Repository: pyg-team/pytorch_geometric Branch: master Commit: 2eadeee0495f Files: 1487 Total size: 6.8 MB Directory structure: gitextract_yj2ugid6/ ├── .github/ │ ├── CODEOWNERS │ ├── CONTRIBUTING.md │ ├── ISSUE_TEMPLATE/ │ │ ├── bug-report.yml │ │ ├── config.yml │ │ ├── documentation.yml │ │ ├── feature-request.yml │ │ ├── installation.yml │ │ └── refactor.yml │ ├── actions/ │ │ └── setup/ │ │ └── action.yml │ ├── dependabot.yml │ ├── labeler.yml │ └── workflows/ │ ├── _testing.yml │ ├── auto-merge.yml │ ├── building_nightly.yml │ ├── changelog.yml │ ├── documentation.yml │ ├── examples.yml │ ├── labeler.yml │ ├── linting.yml │ ├── testing.yml │ └── testing_rag.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CITATION.cff ├── LICENSE ├── README.md ├── benchmark/ │ ├── README.md │ ├── citation/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── appnp.py │ │ ├── arma.py │ │ ├── cheb.py │ │ ├── datasets.py │ │ ├── gat.py │ │ ├── gcn.py │ │ ├── inference.sh │ │ ├── run.sh │ │ ├── sgc.py │ │ ├── statistics.py │ │ └── train_eval.py │ ├── inference/ │ │ ├── README.md │ │ └── inference_benchmark.py │ ├── kernel/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── asap.py │ │ ├── datasets.py │ │ ├── diff_pool.py │ │ ├── edge_pool.py │ │ ├── gcn.py │ │ ├── gin.py │ │ ├── global_attention.py │ │ ├── graclus.py │ │ ├── graph_sage.py │ │ ├── main.py │ │ ├── main_performance.py │ │ ├── sag_pool.py │ │ ├── set2set.py │ │ ├── sort_pool.py │ │ ├── statistics.py │ │ ├── top_k.py │ │ └── train_eval.py │ ├── loader/ │ │ └── neighbor_loader.py │ ├── multi_gpu/ │ │ └── training/ │ │ ├── README.md │ │ ├── common.py │ │ ├── training_benchmark_cuda.py │ │ └── training_benchmark_xpu.py │ ├── points/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── datasets.py │ │ ├── edge_cnn.py │ │ ├── mpnn.py │ │ ├── point_cnn.py │ │ ├── point_net.py │ │ ├── spline_cnn.py │ │ ├── statistics.py │ │ └── train_eval.py │ ├── runtime/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── dgl/ │ │ │ ├── gat.py │ 
│ │ ├── gcn.py │ │ │ ├── hidden.py │ │ │ ├── main.py │ │ │ ├── rgcn.py │ │ │ └── train.py │ │ ├── gat.py │ │ ├── gcn.py │ │ ├── main.py │ │ ├── rgcn.py │ │ └── train.py │ ├── setup.py │ ├── training/ │ │ ├── README.md │ │ └── training_benchmark.py │ └── utils/ │ ├── __init__.py │ ├── hetero_gat.py │ ├── hetero_sage.py │ └── utils.py ├── codecov.yml ├── docker/ │ ├── Dockerfile │ ├── Dockerfile.xpu │ ├── README.md │ └── singularity ├── docs/ │ ├── Makefile │ ├── README.md │ ├── requirements.txt │ └── source/ │ ├── .gitignore │ ├── _figures/ │ │ ├── .gitignore │ │ ├── build.sh │ │ ├── graph.tex │ │ ├── hg_example.tex │ │ ├── to_hetero.tex │ │ └── to_hetero_with_bases.tex │ ├── _static/ │ │ └── js/ │ │ └── version_alert.js │ ├── _templates/ │ │ └── autosummary/ │ │ ├── class.rst │ │ ├── inherited_class.rst │ │ ├── metrics.rst │ │ ├── nn.rst │ │ └── only_class.rst │ ├── advanced/ │ │ ├── batching.rst │ │ ├── compile.rst │ │ ├── cpu_affinity.rst │ │ ├── graphgym.rst │ │ ├── hgam.rst │ │ ├── jit.rst │ │ ├── remote.rst │ │ └── sparse_tensor.rst │ ├── cheatsheet/ │ │ ├── data_cheatsheet.rst │ │ └── gnn_cheatsheet.rst │ ├── conf.py │ ├── external/ │ │ └── resources.rst │ ├── get_started/ │ │ ├── colabs.rst │ │ └── introduction.rst │ ├── index.rst │ ├── install/ │ │ ├── installation.rst │ │ └── quick-start.html │ ├── modules/ │ │ ├── contrib.rst │ │ ├── data.rst │ │ ├── datasets.rst │ │ ├── distributed.rst │ │ ├── explain.rst │ │ ├── graphgym.rst │ │ ├── llm.rst │ │ ├── loader.rst │ │ ├── metrics.rst │ │ ├── nn.rst │ │ ├── profile.rst │ │ ├── root.rst │ │ ├── sampler.rst │ │ ├── transforms.rst │ │ └── utils.rst │ ├── notes/ │ │ ├── batching.rst │ │ ├── cheatsheet.rst │ │ ├── colabs.rst │ │ ├── create_dataset.rst │ │ ├── create_gnn.rst │ │ ├── data_cheatsheet.rst │ │ ├── explain.rst │ │ ├── graphgym.rst │ │ ├── heterogeneous.rst │ │ ├── installation.rst │ │ ├── introduction.rst │ │ ├── jit.rst │ │ ├── load_csv.rst │ │ ├── remote.rst │ │ ├── resources.rst │ │ └── 
sparse_tensor.rst │ └── tutorial/ │ ├── application.rst │ ├── compile.rst │ ├── create_dataset.rst │ ├── create_gnn.rst │ ├── dataset.rst │ ├── dataset_splitting.rst │ ├── distributed.rst │ ├── distributed_pyg.rst │ ├── explain.rst │ ├── gnn_design.rst │ ├── graph_transformer.rst │ ├── heterogeneous.rst │ ├── load_csv.rst │ ├── multi_gpu_vanilla.rst │ ├── multi_node_multi_gpu_vanilla.rst │ ├── neighbor_loader.rst │ ├── point_cloud.rst │ └── shallow_node_embeddings.rst ├── examples/ │ ├── README.md │ ├── agnn.py │ ├── ar_link_pred.py │ ├── argva_node_clustering.py │ ├── arma.py │ ├── attentive_fp.py │ ├── autoencoder.py │ ├── cluster_gcn_ppi.py │ ├── cluster_gcn_reddit.py │ ├── colors_topk_pool.py │ ├── compile/ │ │ ├── gcn.py │ │ └── gin.py │ ├── contrib/ │ │ ├── README.md │ │ ├── pgm_explainer_graph_classification.py │ │ ├── pgm_explainer_node_classification.py │ │ ├── rbcd_attack.py │ │ └── rbcd_attack_poisoning.py │ ├── cora.py │ ├── correct_and_smooth.py │ ├── cpp/ │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── main.cpp │ │ └── save_model.py │ ├── datapipe.py │ ├── dgcnn_classification.py │ ├── dgcnn_segmentation.py │ ├── dir_gnn.py │ ├── distributed/ │ │ ├── README.md │ │ ├── graphlearn_for_pytorch/ │ │ │ ├── README.md │ │ │ ├── dist_train_sage_sup_config.yml │ │ │ ├── dist_train_sage_supervised.py │ │ │ ├── launch.py │ │ │ └── partition_ogbn_dataset.py │ │ ├── kuzu/ │ │ │ ├── README.md │ │ │ └── papers_100M/ │ │ │ ├── README.md │ │ │ ├── prepare_data.py │ │ │ └── train.py │ │ └── pyg/ │ │ └── README.md │ ├── dna.py │ ├── egc.py │ ├── equilibrium_median.py │ ├── explain/ │ │ ├── README.md │ │ ├── captum_explainer.py │ │ ├── captum_explainer_hetero_link.py │ │ ├── gnn_explainer.py │ │ ├── gnn_explainer_ba_shapes.py │ │ ├── gnn_explainer_link_pred.py │ │ └── graphmask_explainer.py │ ├── faust.py │ ├── film.py │ ├── gat.py │ ├── gcn.py │ ├── gcn2_cora.py │ ├── gcn2_ppi.py │ ├── geniepath.py │ ├── glnn.py │ ├── gpse.py │ ├── graph_gps.py │ ├── 
graph_sage_unsup.py │ ├── graph_sage_unsup_ppi.py │ ├── graph_saint.py │ ├── graph_unet.py │ ├── hetero/ │ │ ├── README.md │ │ ├── bipartite_sage.py │ │ ├── bipartite_sage_unsup.py │ │ ├── dmgi_unsup.py │ │ ├── han_imdb.py │ │ ├── hetero_conv_dblp.py │ │ ├── hetero_link_pred.py │ │ ├── hgt_dblp.py │ │ ├── hierarchical_sage.py │ │ ├── load_csv.py │ │ ├── metapath2vec.py │ │ ├── recommender_system.py │ │ ├── temporal_link_pred.py │ │ └── to_hetero_mag.py │ ├── hierarchical_sampling.py │ ├── infomax_inductive.py │ ├── infomax_transductive.py │ ├── jit/ │ │ ├── README.md │ │ ├── film.py │ │ ├── gat.py │ │ ├── gcn.py │ │ └── gin.py │ ├── kge_fb15k_237.py │ ├── label_prop.py │ ├── lcm_aggr_2nd_min.py │ ├── lightgcn.py │ ├── link_pred.py │ ├── linkx.py │ ├── llm/ │ │ ├── README.md │ │ ├── g_retriever.py │ │ ├── git_mol.py │ │ ├── glem.py │ │ ├── molecule_gpt.py │ │ ├── protein_mpnn.py │ │ └── txt2kg_rag.py │ ├── lpformer.py │ ├── mem_pool.py │ ├── mixhop.py │ ├── mnist_graclus.py │ ├── mnist_nn_conv.py │ ├── mnist_voxel_grid.py │ ├── multi_gpu/ │ │ ├── README.md │ │ ├── distributed_batching.py │ │ ├── distributed_sampling.py │ │ ├── distributed_sampling_multinode.py │ │ ├── distributed_sampling_multinode.sbatch │ │ ├── distributed_sampling_xpu.py │ │ ├── mag240m_graphsage.py │ │ ├── model_parallel.py │ │ ├── papers100m_gcn.py │ │ ├── papers100m_gcn_multinode.py │ │ ├── pcqm4m_ogb.py │ │ └── taobao.py │ ├── mutag_gin.py │ ├── node2vec.py │ ├── ogbn_proteins_deepgcn.py │ ├── ogbn_train.py │ ├── ogc.py │ ├── pmlp.py │ ├── pna.py │ ├── point_transformer_classification.py │ ├── point_transformer_segmentation.py │ ├── pointnet2_classification.py │ ├── pointnet2_segmentation.py │ ├── ppi.py │ ├── proteins_diff_pool.py │ ├── proteins_dmon_pool.py │ ├── proteins_gmt.py │ ├── proteins_mincut_pool.py │ ├── proteins_topk_pool.py │ ├── pytorch_ignite/ │ │ ├── README.md │ │ └── gin.py │ ├── pytorch_lightning/ │ │ ├── README.md │ │ ├── gin.py │ │ ├── graph_sage.py │ │ └── 
relational_gnn.py │ ├── qm9_nn_conv.py │ ├── qm9_pretrained_dimenet.py │ ├── qm9_pretrained_schnet.py │ ├── quiver/ │ │ ├── README.md │ │ ├── multi_gpu_quiver.py │ │ └── single_gpu_quiver.py │ ├── randlanet_classification.py │ ├── randlanet_segmentation.py │ ├── rdl.py │ ├── rect.py │ ├── reddit.py │ ├── renet.py │ ├── rev_gnn.py │ ├── rgat.py │ ├── rgcn.py │ ├── rgcn_link_pred.py │ ├── seal_link_pred.py │ ├── sgc.py │ ├── shadow.py │ ├── sign.py │ ├── signed_gcn.py │ ├── super_gat.py │ ├── tagcn.py │ ├── tensorboard_logging.py │ ├── tgn.py │ ├── triangles_sag_pool.py │ ├── unimp_arxiv.py │ ├── upfd.py │ └── wl_kernel.py ├── graphgym/ │ ├── agg_batch.py │ ├── configs/ │ │ ├── example.yaml │ │ └── pyg/ │ │ ├── example_graph.yaml │ │ ├── example_link.yaml │ │ └── example_node.yaml │ ├── configs_gen.py │ ├── custom_graphgym/ │ │ ├── __init__.py │ │ ├── act/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── config/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── encoder/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── head/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── layer/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── loader/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── loss/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── network/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── optimizer/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── pooling/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── stage/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ ├── train/ │ │ │ ├── __init__.py │ │ │ └── example.py │ │ └── transform/ │ │ └── __init__.py │ ├── grids/ │ │ ├── example.txt │ │ └── pyg/ │ │ └── example.txt │ ├── main.py │ ├── parallel.sh │ ├── run_batch.sh │ ├── run_single.sh │ └── sample/ │ ├── dimensions.txt │ └── dimensionsatt.txt ├── pyproject.toml ├── readthedocs.yml ├── test/ │ ├── conftest.py │ ├── contrib/ │ │ ├── explain/ │ │ │ └── test_pgm_explainer.py │ │ └── nn/ │ │ └── models/ │ │ └── test_rbcd_attack.py │ ├── data/ │ │ ├── 
lightning/ │ │ │ └── test_datamodule.py │ │ ├── test_batch.py │ │ ├── test_data.py │ │ ├── test_database.py │ │ ├── test_datapipes.py │ │ ├── test_dataset.py │ │ ├── test_dataset_summary.py │ │ ├── test_feature_store.py │ │ ├── test_graph_store.py │ │ ├── test_hetero_data.py │ │ ├── test_hypergraph_data.py │ │ ├── test_inherit.py │ │ ├── test_on_disk_dataset.py │ │ ├── test_remote_backend_utils.py │ │ ├── test_storage.py │ │ ├── test_temporal.py │ │ └── test_view.py │ ├── datasets/ │ │ ├── graph_generator/ │ │ │ ├── test_ba_graph.py │ │ │ ├── test_er_graph.py │ │ │ ├── test_grid_graph.py │ │ │ └── test_tree_graph.py │ │ ├── motif_generator/ │ │ │ ├── test_custom_motif.py │ │ │ ├── test_cycle_motif.py │ │ │ ├── test_grid_motif.py │ │ │ └── test_house_motif.py │ │ ├── test_ba_shapes.py │ │ ├── test_bzr.py │ │ ├── test_elliptic.py │ │ ├── test_enzymes.py │ │ ├── test_explainer_dataset.py │ │ ├── test_fake.py │ │ ├── test_git_mol_dataset.py │ │ ├── test_imdb_binary.py │ │ ├── test_infection_dataset.py │ │ ├── test_karate.py │ │ ├── test_medshapenet.py │ │ ├── test_molecule_gpt_dataset.py │ │ ├── test_mutag.py │ │ ├── test_planetoid.py │ │ ├── test_protein_mpnn_dataset.py │ │ ├── test_snap_dataset.py │ │ ├── test_suite_sparse.py │ │ ├── test_tag_dataset.py │ │ ├── test_teeth3ds.py │ │ └── test_web_qsp_dataset.py │ ├── distributed/ │ │ ├── test_dist_link_neighbor_loader.py │ │ ├── test_dist_link_neighbor_sampler.py │ │ ├── test_dist_neighbor_loader.py │ │ ├── test_dist_neighbor_sampler.py │ │ ├── test_dist_utils.py │ │ ├── test_local_feature_store.py │ │ ├── test_local_graph_store.py │ │ ├── test_partition.py │ │ └── test_rpc.py │ ├── explain/ │ │ ├── algorithm/ │ │ │ ├── test_attention_explainer.py │ │ │ ├── test_captum.py │ │ │ ├── test_captum_explainer.py │ │ │ ├── test_captum_hetero.py │ │ │ ├── test_explain_algorithm_utils.py │ │ │ ├── test_gnn_explainer.py │ │ │ ├── test_graphmask_explainer.py │ │ │ └── test_pg_explainer.py │ │ ├── conftest.py │ │ ├── metric/ │ │ │ 
├── test_basic_metric.py │ │ │ ├── test_faithfulness.py │ │ │ └── test_fidelity.py │ │ ├── test_explain_config.py │ │ ├── test_explainer.py │ │ ├── test_explanation.py │ │ ├── test_hetero_explainer.py │ │ └── test_hetero_explanation.py │ ├── graphgym/ │ │ ├── example_node.yml │ │ ├── test_config.py │ │ ├── test_graphgym.py │ │ ├── test_logger.py │ │ └── test_register.py │ ├── io/ │ │ ├── example1.off │ │ ├── example2.off │ │ ├── test_fs.py │ │ └── test_off.py │ ├── llm/ │ │ ├── conftest.py │ │ ├── models/ │ │ │ ├── test_g_retriever.py │ │ │ ├── test_git_mol.py │ │ │ ├── test_glem.py │ │ │ ├── test_llm.py │ │ │ ├── test_llm_judge.py │ │ │ ├── test_molecule_gpt.py │ │ │ ├── test_protein_mpnn.py │ │ │ ├── test_sentence_transformer.py │ │ │ ├── test_txt2kg.py │ │ │ └── test_vision_transformer.py │ │ ├── test_large_graph_indexer.py │ │ ├── test_rag_loader.py │ │ └── utils/ │ │ ├── test_rag_backend_utils.py │ │ ├── test_rag_feature_store.py │ │ ├── test_rag_graph_store.py │ │ └── test_vectorrag.py │ ├── loader/ │ │ ├── test_cache.py │ │ ├── test_cluster.py │ │ ├── test_dataloader.py │ │ ├── test_dynamic_batch_sampler.py │ │ ├── test_graph_saint.py │ │ ├── test_hgt_loader.py │ │ ├── test_ibmb_loader.py │ │ ├── test_imbalanced_sampler.py │ │ ├── test_link_neighbor_loader.py │ │ ├── test_mixin.py │ │ ├── test_neighbor_loader.py │ │ ├── test_neighbor_sampler.py │ │ ├── test_prefetch.py │ │ ├── test_random_node_loader.py │ │ ├── test_shadow.py │ │ ├── test_temporal_dataloader.py │ │ ├── test_utils.py │ │ └── test_zip_loader.py │ ├── metrics/ │ │ └── test_link_pred_metric.py │ ├── my_config.yaml │ ├── nn/ │ │ ├── aggr/ │ │ │ ├── test_aggr_utils.py │ │ │ ├── test_attention.py │ │ │ ├── test_basic.py │ │ │ ├── test_deep_sets.py │ │ │ ├── test_equilibrium.py │ │ │ ├── test_fused.py │ │ │ ├── test_gmt.py │ │ │ ├── test_gru.py │ │ │ ├── test_lcm.py │ │ │ ├── test_lstm.py │ │ │ ├── test_mlp_aggr.py │ │ │ ├── test_multi.py │ │ │ ├── test_patch_transformer.py │ │ │ ├── 
test_quantile.py │ │ │ ├── test_scaler.py │ │ │ ├── test_set2set.py │ │ │ ├── test_set_transformer.py │ │ │ ├── test_sort.py │ │ │ └── test_variance_preserving.py │ │ ├── attention/ │ │ │ ├── test_performer_attention.py │ │ │ ├── test_polynormer_attention.py │ │ │ └── test_qformer.py │ │ ├── conv/ │ │ │ ├── cugraph/ │ │ │ │ ├── test_cugraph_gat_conv.py │ │ │ │ ├── test_cugraph_rgcn_conv.py │ │ │ │ └── test_cugraph_sage_conv.py │ │ │ ├── test_agnn_conv.py │ │ │ ├── test_antisymmetric_conv.py │ │ │ ├── test_appnp.py │ │ │ ├── test_arma_conv.py │ │ │ ├── test_cg_conv.py │ │ │ ├── test_cheb_conv.py │ │ │ ├── test_cluster_gcn_conv.py │ │ │ ├── test_create_gnn.py │ │ │ ├── test_dir_gnn_conv.py │ │ │ ├── test_dna_conv.py │ │ │ ├── test_edge_conv.py │ │ │ ├── test_eg_conv.py │ │ │ ├── test_fa_conv.py │ │ │ ├── test_feast_conv.py │ │ │ ├── test_film_conv.py │ │ │ ├── test_fused_gat_conv.py │ │ │ ├── test_gat_conv.py │ │ │ ├── test_gated_graph_conv.py │ │ │ ├── test_gatv2_conv.py │ │ │ ├── test_gcn2_conv.py │ │ │ ├── test_gcn_conv.py │ │ │ ├── test_gen_conv.py │ │ │ ├── test_general_conv.py │ │ │ ├── test_gin_conv.py │ │ │ ├── test_gmm_conv.py │ │ │ ├── test_gps_conv.py │ │ │ ├── test_graph_conv.py │ │ │ ├── test_gravnet_conv.py │ │ │ ├── test_han_conv.py │ │ │ ├── test_heat_conv.py │ │ │ ├── test_hetero_conv.py │ │ │ ├── test_hgt_conv.py │ │ │ ├── test_hypergraph_conv.py │ │ │ ├── test_le_conv.py │ │ │ ├── test_lg_conv.py │ │ │ ├── test_meshcnn_conv.py │ │ │ ├── test_message_passing.py │ │ │ ├── test_mf_conv.py │ │ │ ├── test_mixhop_conv.py │ │ │ ├── test_nn_conv.py │ │ │ ├── test_pan_conv.py │ │ │ ├── test_pdn_conv.py │ │ │ ├── test_pna_conv.py │ │ │ ├── test_point_conv.py │ │ │ ├── test_point_gnn_conv.py │ │ │ ├── test_point_transformer_conv.py │ │ │ ├── test_ppf_conv.py │ │ │ ├── test_res_gated_graph_conv.py │ │ │ ├── test_rgat_conv.py │ │ │ ├── test_rgcn_conv.py │ │ │ ├── test_sage_conv.py │ │ │ ├── test_sg_conv.py │ │ │ ├── test_signed_conv.py │ │ │ ├── 
test_simple_conv.py │ │ │ ├── test_spline_conv.py │ │ │ ├── test_ssg_conv.py │ │ │ ├── test_static_graph.py │ │ │ ├── test_supergat_conv.py │ │ │ ├── test_tag_conv.py │ │ │ ├── test_transformer_conv.py │ │ │ ├── test_wl_conv.py │ │ │ ├── test_wl_conv_continuous.py │ │ │ ├── test_x_conv.py │ │ │ └── utils/ │ │ │ └── test_gnn_cheatsheet.py │ │ ├── dense/ │ │ │ ├── test_dense_gat_conv.py │ │ │ ├── test_dense_gcn_conv.py │ │ │ ├── test_dense_gin_conv.py │ │ │ ├── test_dense_graph_conv.py │ │ │ ├── test_dense_sage_conv.py │ │ │ ├── test_diff_pool.py │ │ │ ├── test_dmon_pool.py │ │ │ ├── test_linear.py │ │ │ └── test_mincut_pool.py │ │ ├── functional/ │ │ │ ├── test_bro.py │ │ │ └── test_gini.py │ │ ├── kge/ │ │ │ ├── test_complex.py │ │ │ ├── test_distmult.py │ │ │ ├── test_rotate.py │ │ │ └── test_transe.py │ │ ├── models/ │ │ │ ├── test_attentive_fp.py │ │ │ ├── test_attract_repel.py │ │ │ ├── test_autoencoder.py │ │ │ ├── test_basic_gnn.py │ │ │ ├── test_correct_and_smooth.py │ │ │ ├── test_deep_graph_infomax.py │ │ │ ├── test_deepgcn.py │ │ │ ├── test_dimenet.py │ │ │ ├── test_gnnff.py │ │ │ ├── test_gpse.py │ │ │ ├── test_graph_mixer.py │ │ │ ├── test_graph_unet.py │ │ │ ├── test_jumping_knowledge.py │ │ │ ├── test_label_prop.py │ │ │ ├── test_lightgcn.py │ │ │ ├── test_linkx.py │ │ │ ├── test_lpformer.py │ │ │ ├── test_mask_label.py │ │ │ ├── test_meta.py │ │ │ ├── test_metapath2vec.py │ │ │ ├── test_mlp.py │ │ │ ├── test_neural_fingerprint.py │ │ │ ├── test_node2vec.py │ │ │ ├── test_pmlp.py │ │ │ ├── test_polynormer.py │ │ │ ├── test_re_net.py │ │ │ ├── test_rect.py │ │ │ ├── test_rev_gnn.py │ │ │ ├── test_schnet.py │ │ │ ├── test_sgformer.py │ │ │ ├── test_signed_gcn.py │ │ │ ├── test_tgn.py │ │ │ └── test_visnet.py │ │ ├── norm/ │ │ │ ├── test_batch_norm.py │ │ │ ├── test_diff_group_norm.py │ │ │ ├── test_graph_norm.py │ │ │ ├── test_graph_size_norm.py │ │ │ ├── test_instance_norm.py │ │ │ ├── test_layer_norm.py │ │ │ ├── test_mean_subtraction_norm.py │ │ │ 
├── test_msg_norm.py │ │ │ └── test_pair_norm.py │ │ ├── pool/ │ │ │ ├── connect/ │ │ │ │ └── test_filter_edges.py │ │ │ ├── select/ │ │ │ │ └── test_select_topk.py │ │ │ ├── test_approx_knn.py │ │ │ ├── test_asap.py │ │ │ ├── test_avg_pool.py │ │ │ ├── test_cluster_pool.py │ │ │ ├── test_consecutive.py │ │ │ ├── test_decimation.py │ │ │ ├── test_edge_pool.py │ │ │ ├── test_glob.py │ │ │ ├── test_graclus.py │ │ │ ├── test_knn.py │ │ │ ├── test_max_pool.py │ │ │ ├── test_mem_pool.py │ │ │ ├── test_pan_pool.py │ │ │ ├── test_pool.py │ │ │ ├── test_sag_pool.py │ │ │ ├── test_topk_pool.py │ │ │ └── test_voxel_grid.py │ │ ├── test_compile_basic.py │ │ ├── test_compile_conv.py │ │ ├── test_compile_dynamic.py │ │ ├── test_data_parallel.py │ │ ├── test_encoding.py │ │ ├── test_fvcore.py │ │ ├── test_fx.py │ │ ├── test_inits.py │ │ ├── test_model_hub.py │ │ ├── test_model_summary.py │ │ ├── test_module_dict.py │ │ ├── test_parameter_dict.py │ │ ├── test_reshape.py │ │ ├── test_resolver.py │ │ ├── test_sequential.py │ │ ├── test_to_fixed_size_transformer.py │ │ ├── test_to_hetero_module.py │ │ ├── test_to_hetero_transformer.py │ │ ├── test_to_hetero_with_bases_transformer.py │ │ └── unpool/ │ │ └── test_knn_interpolate.py │ ├── profile/ │ │ ├── test_benchmark.py │ │ ├── test_nvtx.py │ │ ├── test_profile.py │ │ ├── test_profile_utils.py │ │ └── test_profiler.py │ ├── sampler/ │ │ ├── test_sampler_base.py │ │ └── test_sampler_neighbor_sampler.py │ ├── test_config_mixin.py │ ├── test_config_store.py │ ├── test_debug.py │ ├── test_edge_index.py │ ├── test_experimental.py │ ├── test_hash_tensor.py │ ├── test_home.py │ ├── test_index.py │ ├── test_inspector.py │ ├── test_isinstance.py │ ├── test_onnx.py │ ├── test_seed.py │ ├── test_typing.py │ ├── test_warnings.py │ ├── testing/ │ │ └── test_decorators.py │ ├── transforms/ │ │ ├── test_add_gpse.py │ │ ├── test_add_metapaths.py │ │ ├── test_add_positional_encoding.py │ │ ├── test_add_remaining_self_loops.py │ │ ├── 
test_add_self_loops.py │ │ ├── test_cartesian.py │ │ ├── test_center.py │ │ ├── test_compose.py │ │ ├── test_constant.py │ │ ├── test_delaunay.py │ │ ├── test_distance.py │ │ ├── test_face_to_edge.py │ │ ├── test_feature_propagation.py │ │ ├── test_fixed_points.py │ │ ├── test_gcn_norm.py │ │ ├── test_gdc.py │ │ ├── test_generate_mesh_normals.py │ │ ├── test_grid_sampling.py │ │ ├── test_half_hop.py │ │ ├── test_knn_graph.py │ │ ├── test_laplacian_lambda_max.py │ │ ├── test_largest_connected_components.py │ │ ├── test_line_graph.py │ │ ├── test_linear_transformation.py │ │ ├── test_local_cartesian.py │ │ ├── test_local_degree_profile.py │ │ ├── test_mask_transform.py │ │ ├── test_node_property_split.py │ │ ├── test_normalize_features.py │ │ ├── test_normalize_rotation.py │ │ ├── test_normalize_scale.py │ │ ├── test_one_hot_degree.py │ │ ├── test_pad.py │ │ ├── test_point_pair_features.py │ │ ├── test_polar.py │ │ ├── test_radius_graph.py │ │ ├── test_random_flip.py │ │ ├── test_random_jitter.py │ │ ├── test_random_link_split.py │ │ ├── test_random_node_split.py │ │ ├── test_random_rotate.py │ │ ├── test_random_scale.py │ │ ├── test_random_shear.py │ │ ├── test_remove_duplicated_edges.py │ │ ├── test_remove_isolated_nodes.py │ │ ├── test_remove_self_loops.py │ │ ├── test_remove_training_classes.py │ │ ├── test_rooted_subgraph.py │ │ ├── test_sample_points.py │ │ ├── test_sign.py │ │ ├── test_spherical.py │ │ ├── test_svd_feature_reduction.py │ │ ├── test_target_indegree.py │ │ ├── test_to_dense.py │ │ ├── test_to_device.py │ │ ├── test_to_sparse_tensor.py │ │ ├── test_to_superpixels.py │ │ ├── test_to_undirected.py │ │ ├── test_two_hop.py │ │ └── test_virtual_node.py │ ├── utils/ │ │ ├── conftest.py │ │ ├── test_assortativity.py │ │ ├── test_augmentation.py │ │ ├── test_coalesce.py │ │ ├── test_convert.py │ │ ├── test_cross_entropy.py │ │ ├── test_degree.py │ │ ├── test_dropout.py │ │ ├── test_embedding.py │ │ ├── test_functions.py │ │ ├── test_geodesic.py │ │ ├── 
test_grid.py │ │ ├── test_hetero.py │ │ ├── test_homophily.py │ │ ├── test_index_sort.py │ │ ├── test_isolated.py │ │ ├── test_laplacian.py │ │ ├── test_lexsort.py │ │ ├── test_loop.py │ │ ├── test_map.py │ │ ├── test_mask.py │ │ ├── test_mesh_laplacian.py │ │ ├── test_negative_sampling.py │ │ ├── test_nested.py │ │ ├── test_noise_scheduler.py │ │ ├── test_normalize_edge_index.py │ │ ├── test_normalized_cut.py │ │ ├── test_num_nodes.py │ │ ├── test_one_hot.py │ │ ├── test_ppr.py │ │ ├── test_random.py │ │ ├── test_repeat.py │ │ ├── test_scatter.py │ │ ├── test_segment.py │ │ ├── test_select.py │ │ ├── test_smiles.py │ │ ├── test_softmax.py │ │ ├── test_sort_edge_index.py │ │ ├── test_sparse.py │ │ ├── test_spmm.py │ │ ├── test_subgraph.py │ │ ├── test_to_dense_adj.py │ │ ├── test_to_dense_batch.py │ │ ├── test_total_influence.py │ │ ├── test_train_test_split_edges.py │ │ ├── test_tree_decomposition.py │ │ ├── test_trim_to_layer.py │ │ ├── test_unbatch.py │ │ └── test_undirected.py │ └── visualization/ │ ├── test_graph_visualization.py │ └── test_influence.py └── torch_geometric/ ├── __init__.py ├── _compile.py ├── _onnx.py ├── backend.py ├── config_mixin.py ├── config_store.py ├── contrib/ │ ├── __init__.py │ ├── datasets/ │ │ └── __init__.py │ ├── explain/ │ │ ├── __init__.py │ │ └── pgm_explainer.py │ ├── nn/ │ │ ├── __init__.py │ │ ├── conv/ │ │ │ └── __init__.py │ │ └── models/ │ │ ├── __init__.py │ │ └── rbcd_attack.py │ └── transforms/ │ └── __init__.py ├── data/ │ ├── __init__.py │ ├── batch.py │ ├── collate.py │ ├── data.py │ ├── database.py │ ├── datapipes.py │ ├── dataset.py │ ├── download.py │ ├── extract.py │ ├── feature_store.py │ ├── graph_store.py │ ├── hetero_data.py │ ├── hypergraph_data.py │ ├── in_memory_dataset.py │ ├── lightning/ │ │ ├── __init__.py │ │ └── datamodule.py │ ├── makedirs.py │ ├── on_disk_dataset.py │ ├── remote_backend_utils.py │ ├── separate.py │ ├── storage.py │ ├── summary.py │ ├── temporal.py │ └── view.py ├── datasets/ │ ├── 
__init__.py │ ├── actor.py │ ├── airfrans.py │ ├── airports.py │ ├── amazon.py │ ├── amazon_book.py │ ├── amazon_products.py │ ├── aminer.py │ ├── aqsol.py │ ├── attributed_graph_dataset.py │ ├── ba2motif_dataset.py │ ├── ba_multi_shapes.py │ ├── ba_shapes.py │ ├── bitcoin_otc.py │ ├── brca_tgca.py │ ├── citation_full.py │ ├── city.py │ ├── coauthor.py │ ├── coma.py │ ├── cornell.py │ ├── dblp.py │ ├── dbp15k.py │ ├── deezer_europe.py │ ├── dgraph.py │ ├── dynamic_faust.py │ ├── elliptic.py │ ├── elliptic_temporal.py │ ├── email_eu_core.py │ ├── entities.py │ ├── explainer_dataset.py │ ├── facebook.py │ ├── fake.py │ ├── faust.py │ ├── flickr.py │ ├── freebase.py │ ├── gdelt.py │ ├── gdelt_lite.py │ ├── ged_dataset.py │ ├── gemsec.py │ ├── geometry.py │ ├── git_mol_dataset.py │ ├── github.py │ ├── gnn_benchmark_dataset.py │ ├── graph_generator/ │ │ ├── __init__.py │ │ ├── ba_graph.py │ │ ├── base.py │ │ ├── er_graph.py │ │ ├── grid_graph.py │ │ └── tree_graph.py │ ├── heterophilous_graph_dataset.py │ ├── hgb_dataset.py │ ├── hm.py │ ├── hydro_net.py │ ├── icews.py │ ├── igmc_dataset.py │ ├── imdb.py │ ├── infection_dataset.py │ ├── instruct_mol_dataset.py │ ├── jodie.py │ ├── karate.py │ ├── last_fm.py │ ├── lastfm_asia.py │ ├── linkx_dataset.py │ ├── lrgb.py │ ├── malnet_tiny.py │ ├── md17.py │ ├── medshapenet.py │ ├── mixhop_synthetic_dataset.py │ ├── mnist_superpixels.py │ ├── modelnet.py │ ├── molecule_gpt_dataset.py │ ├── molecule_net.py │ ├── motif_generator/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── custom.py │ │ ├── cycle.py │ │ ├── grid.py │ │ └── house.py │ ├── movie_lens.py │ ├── movie_lens_100k.py │ ├── movie_lens_1m.py │ ├── myket.py │ ├── nell.py │ ├── neurograph.py │ ├── ogb_mag.py │ ├── omdb.py │ ├── opf.py │ ├── ose_gvcs.py │ ├── particle.py │ ├── pascal.py │ ├── pascal_pf.py │ ├── pcpnet_dataset.py │ ├── pcqm4m.py │ ├── planetoid.py │ ├── polblogs.py │ ├── ppi.py │ ├── protein_mpnn_dataset.py │ ├── qm7.py │ ├── qm9.py │ ├── rcdd.py │ ├── 
reddit.py │ ├── reddit2.py │ ├── rel_link_pred_dataset.py │ ├── s3dis.py │ ├── sbm_dataset.py │ ├── shapenet.py │ ├── shrec2016.py │ ├── snap_dataset.py │ ├── suite_sparse.py │ ├── tag_dataset.py │ ├── taobao.py │ ├── teeth3ds.py │ ├── tosca.py │ ├── tu_dataset.py │ ├── twitch.py │ ├── upfd.py │ ├── utils/ │ │ ├── __init__.py │ │ └── cheatsheet.py │ ├── web_qsp_dataset.py │ ├── webkb.py │ ├── wikics.py │ ├── wikidata.py │ ├── wikipedia_network.py │ ├── willow_object_class.py │ ├── word_net.py │ ├── yelp.py │ └── zinc.py ├── debug.py ├── deprecation.py ├── device.py ├── distributed/ │ ├── __init__.py │ ├── dist_context.py │ ├── dist_link_neighbor_loader.py │ ├── dist_loader.py │ ├── dist_neighbor_loader.py │ ├── dist_neighbor_sampler.py │ ├── event_loop.py │ ├── local_feature_store.py │ ├── local_graph_store.py │ ├── partition.py │ ├── rpc.py │ └── utils.py ├── edge_index.py ├── experimental.py ├── explain/ │ ├── __init__.py │ ├── algorithm/ │ │ ├── __init__.py │ │ ├── attention_explainer.py │ │ ├── base.py │ │ ├── captum.py │ │ ├── captum_explainer.py │ │ ├── dummy_explainer.py │ │ ├── gnn_explainer.py │ │ ├── graphmask_explainer.py │ │ ├── pg_explainer.py │ │ └── utils.py │ ├── config.py │ ├── explainer.py │ ├── explanation.py │ └── metric/ │ ├── __init__.py │ ├── basic.py │ ├── faithfulness.py │ └── fidelity.py ├── graphgym/ │ ├── __init__.py │ ├── benchmark.py │ ├── checkpoint.py │ ├── cmd_args.py │ ├── config.py │ ├── contrib/ │ │ ├── __init__.py │ │ ├── act/ │ │ │ └── __init__.py │ │ ├── config/ │ │ │ └── __init__.py │ │ ├── encoder/ │ │ │ └── __init__.py │ │ ├── head/ │ │ │ └── __init__.py │ │ ├── layer/ │ │ │ ├── __init__.py │ │ │ └── generalconv.py │ │ ├── loader/ │ │ │ └── __init__.py │ │ ├── loss/ │ │ │ └── __init__.py │ │ ├── network/ │ │ │ └── __init__.py │ │ ├── optimizer/ │ │ │ └── __init__.py │ │ ├── pooling/ │ │ │ └── __init__.py │ │ ├── stage/ │ │ │ └── __init__.py │ │ ├── train/ │ │ │ └── __init__.py │ │ └── transform/ │ │ └── __init__.py │ ├── 
imports.py │ ├── init.py │ ├── loader.py │ ├── logger.py │ ├── loss.py │ ├── model_builder.py │ ├── models/ │ │ ├── __init__.py │ │ ├── act.py │ │ ├── encoder.py │ │ ├── gnn.py │ │ ├── head.py │ │ ├── layer.py │ │ ├── pooling.py │ │ └── transform.py │ ├── optim.py │ ├── register.py │ ├── train.py │ └── utils/ │ ├── LICENSE │ ├── __init__.py │ ├── agg_runs.py │ ├── comp_budget.py │ ├── device.py │ ├── epoch.py │ ├── io.py │ ├── plot.py │ └── tools.py ├── hash_tensor.py ├── home.py ├── index.py ├── inspector.py ├── io/ │ ├── __init__.py │ ├── fs.py │ ├── npz.py │ ├── obj.py │ ├── off.py │ ├── planetoid.py │ ├── ply.py │ ├── sdf.py │ ├── tu.py │ └── txt_array.py ├── isinstance.py ├── lazy_loader.py ├── llm/ │ ├── __init__.py │ ├── large_graph_indexer.py │ ├── models/ │ │ ├── __init__.py │ │ ├── g_retriever.py │ │ ├── git_mol.py │ │ ├── glem.py │ │ ├── llm.py │ │ ├── llm_judge.py │ │ ├── molecule_gpt.py │ │ ├── protein_mpnn.py │ │ ├── sentence_transformer.py │ │ ├── txt2kg.py │ │ └── vision_transformer.py │ ├── rag_loader.py │ └── utils/ │ ├── __init__.py │ ├── backend_utils.py │ ├── feature_store.py │ ├── graph_store.py │ └── vectorrag.py ├── loader/ │ ├── __init__.py │ ├── base.py │ ├── cache.py │ ├── cluster.py │ ├── data_list_loader.py │ ├── dataloader.py │ ├── dense_data_loader.py │ ├── dynamic_batch_sampler.py │ ├── graph_saint.py │ ├── hgt_loader.py │ ├── ibmb_loader.py │ ├── imbalanced_sampler.py │ ├── link_loader.py │ ├── link_neighbor_loader.py │ ├── mixin.py │ ├── neighbor_loader.py │ ├── neighbor_sampler.py │ ├── node_loader.py │ ├── prefetch.py │ ├── random_node_loader.py │ ├── shadow.py │ ├── temporal_dataloader.py │ ├── utils.py │ └── zip_loader.py ├── logging.py ├── metrics/ │ ├── __init__.py │ └── link_pred.py ├── nn/ │ ├── __init__.py │ ├── aggr/ │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── base.py │ │ ├── basic.py │ │ ├── deep_sets.py │ │ ├── equilibrium.py │ │ ├── fused.py │ │ ├── gmt.py │ │ ├── gru.py │ │ ├── lcm.py │ │ ├── lstm.py │ │ ├── 
mlp.py │ │ ├── multi.py │ │ ├── patch_transformer.py │ │ ├── quantile.py │ │ ├── scaler.py │ │ ├── set2set.py │ │ ├── set_transformer.py │ │ ├── sort.py │ │ ├── utils.py │ │ └── variance_preserving.py │ ├── attention/ │ │ ├── __init__.py │ │ ├── performer.py │ │ ├── polynormer.py │ │ ├── qformer.py │ │ └── sgformer.py │ ├── conv/ │ │ ├── __init__.py │ │ ├── agnn_conv.py │ │ ├── antisymmetric_conv.py │ │ ├── appnp.py │ │ ├── arma_conv.py │ │ ├── cg_conv.py │ │ ├── cheb_conv.py │ │ ├── cluster_gcn_conv.py │ │ ├── collect.jinja │ │ ├── cugraph/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── gat_conv.py │ │ │ ├── rgcn_conv.py │ │ │ └── sage_conv.py │ │ ├── dir_gnn_conv.py │ │ ├── dna_conv.py │ │ ├── edge_conv.py │ │ ├── edge_updater.jinja │ │ ├── eg_conv.py │ │ ├── fa_conv.py │ │ ├── feast_conv.py │ │ ├── film_conv.py │ │ ├── fused_gat_conv.py │ │ ├── gat_conv.py │ │ ├── gated_graph_conv.py │ │ ├── gatv2_conv.py │ │ ├── gcn2_conv.py │ │ ├── gcn_conv.py │ │ ├── gen_conv.py │ │ ├── general_conv.py │ │ ├── gin_conv.py │ │ ├── gmm_conv.py │ │ ├── gps_conv.py │ │ ├── graph_conv.py │ │ ├── gravnet_conv.py │ │ ├── han_conv.py │ │ ├── heat_conv.py │ │ ├── hetero_conv.py │ │ ├── hgt_conv.py │ │ ├── hypergraph_conv.py │ │ ├── le_conv.py │ │ ├── lg_conv.py │ │ ├── meshcnn_conv.py │ │ ├── message_passing.py │ │ ├── mf_conv.py │ │ ├── mixhop_conv.py │ │ ├── nn_conv.py │ │ ├── pan_conv.py │ │ ├── pdn_conv.py │ │ ├── pna_conv.py │ │ ├── point_conv.py │ │ ├── point_gnn_conv.py │ │ ├── point_transformer_conv.py │ │ ├── ppf_conv.py │ │ ├── propagate.jinja │ │ ├── res_gated_graph_conv.py │ │ ├── rgat_conv.py │ │ ├── rgcn_conv.py │ │ ├── sage_conv.py │ │ ├── sg_conv.py │ │ ├── signed_conv.py │ │ ├── simple_conv.py │ │ ├── spline_conv.py │ │ ├── ssg_conv.py │ │ ├── supergat_conv.py │ │ ├── tag_conv.py │ │ ├── transformer_conv.py │ │ ├── utils/ │ │ │ ├── __init__.py │ │ │ └── cheatsheet.py │ │ ├── wl_conv.py │ │ ├── wl_conv_continuous.py │ │ └── x_conv.py │ ├── data_parallel.py │ ├── 
dense/ │ │ ├── __init__.py │ │ ├── dense_gat_conv.py │ │ ├── dense_gcn_conv.py │ │ ├── dense_gin_conv.py │ │ ├── dense_graph_conv.py │ │ ├── dense_sage_conv.py │ │ ├── diff_pool.py │ │ ├── dmon_pool.py │ │ ├── linear.py │ │ └── mincut_pool.py │ ├── encoding.py │ ├── functional/ │ │ ├── __init__.py │ │ ├── bro.py │ │ └── gini.py │ ├── fx.py │ ├── glob.py │ ├── inits.py │ ├── kge/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── complex.py │ │ ├── distmult.py │ │ ├── loader.py │ │ ├── rotate.py │ │ └── transe.py │ ├── lr_scheduler.py │ ├── model_hub.py │ ├── models/ │ │ ├── __init__.py │ │ ├── attentive_fp.py │ │ ├── attract_repel.py │ │ ├── autoencoder.py │ │ ├── basic_gnn.py │ │ ├── captum.py │ │ ├── correct_and_smooth.py │ │ ├── deep_graph_infomax.py │ │ ├── deepgcn.py │ │ ├── dimenet.py │ │ ├── dimenet_utils.py │ │ ├── gnnff.py │ │ ├── gpse.py │ │ ├── graph_mixer.py │ │ ├── graph_unet.py │ │ ├── jumping_knowledge.py │ │ ├── label_prop.py │ │ ├── lightgcn.py │ │ ├── linkx.py │ │ ├── lpformer.py │ │ ├── mask_label.py │ │ ├── meta.py │ │ ├── metapath2vec.py │ │ ├── mlp.py │ │ ├── neural_fingerprint.py │ │ ├── node2vec.py │ │ ├── pmlp.py │ │ ├── polynormer.py │ │ ├── re_net.py │ │ ├── rect.py │ │ ├── rev_gnn.py │ │ ├── schnet.py │ │ ├── sgformer.py │ │ ├── signed_gcn.py │ │ ├── tgn.py │ │ └── visnet.py │ ├── module_dict.py │ ├── norm/ │ │ ├── __init__.py │ │ ├── batch_norm.py │ │ ├── diff_group_norm.py │ │ ├── graph_norm.py │ │ ├── graph_size_norm.py │ │ ├── instance_norm.py │ │ ├── layer_norm.py │ │ ├── mean_subtraction_norm.py │ │ ├── msg_norm.py │ │ └── pair_norm.py │ ├── parameter_dict.py │ ├── pool/ │ │ ├── __init__.py │ │ ├── approx_knn.py │ │ ├── asap.py │ │ ├── avg_pool.py │ │ ├── cluster_pool.py │ │ ├── connect/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── filter_edges.py │ │ ├── consecutive.py │ │ ├── decimation.py │ │ ├── edge_pool.py │ │ ├── glob.py │ │ ├── graclus.py │ │ ├── knn.py │ │ ├── max_pool.py │ │ ├── mem_pool.py │ │ ├── pan_pool.py │ │ ├── 
pool.py │ │ ├── sag_pool.py │ │ ├── select/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── topk.py │ │ ├── topk_pool.py │ │ └── voxel_grid.py │ ├── reshape.py │ ├── resolver.py │ ├── sequential.jinja │ ├── sequential.py │ ├── summary.py │ ├── to_fixed_size_transformer.py │ ├── to_hetero_module.py │ ├── to_hetero_transformer.py │ ├── to_hetero_with_bases_transformer.py │ └── unpool/ │ ├── __init__.py │ └── knn_interpolate.py ├── profile/ │ ├── __init__.py │ ├── benchmark.py │ ├── nvtx.py │ ├── profile.py │ ├── profiler.py │ └── utils.py ├── resolver.py ├── sampler/ │ ├── __init__.py │ ├── base.py │ ├── hgt_sampler.py │ ├── neighbor_sampler.py │ └── utils.py ├── seed.py ├── template.py ├── testing/ │ ├── __init__.py │ ├── asserts.py │ ├── data.py │ ├── decorators.py │ ├── distributed.py │ ├── feature_store.py │ └── graph_store.py ├── transforms/ │ ├── __init__.py │ ├── add_gpse.py │ ├── add_metapaths.py │ ├── add_positional_encoding.py │ ├── add_remaining_self_loops.py │ ├── add_self_loops.py │ ├── base_transform.py │ ├── cartesian.py │ ├── center.py │ ├── compose.py │ ├── constant.py │ ├── delaunay.py │ ├── distance.py │ ├── face_to_edge.py │ ├── feature_propagation.py │ ├── fixed_points.py │ ├── gcn_norm.py │ ├── gdc.py │ ├── generate_mesh_normals.py │ ├── grid_sampling.py │ ├── half_hop.py │ ├── knn_graph.py │ ├── laplacian_lambda_max.py │ ├── largest_connected_components.py │ ├── line_graph.py │ ├── linear_transformation.py │ ├── local_cartesian.py │ ├── local_degree_profile.py │ ├── mask.py │ ├── node_property_split.py │ ├── normalize_features.py │ ├── normalize_rotation.py │ ├── normalize_scale.py │ ├── one_hot_degree.py │ ├── pad.py │ ├── point_pair_features.py │ ├── polar.py │ ├── radius_graph.py │ ├── random_flip.py │ ├── random_jitter.py │ ├── random_link_split.py │ ├── random_node_split.py │ ├── random_rotate.py │ ├── random_scale.py │ ├── random_shear.py │ ├── remove_duplicated_edges.py │ ├── remove_isolated_nodes.py │ ├── remove_self_loops.py │ ├── 
remove_training_classes.py │ ├── rooted_subgraph.py │ ├── sample_points.py │ ├── sign.py │ ├── spherical.py │ ├── svd_feature_reduction.py │ ├── target_indegree.py │ ├── to_dense.py │ ├── to_device.py │ ├── to_sparse_tensor.py │ ├── to_superpixels.py │ ├── to_undirected.py │ ├── two_hop.py │ └── virtual_node.py ├── typing.py ├── utils/ │ ├── __init__.py │ ├── _assortativity.py │ ├── _coalesce.py │ ├── _degree.py │ ├── _grid.py │ ├── _homophily.py │ ├── _index_sort.py │ ├── _lexsort.py │ ├── _negative_sampling.py │ ├── _normalize_edge_index.py │ ├── _normalized_cut.py │ ├── _one_hot.py │ ├── _scatter.py │ ├── _segment.py │ ├── _select.py │ ├── _softmax.py │ ├── _sort_edge_index.py │ ├── _spmm.py │ ├── _subgraph.py │ ├── _to_dense_adj.py │ ├── _to_dense_batch.py │ ├── _train_test_split_edges.py │ ├── _tree_decomposition.py │ ├── _trim_to_layer.py │ ├── _unbatch.py │ ├── augmentation.py │ ├── convert.py │ ├── cross_entropy.py │ ├── dropout.py │ ├── embedding.py │ ├── functions.py │ ├── geodesic.py │ ├── hetero.py │ ├── influence.py │ ├── isolated.py │ ├── laplacian.py │ ├── loop.py │ ├── map.py │ ├── mask.py │ ├── mesh_laplacian.py │ ├── mixin.py │ ├── nested.py │ ├── noise_scheduler.py │ ├── num_nodes.py │ ├── ppr.py │ ├── random.py │ ├── repeat.py │ ├── smiles.py │ ├── sparse.py │ └── undirected.py ├── visualization/ │ ├── __init__.py │ ├── graph.py │ └── influence.py └── warnings.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CODEOWNERS ================================================ # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners * @rusty1s @akihironitta *.py @rusty1s @wsad1 @akihironitta /.github/ @rusty1s @akihironitta /.github/CODEOWNERS @rusty1s /torch_geometric/data/ @rusty1s @mananshah99 @akihironitta /torch_geometric/loader/ @rusty1s 
@mananshah99 @akihironitta /torch_geometric/sampler/ @rusty1s @mananshah99 @akihironitta /docs/ @rusty1s @akihironitta /torch_geometric/nn/conv/cugraph @tingyu66 /examples/llm @puririshi98 /torch_geometric/llm @puririshi98 ================================================ FILE: .github/CONTRIBUTING.md ================================================ # Contributing to PyG If you are interested in contributing to PyG, your contributions will likely fall into one of the following two categories: 1. You want to implement a new feature: - In general, we accept any features as long as they fit the scope of this package. If you are unsure about this or need help on the design/implementation of your feature, post about it in an issue. 1. You want to fix a bug: - Feel free to send a Pull Request (PR) any time you encounter a bug. Please provide a clear and concise description of what the bug was. If you are unsure whether this is a bug at all or how to fix it, post about it in an issue. Once you finish implementing a feature or bug-fix, please send a PR to https://github.com/pyg-team/pytorch_geometric. Your PR will be merged after one or more rounds of reviews by the [pyg-team](https://github.com/pyg-team). If your PR isn't merged anytime soon (*e.g.,* due to its large size, complexity or unavailability of reviewers), try moving your contribution to the [`torch_geometric.contrib`](https://pytorch-geometric.readthedocs.io/en/latest/modules/contrib.html) package. [`torch_geometric.contrib`](https://pytorch-geometric.readthedocs.io/en/latest/modules/contrib.html) has less rigorous review requirements and might lead to your PR getting merged faster. ## Developing PyG To develop PyG on your machine, here are some tips: 1. Ensure that you are running on one of the two latest PyTorch releases (*e.g.*, `2.8.0`): ```python import torch print(torch.__version__) ``` 1.
*(Optional)* Follow the [installation instructions](https://github.com/pyg-team/pytorch_geometric#installation) to install `pyg-lib`, `torch-scatter`, `torch-sparse`, and `torch-cluster` (if you haven't already). Note that this step is optional and only necessary if you develop a feature that uses one of these libraries. ```bash pip install pyg-lib torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html ``` where `${TORCH}` should be replaced by your PyTorch version (*e.g.*, `2.8.0`), and `${CUDA}` should be replaced by your CUDA version (*e.g.*, `cpu`, `cu126`, `cu128`, or `cu129`). 1. Uninstall all existing PyG installations. It is advised to run this command repeatedly to confirm that installations across all locations are properly removed. ```bash pip uninstall torch-geometric pip uninstall torch-geometric # run this command twice ``` 1. Fork and clone the PyG repository: ```bash git clone https://github.com/<your_username>/pytorch_geometric cd pytorch_geometric ``` 1. If you already cloned PyG from source, update it: ```bash git pull ``` 1. Install PyG in editable mode: ```bash pip install -e ".[dev,full]" ``` This mode will symlink the Python files from the current local source tree into the Python install. Hence, if you modify a Python file, you do not need to re-install PyG. 1. Ensure that you have a working PyG installation by running the entire test suite with ```bash pytest ``` In case an error occurs, please first check if all sub-packages ([`pyg-lib`](https://github.com/pyg-team/pyg-lib), [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter), [`torch-sparse`](https://github.com/rusty1s/pytorch_sparse) and [`torch-cluster`](https://github.com/rusty1s/pytorch_cluster)) are on their latest released version. 1. Install pre-commit hooks: ```bash pre-commit install ``` ## Unit Testing The PyG testing suite is located under `test/`.
Run the entire test suite with ```bash pytest ``` or test individual files via, *e.g.*, `pytest test/utils/test_convert.py`. ## Continuous Integration PyG uses [GitHub Actions](https://github.com/pyg-team/pytorch_geometric/actions) in combination with [CodeCov](https://codecov.io/github/pyg-team/pytorch_geometric?branch=master) for continuous integration. Every time you send a Pull Request, your commit will be built and checked against the PyG guidelines: 1. Ensure that your code is formatted correctly by testing against the styleguide of [`flake8`](https://github.com/PyCQA/flake8). We use the [`Flake8-pyproject`](https://pypi.org/project/Flake8-pyproject/) plugin for configuration: ```bash flake8 . ``` If you do not want to format your code manually, we recommend using [`yapf`](https://github.com/google/yapf). 1. Ensure that the entire test suite passes and that code coverage roughly stays the same. Please feel encouraged to provide a test with your submitted code. To test, either run ```bash pytest --cov ``` or ```bash FULL_TEST=1 pytest --cov ``` (which runs a set of additional but time-consuming tests) depending on your needs. 1. Add your feature/bugfix to the [`CHANGELOG.md`](https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md?plain=1). If multiple PRs move towards integrating a single feature, it is advised to group them together into one bullet point. ## Building Documentation To build the documentation: 1. [Build and install](#developing-pyg) PyG from source. 1. Install the [Sphinx](https://www.sphinx-doc.org/en/master/) theme via ```bash pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git ``` 1. Generate the documentation via: ```bash cd docs make html ``` The documentation is now available to view by opening `docs/build/html/index.html`.
================================================ FILE: .github/ISSUE_TEMPLATE/bug-report.yml ================================================ name: "🐛 Bug Report" description: "Submit a report to help us reproduce and fix the bug" labels: bug body: - type: markdown attributes: value: > #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/pyg-team/pytorch_geometric/issues). # - type: textarea attributes: label: 🐛 Describe the bug description: | Please provide a clear and concise description of the bug. If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as minimal as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports: ```python # All necessary imports at the beginning import torch from torch_geometric.utils import to_undirected # A minimal reproducing example trimmed down to the essential parts: edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]]) edge_index = to_undirected(edge_index, num_nodes=1) assert edge_index.size(1) == 6 # We expect that the number of edges is doubled. # NOTE: the bug is that num_nodes < edge_index.max() + 1 ``` Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. placeholder: | A clear and concise description of the bug. ```python # Sample code to reproduce the problem ``` ``` The error message you got, with the full traceback. 
``` validations: required: true - type: textarea attributes: label: Versions description: | Please run the following and paste the output below. ```sh curl -OL https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py # For security purposes, please check the contents of collect_env.py before running it. python3 collect_env.py ``` validations: required: true ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false contact_links: - name: 🙏 Ask a Question url: https://github.com/pyg-team/pytorch_geometric/discussions/new about: Ask and answer PyG related questions - name: 💬 Slack url: https://join.slack.com/t/torchgeometricco/shared_invite/zt-p6br3yuo-BxRoe36OHHLF6jYU8xHtBA about: Chat with our community ================================================ FILE: .github/ISSUE_TEMPLATE/documentation.yml ================================================ name: "📚 Typos and Doc Fixes" description: "Tell us about how we can improve our documentation" labels: documentation body: - type: textarea attributes: label: 📚 Describe the documentation issue description: | A clear and concise description of the issue. validations: required: true - type: textarea attributes: label: Suggest a potential alternative/fix description: | Tell us how we could improve the documentation in this regard. ================================================ FILE: .github/ISSUE_TEMPLATE/feature-request.yml ================================================ name: "🚀 Feature Request" description: "Propose a new PyG feature" labels: feature body: - type: textarea attributes: label: 🚀 The feature, motivation and pitch description: > A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. 
If this is related to another GitHub issue, please link it here too. validations: required: true - type: textarea attributes: label: Alternatives description: > A description of any alternative solutions or features you've considered, if any. - type: textarea attributes: label: Additional context description: > Add any other context or screenshots about the feature request. ================================================ FILE: .github/ISSUE_TEMPLATE/installation.yml ================================================ name: "😵 Installation" description: "Report an installation problem" labels: installation body: - type: markdown attributes: value: > #### Before submitting an installation problem, please make sure the issue hasn't been already reported by searching through [the existing and past issues](https://github.com/pyg-team/pytorch_geometric/issues). # - type: textarea attributes: label: 😵 Describe the installation problem description: | Please provide a clear and concise description of the installation problem. If you have installation log files, please provide them here as well. It may be relevant to wrap the log files in ```` ```triple quotes blocks``` ````. placeholder: | A clear and concise description of the installation problem.
validations: required: true - type: textarea attributes: label: Environment description: | Please provide as much information as possible about your environment, such as your PyG (`print(torch_geometric.__version__)`) and PyTorch version (`print(torch.__version__)`), your OS (*e.g.*, Linux), and your Python version (*e.g.*, `3.13`): value: | * PyG version: * PyTorch version: * OS: * Python version: * CUDA/cuDNN version: * How you installed PyTorch and PyG (`conda`, `pip`, source): * Any other relevant information (*e.g.*, version of `torch-scatter`): ================================================ FILE: .github/ISSUE_TEMPLATE/refactor.yml ================================================ name: "🛠 Refactor" description: "Suggest a code refactor or deprecation" labels: refactor body: - type: textarea attributes: label: 🛠 Proposed Refactor description: | A clear and concise description of the refactor proposal. Please outline the motivation for the proposal. If this is related to another GitHub issue, please link here too. validations: required: true - type: textarea attributes: label: Suggest a potential alternative/fix description: | Tell us how we could improve the code in this regard. 
validations: required: true ================================================ FILE: .github/actions/setup/action.yml ================================================ name: Setup inputs: python-version: required: false default: '3.10' torch-version: required: false default: '2.10.0' cuda-version: required: false default: cpu full_install: required: false default: true runs: using: composite steps: - name: Install uv uses: astral-sh/setup-uv@v7 with: python-version: ${{ inputs.python-version }} activate-environment: true - name: Set up Python ${{ inputs.python-version }} run: | uv pip install --upgrade pip setuptools shell: bash - name: Install numpy run: | uv pip install "numpy<2" shell: bash - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }} if: ${{ inputs.torch-version != 'nightly' }} run: | uv pip install torch==${{ inputs.torch-version }} --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }} shell: bash - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }} if: ${{ inputs.torch-version == 'nightly' }} run: | uv pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/${{ inputs.cuda-version }} shell: bash - name: Check installation run: | uv run --no-project python -c "import torch; print('PyTorch:', torch.__version__)" uv run --no-project python -c "import torch; print('CUDA available:', torch.cuda.is_available())" uv run --no-project python -c "import torch; print('CUDA:', torch.version.cuda)" shell: bash - name: Install pyg-lib if: ${{ inputs.torch-version != 'nightly' }} run: | uv pip install --no-index --upgrade pyg-lib -f https://data.pyg.org/whl/nightly/torch-${{ inputs.torch-version }}+${{ inputs.cuda-version }}.html shell: bash - name: Install faiss-cpu if: ${{ inputs.full_install == 'true' && inputs.cuda-version == 'cpu' && runner.os == 'Linux' }} run: | uv pip install faiss-cpu==1.7.2 shell: bash - name: Install faiss-gpu if: ${{ inputs.full_install == 
'true' && inputs.cuda-version != 'cpu' && runner.os == 'Linux' }} run: | uv pip install faiss-gpu==1.7.2 shell: bash - name: Install torchvision if: ${{ inputs.full_install == 'true' && inputs.torch-version != 'nightly' }} run: | if [ ${{ inputs.torch-version }} == '2.10.0' ]; then TORCHVISION_VERSION=0.25.0 elif [ ${{ inputs.torch-version }} == '2.9.0' ]; then TORCHVISION_VERSION=0.24.0 elif [ ${{ inputs.torch-version }} == '2.8.0' ]; then TORCHVISION_VERSION=0.23.0 elif [ ${{ inputs.torch-version }} == '2.7.0' ]; then TORCHVISION_VERSION=0.22.0 elif [ ${{ inputs.torch-version }} == '2.6.0' ]; then TORCHVISION_VERSION=0.21.0 fi uv pip install torchvision==${TORCHVISION_VERSION} --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }} shell: bash - name: Install extension packages if: ${{ inputs.full_install == 'true' && inputs.torch-version != 'nightly' }} run: | uv pip install scipy uv pip install --no-index --upgrade torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-${{ inputs.torch-version }}+${{ inputs.cuda-version }}.html shell: bash ================================================ FILE: .github/dependabot.yml ================================================ # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/dependabot-options-reference version: 2 updates: - package-ecosystem: "github-actions" directories: - "/" - "/.github/actions/setup" schedule: interval: "daily" time: "00:00" labels: - "ci" - "skip-changelog" pull-request-branch-name: separator: "-" open-pull-requests-limit: 10 ================================================ FILE: .github/labeler.yml ================================================ installation: - changed-files: - any-glob-to-any-file: ["pyproject.toml"] ci: - changed-files: - any-glob-to-any-file: [".github/**/*", "codecov.yaml", ".pre-commit-config.yaml"] documentation: - changed-files: - any-glob-to-any-file: ["docs/**/*", "readthedocs.yml", "README.MD"] example: 
- changed-files: - any-glob-to-any-file: "examples/**/*" data: - changed-files: - any-glob-to-any-file: "torch_geometric/data/**/*" dataset: - changed-files: - any-glob-to-any-file: ["torch_geometric/io/**/*", "torch_geometric/datasets/**/*"] sampler: - changed-files: - any-glob-to-any-file: "torch_geometric/sampler/**/*" loader: - changed-files: - any-glob-to-any-file: "torch_geometric/loader/**/*" nn: - changed-files: - any-glob-to-any-file: "torch_geometric/nn/**/*" explain: - changed-files: - any-glob-to-any-file: "torch_geometric/explain/**/*" transform: - changed-files: - any-glob-to-any-file: "torch_geometric/transforms/**/*" metrics: - changed-files: - any-glob-to-any-file: "torch_geometric/metrics/**/*" utils: - changed-files: - any-glob-to-any-file: "torch_geometric/utils/**/*" distributed: - changed-files: - any-glob-to-any-file: "torch_geometric/distributed/**/*" contrib: - changed-files: - any-glob-to-any-file: "torch_geometric/contrib/**/*" graphgym: - changed-files: - any-glob-to-any-file: ["graphgym/**/*", "torch_geometric/graphgym/**/*"] benchmark: - changed-files: - any-glob-to-any-file: ["benchmark/**/*", "torch_geometric/profile/**/*"] ================================================ FILE: .github/workflows/_testing.yml ================================================ name: Reusable Testing on: # yamllint disable-line rule:truthy workflow_call: inputs: test-matrix: type: string required: true defaults: run: shell: bash jobs: test: runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: ${{ fromJSON(inputs.test-matrix) }} steps: - name: Checkout repository uses: actions/checkout@v6 - name: Set up packages uses: ./.github/actions/setup with: python-version: ${{ matrix.python-version }} torch-version: ${{ matrix.torch-version }} cuda-version: ${{ matrix.cuda-version }} - name: Install graphviz if: ${{ runner.os == 'Linux' && matrix.full_test == '1' }} run: | sudo apt-get install graphviz - name: Install main package (Windows) if: ${{ runner.os 
== 'Windows' }} run: | uv pip install -e ".[test]" - name: Install main package if: ${{ runner.os != 'Windows' }} run: | uv pip install -e ".[full,test]" - name: Check installation run: | python -c "import torch; print('PyTorch:', torch.__version__)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" python -c "import torch; print('CUDA:', torch.version.cuda)" shell: bash - name: Run tests timeout-minutes: 20 run: | pytest -m "not rag" --cov --cov-report=xml --durations 10 env: FULL_TEST: ${{ matrix.full_test }} - name: Upload coverage if: ${{ runner.os == 'Linux' }} uses: codecov/codecov-action@v5 with: fail_ci_if_error: false ================================================ FILE: .github/workflows/auto-merge.yml ================================================ name: Auto-merge Bot PRs on: # yamllint disable-line rule:truthy pull_request_target: types: [opened, reopened] permissions: contents: write pull-requests: write jobs: auto-merge: runs-on: ubuntu-latest if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' || github.event.pull_request.user.login == 'pre-commit-ci[bot]' }} steps: - uses: actions/checkout@v6 - name: Label bot PRs run: gh pr edit --add-label "ci,skip-changelog" ${{ github.event.pull_request.html_url }} env: GITHUB_TOKEN: ${{ secrets.PAT }} - name: Auto-approve uses: hmarr/auto-approve-action@v4 with: github-token: ${{ secrets.PAT }} - name: Enable auto-merge run: gh pr merge --auto --squash ${{ github.event.pull_request.html_url }} env: GITHUB_TOKEN: ${{ secrets.PAT }} ================================================ FILE: .github/workflows/building_nightly.yml ================================================ name: Nightly Build on: # yamllint disable-line rule:truthy workflow_dispatch: schedule: - cron: "0 6 * * *" # Everyday at 6:00am UTC/10:00pm PST jobs: build: if: github.repository == 'pyg-team/pytorch_geometric' runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v6 - 
name: Set up Python uses: actions/setup-python@v6 with: python-version: '3.10' - name: Set version run: echo "VERSION=$(sed -n "s/^__version__ = '\(.*\)'/\1/p" torch_geometric/__init__.py)" >> ${GITHUB_ENV} - name: Set time run: echo "TODAY=$(date +'%Y%m%d')" >> ${GITHUB_ENV} - name: Customize build version run: | sed -i "s/$VERSION/$VERSION.dev$TODAY/" torch_geometric/__init__.py sed -i '0,/name="torch-geometric"/s//name="pyg-nightly"/' pyproject.toml # Only change first occurrence sed -i "s/version=\"$VERSION\"/version=\"$VERSION.dev$TODAY\"/" pyproject.toml - name: Build package run: | pip install build python -m build - name: Publish package uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} ================================================ FILE: .github/workflows/changelog.yml ================================================ name: Changelog Enforcer on: # yamllint disable-line rule:truthy pull_request: types: [opened, synchronize, reopened, ready_for_review, labeled, unlabeled] jobs: changelog: runs-on: ubuntu-latest steps: - name: Enforce changelog entry uses: dangoslen/changelog-enforcer@v3 with: skipLabels: skip-changelog ================================================ FILE: .github/workflows/documentation.yml ================================================ name: Documentation on: # yamllint disable-line rule:truthy push: branches: - master paths: - 'torch_geometric/**' - 'docs/**' - 'pyproject.toml' - '.github/actions/setup/action.yml' - '.github/workflows/documentation.yml' pull_request: paths: - 'torch_geometric/**' - 'docs/**' - 'pyproject.toml' - '.github/actions/setup/action.yml' - '.github/workflows/documentation.yml' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }} # yamllint disable-line # Only cancel intermediate builds if on a PR: cancel-in-progress: ${{ 
startsWith(github.ref,
'refs/pull/') }} jobs: make_html: runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v6 - name: Setup packages uses: ./.github/actions/setup with: full_install: false - name: Install main package run: | uv pip install -e . nbsphinx git+https://github.com/pyg-team/pyg_sphinx_theme.git - name: Build documentation working-directory: ./docs run: | sphinx-build -M html "source" "build" -W # Fail on warning. ================================================ FILE: .github/workflows/examples.yml ================================================ name: Examples on: # yamllint disable-line rule:truthy workflow_dispatch: schedule: - cron: "0 7 * * *" # Everyday at 7:00am UTC/11:00pm PST jobs: pytest: if: github.repository == 'pyg-team/pytorch_geometric' runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v6 - name: Setup packages uses: ./.github/actions/setup - name: Install main package run: | uv pip install ".[benchmark]" - name: Run GCN on Cora run: | python examples/gcn.py --wandb env: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - name: Run GAT on Cora run: | python examples/gat.py --wandb env: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - name: Run GIN on MUTAG run: | python examples/mutag_gin.py --wandb env: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - name: Run GNNExplainer run: | python examples/explain/gnn_explainer.py ================================================ FILE: .github/workflows/labeler.yml ================================================ name: PR Labeler on: # yamllint disable-line rule:truthy pull_request: jobs: assign-labels: if: github.repository == 'pyg-team/pytorch_geometric' runs-on: ubuntu-latest permissions: contents: read pull-requests: write steps: - name: Add PR labels uses: actions/labeler@v6 continue-on-error: true with: repo-token: "${{ secrets.GITHUB_TOKEN }}" sync-labels: true assign-author: if: github.repository == 'pyg-team/pytorch_geometric' runs-on: ubuntu-latest steps: - name: 
Add PR author uses: samspills/assign-pr-to-author@v1.0 continue-on-error: true if: github.event_name == 'pull_request' with: repo-token: "${{ secrets.GITHUB_TOKEN }}" ================================================ FILE: .github/workflows/linting.yml ================================================ name: Linting on: # yamllint disable-line rule:truthy push: branches: - master paths: - 'torch_geometric/**' - 'test/**' - 'examples/**' - 'benchmark/**' - 'pyproject.toml' - '.github/actions/setup/action.yml' - '.github/workflows/linting.yml' pull_request: paths: - 'torch_geometric/**' - 'test/**' - 'examples/**' - 'benchmark/**' - 'pyproject.toml' - '.github/actions/setup/action.yml' - '.github/workflows/linting.yml' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }} # yamllint disable-line # Only cancel intermediate builds if on a PR: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: mypy: runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v6 - name: Setup packages uses: ./.github/actions/setup - name: Install main package run: | uv pip install -e ".[full,test]" mypy types-requests - name: Check type hints run: | mypy -v --cache-dir=/dev/null ================================================ FILE: .github/workflows/testing.yml ================================================ name: Testing on: # yamllint disable-line rule:truthy push: branches: - master pull_request: schedule: - cron: "0 6 * * *" # Everyday at 6:00am UTC/10:00pm PST concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }} # yamllint disable-line # Only cancel intermediate builds if on a PR: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: trigger: runs-on: ubuntu-latest outputs: triggered: ${{ 
steps.check.outputs.triggered }} steps: - uses: actions/checkout@v6 with: fetch-depth: 2 - id: check run: | if [ "${{ github.event_name }}" = "schedule" ]; then echo "triggered=true" >> "$GITHUB_OUTPUT" exit 0 fi if [ "${{ github.event_name }}" = "push" ]; then BASE=${{ github.event.before }} else BASE=${{ github.event.pull_request.base.sha }} fi PATHS=( 'torch_geometric/' 'test/' 'pyproject.toml' 'codecov.yml' '.github/actions/setup/action.yml' '.github/workflows/testing.yml' '.github/workflows/_testing.yml' ) CHANGED=$(git diff --name-only "$BASE" HEAD) for path in "${PATHS[@]}"; do if echo "$CHANGED" | grep -q "^${path}"; then echo "triggered=true" >> "$GITHUB_OUTPUT" exit 0 fi done echo "triggered=false" >> "$GITHUB_OUTPUT" pytest: needs: trigger if: ${{ needs.trigger.outputs.triggered == 'true' && github.repository == 'pyg-team/pytorch_geometric' && github.event_name != 'schedule' }} uses: ./.github/workflows/_testing.yml with: test-matrix: | { include: [ { "os": "ubuntu-22.04", "python-version": "3.10", "torch-version": "nightly", "cuda-version": "cpu" }, { "os": "ubuntu-22.04", "python-version": "3.10", "torch-version": "2.10.0", "cuda-version": "cpu" }, { "os": "ubuntu-22.04", "python-version": "3.10", "torch-version": "2.9.0", "cuda-version": "cpu" }, { "os": "ubuntu-22.04", "python-version": "3.10", "torch-version": "2.8.0", "cuda-version": "cpu" }, { "os": "macos-14", "python-version": "3.10", "torch-version": "2.10.0", "cuda-version": "cpu" }, { "os": "windows-2022", "python-version": "3.10", "torch-version": "2.10.0", "cuda-version": "cpu" } ]} pytest-full: # Only run this on nightly schedule for now. 
if: ${{ github.repository == 'pyg-team/pytorch_geometric' && github.event_name == 'schedule' }} uses: ./.github/workflows/_testing.yml with: test-matrix: | { "os": ["ubuntu-22.04", "macos-14", "windows-2022"], "python-version": ["3.10"], "torch-version": ["2.8.0", "2.9.0", "2.10.0", "nightly"], "cuda-version": ["cpu"], "full-test": "1" } # gpu: # if: github.repository == 'pyg-team/pytorch_geometric' # uses: ./.github/workflows/_testing.yml # with: # test-matrix: | # { # "os": [["self-hosted, nvidia"]], # "python-version": ["3.10"], # "torch-version": ["2.8.0"], # "cuda-version": ["cu128"] # } status: if: always() && github.event_name == 'pull_request' needs: - pytest - pytest-full runs-on: ubuntu-latest steps: - run: | if [ "${{ needs.pytest.result }}" = "failure" ] || \ [ "${{ needs.pytest-full.result }}" = "failure" ] || \ [ "${{ needs.pytest.result }}" = "cancelled" ] || \ [ "${{ needs.pytest-full.result }}" = "cancelled" ]; then exit 1 fi ================================================ FILE: .github/workflows/testing_rag.yml ================================================ name: Testing RAG on: # yamllint disable-line rule:truthy push: branches: - master paths: - 'torch_geometric/datasets/web_qsp_dataset.py' - 'torch_geometric/llm/**' - '.github/workflows/testing_rag.yml' pull_request: paths: - 'torch_geometric/datasets/web_qsp_dataset.py' - 'torch_geometric/llm/**' - '.github/workflows/testing_rag.yml' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }} # yamllint disable-line # Only cancel intermediate builds if on a PR: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: rag_pytest: runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v6 - name: Setup packages uses: ./.github/actions/setup with: full_install: false - name: Install main package run: | uv pip install -e ".[test,rag]" - name: Run tests 
timeout-minutes: 15 run: | # ignore mysterious segfault (139) if tests pass since this does not repro locally bash -c 'pytest -m rag --cov --cov-report=xml -v; E=$?; [[ $E == 0 || $E == 139 ]] && exit 0 || exit $E' shell: bash env: TOKENIZERS_PARALLELISM: "false" OMP_NUM_THREADS: "1" MKL_NUM_THREADS: "1" ================================================ FILE: .gitignore ================================================ __pycache__/ .pytest_cache/ .DS_Store data/ build/ dist/ alpha/ runs/ wandb/ .cache/ .eggs/ lightning_logs/ outputs/ graphgym/datasets/ graphgym/results/ *.egg-info/ .ipynb_checkpoints .coverage .coverage.* coverage.xml .vscode .idea .venv *.out *.pt *.onnx examples/**/*.png examples/**/*.pdf benchmark/results/ .mypy_cache/ uv.lock !torch_geometric/data/ !test/data/ ================================================ FILE: .pre-commit-config.yaml ================================================ ci: # https://pre-commit.ci/#configuration autofix_prs: true autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' autoupdate_schedule: monthly repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: - id: no-commit-to-branch name: No commits to master - id: end-of-file-fixer name: End-of-file fixer - name: mixed-line-ending id: mixed-line-ending args: [--fix, lf] - id: trailing-whitespace name: Remove trailing whitespaces - id: check-toml name: Check toml - id: check-yaml name: Check yaml - repo: https://github.com/adrienverge/yamllint.git rev: v1.38.0 hooks: - id: yamllint name: Lint yaml args: [-d, '{extends: default, rules: {line-length: disable, document-start: disable, truthy: {level: error}, braces: {max-spaces-inside: 1}}}'] - repo: https://github.com/asottile/pyupgrade rev: v3.21.2 hooks: - id: pyupgrade name: Upgrade Python syntax args: [--py38-plus] - repo: https://github.com/PyCQA/autoflake rev: v2.3.3 hooks: - id: autoflake name: Remove unused imports and variables args: [ --remove-all-unused-imports, 
--remove-unused-variables, --remove-duplicate-keys, --ignore-init-module-imports, --in-place, ] - repo: https://github.com/google/yapf rev: v0.43.0 hooks: - id: yapf name: Format code additional_dependencies: [toml] - repo: https://github.com/pycqa/isort rev: 8.0.1 hooks: - id: isort name: Sort imports - repo: https://github.com/PyCQA/flake8 rev: 7.3.0 hooks: - id: flake8 name: Check PEP8 additional_dependencies: [Flake8-pyproject] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.15.4 hooks: - id: ruff name: Ruff formatting args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/executablebooks/mdformat rev: 1.0.0 hooks: - id: mdformat name: Format Markdown additional_dependencies: - mdformat-gfm - mdformat-front-matters - mdformat-footnote - repo: https://github.com/sphinx-contrib/sphinx-lint rev: v1.0.2 hooks: - id: sphinx-lint name: Check Sphinx ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). 
## [Unreleased] - YYYY-MM-DD ### Added ### Changed - Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596)) ### Deprecated - Deprecated support for `torch-spline-conv` in favor of `pyg-lib>=0.6.0` ([#10622](https://github.com/pyg-team/pytorch_geometric/pull/10622)) ### Removed ### Fixed - Fixed `return_attention_weights: bool` not being respected in `GATConv` and `GATv2Conv` ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596)) - Fixed download links for politifact and gossipcop datasets of `UPFD` ([#10558](https://github.com/pyg-team/pytorch_geometric/pull/10558)) ### Security ## [2.7.0] - 2025-10-14 ### Added - Added LLM-generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918)) - Added `torch_geometric.llm` and its examples ([#10436](https://github.com/pyg-team/pytorch_geometric/pull/10436)) - Added support for negative weights in `sparse_cross_entropy` ([#10432](https://github.com/pyg-team/pytorch_geometric/pull/10432)) - Added `connected_components()` method to `Data` and `HeteroData` ([#10388](https://github.com/pyg-team/pytorch_geometric/pull/10388)) - Added LPFormer Graph Transformer for Link Prediction ([#9956](https://github.com/pyg-team/pytorch_geometric/pull/9956)) - Added `BidirectionalSampler`, which samples both forwards and backwards on graph edges ([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126)) - Enabled sampling of both forward and reverse edges in `NeighborSampler` ([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126)) - Added ability to merge together `SamplerOutput` objects ([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126)) - Added ability to get global row and col ids from `SamplerOutput` ([#10200](https://github.com/pyg-team/pytorch_geometric/pull/10200)) - Added PyTorch 2.8 support
([#10403](https://github.com/pyg-team/pytorch_geometric/pull/10403)) - Added `Polynormer` model and example ([#9908](https://github.com/pyg-team/pytorch_geometric/pull/9908)) - Added `ProteinMPNN` model and example ([#10289](https://github.com/pyg-team/pytorch_geometric/pull/10289)) - Added the `Teeth3DS` dataset, an extended benchmark for intraoral 3D scan analysis ([#9833](https://github.com/pyg-team/pytorch_geometric/pull/9833)) - Added `torch.device` to `PatchTransformerAggregation` ([#10342](https://github.com/pyg-team/pytorch_geometric/pull/10342)) - Added `torch.device` to normalization layers ([#10341](https://github.com/pyg-team/pytorch_geometric/pull/10341)) - Added `total_influence` for quantifying long-range dependency ([#10263](https://github.com/pyg-team/pytorch_geometric/pull/10263)) - Added `MedShapeNet` Dataset ([#9823](https://github.com/pyg-team/pytorch_geometric/pull/9823)) - Added RelBench example ([#10230](https://github.com/pyg-team/pytorch_geometric/pull/10230)) - Added `CityNetwork` dataset ([#10115](https://github.com/pyg-team/pytorch_geometric/pull/10115)) - Added `visualize_graph` to `HeteroExplanation` ([#10207](https://github.com/pyg-team/pytorch_geometric/pull/10207)) - Added PyTorch 2.6 support ([#10170](https://github.com/pyg-team/pytorch_geometric/pull/10170)) - Added support for heterogeneous graphs in `AttentionExplainer` ([#10169](https://github.com/pyg-team/pytorch_geometric/pull/10169)) - Added support for heterogeneous graphs in `PGExplainer` ([#10168](https://github.com/pyg-team/pytorch_geometric/pull/10168)) - Added support for heterogeneous graphs in `GNNExplainer` ([#10158](https://github.com/pyg-team/pytorch_geometric/pull/10158)) - Added Graph Positional and Structural Encoder (GPSE) and example ([#9018](https://github.com/pyg-team/pytorch_geometric/pull/9018), [#10118](https://github.com/pyg-team/pytorch_geometric/pull/10118)) - Added attract-repel link prediction example
([#10107](https://github.com/pyg-team/pytorch_geometric/pull/10107)) - Added `ARLinkPredictor` for implementing Attract-Repel embeddings for link prediction ([#10105](https://github.com/pyg-team/pytorch_geometric/pull/10105)) - Improving documentation for [cuGraph](https://github.com/rapidsai/cugraph) ([#10083](https://github.com/pyg-team/pytorch_geometric/pull/10083)) - Added `HashTensor` ([#10072](https://github.com/pyg-team/pytorch_geometric/pull/10072)) - Added `SGFormer` model and example ([#9904](https://github.com/pyg-team/pytorch_geometric/pull/9904)) - Added `AveragePopularity` metric for link prediction ([#10022](https://github.com/pyg-team/pytorch_geometric/pull/10022)) - Added `Personalization` metric for link prediction ([#10015](https://github.com/pyg-team/pytorch_geometric/pull/10015)) - Added `HitRatio` metric for link prediction ([#10013](https://github.com/pyg-team/pytorch_geometric/pull/10013)) - Added Data Splitting Tutorial ([#8366](https://github.com/pyg-team/pytorch_geometric/pull/8366)) - Added `Diversity` metric for link prediction ([#10009](https://github.com/pyg-team/pytorch_geometric/pull/10009)) - Added `Coverage` metric for link prediction ([#10006](https://github.com/pyg-team/pytorch_geometric/pull/10006)) - Added Graph Transformer Tutorial ([#8144](https://github.com/pyg-team/pytorch_geometric/pull/8144)) - Consolidate Cugraph examples into `ogbn_train_cugraph.py` and `ogbn_train_cugraph_multigpu.py` for `ogbn-arxiv`, `ogbn-products` and `ogbn-papers100M` ([#9953](https://github.com/pyg-team/pytorch_geometric/pull/9953)) - Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975)) - Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947)) - Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945)) - Added `LinkPredMetricCollection` 
([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941)) - Added various `GRetriever` architecture benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730)) - Added comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo ([#9797](https://github.com/pyg-team/pytorch_geometric/pull/9797)) - Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710)) - Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748)) - Added PyTorch 2.5 support ([#9779](https://github.com/pyg-team/pytorch_geometric/pull/9779), [#9780](https://github.com/pyg-team/pytorch_geometric/pull/9780)) - Support 3D tetrahedral mesh elements of shape `[4, num_faces]` in the `FaceToEdge` transformation ([#9776](https://github.com/pyg-team/pytorch_geometric/pull/9776)) - Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722)) - Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737)) - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467)) - Add ComplexWebQuestions (CWQ) dataset
([#9950](https://github.com/pyg-team/pytorch_geometric/pull/9950)) ### Changed - Added `edge_attr` in `CuGraphGATConv` ([#10383](https://github.com/pyg-team/pytorch_geometric/pull/10383)) - Adapt `dgcnn_classification` example to work with `ModelNet` and `MedShapeNet` Datasets ([#9823](https://github.com/pyg-team/pytorch_geometric/pull/9823)) - Chained exceptions explicitly instead of implicitly ([#10242](https://github.com/pyg-team/pytorch_geometric/pull/10242)) - Updated cuGraph examples to use buffered sampling which keeps data in memory and is significantly faster than the deprecated buffered sampling ([#10079](https://github.com/pyg-team/pytorch_geometric/pull/10079)) - Updated Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) - Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9606)) - Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807)) - Automatic num_params in LLM + update `GRetriever` default llm ([#9938](https://github.com/pyg-team/pytorch_geometric/pull/9938)) - Updated calls to NumPy's deprecated `np.in1d` to `np.isin` ([#10283](https://github.com/pyg-team/pytorch_geometric/pull/10283)) ### Deprecated - Deprecated `torch_geometric.distributed` ([#10411](https://github.com/pyg-team/pytorch_geometric/pull/10411)) ### Fixed - Fixed `ogbn_train_cugraph` example for distributed cuGraph ([#10439](https://github.com/pyg-team/pytorch_geometric/pull/10439)) - Added `safe_onnx_export` function with workarounds for `onnx_ir.serde.SerdeError` issues in ONNX export ([#10422](https://github.com/pyg-team/pytorch_geometric/pull/10422)) - Fixed importing PyTorch Lightning in `torch_geometric.graphgym` and `torch_geometric.data.lightning` when using `lightning` instead of `pytorch-lightning` ([#10404](https://github.com/pyg-team/pytorch_geometric/pull/10404), 
[#10417](https://github.com/pyg-team/pytorch_geometric/pull/10417)) - Fixed `detach()` warnings in example scripts involving tensor conversions ([#10357](https://github.com/pyg-team/pytorch_geometric/pull/10357)) - Fixed non-tuple indexing to resolve PyTorch deprecation warning ([#10389](https://github.com/pyg-team/pytorch_geometric/pull/10389)) - Fixed conversion to/from `cuGraph` graph objects by ensuring `cudf` column names are correctly specified ([#10343](https://github.com/pyg-team/pytorch_geometric/pull/10343)) - Fixed `_recursive_config()` for `torch.nn.ModuleList` and `torch.nn.ModuleDict` ([#10124](https://github.com/pyg-team/pytorch_geometric/pull/10124), [#10129](https://github.com/pyg-team/pytorch_geometric/pull/10129)) - Fixed the `k_hop_subgraph()` method for directed graphs ([#9756](https://github.com/pyg-team/pytorch_geometric/pull/9756)) - Fixed `utils.group_cat` concatenating dimension ([#9766](https://github.com/pyg-team/pytorch_geometric/pull/9766)) - Fixed `WebQSPDataset.process` raising exceptions ([#9665](https://github.com/pyg-team/pytorch_geometric/pull/9665)) - Fixed `is_node_attr()` and `is_edge_attr()` errors when `cat_dim` is a tuple ([#9895](https://github.com/pyg-team/pytorch_geometric/issues/9895)) - Avoid `GRetriever` instantiation when `num_gnn_layers == 0` ([#10156](https://github.com/pyg-team/pytorch_geometric/pull/10156)) ### Removed - Removed `proxies` and `resume_download` arguments from `PyGModelHubMixin` ([#10521](https://github.com/pyg-team/pytorch_geometric/pull/10521)) - Dropped support for Python 3.9 ([#10461](https://github.com/pyg-team/pytorch_geometric/pull/10461)) - Dropped support for PyTorch 1.13 - 2.5 ([#00000](https://github.com/pyg-team/pytorch_geometric/pull/00000)) - Dropped support for PyTorch 1.12 ([#10248](https://github.com/pyg-team/pytorch_geometric/pull/10248)) - Dropped support for PyTorch 1.11 ([#10247](https://github.com/pyg-team/pytorch_geometric/pull/10247)) ## [2.6.0] - 2024-09-13 ### Added - Added the
`WebQSPDataset` dataset ([#9481](https://github.com/pyg-team/pytorch_geometric/pull/9481)) - Added the `GRetriever` model and an example ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480), [#9167](https://github.com/pyg-team/pytorch_geometric/pull/9167)) - Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627)) - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632)) - Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594)) - Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554)) - Added the `RemoveSelfLoops` transformation ([#9562](https://github.com/pyg-team/pytorch_geometric/pull/9562)) - Added ONNX export for `scatter` with min/max reductions ([#9587](https://github.com/pyg-team/pytorch_geometric/pull/9587)) - Added a `residual` option in `GATConv` and `GATv2Conv` ([#9515](https://github.com/pyg-team/pytorch_geometric/pull/9515)) - Added the `PatchTransformerAggregation` layer ([#9487](https://github.com/pyg-team/pytorch_geometric/pull/9487)) - Added the `nn.nlp.LLM` model ([#9462](https://github.com/pyg-team/pytorch_geometric/pull/9462)) - Added an example of training GNNs for a graph-level regression task ([#9070](https://github.com/pyg-team/pytorch_geometric/pull/9070)) - Added `utils.from_rdmol`/`utils.to_rdmol` functionality ([#9452](https://github.com/pyg-team/pytorch_geometric/pull/9452)) - Added the `OPFDataset` ([#9379](https://github.com/pyg-team/pytorch_geometric/pull/9379)) - Added the heterogeneous `HeteroJumpingKnowledge` module ([#9380](https://github.com/pyg-team/pytorch_geometric/pull/9380)) - Started work on GNN+LLM package ([#9350](https://github.com/pyg-team/pytorch_geometric/pull/9350)) - Added support for negative sampling in `LinkLoader` according to source and destination node weights
([#9316](https://github.com/pyg-team/pytorch_geometric/pull/9316)) - Added support for `EdgeIndex.unbind` ([#9298](https://github.com/pyg-team/pytorch_geometric/pull/9298)) - Integrate `torch_geometric.Index` into `torch_geometric.EdgeIndex` ([#9296](https://github.com/pyg-team/pytorch_geometric/pull/9296)) - Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291)) - Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297)) - Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240)) - Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131)) - Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090)) - Added support for cuGraph data loading and `GAT` in single node Papers100m examples ([#8173](https://github.com/pyg-team/pytorch_geometric/pull/8173)) - Added the `VariancePreservingAggregation` (VPA) ([#9075](https://github.com/pyg-team/pytorch_geometric/pull/9075)) - Added option to pass custom `from_smiles` functionality to
`PCQM4Mv2` and `MoleculeNet` ([#9073](https://github.com/pyg-team/pytorch_geometric/pull/9073)) - Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029)) - Added support for `EdgeIndex` in `spmm` ([#9026](https://github.com/pyg-team/pytorch_geometric/pull/9026)) - Added option to pre-allocate memory in GPU-based `ApproxKNN` ([#9046](https://github.com/pyg-team/pytorch_geometric/pull/9046)) - Added support for `EdgeIndex` in `MessagePassing` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007)) - Added support for `torch.compile` in combination with `EdgeIndex` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007)) - Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249)) - Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983)) - Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952)) - Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407)) ### Changed - Add args to Taobao multi-GPU example and move item-item compute to dataset ([#9550](https://github.com/pyg-team/pytorch_geometric/pull/9550)) - Use `torch.load(weights_only=True)` by default ([#9618](https://github.com/pyg-team/pytorch_geometric/pull/9618)) - Adapt `cugraph` examples to its new API ([#9541](https://github.com/pyg-team/pytorch_geometric/pull/9541)) - Allow optional but untyped tensors in `MessagePassing` ([#9494](https://github.com/pyg-team/pytorch_geometric/pull/9494)) - Added support for modifying `filename` of the stored partitioned file in `ClusterLoader` ([#9448](https://github.com/pyg-team/pytorch_geometric/pull/9448)) - Support other than two-dimensional inputs in `AttentionalAggregation` ([#9433](https://github.com/pyg-team/pytorch_geometric/pull/9433)) - Improved model performance of the 
`examples/ogbn_papers_100m.py` script ([#9386](https://github.com/pyg-team/pytorch_geometric/pull/9386), [#9445](https://github.com/pyg-team/pytorch_geometric/pull/9445)) - Added the `fmt` arg to `Dataset.get_summary` ([#9408](https://github.com/pyg-team/pytorch_geometric/pull/9408)) - Skipped zero atom molecules in `MoleculeNet` ([#9318](https://github.com/pyg-team/pytorch_geometric/pull/9318)) - Ensure proper parallelism in `OnDiskDataset` for multi-threaded `get` calls ([#9140](https://github.com/pyg-team/pytorch_geometric/pull/9140)) - Allow `None` outputs in `FeatureStore.get_tensor()` - `KeyError` should now be raised based on the implementation in `FeatureStore._get_tensor()` ([#9102](https://github.com/pyg-team/pytorch_geometric/pull/9102)) - Allow mini-batching of uncoalesced sparse matrices ([#9099](https://github.com/pyg-team/pytorch_geometric/pull/9099)) - Improvements to multi-node `ogbn-papers100m` default hyperparameters and adding evaluation on all ranks ([#8823](https://github.com/pyg-team/pytorch_geometric/pull/8823)) - Changed distributed sampler and loader tests to correctly report failures in subprocesses to `pytest` ([#8978](https://github.com/pyg-team/pytorch_geometric/pull/8978)) - Remove filtering of node/edge types in `trim_to_layer` functionality ([#9021](https://github.com/pyg-team/pytorch_geometric/pull/9021)) - Default to `scatter` operations in `MessagePassing` in case `torch.use_deterministic_algorithms` is not set ([#9009](https://github.com/pyg-team/pytorch_geometric/pull/9009)) - Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001)) - Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937)) - Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918)) - Added XPU support to basic GNN examples 
([#9421](https://github.com/pyg-team/pytorch_geometric/pull/9421), [#9439](https://github.com/pyg-team/pytorch_geometric/pull/9439)) ### Deprecated ### Fixed - Fixed `VirtualNode` transform for empty edge indices ([#9605](https://github.com/pyg-team/pytorch_geometric/pull/9605)) - Fixed an issue where import order in the multi-GPU `cugraph` example could cause an `rmm` error ([#9577](https://github.com/pyg-team/pytorch_geometric/pull/9577)) - Made the output of the single-GPU `cugraph` example more readable ([#9577](https://github.com/pyg-team/pytorch_geometric/pull/9577)) - Fixed `load_state_dict` behavior with lazy parameters in `HeteroDictLinear` ([#9493](https://github.com/pyg-team/pytorch_geometric/pull/9493)) - `Sequential` can now be properly pickled ([#9369](https://github.com/pyg-team/pytorch_geometric/pull/9369)) - Fixed `pickle.load` for jittable `MessagePassing` modules ([#9368](https://github.com/pyg-team/pytorch_geometric/pull/9368)) - Fixed batching of sparse tensors saved via `data.edge_index` ([#9317](https://github.com/pyg-team/pytorch_geometric/pull/9317)) - Fixed arbitrary keyword ordering in `MessagePassing.propagate` ([#9245](https://github.com/pyg-team/pytorch_geometric/pull/9245)) - Fixed node mapping bug in `RCDD` dataset ([#9234](https://github.com/pyg-team/pytorch_geometric/pull/9234)) - Fixed incorrect treatment of `edge_label` and `edge_label_index` in `ToSparseTensor` transform ([#9199](https://github.com/pyg-team/pytorch_geometric/pull/9199)) - Fixed `EgoData` processing in `SnapDataset` in case filenames are unsorted ([#9195](https://github.com/pyg-team/pytorch_geometric/pull/9195)) - Fixed empty graph and isolated node handling in `to_dgl` ([#9188](https://github.com/pyg-team/pytorch_geometric/pull/9188)) - Fixed bug in `to_scipy_sparse_matrix` when CUDA is set as default torch device ([#9146](https://github.com/pyg-team/pytorch_geometric/pull/9146)) - Fixed `MetaPath2Vec` in case the last node is isolated
([#9145](https://github.com/pyg-team/pytorch_geometric/pull/9145)) - Ensure backward compatibility in `MessagePassing` via `torch.load` ([#9105](https://github.com/pyg-team/pytorch_geometric/pull/9105)) - Prevent model compilation on custom `propagate` functions ([#9079](https://github.com/pyg-team/pytorch_geometric/pull/9079)) - Ignore `self.propagate` appearances in comments when parsing `MessagePassing` implementation ([#9044](https://github.com/pyg-team/pytorch_geometric/pull/9044)) - Fixed `OSError` on read-only file systems within `MessagePassing` ([#9032](https://github.com/pyg-team/pytorch_geometric/pull/9032)) - Fixed metaclass conflict in `Dataset` ([#8999](https://github.com/pyg-team/pytorch_geometric/pull/8999)) - Fixed import errors on `MessagePassing` modules with nested inheritance ([#8973](https://github.com/pyg-team/pytorch_geometric/pull/8973)) - Fixed bug in multi-XPU training ([#9456](https://github.com/pyg-team/pytorch_geometric/pull/9456)) - Fixed TorchScript compilation error for `MessagePassing._check_input` on older torch versions ([#9564](https://github.com/pyg-team/pytorch_geometric/pull/9564)) ### Removed ## [2.5.0] - 2024-02-16 ### Added - Added an example for recommender systems, including k-NN search and retrieval metrics ([#8546](https://github.com/pyg-team/pytorch_geometric/pull/8546)) - Added multi-GPU evaluation in distributed sampling example ([#8880](https://github.com/pyg-team/pytorch_geometric/pull/8880)) - Added end-to-end example for distributed CPU training ([#8713](https://github.com/pyg-team/pytorch_geometric/pull/8713)) - Added PyTorch 2.2 support ([#8857](https://github.com/pyg-team/pytorch_geometric/pull/8857)) - Added fallback code path for `segment` in case `torch-scatter` is not installed ([#8852](https://github.com/pyg-team/pytorch_geometric/pull/8852)) - Added support for custom node labels in `visualize_graph()` ([#8816](https://github.com/pyg-team/pytorch_geometric/pull/8816)) - Added support for graph partitioning for
temporal data in `torch_geometric.distributed` ([#8718](https://github.com/pyg-team/pytorch_geometric/pull/8718), [#8815](https://github.com/pyg-team/pytorch_geometric/pull/8815), [#8874](https://github.com/pyg-team/pytorch_geometric/pull/8874)) - Added `TreeGraph` and `GridMotif` generators ([#8736](https://github.com/pyg-team/pytorch_geometric/pull/8736)) - Added two examples for edge-level temporal sampling on a heterogenous graph, with and without distributed training ([#8383](https://github.com/pyg-team/pytorch_geometric/pull/8383), [#8820](https://github.com/pyg-team/pytorch_geometric/pull/8820)) - Added the `num_graphs` option to the `StochasticBlockModelDataset` ([#8648](https://github.com/pyg-team/pytorch_geometric/pull/8648)) - Added noise scheduler utility for diffusion based graph generative models ([#8347](https://github.com/pyg-team/pytorch_geometric/pull/8347)) - Added the equivariant `ViSNet` model ([#8287](https://github.com/pyg-team/pytorch_geometric/pull/8287)) - Added temporal-related capabilities to `Data` ([#8454](https://github.com/pyg-team/pytorch_geometric/pull/8454)) - Added support for returning multi graphs in `to_networkx` ([#8575](https://github.com/pyg-team/pytorch_geometric/pull/8575)) - Added support for XPU device in `profileit` decorator ([#8532](https://github.com/pyg-team/pytorch_geometric/pull/8532)) - Added `KNNIndex` exclusion logic ([#8573](https://github.com/pyg-team/pytorch_geometric/pull/8573)) - Added warning when calling `dataset.num_classes` on regression problems ([#8550](https://github.com/pyg-team/pytorch_geometric/pull/8550)) - Added relabel node functionality to `dropout_node` ([#8524](https://github.com/pyg-team/pytorch_geometric/pull/8524)) - Added support for type checking via `mypy` ([#8254](https://github.com/pyg-team/pytorch_geometric/pull/8254)) - Added support for link-prediction retrieval metrics ([#8499](https://github.com/pyg-team/pytorch_geometric/pull/8499), 
[#8326](https://github.com/pyg-team/pytorch_geometric/pull/8326), [#8566](https://github.com/pyg-team/pytorch_geometric/pull/8566), [#8647](https://github.com/pyg-team/pytorch_geometric/pull/8647)) - Added METIS partitioning with CSC/CSR format selection in `ClusterData` ([#8438](https://github.com/pyg-team/pytorch_geometric/pull/8438)) - Added `is_torch_instance` to check against the original class of compiled models ([#8461](https://github.com/pyg-team/pytorch_geometric/pull/8461)) - Added dense computation for `AddRandomWalkPE` ([#8431](https://github.com/pyg-team/pytorch_geometric/pull/8431)) - Added a tutorial for point cloud processing ([#8015](https://github.com/pyg-team/pytorch_geometric/pull/8015)) - Added `fsspec` as file system backend ([#8379](https://github.com/pyg-team/pytorch_geometric/pull/8379), [#8426](https://github.com/pyg-team/pytorch_geometric/pull/8426), [#8434](https://github.com/pyg-team/pytorch_geometric/pull/8434), [#8474](https://github.com/pyg-team/pytorch_geometric/pull/8474)) - Added support for floating-point average degree numbers in `FakeDataset` and `FakeHeteroDataset` ([#8404](https://github.com/pyg-team/pytorch_geometric/pull/8404)) - Added support for device conversions of `InMemoryDataset` ([#8402](https://github.com/pyg-team/pytorch_geometric/pull/8402)) - Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372](https://github.com/pyg-team/pytorch_geometric/pull/8372), [#8428](https://github.com/pyg-team/pytorch_geometric/pull/8428)) - Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363)) - Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357), [#8436](https://github.com/pyg-team/pytorch_geometric/pull/8436)) - Added support for `torch.compile` 
in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345)) - Added support for `torch.compile` in `HeteroConv` ([#8344](https://github.com/pyg-team/pytorch_geometric/pull/8344)) - Added support for weighted `sparse_cross_entropy` ([#8340](https://github.com/pyg-team/pytorch_geometric/pull/8340)) - Added a multi GPU training benchmarks for CUDA and XPU devices ([#8288](https://github.com/pyg-team/pytorch_geometric/pull/8288), [#8382](https://github.com/pyg-team/pytorch_geometric/pull/8382), [#8386](https://github.com/pyg-team/pytorch_geometric/pull/8386)) - Support MRR computation in `KGEModel.test()` ([#8298](https://github.com/pyg-team/pytorch_geometric/pull/8298)) - Added an example for model parallelism (`examples/multi_gpu/model_parallel.py`) ([#8309](https://github.com/pyg-team/pytorch_geometric/pull/8309)) - Added a tutorial for multi-node multi-GPU training with pure PyTorch ([#8071](https://github.com/pyg-team/pytorch_geometric/pull/8071)) - Added a multinode-multigpu example on `ogbn-papers100M` ([#8070](https://github.com/pyg-team/pytorch_geometric/pull/8070)) - Added support for `to_hetero_with_bases` on static graphs ([#8247](https://github.com/pyg-team/pytorch_geometric/pull/8247)) - Added the `RCDD` dataset ([#8196](https://github.com/pyg-team/pytorch_geometric/pull/8196)) - Added distributed `GAT + ogbn-products` example targeting XPU device ([#8032](https://github.com/pyg-team/pytorch_geometric/pull/8032)) - Added the option to skip explanations of certain message passing layers via `conv.explain = False` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216)) ### Changed - Changed the default inference mode for `use_segment_matmul` based on benchmarking (from a heuristic-based version) ([#8615](https://github.com/pyg-team/pytorch_geometric/pull/8615)) - Return an empty tensor from `utils.group_argsort` if its input tensor is empty ([#8752](https://github.com/pyg-team/pytorch_geometric/pull/8752)) - GNN 
layers are now jittable by default ([#8745](https://github.com/pyg-team/pytorch_geometric/pull/8745)) - Sparse node features in `NELL` and `AttributedGraphDataset` are now represented as `torch.sparse_csr_tensor` instead of `torch_sparse.SparseTensor` ([#8679](https://github.com/pyg-team/pytorch_geometric/pull/8679)) - Accelerated mini-batching of `torch.sparse` tensors ([#8670](https://github.com/pyg-team/pytorch_geometric/pull/8670)) - Fixed RPC timeout due to worker closing in `DistLoader` with `atexit` not executed correctly in `worker_init_fn` ([#8605](https://github.com/pyg-team/pytorch_geometric/pull/8605)) - `ExplainerDataset` will now contain node labels for any motif generator ([#8519](https://github.com/pyg-team/pytorch_geometric/pull/8519)) - Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399)) - Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369)) - Update `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), [#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624), [#8722](https://github.com/pyg-team/pytorch_geometric/pull/8722)) - Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083)) - Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210)) - Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220)) ### Deprecated - Deprecated `MessagePassing.jittable` ([#8781](https://github.com/pyg-team/pytorch_geometric/pull/8781)) - Deprecated the usage of `torch_geometric.compile`; use `torch.compile` instead
([#8780](https://github.com/pyg-team/pytorch_geometric/pull/8780)) - Deprecated the `typing` argument in `MessagePassing.jittable()` ([#8731](https://github.com/pyg-team/pytorch_geometric/pull/8731)) - Deprecated `torch_geometric.data.makedirs` in favor of `os.makedirs` ([#8421](https://github.com/pyg-team/pytorch_geometric/pull/8421)) - Deprecated `DataParallel` in favor of `DistributedDataParallel` ([#8250](https://github.com/pyg-team/pytorch_geometric/pull/8250)) ### Fixed - Fixed dummy value creation of boolean tensors in `to_homogeneous()` ([#8858](https://github.com/pyg-team/pytorch_geometric/pull/8858)) - Fixed Google Drive download issues ([#8804](https://github.com/pyg-team/pytorch_geometric/pull/8804)) - Fixed a bug in which `InMemoryDataset` did not reconstruct the correct data class when a `pre_transform` has modified it ([#8692](https://github.com/pyg-team/pytorch_geometric/pull/8692)) - Fixed a bug in which transforms were not applied for `OnDiskDataset` ([#8663](https://github.com/pyg-team/pytorch_geometric/pull/8663)) - Fixed mini-batch computation in `DMoNPooling` loss function ([#8285](https://github.com/pyg-team/pytorch_geometric/pull/8285)) - Fixed `NaN` handling in `SQLDatabase` ([#8479](https://github.com/pyg-team/pytorch_geometric/pull/8479)) - Fixed `CaptumExplainer` in case no `index` is passed ([#8440](https://github.com/pyg-team/pytorch_geometric/pull/8440)) - Fixed `edge_index` construction in the `UPFD` dataset ([#8413](https://github.com/pyg-team/pytorch_geometric/pull/8413)) - Fixed TorchScript support in `AttentionalAggregation` and `DeepSetsAggregation` ([#8406](https://github.com/pyg-team/pytorch_geometric/pull/8406)) - Fixed `GraphMaskExplainer` for GNNs with more than two layers ([#8401](https://github.com/pyg-team/pytorch_geometric/pull/8401)) - Breaking Change: Properly initialize modules in `GATConv` depending on whether the input is bipartite or non-bipartite ([#8397](https://github.com/pyg-team/pytorch_geometric/pull/8397)) -
Fixed `input_id` computation in `NeighborLoader` in case a `mask` is given ([#8312](https://github.com/pyg-team/pytorch_geometric/pull/8312)) - Respect current device when deep-copying `Linear` layers ([#8311](https://github.com/pyg-team/pytorch_geometric/pull/8311)) - Fixed `Data.subgraph()`/`HeteroData.subgraph()` in case `edge_index` is not defined ([#8277](https://github.com/pyg-team/pytorch_geometric/pull/8277)) - Fixed empty edge handling in `MetaPath2Vec` ([#8248](https://github.com/pyg-team/pytorch_geometric/pull/8248)) - Fixed `AttentionExplainer` usage within `AttentiveFP` ([#8244](https://github.com/pyg-team/pytorch_geometric/pull/8244)) - Fixed `load_from_state_dict` in lazy `Linear` modules ([#8242](https://github.com/pyg-team/pytorch_geometric/pull/8242)) - Fixed pre-trained `DimeNet++` performance on `QM9` ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239)) - Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216)) - Fixed `to_networkx(to_undirected=True)` in case the input graph is not undirected ([#8204](https://github.com/pyg-team/pytorch_geometric/pull/8204)) - Fixed sparse-sparse matrix multiplication support on Windows in `TwoHop` and `AddRandomWalkPE` transformations ([#8197](https://github.com/pyg-team/pytorch_geometric/pull/8197), [#8225](https://github.com/pyg-team/pytorch_geometric/pull/8225)) - Fixed batching of `HeteroData` converted using `ToSparseTensor()` when `torch_sparse` is not installed ([#8356](https://github.com/pyg-team/pytorch_geometric/pull/8356)) ### Removed - Removed disabling of extension packages during `torch_geometric.compile` ([#8698](https://github.com/pyg-team/pytorch_geometric/pull/8698)) ## [2.4.0] - 2023-10-12 ### Added - Add the `ogc` method as example ([#8168](https://github.com/pyg-team/pytorch_geometric/pull/8168)) - Added a tutorial on `NeighborLoader` ([#7931](https://github.com/pyg-team/pytorch_geometric/pull/7931)) - Added the option 
to override usage of `segment_matmul`/`grouped_matmul` via the `torch_geometric.backend.use_segment_matmul` flag ([#8148](https://github.com/pyg-team/pytorch_geometric/pull/8148)) - Added support for PyTorch 2.1.0 ([#8134](https://github.com/pyg-team/pytorch_geometric/pull/8134)) - Added the `NeuroGraphDataset` benchmark collection ([#8122](https://github.com/pyg-team/pytorch_geometric/pull/8122)) - Added support for a node-level `mask` tensor in `dense_to_sparse` ([#8117](https://github.com/pyg-team/pytorch_geometric/pull/8117)) - Added the `to_on_disk_dataset()` method to convert `InMemoryDataset` instances to `OnDiskDataset` instances ([#8116](https://github.com/pyg-team/pytorch_geometric/pull/8116)) - Added `torch-frame` support ([#8110](https://github.com/pyg-team/pytorch_geometric/pull/8110), [#8118](https://github.com/pyg-team/pytorch_geometric/pull/8118), [#8151](https://github.com/pyg-team/pytorch_geometric/pull/8151), [#8152](https://github.com/pyg-team/pytorch_geometric/pull/8152)) - Added the `DistLoader` base class ([#8079](https://github.com/pyg-team/pytorch_geometric/pull/8079)) - Added `HyperGraphData` to support hypergraphs ([#7611](https://github.com/pyg-team/pytorch_geometric/pull/7611)) - Added the `PCQM4Mv2` dataset as a reference implementation for `OnDiskDataset` ([#8102](https://github.com/pyg-team/pytorch_geometric/pull/8102)) - Added `module_headers` property to `nn.Sequential` models ([#8093](https://github.com/pyg-team/pytorch_geometric/pull/8093)) - Added `OnDiskDataset` interface with data loader support ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066), [#8088](https://github.com/pyg-team/pytorch_geometric/pull/8088), [#8092](https://github.com/pyg-team/pytorch_geometric/pull/8092), [#8106](https://github.com/pyg-team/pytorch_geometric/pull/8106)) - Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938)) - Added a tutorial for multi-GPU training with 
pure PyTorch ([#7894](https://github.com/pyg-team/pytorch_geometric/pull/7894)) - Added `edge_attr` support to `ResGatedGraphConv` ([#8048](https://github.com/pyg-team/pytorch_geometric/pull/8048)) - Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054), [#8057](https://github.com/pyg-team/pytorch_geometric/pull/8057), [#8058](https://github.com/pyg-team/pytorch_geometric/pull/8058)) - Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038)) - Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025)) - Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033)) - Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders ([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230)) - Added the `NeuralFingerprint` model for learning fingerprints of molecules ([#7919](https://github.com/pyg-team/pytorch_geometric/pull/7919)) - Added `SparseTensor` support to `WLConvContinuous`, `GeneralConv`, `PDNConv` and `ARMAConv` ([#8013](https://github.com/pyg-team/pytorch_geometric/pull/8013)) - Added `LCMAggregation`, an implementation of Learnable Commutative Monoids, along with an example ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976),
[#8020](https://github.com/pyg-team/pytorch_geometric/pull/8020), [#8023](https://github.com/pyg-team/pytorch_geometric/pull/8023), [#8026](https://github.com/pyg-team/pytorch_geometric/pull/8026), [#8075](https://github.com/pyg-team/pytorch_geometric/pull/8075)) - Added a warning for isolated/non-existing node types in `HeteroData.validate()` ([#7995](https://github.com/pyg-team/pytorch_geometric/pull/7995)) - Added `utils.cumsum` implementation ([#7994](https://github.com/pyg-team/pytorch_geometric/pull/7994)) - Added the `BrcaTcga` dataset ([#7905](https://github.com/pyg-team/pytorch_geometric/pull/7905)) - Added the `MyketDataset` ([#7959](https://github.com/pyg-team/pytorch_geometric/pull/7959)) - Added a multi-GPU `ogbn-papers100M` example ([#7921](https://github.com/pyg-team/pytorch_geometric/pull/7921)) - Added `group_argsort` implementation ([#7948](https://github.com/pyg-team/pytorch_geometric/pull/7948)) - Added `CachedLoader` implementation ([#7896](https://github.com/pyg-team/pytorch_geometric/pull/7896), [#7897](https://github.com/pyg-team/pytorch_geometric/pull/7897)) - Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925)) - Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917)) - Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918)) - Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915)) - Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895)) - Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827)) - Added the `Wikidata5M` dataset ([#7864](https://github.com/pyg-team/pytorch_geometric/pull/7864)) - Added TorchScript support inside `BasicGNN` models 
([#7865](https://github.com/pyg-team/pytorch_geometric/pull/7865)) - Added a `batch_size` argument to `unbatch` functionalities ([#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851)) - Added a distributed example using `graphlearn-for-pytorch` ([#7402](https://github.com/pyg-team/pytorch_geometric/pull/7402)) - Integrated `neg_sampling_ratio` into `TemporalDataLoader` ([#7644](https://github.com/pyg-team/pytorch_geometric/pull/7644)) - Added `faiss`-based `KNNIndex` classes for L2 or maximum inner product search ([#7842](https://github.com/pyg-team/pytorch_geometric/pull/7842)) - Added the `OSE_GVCS` dataset ([#7811](https://github.com/pyg-team/pytorch_geometric/pull/7811)) - Added `output_initializer` argument to `DimeNet` models ([#7774](https://github.com/pyg-team/pytorch_geometric/pull/7774), [#7780](https://github.com/pyg-team/pytorch_geometric/pull/7780)) - Added `lexsort` implementation ([#7775](https://github.com/pyg-team/pytorch_geometric/pull/7775)) - Added possibility to run inference benchmarks on XPU device ([#7705](https://github.com/pyg-team/pytorch_geometric/pull/7705)) - Added `HeteroData` support in `to_networkx` ([#7713](https://github.com/pyg-team/pytorch_geometric/pull/7713)) - Added `FlopsCount` support via `fvcore` ([#7693](https://github.com/pyg-team/pytorch_geometric/pull/7693)) - Added back support for PyTorch >= 1.11.0 ([#7656](https://github.com/pyg-team/pytorch_geometric/pull/7656)) - Added `Data.sort()` and `HeteroData.sort()` functionalities ([#7649](https://github.com/pyg-team/pytorch_geometric/pull/7649)) - Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643), [#7647](https://github.com/pyg-team/pytorch_geometric/pull/7647)) - Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614),
[#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700)) - Added a `LightGCN` example on the `AmazonBook` dataset ([#7603](https://github.com/pyg-team/pytorch_geometric/pull/7603)) - Added a tutorial on hierarchical neighborhood sampling ([#7594](https://github.com/pyg-team/pytorch_geometric/pull/7594)) - Enabled different attention modes in `HypergraphConv` via the `attention_mode` argument ([#7601](https://github.com/pyg-team/pytorch_geometric/pull/7601)) - Added the `FilterEdges` graph coarsening operator ([#7361](https://github.com/pyg-team/pytorch_geometric/pull/7361)) - Added the `DirGNN` model for learning on directed graphs ([#7458](https://github.com/pyg-team/pytorch_geometric/pull/7458)) - Allow GPU tensors as input to `NodeLoader` and `LinkLoader` ([#7572](https://github.com/pyg-team/pytorch_geometric/pull/7572)) - Added an `embedding_device` option to allow for GPU inference in `BasicGNN` ([#7548](https://github.com/pyg-team/pytorch_geometric/pull/7548), [#7829](https://github.com/pyg-team/pytorch_geometric/pull/7829)) - Added `Performer` to `GPSConv` and removed `attn_dropout` argument from `GPSConv` ([#7465](https://github.com/pyg-team/pytorch_geometric/pull/7465)) - Enabled `LinkNeighborLoader` to return number of sampled nodes and edges per hop ([#7516](https://github.com/pyg-team/pytorch_geometric/pull/7516)) - Added the `HM` personalized fashion recommendation dataset ([#7515](https://github.com/pyg-team/pytorch_geometric/pull/7515)) - Added the `GraphMixer` model ([#7501](https://github.com/pyg-team/pytorch_geometric/pull/7501), [#7459](https://github.com/pyg-team/pytorch_geometric/pull/7459)) - Added the `disable_dynamic_shape` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246), [#7534](https://github.com/pyg-team/pytorch_geometric/pull/7534)) - Added the `MovieLens-1M` heterogeneous dataset ([#7479](https://github.com/pyg-team/pytorch_geometric/pull/7479)) - Added a CPU-based and GPU-based `map_index`
implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493), [#7764](https://github.com/pyg-team/pytorch_geometric/pull/7764), [#7765](https://github.com/pyg-team/pytorch_geometric/pull/7765)) - Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483)) - Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425)) - Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671), [#7846](https://github.com/pyg-team/pytorch_geometric/pull/7846), [#7715](https://github.com/pyg-team/pytorch_geometric/pull/7715), [#7974](https://github.com/pyg-team/pytorch_geometric/pull/7974)) - Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442)) - Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421)) - Added the `IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441)) - Added a sparse `cross_entropy` implementation ([#7447](https://github.com/pyg-team/pytorch_geometric/pull/7447), [#7466](https://github.com/pyg-team/pytorch_geometric/pull/7466)) - Added the `MovieLens-100K` heterogeneous dataset ([#7398](https://github.com/pyg-team/pytorch_geometric/pull/7398)) - Added the `PMLP` model and an example ([#7370](https://github.com/pyg-team/pytorch_geometric/pull/7370), [#7543](https://github.com/pyg-team/pytorch_geometric/pull/7543)) - Added padding capabilities to `HeteroData.to_homogeneous()` in case feature
dimensionalities do not match ([#7374](https://github.com/pyg-team/pytorch_geometric/pull/7374)) - Added an optional `batch_size` argument to `fps`, `knn`, `knn_graph`, `radius` and `radius_graph` ([#7368](https://github.com/pyg-team/pytorch_geometric/pull/7368)) - Added `PrefetchLoader` capabilities ([#7376](https://github.com/pyg-team/pytorch_geometric/pull/7376), [#7378](https://github.com/pyg-team/pytorch_geometric/pull/7378), [#7383](https://github.com/pyg-team/pytorch_geometric/pull/7383)) - Added an example for hierarchical sampling ([#7244](https://github.com/pyg-team/pytorch_geometric/pull/7244)) - Added Kùzu remote backend examples ([#7298](https://github.com/pyg-team/pytorch_geometric/pull/7298)) - Added an optional `add_pad_mask` argument to the `Pad` transform ([#7339](https://github.com/pyg-team/pytorch_geometric/pull/7339)) - Added `keep_inter_cluster_edges` option to `ClusterData` to support inter-subgraph edge connections when doing graph partitioning ([#7326](https://github.com/pyg-team/pytorch_geometric/pull/7326)) - Unify graph pooling framework ([#7308](https://github.com/pyg-team/pytorch_geometric/pull/7308), [#7625](https://github.com/pyg-team/pytorch_geometric/pull/7625)) - Added support for tuples as keys in `ModuleDict`/`ParameterDict` ([#7294](https://github.com/pyg-team/pytorch_geometric/pull/7294)) - Added `NodePropertySplit` transform for creating node-level splits using structural node properties ([#6894](https://github.com/pyg-team/pytorch_geometric/pull/6894)) - Added an option to preserve directed graphs in `CitationFull` datasets ([#7275](https://github.com/pyg-team/pytorch_geometric/pull/7275)) - Added support for `torch.sparse.Tensor` in `DataLoader` ([#7252](https://github.com/pyg-team/pytorch_geometric/pull/7252)) - Added `save` and `load` methods to `InMemoryDataset` ([#7250](https://github.com/pyg-team/pytorch_geometric/pull/7250), [#7413](https://github.com/pyg-team/pytorch_geometric/pull/7413)) - Added an example for 
heterogeneous GNN explanation via `CaptumExplainer` ([#7096](https://github.com/pyg-team/pytorch_geometric/pull/7096)) - Added `visualize_feature_importance` functionality to `HeteroExplanation` ([#7096](https://github.com/pyg-team/pytorch_geometric/pull/7096)) - Added a `AddRemainingSelfLoops` transform ([#7192](https://github.com/pyg-team/pytorch_geometric/pull/7192)) - Added `optimizer_resolver` ([#7209](https://github.com/pyg-team/pytorch_geometric/pull/7209)) - Added `type_ptr` argument to `HeteroLayerNorm` ([#7208](https://github.com/pyg-team/pytorch_geometric/pull/7208)) - Added an option to benchmark scripts to write PyTorch profiler results to CSV ([#7114](https://github.com/pyg-team/pytorch_geometric/pull/7114)) - Added subgraph type sampling option with bidirectional edge support ([#7199](https://github.com/pyg-team/pytorch_geometric/pull/7199), [#7200](https://github.com/pyg-team/pytorch_geometric/pull/7200)) - Added support for `"any"`-reductions in `scatter` ([#7198](https://github.com/pyg-team/pytorch_geometric/pull/7198)) - Added manual sampling interface to `NodeLoader` and `LinkLoader` ([#7197](https://github.com/pyg-team/pytorch_geometric/pull/7197)) - Extending `torch.sparse` support ([#7155](https://github.com/pyg-team/pytorch_geometric/pull/7155)) - Added edge weight support to `LightGCN` ([#7157](https://github.com/pyg-team/pytorch_geometric/pull/7157)) - Added `SparseTensor` support to `trim_to_layer` function ([#7089](https://github.com/pyg-team/pytorch_geometric/pull/7089)) - Added instructions for ROCm build wheels ([#7143](https://github.com/pyg-team/pytorch_geometric/pull/7143)) - Added a `ComposeFilters` class to compose `pre_filter` functions in `Dataset` ([#7097](https://github.com/pyg-team/pytorch_geometric/pull/7097)) - Added a time-step aware variant of the `EllipticBitcoinDataset` called `EllipticBitcoinTemporalDataset` ([#7011](https://github.com/pyg-team/pytorch_geometric/pull/7011)) - Added `to_dgl` and `from_dgl` conversion 
functions ([#7053](https://github.com/pyg-team/pytorch_geometric/pull/7053)) - Added support for `torch.jit.script` within `MessagePassing` layers without `torch_sparse` being installed ([#7061](https://github.com/pyg-team/pytorch_geometric/pull/7061), [#7062](https://github.com/pyg-team/pytorch_geometric/pull/7062)) - Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037)) - Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026)) - Added support for Apple silicon GPU acceleration in some main examples ([#7770](https://github.com/pyg-team/pytorch_geometric/pull/7770), [#7711](https://github.com/pyg-team/pytorch_geometric/pull/7711), [#7784](https://github.com/pyg-team/pytorch_geometric/pull/7784), [#7785](https://github.com/pyg-team/pytorch_geometric/pull/7785)) ### Changed - Fixed `HeteroConv` for layers that have a non-default argument order, *e.g.*, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166)) - Handle reserved keywords as keys in `ModuleDict` and `ParameterDict` ([#8163](https://github.com/pyg-team/pytorch_geometric/pull/8163)) - Updated the examples and tutorials to account for `torch.compile(dynamic=True)` in PyTorch 2.1.0 ([#8145](https://github.com/pyg-team/pytorch_geometric/pull/8145)) - Enabled dense eigenvalue computation in `AddLaplacianEigenvectorPE` for small-scale graphs ([#8143](https://github.com/pyg-team/pytorch_geometric/pull/8143)) - Fix `DynamicBatchSampler.__len__` to raise an error in case `num_steps` is undefined ([#8137](https://github.com/pyg-team/pytorch_geometric/pull/8137)) - Enabled pickling of `DimeNet` models ([#8019](https://github.com/pyg-team/pytorch_geometric/pull/8019)) - Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942)) - Accelerated and simplified `top_k` 
computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737)) - Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955)) - Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956)) - Fixed a bug in which `batch.e_id` was not correctly computed on unsorted graph inputs ([#7953](https://github.com/pyg-team/pytorch_geometric/pull/7953)) - Fixed `from_networkx` conversion from `nx.stochastic_block_model` graphs ([#7941](https://github.com/pyg-team/pytorch_geometric/pull/7941)) - Fixed the usage of `bias_initializer` in `HeteroLinear` ([#7923](https://github.com/pyg-team/pytorch_geometric/pull/7923)) - Fixed broken links in `HGBDataset` ([#7907](https://github.com/pyg-team/pytorch_geometric/pull/7907)) - Fixed an issue where `SetTransformerAggregation` produced `NaN` values for isolated nodes ([#7902](https://github.com/pyg-team/pytorch_geometric/pull/7902)) - Fixed `model_summary` on modules with uninitialized parameters ([#7884](https://github.com/pyg-team/pytorch_geometric/pull/7884)) - Updated `QM9` data pre-processing to include the SMILES string ([#7867](https://github.com/pyg-team/pytorch_geometric/pull/7867)) - Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330)) - Fixed device issue in `PNAConv.get_degree_histogram` ([#7830](https://github.com/pyg-team/pytorch_geometric/pull/7830)) - Fixed the shape of `edge_label_time` when using temporal sampling on homogeneous graphs ([#7807](https://github.com/pyg-team/pytorch_geometric/pull/7807)) - Moved `torch_geometric.contrib.explain.GraphMaskExplainer` to `torch_geometric.explain.algorithm.GraphMaskExplainer` ([#7779](https://github.com/pyg-team/pytorch_geometric/pull/7779)) - Made
`FieldStatus` enum picklable to avoid `PicklingError` in a multi-process setting ([#7808](https://github.com/pyg-team/pytorch_geometric/pull/7808)) - Fixed `edge_label_index` computation in `LinkNeighborLoader` for the homogeneous+`disjoint` mode ([#7791](https://github.com/pyg-team/pytorch_geometric/pull/7791)) - Fixed `CaptumExplainer` for `binary_classification` tasks ([#7787](https://github.com/pyg-team/pytorch_geometric/pull/7787)) - Warn user when using the `training` flag in `to_hetero` modules ([#7772](https://github.com/pyg-team/pytorch_geometric/pull/7772)) - Unchained exceptions raised when accessing non-existent data attributes for better readability ([#7734](https://github.com/pyg-team/pytorch_geometric/pull/7734)) - Raise error when collecting non-existing attributes in `HeteroData` ([#7714](https://github.com/pyg-team/pytorch_geometric/pull/7714)) - Renamed `dest` argument to `dst` in `utils.geodesic_distance` ([#7708](https://github.com/pyg-team/pytorch_geometric/pull/7708)) - Changed `add_random_edge` to only add true negative edges ([#7654](https://github.com/pyg-team/pytorch_geometric/pull/7654)) - Allowed the usage of `BasicGNN` models in `DeepGraphInfomax` ([#7648](https://github.com/pyg-team/pytorch_geometric/pull/7648)) - Breaking Change: Made `Data.keys` a method rather than a property ([#7629](https://github.com/pyg-team/pytorch_geometric/pull/7629)) - Added a `num_edges` parameter to the forward method of `HypergraphConv` ([#7560](https://github.com/pyg-team/pytorch_geometric/pull/7560)) - Fixed `get_mesh_laplacian` for `normalization="sym"` ([#7544](https://github.com/pyg-team/pytorch_geometric/pull/7544)) - Use `dim_size` to initialize output size of the `EquilibriumAggregation` layer ([#7530](https://github.com/pyg-team/pytorch_geometric/pull/7530)) - Added a `max_num_elements` parameter to the forward method of `GraphMultisetTransformer`, `GRUAggregation`, `LSTMAggregation` and `SetTransformerAggregation` 
([#7529](https://github.com/pyg-team/pytorch_geometric/pull/7529)) - Fixed empty edge indices handling in `SparseTensor` ([#7519](https://github.com/pyg-team/pytorch_geometric/pull/7519)) - Move the `scaler` tensor in `GeneralConv` to the correct device ([#7484](https://github.com/pyg-team/pytorch_geometric/pull/7484)) - Fixed `HeteroLinear` bug when used via mixed precision ([#7473](https://github.com/pyg-team/pytorch_geometric/pull/7473)) - All transforms are now immutable, i.e., they perform a shallow-copy of the data and therefore no longer modify data in-place ([#7429](https://github.com/pyg-team/pytorch_geometric/pull/7429)) - Set `output_size` in the `repeat_interleave` operation in `QuantileAggregation` ([#7426](https://github.com/pyg-team/pytorch_geometric/pull/7426)) - Fixed gradient computation of edge weights in `utils.spmm` ([#7428](https://github.com/pyg-team/pytorch_geometric/pull/7428)) - Re-factored `ClusterLoader` to integrate `pyg-lib` METIS routine ([#7416](https://github.com/pyg-team/pytorch_geometric/pull/7416)) - Fixed an index-out-of-range bug in `QuantileAggregation` when `dim_size` is passed ([#7407](https://github.com/pyg-team/pytorch_geometric/pull/7407)) - The `filter_per_worker` option will not get automatically inferred by default based on the device of the underlying data ([#7399](https://github.com/pyg-team/pytorch_geometric/pull/7399)) - Fixed a bug in `LightGCN.recommendation_loss()` to only use the embeddings of the nodes involved in the current mini-batch ([#7384](https://github.com/pyg-team/pytorch_geometric/pull/7384)) - Added an optional `max_num_elements` argument to `SortAggregation` ([#7367](https://github.com/pyg-team/pytorch_geometric/pull/7367)) - Added the option to pass `fill_value` as a `torch.tensor` to `utils.to_dense_batch` ([#7367](https://github.com/pyg-team/pytorch_geometric/pull/7367)) - Fixed a bug in which inputs were modified in-place in `to_hetero_with_bases`
([#7363](https://github.com/pyg-team/pytorch_geometric/pull/7363)) - Do not load `node_default` and `edge_default` attributes in `from_networkx` ([#7348](https://github.com/pyg-team/pytorch_geometric/pull/7348)) - Updated examples to use `NeighborLoader` instead of `NeighborSampler` ([#7152](https://github.com/pyg-team/pytorch_geometric/pull/7152)) - Fixed `HGTConv` utility function `_construct_src_node_feat` ([#7194](https://github.com/pyg-team/pytorch_geometric/pull/7194)) - Extend dataset summary to create stats for each node/edge type ([#7203](https://github.com/pyg-team/pytorch_geometric/pull/7203)) - Added an optional `batch_size` argument to `avg_pool_x` and `max_pool_x` ([#7216](https://github.com/pyg-team/pytorch_geometric/pull/7216)) - Fixed `subgraph` on unordered inputs ([#7187](https://github.com/pyg-team/pytorch_geometric/pull/7187)) - Allow missing node types in `HeteroDictLinear` ([#7185](https://github.com/pyg-team/pytorch_geometric/pull/7185)) - Optimized `from_networkx` memory footprint by reducing unnecessary copies ([#7119](https://github.com/pyg-team/pytorch_geometric/pull/7119)) - Added an optional `batch_size` argument to `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135)) - Improved code coverage ([#7093](https://github.com/pyg-team/pytorch_geometric/pull/7093), [#7195](https://github.com/pyg-team/pytorch_geometric/pull/7195)) - Fixed `numpy` incompatibility when reading files for `Planetoid` datasets ([#7141](https://github.com/pyg-team/pytorch_geometric/pull/7141)) - Added support for `Data.num_edges` for native `torch.sparse.Tensor` adjacency matrices ([#7104](https://github.com/pyg-team/pytorch_geometric/pull/7104)) - Fixed crash of heterogeneous data loaders if node or edge types are missing ([#7060](https://github.com/pyg-team/pytorch_geometric/pull/7060), [#7087](https://github.com/pyg-team/pytorch_geometric/pull/7087)) - Accelerated attention-based
`MultiAggregation` ([#7077](https://github.com/pyg-team/pytorch_geometric/pull/7077)) - Edges in `HeterophilousGraphDataset` are now undirected by default ([#7065](https://github.com/pyg-team/pytorch_geometric/pull/7065)) - Fixed a bug in `FastHGTConv` that computed values via parameters used to compute the keys ([#7050](https://github.com/pyg-team/pytorch_geometric/pull/7050)) - Accelerated sparse tensor conversion routines ([#7042](https://github.com/pyg-team/pytorch_geometric/pull/7042), [#7043](https://github.com/pyg-team/pytorch_geometric/pull/7043)) - Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041)) - Added an optional `batch_size` and `max_num_nodes` arguments to `MemPooling` layer ([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239)) - Fixed training issues of the GraphGPS example ([#7377](https://github.com/pyg-team/pytorch_geometric/pull/7377)) - Allowed `CaptumExplainer` to be called multiple times in a row ([#7391](https://github.com/pyg-team/pytorch_geometric/pull/7391)) ### Removed - Dropped Python 3.7 support ([#7939](https://github.com/pyg-team/pytorch_geometric/pull/7939)) - Removed `layer_type` argument in `contrib.explain.GraphMaskExplainer` ([#7445](https://github.com/pyg-team/pytorch_geometric/pull/7445)) - Replaced `FastHGTConv` with `HGTConv` ([#7117](https://github.com/pyg-team/pytorch_geometric/pull/7117)) ## [2.3.0] - 2023-03-23 ### Added - Added a memory-efficient `utils.one_hot` implementation ([#7005](https://github.com/pyg-team/pytorch_geometric/pull/7005)) - Added `HeteroDictLinear` and an optimized `FastHGTConv` module ([#6178](https://github.com/pyg-team/pytorch_geometric/pull/6178), [#6998](https://github.com/pyg-team/pytorch_geometric/pull/6998)) - Added the `DenseGATConv` module ([#6928](https://github.com/pyg-team/pytorch_geometric/pull/6928)) - Added `trim_to_layer` utility function for more efficient `NeighborLoader` 
use-cases ([#6661](https://github.com/pyg-team/pytorch_geometric/pull/6661)) - Added the `DistMult` KGE model ([#6958](https://github.com/pyg-team/pytorch_geometric/pull/6958)) - Added `HeteroData.set_value_dict` functionality ([#6961](https://github.com/pyg-team/pytorch_geometric/pull/6961), [#6974](https://github.com/pyg-team/pytorch_geometric/pull/6974)) - Added PyTorch >= 2.0 support ([#6934](https://github.com/pyg-team/pytorch_geometric/pull/6934), [#7000](https://github.com/pyg-team/pytorch_geometric/pull/7000)) - Added PyTorch Lightning >= 2.0 support ([#6929](https://github.com/pyg-team/pytorch_geometric/pull/6929)) - Added the `ComplEx` KGE model ([#6898](https://github.com/pyg-team/pytorch_geometric/pull/6898)) - Added option to write benchmark results to csv ([#6888](https://github.com/pyg-team/pytorch_geometric/pull/6888)) - Added `HeteroLayerNorm` and `HeteroBatchNorm` layers ([#6838](https://github.com/pyg-team/pytorch_geometric/pull/6838)) - Added the `HeterophilousGraphDataset` suite ([#6846](https://github.com/pyg-team/pytorch_geometric/pull/6846)) - Added support for sparse tensor in full batch mode inference benchmark ([#6843](https://github.com/pyg-team/pytorch_geometric/pull/6843)) - Enabled `NeighborLoader` to return number of sampled nodes and edges per hop ([#6834](https://github.com/pyg-team/pytorch_geometric/pull/6834)) - Added `ZipLoader` to execute multiple `NodeLoader` or `LinkLoader` instances ([#6829](https://github.com/pyg-team/pytorch_geometric/issues/6829)) - Added common `utils.select` and `utils.narrow` functionality to support filtering of both tensors and lists ([#6162](https://github.com/pyg-team/pytorch_geometric/issues/6162)) - Support `normalization` customization in `get_mesh_laplacian` ([#6790](https://github.com/pyg-team/pytorch_geometric/issues/6790)) - Added the `TemporalEncoding` module ([#6785](https://github.com/pyg-team/pytorch_geometric/pull/6785)) - Added CPU-optimized `spmm_reduce` functionality via CSR format 
([#6699](https://github.com/pyg-team/pytorch_geometric/pull/6699), [#6759](https://github.com/pyg-team/pytorch_geometric/pull/6759)) - Added support for the revised version of the `MD17` dataset ([#6734](https://github.com/pyg-team/pytorch_geometric/pull/6734)) - Added TorchScript support to the `RECT_L` model ([#6727](https://github.com/pyg-team/pytorch_geometric/pull/6727)) - Added TorchScript support to the `Node2Vec` model ([#6726](https://github.com/pyg-team/pytorch_geometric/pull/6726)) - Added `utils.to_edge_index` to convert sparse tensors to edge indices and edge attributes ([#6728](https://github.com/pyg-team/pytorch_geometric/issues/6728)) - Fixed expected data format in `PolBlogs` dataset ([#6714](https://github.com/pyg-team/pytorch_geometric/issues/6714)) - Added `SimpleConv` to perform non-trainable propagation ([#6718](https://github.com/pyg-team/pytorch_geometric/pull/6718)) - Added a `RemoveDuplicatedEdges` transform ([#6709](https://github.com/pyg-team/pytorch_geometric/pull/6709)) - Added TorchScript support to the `LINKX` model ([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712)) - Added `torch.jit` examples for `example/film.py` and `example/gcn.py`([#6602](https://github.com/pyg-team/pytorch_geometric/pull/6692)) - Added `Pad` transform ([#5940](https://github.com/pyg-team/pytorch_geometric/pull/5940), [#6697](https://github.com/pyg-team/pytorch_geometric/pull/6697), [#6731](https://github.com/pyg-team/pytorch_geometric/pull/6731), [#6758](https://github.com/pyg-team/pytorch_geometric/pull/6758)) - Added full batch mode to the inference benchmark ([#6631](https://github.com/pyg-team/pytorch_geometric/pull/6631)) - Added `cat` aggregation type to the `HeteroConv` class so that features can be concatenated during grouping ([#6634](https://github.com/pyg-team/pytorch_geometric/pull/6634)) - Added `torch.compile` support and benchmark study ([#6610](https://github.com/pyg-team/pytorch_geometric/pull/6610), 
[#6952](https://github.com/pyg-team/pytorch_geometric/pull/6952), [#6953](https://github.com/pyg-team/pytorch_geometric/pull/6953), [#6980](https://github.com/pyg-team/pytorch_geometric/pull/6980), [#6983](https://github.com/pyg-team/pytorch_geometric/pull/6983), [#6984](https://github.com/pyg-team/pytorch_geometric/pull/6984), [#6985](https://github.com/pyg-team/pytorch_geometric/pull/6985), [#6986](https://github.com/pyg-team/pytorch_geometric/pull/6986), [#6989](https://github.com/pyg-team/pytorch_geometric/pull/6989), [#7002](https://github.com/pyg-team/pytorch_geometric/pull/7002)) - Added the `AntiSymmetricConv` layer ([#6577](https://github.com/pyg-team/pytorch_geometric/pull/6577)) - Added a mixin for Huggingface model hub integration ([#5930](https://github.com/pyg-team/pytorch_geometric/pull/5930), [#6591](https://github.com/pyg-team/pytorch_geometric/pull/6591)) - Added support for accelerated GNN layers in `nn.conv.cugraph` via `cugraph-ops` ([#6278](https://github.com/pyg-team/pytorch_geometric/pull/6278), [#6388](https://github.com/pyg-team/pytorch_geometric/pull/6388), [#6412](https://github.com/pyg-team/pytorch_geometric/pull/6412)) - Added accelerated `index_sort` function from `pyg-lib` for faster sorting ([#6554](https://github.com/pyg-team/pytorch_geometric/pull/6554)) - Fix incorrect device in `EquilibriumAggregation` ([#6560](https://github.com/pyg-team/pytorch_geometric/pull/6560)) - Added bipartite graph support in `dense_to_sparse()` ([#6546](https://github.com/pyg-team/pytorch_geometric/pull/6546)) - Add CPU affinity support for more data loaders ([#6534](https://github.com/pyg-team/pytorch_geometric/pull/6534), [#6922](https://github.com/pyg-team/pytorch_geometric/pull/6922)) - Added the `BAMultiShapesDataset` ([#6541](https://github.com/pyg-team/pytorch_geometric/pull/6541)) - Added the interfaces of a graph pooling framework ([#6540](https://github.com/pyg-team/pytorch_geometric/pull/6540)) - Added automatic `n_id` and `e_id` 
attributes to mini-batches produced by `NodeLoader` and `LinkLoader` ([#6524](https://github.com/pyg-team/pytorch_geometric/pull/6524)) - Added `PGMExplainer` to `torch_geometric.contrib` ([#6149](https://github.com/pyg-team/pytorch_geometric/pull/6149), [#6588](https://github.com/pyg-team/pytorch_geometric/pull/6588), [#6589](https://github.com/pyg-team/pytorch_geometric/pull/6589)) - Added a `NumNeighbors` helper class for specifying the number of neighbors when sampling ([#6501](https://github.com/pyg-team/pytorch_geometric/pull/6501), [#6505](https://github.com/pyg-team/pytorch_geometric/pull/6505), [#6690](https://github.com/pyg-team/pytorch_geometric/pull/6690)) - Added caching to `is_node_attr()` and `is_edge_attr()` calls ([#6492](https://github.com/pyg-team/pytorch_geometric/pull/6492)) - Added `ToHeteroLinear` and `ToHeteroMessagePassing` modules to accelerate `to_hetero` functionality ([#5992](https://github.com/pyg-team/pytorch_geometric/pull/5992), [#6456](https://github.com/pyg-team/pytorch_geometric/pull/6456)) - Added `GraphMaskExplainer` ([#6284](https://github.com/pyg-team/pytorch_geometric/pull/6284)) - Added the `GRBCD` and `PRBCD` adversarial attack models ([#5972](https://github.com/pyg-team/pytorch_geometric/pull/5972)) - Added `dropout` option to `SetTransformer` and `GraphMultisetTransformer` ([#6484](https://github.com/pyg-team/pytorch_geometric/pull/6484)) - Added option to customize loader arguments for evaluation in `LightningNodeData` and `LightningLinkData` ([#6450](https://github.com/pyg-team/pytorch_geometric/pull/6450), [#6456](https://github.com/pyg-team/pytorch_geometric/pull/6456)) - Added option to customize `num_neighbors` in `NeighborSampler` after instantiation ([#6446](https://github.com/pyg-team/pytorch_geometric/pull/6446)) - Added the `Taobao` dataset and a corresponding example for it ([#6144](https://github.com/pyg-team/pytorch_geometric/pull/6144)) - Added `pyproject.toml` 
([#6431](https://github.com/pyg-team/pytorch_geometric/pull/6431)) - Added the `torch_geometric.contrib` sub-package ([#6422](https://github.com/pyg-team/pytorch_geometric/pull/6422)) - Warn on using latest documentation ([#6418](https://github.com/pyg-team/pytorch_geometric/pull/6418)) - Added basic `pyright` type checker support ([#6415](https://github.com/pyg-team/pytorch_geometric/pull/6415)) - Added a new external resource for link prediction ([#6396](https://github.com/pyg-team/pytorch_geometric/pull/6396)) - Added `CaptumExplainer` ([#6383](https://github.com/pyg-team/pytorch_geometric/pull/6383), [#6387](https://github.com/pyg-team/pytorch_geometric/pull/6387), [#6433](https://github.com/pyg-team/pytorch_geometric/pull/6433), [#6487](https://github.com/pyg-team/pytorch_geometric/pull/6487), [#6966](https://github.com/pyg-team/pytorch_geometric/pull/6966)) - Added support for custom `HeteroData` mini-batch class in remote backends ([#6377](https://github.com/pyg-team/pytorch_geometric/pull/6377)) - Added the `GNNFF` model ([#5866](https://github.com/pyg-team/pytorch_geometric/pull/5866)) - Added `MLPAggregation`, `SetTransformerAggregation`, `GRUAggregation`, and `DeepSetsAggregation` as adaptive readout functions ([#6301](https://github.com/pyg-team/pytorch_geometric/pull/6301), [#6336](https://github.com/pyg-team/pytorch_geometric/pull/6336), [#6338](https://github.com/pyg-team/pytorch_geometric/pull/6338)) - Added `Dataset.to_datapipe` for converting PyG datasets into a torchdata `DataPipe`([#6141](https://github.com/pyg-team/pytorch_geometric/pull/6141)) - Added `to_nested_tensor` and `from_nested_tensor` functionality ([#6329](https://github.com/pyg-team/pytorch_geometric/pull/6329), [#6330](https://github.com/pyg-team/pytorch_geometric/pull/6330), [#6331](https://github.com/pyg-team/pytorch_geometric/pull/6331), [#6332](https://github.com/pyg-team/pytorch_geometric/pull/6332)) - Added the `GPSConv` Graph Transformer layer and example 
([#6326](https://github.com/pyg-team/pytorch_geometric/pull/6326), [#6327](https://github.com/pyg-team/pytorch_geometric/pull/6327)) - Added `networkit` conversion utilities ([#6321](https://github.com/pyg-team/pytorch_geometric/pull/6321)) - Added global dataset attribute access via `dataset.{attr_name}` ([#6319](https://github.com/pyg-team/pytorch_geometric/pull/6319)) - Added the `TransE` KGE model and example ([#6314](https://github.com/pyg-team/pytorch_geometric/pull/6314)) - Added the Freebase `FB15k_237` dataset ([#3204](https://github.com/pyg-team/pytorch_geometric/pull/3204)) - Added `Data.update()` and `HeteroData.update()` functionality ([#6313](https://github.com/pyg-team/pytorch_geometric/pull/6313)) - Added `PGExplainer` ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6204)) - Added the `AirfRANS` dataset ([#6287](https://github.com/pyg-team/pytorch_geometric/pull/6287)) - Added `AttentionExplainer` ([#6279](https://github.com/pyg-team/pytorch_geometric/pull/6279)) - Added (un)faithfulness explainability metric ([#6090](https://github.com/pyg-team/pytorch_geometric/pull/6090)) - Added fidelity explainability metric ([#6116](https://github.com/pyg-team/pytorch_geometric/pull/6116), [#6510](https://github.com/pyg-team/pytorch_geometric/pull/6510)) - Added subgraph visualization of GNN explanations ([#6235](https://github.com/pyg-team/pytorch_geometric/pull/6235), [#6271](https://github.com/pyg-team/pytorch_geometric/pull/6271)) - Added weighted negative sampling option in `LinkNeighborLoader` ([#6264](https://github.com/pyg-team/pytorch_geometric/pull/6264)) - Added the `BA2MotifDataset` explainer dataset ([#6257](https://github.com/pyg-team/pytorch_geometric/pull/6257)) - Added `CycleMotif` motif generator to generate `n`-node cycle shaped motifs ([#6256](https://github.com/pyg-team/pytorch_geometric/pull/6256)) - Added the `InfectionDataset` to evaluate explanations ([#6222](https://github.com/pyg-team/pytorch_geometric/pull/6222)) - Added 
`characterization_score` and `fidelity_curve_auc` explainer metrics ([#6188](https://github.com/pyg-team/pytorch_geometric/pull/6188)) - Added `get_message_passing_embeddings` ([#6201](https://github.com/pyg-team/pytorch_geometric/pull/6201)) - Added the `PointGNNConv` layer ([#6194](https://github.com/pyg-team/pytorch_geometric/pull/6194)) - Added `GridGraph` graph generator to generate grid graphs ([#6220](https://github.com/pyg-team/pytorch_geometric/pull/6220)) - Added explainability metrics for when ground truth is available ([#6137](https://github.com/pyg-team/pytorch_geometric/pull/6137)) - Added `visualize_feature_importance` to support node feature visualizations ([#6094](https://github.com/pyg-team/pytorch_geometric/pull/6094)) - Added heterogeneous graph support to `Explanation` framework ([#6091](https://github.com/pyg-team/pytorch_geometric/pull/6091), [#6218](https://github.com/pyg-team/pytorch_geometric/pull/6218)) - Added a `CustomMotif` motif generator ([#6179](https://github.com/pyg-team/pytorch_geometric/pull/6179)) - Added `ERGraph` graph generator to generate Erdős-Rényi (ER) graphs ([#6073](https://github.com/pyg-team/pytorch_geometric/pull/6073)) - Added `BAGraph` graph generator to generate Barabasi-Albert graphs - the usage of `datasets.BAShapes` is now deprecated ([#6072](https://github.com/pyg-team/pytorch_geometric/pull/6072)) - Added explainability benchmark dataset framework ([#6104](https://github.com/pyg-team/pytorch_geometric/pull/6104)) - Added `seed_time` attribute to temporal `NodeLoader` outputs in case `input_time` is given ([#6196](https://github.com/pyg-team/pytorch_geometric/pull/6196)) - Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193)) - Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187)) - Added `summary` for 
PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161)) - Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6930](https://github.com/pyg-team/pytorch_geometric/pull/6930), [#6932](https://github.com/pyg-team/pytorch_geometric/pull/6932), [#6936](https://github.com/pyg-team/pytorch_geometric/pull/6936), [#6937](https://github.com/pyg-team/pytorch_geometric/pull/6937), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950), [#6951](https://github.com/pyg-team/pytorch_geometric/pull/6951), [#6957](https://github.com/pyg-team/pytorch_geometric/pull/6957)) - Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154)) - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124)) - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117)) ### Changed - Migrate to `pyproject.toml` for packaging 
([#6880](https://github.com/pyg-team/pytorch_geometric/pull/6880)) - Drop internal usage of `__dunder__` names ([#6999](https://github.com/pyg-team/pytorch_geometric/issues/6999)) - Changed the interface of `sort_edge_index`, `coalesce` and `to_undirected` to only return single `edge_index` information in case the `edge_attr` argument is not specified ([#6875](https://github.com/pyg-team/pytorch_geometric/issues/6875), [#6879](https://github.com/pyg-team/pytorch_geometric/issues/6879), [#6893](https://github.com/pyg-team/pytorch_geometric/issues/6893)) - Fixed a bug in `to_hetero` when using an uninitialized submodule without implementing `reset_parameters` ([#6863](https://github.com/pyg-team/pytorch_geometric/pull/6863)) - Fixed a bug in `get_mesh_laplacian` ([#6790](https://github.com/pyg-team/pytorch_geometric/issues/6790)) - Fixed a bug in which masks were not properly masked in `GNNExplainer` on link prediction tasks ([#6787](https://github.com/pyg-team/pytorch_geometric/pull/6787)) - Allow the usage of `ChebConv` within `GNNExplainer` ([#6778](https://github.com/pyg-team/pytorch_geometric/pull/6778)) - Allow setting the `EdgeStorage.num_edges` property ([#6710](https://github.com/pyg-team/pytorch_geometric/pull/6710)) - Fixed a bug in `utils.bipartite_subgraph()` and updated docs of `HeteroData.subgraph()` ([#6654](https://github.com/pyg-team/pytorch_geometric/pull/6654)) - Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685)) - Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613)) - Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609)) - Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), 
[#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), 
[#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736), [#6763](https://github.com/pyg-team/pytorch_geometric/pull/6763), [#6781](https://github.com/pyg-team/pytorch_geometric/pull/6781), [#6797](https://github.com/pyg-team/pytorch_geometric/pull/6797), [#6799](https://github.com/pyg-team/pytorch_geometric/pull/6799), [#6824](https://github.com/pyg-team/pytorch_geometric/pull/6824), [#6858](https://github.com/pyg-team/pytorch_geometric/pull/6858)) - Fixed a bug in which `data.to_heterogeneous()` filtered attributes in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522)) - Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517)) - Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512)) - Allow 1D input to `global_*_pool` functions ([#6504](https://github.com/pyg-team/pytorch_geometric/pull/6504)) - Add information about dynamic shapes in `RGCNConv` ([#6482](https://github.com/pyg-team/pytorch_geometric/pull/6482)) - Fixed the use of types removed in `numpy 1.24.0` ([#6495](https://github.com/pyg-team/pytorch_geometric/pull/6495)) - Fixed keyword parameters in `examples/mnist_voxel_grid.py` ([#6478](https://github.com/pyg-team/pytorch_geometric/pull/6478)) - Unified `LightningNodeData` and `LightningLinkData` code paths 
([#6473](https://github.com/pyg-team/pytorch_geometric/pull/6473)) - Allow indices with any integer type in `RGCNConv` ([#6463](https://github.com/pyg-team/pytorch_geometric/pull/6463)) - Re-structured the documentation ([#6420](https://github.com/pyg-team/pytorch_geometric/pull/6420), [#6423](https://github.com/pyg-team/pytorch_geometric/pull/6423), [#6429](https://github.com/pyg-team/pytorch_geometric/pull/6429), [#6440](https://github.com/pyg-team/pytorch_geometric/pull/6440), [#6443](https://github.com/pyg-team/pytorch_geometric/pull/6443), [#6445](https://github.com/pyg-team/pytorch_geometric/pull/6445), [#6452](https://github.com/pyg-team/pytorch_geometric/pull/6452), [#6453](https://github.com/pyg-team/pytorch_geometric/pull/6453), [#6458](https://github.com/pyg-team/pytorch_geometric/pull/6458), [#6459](https://github.com/pyg-team/pytorch_geometric/pull/6459), [#6460](https://github.com/pyg-team/pytorch_geometric/pull/6460), [#6490](https://github.com/pyg-team/pytorch_geometric/pull/6490), [#6491](https://github.com/pyg-team/pytorch_geometric/pull/6491), [#6693](https://github.com/pyg-team/pytorch_geometric/pull/6693), [#6744](https://github.com/pyg-team/pytorch_geometric/pull/6744)) - Fix the default arguments of `DataParallel` class ([#6376](https://github.com/pyg-team/pytorch_geometric/pull/6376)) - Fix `ImbalancedSampler` on sliced `InMemoryDataset` ([#6374](https://github.com/pyg-team/pytorch_geometric/pull/6374)) - Breaking Change: Changed the interface and implementation of `GraphMultisetTransformer` ([#6343](https://github.com/pyg-team/pytorch_geometric/pull/6343)) - Fixed the approximate PPR variant in `transforms.GDC` to not crash on graphs with isolated nodes ([#6242](https://github.com/pyg-team/pytorch_geometric/pull/6242)) - Added a warning when accessing `InMemoryDataset.data` ([#6318](https://github.com/pyg-team/pytorch_geometric/pull/6318)) - Drop `SparseTensor` dependency in `GraphStore` 
([#5517](https://github.com/pyg-team/pytorch_geometric/pull/5517)) - Replace `NeighborSampler` with `NeighborLoader` in the distributed sampling example ([#6307](https://github.com/pyg-team/pytorch_geometric/pull/6307)) - Fixed the filtering of node features in `transforms.RemoveIsolatedNodes` ([#6308](https://github.com/pyg-team/pytorch_geometric/pull/6308)) - Fixed a bug in `DimeNet` that causes an output dimension mismatch ([#6305](https://github.com/pyg-team/pytorch_geometric/pull/6305)) - Fixed `Data.to_heterogeneous()` with empty `edge_index` ([#6304](https://github.com/pyg-team/pytorch_geometric/pull/6304)) - Unify `Explanation.node_mask` and `Explanation.node_feat_mask` ([#6267](https://github.com/pyg-team/pytorch_geometric/pull/6267)) - Moved thresholding config of the `Explainer` to `Explanation` ([#6215](https://github.com/pyg-team/pytorch_geometric/pull/6215)) - Fixed a bug in the output order in `HeteroLinear` for un-sorted type vectors ([#6198](https://github.com/pyg-team/pytorch_geometric/pull/6198)) - Breaking Change: Move `ExplainerConfig` arguments to the `Explainer` class ([#6176](https://github.com/pyg-team/pytorch_geometric/pull/6176)) - Refactored `NeighborSampler` to be input-type agnostic ([#6173](https://github.com/pyg-team/pytorch_geometric/pull/6173)) - Infer correct CUDA device ID in `profileit` decorator ([#6164](https://github.com/pyg-team/pytorch_geometric/pull/6164)) - Correctly use edge weights in `GDC` example ([#6159](https://github.com/pyg-team/pytorch_geometric/pull/6159)) - Breaking Change: Moved PyTorch Lightning data modules to `torch_geometric.data.lightning` ([#6140](https://github.com/pyg-team/pytorch_geometric/pull/6140)) - Make `torch_sparse` an optional dependency ([#6132](https://github.com/pyg-team/pytorch_geometric/pull/6132), [#6134](https://github.com/pyg-team/pytorch_geometric/pull/6134), [#6138](https://github.com/pyg-team/pytorch_geometric/pull/6138), 
[#6139](https://github.com/pyg-team/pytorch_geometric/pull/6139), [#7387](https://github.com/pyg-team/pytorch_geometric/pull/7387)) - Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113), [#6155](https://github.com/pyg-team/pytorch_geometric/pull/6155), [#6805](https://github.com/pyg-team/pytorch_geometric/pull/6805)) - Optimized `topk` implementation for large enough graphs ([#6123](https://github.com/pyg-team/pytorch_geometric/pull/6123)) ### Removed - `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626), [#6627](https://github.com/pyg-team/pytorch_geometric/pull/6627), [#6628](https://github.com/pyg-team/pytorch_geometric/pull/6628), [#6629](https://github.com/pyg-team/pytorch_geometric/pull/6629), [#6630](https://github.com/pyg-team/pytorch_geometric/pull/6630)) - Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394), [#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395), [#6399](https://github.com/pyg-team/pytorch_geometric/pull/6399), [#6400](https://github.com/pyg-team/pytorch_geometric/pull/6400), [#6615](https://github.com/pyg-team/pytorch_geometric/pull/6615), [#6617](https://github.com/pyg-team/pytorch_geometric/pull/6617)) - Removed the deprecated classes `GNNExplainer` and `Explainer` from `nn.models` ([#6382](https://github.com/pyg-team/pytorch_geometric/pull/6382)) - Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270)) - Removed `Aggregation.set_validate_args` option ([#6175](https://github.com/pyg-team/pytorch_geometric/pull/6175)) ## [2.2.0] - 2022-12-01 ### Added - Extended `GNNExplainer` to support edge level explanations ([#6056](https://github.com/pyg-team/pytorch_geometric/pull/6056), 
[#6083](https://github.com/pyg-team/pytorch_geometric/pull/6083)) - Added CPU affinitization for `NodeLoader` ([#6005](https://github.com/pyg-team/pytorch_geometric/pull/6005)) - Added triplet sampling in `LinkNeighborLoader` ([#6004](https://github.com/pyg-team/pytorch_geometric/pull/6004)) - Added `FusedAggregation` of simple scatter reductions ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036)) - Added a `to_smiles` function ([#6038](https://github.com/pyg-team/pytorch_geometric/pull/6038)) - Added option to make normalization coefficients trainable in `PNAConv` ([#6039](https://github.com/pyg-team/pytorch_geometric/pull/6039)) - Added `semi_grad` option in `VarAggregation` and `StdAggregation` ([#6042](https://github.com/pyg-team/pytorch_geometric/pull/6042)) - Allow for fused aggregations in `MultiAggregation` ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036), [#6040](https://github.com/pyg-team/pytorch_geometric/pull/6040)) - Added `HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934)) - Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007)) - Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834)) - Added `LRGBDataset` to include 5 datasets from the [Long Range Graph Benchmark](https://openreview.net/pdf?id=in7XC5RcjEn) ([#5935](https://github.com/pyg-team/pytorch_geometric/pull/5935)) - Added a warning for invalid node and edge type names in `HeteroData` ([#5990](https://github.com/pyg-team/pytorch_geometric/pull/5990)) - Added PyTorch 1.13 support ([#5975](https://github.com/pyg-team/pytorch_geometric/pull/5975)) - Added `int32` support in `NeighborLoader` ([#5948](https://github.com/pyg-team/pytorch_geometric/pull/5948)) - Add `dgNN` support and `FusedGATConv` implementation 
([#5140](https://github.com/pyg-team/pytorch_geometric/pull/5140)) - Added `lr_scheduler_solver` and customized `lr_scheduler` classes ([#5942](https://github.com/pyg-team/pytorch_geometric/pull/5942)) - Add `to_fixed_size` graph transformer ([#5939](https://github.com/pyg-team/pytorch_geometric/pull/5939)) - Add support for symbolic tracing of `SchNet` model ([#5938](https://github.com/pyg-team/pytorch_geometric/pull/5938)) - Add support for customizable interaction graph in `SchNet` model ([#5919](https://github.com/pyg-team/pytorch_geometric/pull/5919)) - Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6633](https://github.com/pyg-team/pytorch_geometric/pull/6633)) - Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), [#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903)) - Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886)) - Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888)) - Added TorchScript support for `AttentiveFP `([#5868](https://github.com/pyg-team/pytorch_geometric/pull/5868)) - Added `num_steps` argument to training and inference benchmarks ([#5898](https://github.com/pyg-team/pytorch_geometric/pull/5898)) - Added `torch.onnx.export` support ([#5877](https://github.com/pyg-team/pytorch_geometric/pull/5877), [#5997](https://github.com/pyg-team/pytorch_geometric/pull/5997)) - Enable VTune ITT in inference and training benchmarks ([#5830](https://github.com/pyg-team/pytorch_geometric/pull/5830), [#5878](https://github.com/pyg-team/pytorch_geometric/pull/5878)) - Add training benchmark 
([#5774](https://github.com/pyg-team/pytorch_geometric/pull/5774)) - Added a "Link Prediction on MovieLens" Colab notebook ([#5823](https://github.com/pyg-team/pytorch_geometric/pull/5823)) - Added custom `sampler` support in `LightningDataModule` ([#5820](https://github.com/pyg-team/pytorch_geometric/pull/5820)) - Added a `return_semantic_attention_weights` argument to `HANConv` ([#5787](https://github.com/pyg-team/pytorch_geometric/pull/5787)) - Added `disjoint` argument to `NeighborLoader` and `LinkNeighborLoader` ([#5775](https://github.com/pyg-team/pytorch_geometric/pull/5775)) - Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763)) - Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717)) - Added `HeteroData` support for `transforms.Constant` ([#5700](https://github.com/pyg-team/pytorch_geometric/pull/5700)) - Added `np.memmap` support in `NeighborLoader` ([#5696](https://github.com/pyg-team/pytorch_geometric/pull/5696)) - Added `assortativity` that computes degree assortativity coefficient ([#5587](https://github.com/pyg-team/pytorch_geometric/pull/5587)) - Added `SSGConv` layer ([#5599](https://github.com/pyg-team/pytorch_geometric/pull/5599)) - Added `shuffle_node`, `mask_feature` and `add_random_edge` augmentation methods ([#5548](https://github.com/pyg-team/pytorch_geometric/pull/5548)) - Added `dropout_path` augmentation that drops edges from a graph based on random walks ([#5531](https://github.com/pyg-team/pytorch_geometric/pull/5531)) - Add support for filling labels with dummy values in `HeteroData.to_homogeneous()` ([#5540](https://github.com/pyg-team/pytorch_geometric/pull/5540)) - Added `temporal_strategy` option to `neighbor_sample` ([#5576](https://github.com/pyg-team/pytorch_geometric/pull/5576)) - Added `torch_geometric.sampler` package to docs ([#5563](https://github.com/pyg-team/pytorch_geometric/pull/5563)) - Added the 
`DGraphFin` dynamic graph dataset ([#5504](https://github.com/pyg-team/pytorch_geometric/pull/5504)) - Added `dropout_edge` augmentation that randomly drops edges from a graph - the usage of `dropout_adj` is now deprecated ([#5495](https://github.com/pyg-team/pytorch_geometric/pull/5495)) - Added `dropout_node` augmentation that randomly drops nodes from a graph ([#5481](https://github.com/pyg-team/pytorch_geometric/pull/5481)) - Added `AddRandomMetaPaths` that adds edges based on random walks along a metapath ([#5397](https://github.com/pyg-team/pytorch_geometric/pull/5397)) - Added `WLConvContinuous` for performing WL refinement with continuous attributes ([#5316](https://github.com/pyg-team/pytorch_geometric/pull/5316)) - Added `print_summary` method for the `torch_geometric.data.Dataset` interface ([#5438](https://github.com/pyg-team/pytorch_geometric/pull/5438)) - Added `sampler` support to `LightningDataModule` ([#5456](https://github.com/pyg-team/pytorch_geometric/pull/5456), [#5457](https://github.com/pyg-team/pytorch_geometric/pull/5457)) - Added official splits to `MalNetTiny` dataset ([#5078](https://github.com/pyg-team/pytorch_geometric/pull/5078)) - Added `IndexToMask` and `MaskToIndex` transforms ([#5375](https://github.com/pyg-team/pytorch_geometric/pull/5375), [#5455](https://github.com/pyg-team/pytorch_geometric/pull/5455)) - Added `FeaturePropagation` transform ([#5387](https://github.com/pyg-team/pytorch_geometric/pull/5387)) - Added `PositionalEncoding` ([#5381](https://github.com/pyg-team/pytorch_geometric/pull/5381)) - Consolidated sampler routines behind `torch_geometric.sampler`, enabling ease of extensibility in the future ([#5312](https://github.com/pyg-team/pytorch_geometric/pull/5312), [#5365](https://github.com/pyg-team/pytorch_geometric/pull/5365), [#5402](https://github.com/pyg-team/pytorch_geometric/pull/5402), [#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404), 
[#5418](https://github.com/pyg-team/pytorch_geometric/pull/5418)) - Added `pyg-lib` neighbor sampling ([#5384](https://github.com/pyg-team/pytorch_geometric/pull/5384), [#5388](https://github.com/pyg-team/pytorch_geometric/pull/5388)) - Added `pyg_lib.segment_matmul` integration within `HeteroLinear` ([#5330](https://github.com/pyg-team/pytorch_geometric/pull/5330), [#5347](https://github.com/pyg-team/pytorch_geometric/pull/5347)) - Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293), [#5341](https://github.com/pyg-team/pytorch_geometric/pull/5341)) - Added `Aggregation.set_validate_args` option to skip validation of `dim_size` ([#5290](https://github.com/pyg-team/pytorch_geometric/pull/5290)) - Added `SparseTensor` support to inference and training benchmark suite ([#5242](https://github.com/pyg-team/pytorch_geometric/pull/5242), [#5258](https://github.com/pyg-team/pytorch_geometric/pull/5258), [#5881](https://github.com/pyg-team/pytorch_geometric/pull/5881)) - Added experimental mode in inference benchmarks ([#5254](https://github.com/pyg-team/pytorch_geometric/pull/5254)) - Added node classification example instrumented with [Weights and Biases (W&B) logging](https://wandb.com) and [W&B Sweeps](https://wandb.com/sweeps) ([#5192](https://github.com/pyg-team/pytorch_geometric/pull/5192)) - Added experimental mode for `utils.scatter` ([#5232](https://github.com/pyg-team/pytorch_geometric/pull/5232), [#5241](https://github.com/pyg-team/pytorch_geometric/pull/5241), [#5386](https://github.com/pyg-team/pytorch_geometric/pull/5386)) - Added missing test labels in `HGBDataset` ([#5233](https://github.com/pyg-team/pytorch_geometric/pull/5233)) - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240)) - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) - Added 
`torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804), [#6054](https://github.com/pyg-team/pytorch_geometric/pull/6054), [#6089](https://github.com/pyg-team/pytorch_geometric/pull/6089)) ### Changed - Moved and adapted `GNNExplainer` from `torch_geometric.nn` to `torch_geometric.explain.algorithm` ([#5967](https://github.com/pyg-team/pytorch_geometric/pull/5967), [#6065](https://github.com/pyg-team/pytorch_geometric/pull/6065)) - Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052)) - Support temperature value in `dense_mincut_pool` ([#5908](https://github.com/pyg-team/pytorch_geometric/pull/5908)) - Fixed a bug in which `VirtualNode` mistakenly treated node features as edge features ([#5819](https://github.com/pyg-team/pytorch_geometric/pull/5819)) - Fixed `setter` and `getter` handling in `BaseStorage` ([#5815](https://github.com/pyg-team/pytorch_geometric/pull/5815)) - Fixed `path` in `hetero_conv_dblp.py` example ([#5686](https://github.com/pyg-team/pytorch_geometric/pull/5686)) - Fix `auto_select_device` routine in GraphGym for PyTorch Lightning>=1.7 ([#5677](https://github.com/pyg-team/pytorch_geometric/pull/5677)) - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641)) - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642)) - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) - Improved type hint support ([#5842](https://github.com/pyg-team/pytorch_geometric/pull/5842), 
[#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5676](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), 
[#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5738](https://github.com/pyg-team/pytorch_geometric/pull/5738), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768), [#5781](https://github.com/pyg-team/pytorch_geometric/pull/5781), [#5778](https://github.com/pyg-team/pytorch_geometric/pull/5778), [#5797](https://github.com/pyg-team/pytorch_geometric/pull/5797), [#5798](https://github.com/pyg-team/pytorch_geometric/pull/5798), [#5799](https://github.com/pyg-team/pytorch_geometric/pull/5799), 
[#5800](https://github.com/pyg-team/pytorch_geometric/pull/5800), [#5806](https://github.com/pyg-team/pytorch_geometric/pull/5806), [#5810](https://github.com/pyg-team/pytorch_geometric/pull/5810), [#5811](https://github.com/pyg-team/pytorch_geometric/pull/5811), [#5828](https://github.com/pyg-team/pytorch_geometric/pull/5828), [#5847](https://github.com/pyg-team/pytorch_geometric/pull/5847), [#5851](https://github.com/pyg-team/pytorch_geometric/pull/5851), [#5852](https://github.com/pyg-team/pytorch_geometric/pull/5852)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602)) - Fixed a bug when applying several scalers with `PNAConv` ([#5514](https://github.com/pyg-team/pytorch_geometric/issues/5514)) - Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494)) - Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490)) - Improved `utils.scatter` performance by explicitly choosing better implementation for `add` and `mean` reduction ([#5399](https://github.com/pyg-team/pytorch_geometric/pull/5399)) - Fix `to_dense_adj` with empty `edge_index` ([#5476](https://github.com/pyg-team/pytorch_geometric/pull/5476)) - The `AttentionalAggregation` module can now be applied to compute attention on a per-feature level 
([#5449](https://github.com/pyg-team/pytorch_geometric/pull/5449)) - Ensure equal lengths of `num_neighbors` across edge types in `NeighborLoader` ([#5444](https://github.com/pyg-team/pytorch_geometric/pull/5444)) - Fixed a bug in `TUDataset` in which node features were wrongly constructed whenever `node_attributes` only hold a single feature (_e.g._, in `PROTEINS`) ([#5441](https://github.com/pyg-team/pytorch_geometric/pull/5441)) - Breaking change: removed `num_neighbors` as an attribute of loader ([#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404)) - `ASAPooling` is now jittable ([#5395](https://github.com/pyg-team/pytorch_geometric/pull/5395)) - Updated unsupervised `GraphSAGE` example to leverage `LinkNeighborLoader` ([#5317](https://github.com/pyg-team/pytorch_geometric/pull/5317)) - Replace in-place operations with out-of-place ones to align with `torch.scatter_reduce` API ([#5353](https://github.com/pyg-team/pytorch_geometric/pull/5353)) - Breaking bugfix: `PointTransformerConv` now correctly uses `sum` aggregation ([#5332](https://github.com/pyg-team/pytorch_geometric/pull/5332)) - Improve out-of-bounds error message in `MessagePassing` ([#5339](https://github.com/pyg-team/pytorch_geometric/pull/5339)) - Allow file names of a `Dataset` to be specified as either a property or a method ([#5338](https://github.com/pyg-team/pytorch_geometric/pull/5338)) - Fixed separating a list of `SparseTensor` within `InMemoryDataset` ([#5299](https://github.com/pyg-team/pytorch_geometric/pull/5299)) - Improved name resolving of normalization layers ([#5277](https://github.com/pyg-team/pytorch_geometric/pull/5277)) - Fail gracefully on `GLIBC` errors within `torch-spline-conv` ([#5276](https://github.com/pyg-team/pytorch_geometric/pull/5276)) - Fixed `Dataset.num_classes` in case a `transform` modifies `data.y` ([#5274](https://github.com/pyg-team/pytorch_geometric/pull/5274)) - Allow customization of the activation function within `PNAConv` 
([#5262](https://github.com/pyg-team/pytorch_geometric/pull/5262)) - Do not fill `InMemoryDataset` cache on `dataset.num_features` ([#5264](https://github.com/pyg-team/pytorch_geometric/pull/5264)) - Changed tests relying on `dblp` datasets to instead use synthetic data ([#5250](https://github.com/pyg-team/pytorch_geometric/pull/5250)) - Fixed a bug for the initialization of activation function examples in `custom_graphgym` ([#5243](https://github.com/pyg-team/pytorch_geometric/pull/5243)) - Allow any integer tensors when checking edge_index input to message passing ([#5281](https://github.com/pyg-team/pytorch_geometric/pull/5281)) ### Removed - Removed `scatter_reduce` option from experimental mode ([#5399](https://github.com/pyg-team/pytorch_geometric/pull/5399)) ## [2.1.0] - 2022-08-17 ### Added - Added the test for `DeepGCNLayer` ([#5704](https://github.com/pyg-team/pytorch_geometric/pull/5704)) - Allow `.` in `ModuleDict` key names ([#5227](https://github.com/pyg-team/pytorch_geometric/pull/5227)) - Added `edge_label_time` argument to `LinkNeighborLoader` ([#5137](https://github.com/pyg-team/pytorch_geometric/pull/5137), [#5173](https://github.com/pyg-team/pytorch_geometric/pull/5173)) - Let `ImbalancedSampler` accept `torch.Tensor` as input ([#5138](https://github.com/pyg-team/pytorch_geometric/pull/5138)) - Added `flow` argument to `gcn_norm` to correctly normalize the adjacency matrix in `GCNConv` ([#5149](https://github.com/pyg-team/pytorch_geometric/pull/5149)) - `NeighborSampler` supports graphs without edges ([#5072](https://github.com/pyg-team/pytorch_geometric/pull/5072)) - Added the `MeanSubtractionNorm` layer ([#5068](https://github.com/pyg-team/pytorch_geometric/pull/5068)) - Added `pyg_lib.segment_matmul` integration within `RGCNConv` ([#5052](https://github.com/pyg-team/pytorch_geometric/pull/5052), [#5096](https://github.com/pyg-team/pytorch_geometric/pull/5096)) - Support `SparseTensor` as edge label in `LightGCN` 
([#5046](https://github.com/pyg-team/pytorch_geometric/issues/5046)) - Added support for `BasicGNN` models within `to_hetero` ([#5091](https://github.com/pyg-team/pytorch_geometric/pull/5091)) - Added support for computing weighted metapaths in `AddMetapaths` ([#5049](https://github.com/pyg-team/pytorch_geometric/pull/5049)) - Added inference benchmark suite ([#4915](https://github.com/pyg-team/pytorch_geometric/pull/4915)) - Added a dynamically sized batch sampler for filling a mini-batch with a variable number of samples up to a maximum size ([#4972](https://github.com/pyg-team/pytorch_geometric/pull/4972)) - Added fine grained options for setting `bias` and `dropout` per layer in the `MLP` model ([#4981](https://github.com/pyg-team/pytorch_geometric/pull/4981)) - Added `EdgeCNN` model ([#4991](https://github.com/pyg-team/pytorch_geometric/pull/4991)) - Added scalable `inference` mode in `BasicGNN` with layer-wise neighbor loading ([#4977](https://github.com/pyg-team/pytorch_geometric/pull/4977)) - Added inference benchmarks ([#4892](https://github.com/pyg-team/pytorch_geometric/pull/4892), [#5107](https://github.com/pyg-team/pytorch_geometric/pull/5107)) - Added PyTorch 1.12 support ([#4975](https://github.com/pyg-team/pytorch_geometric/pull/4975)) - Added `unbatch_edge_index` functionality for splitting an `edge_index` tensor according to a `batch` vector ([#4903](https://github.com/pyg-team/pytorch_geometric/pull/4903)) - Added node-wise normalization mode in `LayerNorm` ([#4944](https://github.com/pyg-team/pytorch_geometric/pull/4944)) - Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926), [#4951](https://github.com/pyg-team/pytorch_geometric/pull/4951), [#4958](https://github.com/pyg-team/pytorch_geometric/pull/4958), [#4959](https://github.com/pyg-team/pytorch_geometric/pull/4959)) - Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation 
([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927)) - Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837)) - Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885)) - Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868)) - Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884)) - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815), [#4862](https://github.com/pyg-team/pytorch_geometric/pull/4862/files)) - Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), 
[#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088), [#5270](https://github.com/pyg-team/pytorch_geometric/pull/5270), [#5307](https://github.com/pyg-team/pytorch_geometric/pull/5307), [#5318](https://github.com/pyg-team/pytorch_geometric/pull/5318)) - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850)) - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838)) - Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816)) - Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807), [#4853](https://github.com/pyg-team/pytorch_geometric/pull/4853)) - Added `FeatureStore` and `GraphStore` abstractions ([#4534](https://github.com/pyg-team/pytorch_geometric/pull/4534), [#4568](https://github.com/pyg-team/pytorch_geometric/pull/4568), [#5120](https://github.com/pyg-team/pytorch_geometric/pull/5120)) - Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827)) - Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825)) - Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805)) - Added a `max_sample` argument to `AddMetaPaths` in order to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750)) - Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756), [#4841](https://github.com/pyg-team/pytorch_geometric/pull/4841)) - Added the `bias` 
vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755)) - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926)) - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723)) - Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), 
[#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033), [#5085](https://github.com/pyg-team/pytorch_geometric/pull/5085), [#5097](https://github.com/pyg-team/pytorch_geometric/pull/5097), [#5099](https://github.com/pyg-team/pytorch_geometric/pull/5099), [#5104](https://github.com/pyg-team/pytorch_geometric/pull/5104), [#5113](https://github.com/pyg-team/pytorch_geometric/pull/5113), [#5130](https://github.com/pyg-team/pytorch_geometric/pull/5130), [#5098](https://github.com/pyg-team/pytorch_geometric/pull/5098), [#5191](https://github.com/pyg-team/pytorch_geometric/pull/5191)) - Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800)) - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487)) - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730)) - Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672), [#4676](https://github.com/pyg-team/pytorch_geometric/pull/4676)) - Added `unbatch` functionality ([#4628](https://github.com/pyg-team/pytorch_geometric/pull/4628)) - Confirm that `to_hetero()` works with custom functions, _e.g._, `dropout_adj` ([#4653](https://github.com/pyg-team/pytorch_geometric/pull/4653)) - Added the `MLP.plain_last=False` option 
([#4652](https://github.com/pyg-team/pytorch_geometric/pull/4652)) - Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([#4647](https://github.com/pyg-team/pytorch_geometric/pull/4647)) - Added `HeteroData.subgraph()`, `HeteroData.node_type_subgraph()` and `HeteroData.edge_type_subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635)) - Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626)) - Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644)) - Added PyTorch Lightning support in GraphGym ([#4511](https://github.com/pyg-team/pytorch_geometric/pull/4511), [#4516](https://github.com/pyg-team/pytorch_geometric/pull/4516), [#4531](https://github.com/pyg-team/pytorch_geometric/pull/4531), [#4689](https://github.com/pyg-team/pytorch_geometric/pull/4689), [#4843](https://github.com/pyg-team/pytorch_geometric/pull/4843)) - Added support for returning embeddings in `MLP` models ([#4625](https://github.com/pyg-team/pytorch_geometric/pull/4625)) - Added faster initialization of `NeighborLoader` in case edge indices are already sorted (via `is_sorted=True`) ([#4620](https://github.com/pyg-team/pytorch_geometric/pull/4620), [#4702](https://github.com/pyg-team/pytorch_geometric/pull/4702)) - Added `AddPositionalEncoding` transform ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521)) - Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604)) - Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600)) - Added `nn.aggr.EquilibriumAggregation` implicit global layer ([#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522)) - Added support for graph-level 
outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) - Added `HeteroData` support to the `RemoveIsolatedNodes` transform ([#4479](https://github.com/pyg-team/pytorch_geometric/pull/4479)) - Added `HeteroData.num_features` functionality ([#4504](https://github.com/pyg-team/pytorch_geometric/pull/4504)) - Added support for projecting features before propagation in `SAGEConv` ([#4437](https://github.com/pyg-team/pytorch_geometric/pull/4437)) - Added `Geom-GCN` splits to the `Planetoid` datasets ([#4442](https://github.com/pyg-team/pytorch_geometric/pull/4442)) - Added a `LinkNeighborLoader` for training scalable link prediction models ([#4396](https://github.com/pyg-team/pytorch_geometric/pull/4396), [#4439](https://github.com/pyg-team/pytorch_geometric/pull/4439), [#4441](https://github.com/pyg-team/pytorch_geometric/pull/4441), [#4446](https://github.com/pyg-team/pytorch_geometric/pull/4446), [#4508](https://github.com/pyg-team/pytorch_geometric/pull/4508), [#4509](https://github.com/pyg-team/pytorch_geometric/pull/4509)) - Added an unsupervised `GraphSAGE` example on `PPI` ([#4416](https://github.com/pyg-team/pytorch_geometric/pull/4416)) - Added support for `LSTM` aggregation in `SAGEConv` ([#4379](https://github.com/pyg-team/pytorch_geometric/pull/4379)) - Added support for floating-point labels in `RandomLinkSplit` ([#4311](https://github.com/pyg-team/pytorch_geometric/pull/4311), [#4383](https://github.com/pyg-team/pytorch_geometric/pull/4383)) - Added support for `torch.data` `DataPipes` ([#4302](https://github.com/pyg-team/pytorch_geometric/pull/4302), [#4345](https://github.com/pyg-team/pytorch_geometric/pull/4345), [#4349](https://github.com/pyg-team/pytorch_geometric/pull/4349)) - Added support for the `cosine` argument in the `KNNGraph`/`RadiusGraph` transforms ([#4344](https://github.com/pyg-team/pytorch_geometric/pull/4344)) - 
Added support for graph-level attributes in `networkx` conversion ([#4343](https://github.com/pyg-team/pytorch_geometric/pull/4343)) - Added support for renaming node types via `HeteroData.rename` ([#4329](https://github.com/pyg-team/pytorch_geometric/pull/4329)) - Added an example to load a trained PyG model in C++ ([#4307](https://github.com/pyg-team/pytorch_geometric/pull/4307)) - Added a `MessagePassing.explain_message` method to customize making explanations on messages ([#4278](https://github.com/pyg-team/pytorch_geometric/pull/4278), [#4448](https://github.com/pyg-team/pytorch_geometric/pull/4448)) - Added support for `GATv2Conv` in the `nn.models.GAT` model ([#4357](https://github.com/pyg-team/pytorch_geometric/pull/4357)) - Added `HeteroData.subgraph` functionality ([#4243](https://github.com/pyg-team/pytorch_geometric/pull/4243)) - Added the `MaskLabel` module and a corresponding masked label propagation example ([#4197](https://github.com/pyg-team/pytorch_geometric/pull/4197)) - Added temporal sampling support to `NeighborLoader` ([#4025](https://github.com/pyg-team/pytorch_geometric/pull/4025)) - Added an example for unsupervised heterogeneous graph learning based on "Deep Multiplex Graph Infomax" ([#3189](https://github.com/pyg-team/pytorch_geometric/pull/3189)) ### Changed - Changed docstring for `RandomLinkSplit` ([#5190](https://github.com/pyg-team/pytorch_geometric/issues/5190)) - Switched to PyTorch `scatter_reduce` implementation - experimental feature ([#5120](https://github.com/pyg-team/pytorch_geometric/pull/5120)) - Fixed `RGATConv` device mismatches for `f-scaled` mode ([#5187](https://github.com/pyg-team/pytorch_geometric/pull/5187)) - Allow for multi-dimensional `edge_labels` in `LinkNeighborLoader` ([#5186](https://github.com/pyg-team/pytorch_geometric/pull/5186)) - Fixed `GINEConv` bug with non-sequential input ([#5154](https://github.com/pyg-team/pytorch_geometric/pull/5154)) - Improved error message 
([#5095](https://github.com/pyg-team/pytorch_geometric/pull/5095)) - Fixed `HGTLoader` bug which produced outputs with missing edge types ([#5067](https://github.com/pyg-team/pytorch_geometric/pull/5067)) - Fixed dynamic inheritance issue in data batching ([#5051](https://github.com/pyg-team/pytorch_geometric/pull/5051)) - Fixed `load_state_dict` in `Linear` with `strict=False` mode ([#5094](https://github.com/pyg-team/pytorch_geometric/pull/5094)) - Fixed typo in `MaskLabel.ratio_mask` ([#5093](https://github.com/pyg-team/pytorch_geometric/pull/5093)) - Fixed `data.num_node_features` computation for sparse matrices ([#5089](https://github.com/pyg-team/pytorch_geometric/pull/5089)) - Fixed `torch.fx` bug with `torch.nn.aggr` package ([#5021](https://github.com/pyg-team/pytorch_geometric/pull/5021)) - Fixed `GENConv` test ([#4993](https://github.com/pyg-team/pytorch_geometric/pull/4993)) - Fixed packaging tests for Python 3.10 ([#4982](https://github.com/pyg-team/pytorch_geometric/pull/4982)) - Changed `act_dict` (part of `graphgym`) to create individual instances instead of reusing the same ones everywhere ([#4978](https://github.com/pyg-team/pytorch_geometric/pull/4978)) - Fixed issue where one-hot tensors were passed to `F.one_hot` ([#4970](https://github.com/pyg-team/pytorch_geometric/pull/4970)) - Fixed `bool` arguments in `argparse` in `benchmark/` ([#4967](https://github.com/pyg-team/pytorch_geometric/pull/4967)) - Fixed `BasicGNN` for `num_layers=1`, which now respects a desired number of `out_channels` ([#4943](https://github.com/pyg-team/pytorch_geometric/pull/4943)) - `len(batch)` will now return the number of graphs inside the batch, not the number of attributes ([#4931](https://github.com/pyg-team/pytorch_geometric/pull/4931)) - Fixed `data.subgraph` generation for 0-dim tensors ([#4932](https://github.com/pyg-team/pytorch_geometric/pull/4932)) - Removed unnecessary inclusion of self-loops when sampling negative edges 
([#4880](https://github.com/pyg-team/pytorch_geometric/pull/4880)) - Fixed `InMemoryDataset` inferring wrong `len` for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837)) - Fixed `Batch.separate` when using it for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837)) - Corrected docstring for `SAGEConv` ([#4852](https://github.com/pyg-team/pytorch_geometric/pull/4852)) - Fixed a bug in `TUDataset` where `pre_filter` was not applied whenever `pre_transform` was present - Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828)) - Do not allow accessing edge types in `HeteroData` with two node types when there exist multiple relations between these types ([#4782](https://github.com/pyg-team/pytorch_geometric/pull/4782)) - Allow `edge_type == rev_edge_type` argument in `RandomLinkSplit` ([#4757](https://github.com/pyg-team/pytorch_geometric/pull/4757), [#5221](https://github.com/pyg-team/pytorch_geometric/pull/5221)) - Fixed a numerical instability in the `GeneralConv` and `neighbor_sample` tests ([#4754](https://github.com/pyg-team/pytorch_geometric/pull/4754)) - Fixed a bug in `HANConv` in which destination node features rather than source node features were propagated ([#4753](https://github.com/pyg-team/pytorch_geometric/pull/4753)) - Fixed versions of `checkout` and `setup-python` in CI ([#4751](https://github.com/pyg-team/pytorch_geometric/pull/4751)) - Fixed `protobuf` version ([#4719](https://github.com/pyg-team/pytorch_geometric/pull/4719)) - Fixed the ranking protocol bug in the RGCN link prediction example ([#4688](https://github.com/pyg-team/pytorch_geometric/pull/4688)) - Math support in Markdown ([#4683](https://github.com/pyg-team/pytorch_geometric/pull/4683)) - Allow for `setter` properties in `Data` ([#4682](https://github.com/pyg-team/pytorch_geometric/pull/4682),
[#4686](https://github.com/pyg-team/pytorch_geometric/pull/4686)) - Allow for optional `edge_weight` in `GCN2Conv` ([#4670](https://github.com/pyg-team/pytorch_geometric/pull/4670)) - Fixed the interplay between `TUDataset` and `pre_transform` that modify node features ([#4669](https://github.com/pyg-team/pytorch_geometric/pull/4669)) - Make use of the `pyg_sphinx_theme` documentation template ([#4664](https://github.com/pyg-team/pytorch_geometric/pull/4664), [#4667](https://github.com/pyg-team/pytorch_geometric/pull/4667)) - Refactored reading molecular positions from SDF files for the QM9 dataset ([#4654](https://github.com/pyg-team/pytorch_geometric/pull/4654)) - Fixed `MLP.jittable()` bug in case `return_emb=True` ([#4645](https://github.com/pyg-team/pytorch_geometric/pull/4645), [#4648](https://github.com/pyg-team/pytorch_geometric/pull/4648)) - The generated node features of `StochasticBlockModelDataset` are now ordered with respect to their labels ([#4617](https://github.com/pyg-team/pytorch_geometric/pull/4617)) - Fixed typos in the documentation ([#4616](https://github.com/pyg-team/pytorch_geometric/pull/4616), [#4824](https://github.com/pyg-team/pytorch_geometric/pull/4824), [#4895](https://github.com/pyg-team/pytorch_geometric/pull/4895), [#5161](https://github.com/pyg-team/pytorch_geometric/pull/5161)) - The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597)) - Fixed subclass behavior of `process` and `download` in `Dataset` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586)) - Fixed filtering of attributes for loaders in case `__cat_dim__ != 0` ([#4629](https://github.com/pyg-team/pytorch_geometric/pull/4629)) - Fixed `SparseTensor` support in `NeighborLoader` ([#4320](https://github.com/pyg-team/pytorch_geometric/pull/4320)) - Fixed average degree handling in `PNAConv` ([#4312](https://github.com/pyg-team/pytorch_geometric/pull/4312)) - Fixed a bug in `from_networkx` in case some attributes are
PyTorch tensors ([#4486](https://github.com/pyg-team/pytorch_geometric/pull/4486)) - Added a missing clamp in `DimeNet` ([#4506](https://github.com/pyg-team/pytorch_geometric/pull/4506), [#4562](https://github.com/pyg-team/pytorch_geometric/pull/4562)) - Fixed the download link in `DBP15K` ([#4428](https://github.com/pyg-team/pytorch_geometric/pull/4428)) - Fixed an autograd bug in `DimeNet` when resetting parameters ([#4424](https://github.com/pyg-team/pytorch_geometric/pull/4424)) - Fixed bipartite message passing in case `flow="target_to_source"` ([#4418](https://github.com/pyg-team/pytorch_geometric/pull/4418)) - Fixed a bug in which `num_nodes` was not properly updated in the `FixedPoints` transform ([#4394](https://github.com/pyg-team/pytorch_geometric/pull/4394)) - PyTorch Lightning >= 1.6 support ([#4377](https://github.com/pyg-team/pytorch_geometric/pull/4377)) - Fixed a bug in which `GATConv` was not jittable ([#4347](https://github.com/pyg-team/pytorch_geometric/pull/4347)) - Fixed a bug in which the GraphGym config was not stored in each specific experiment directory ([#4338](https://github.com/pyg-team/pytorch_geometric/pull/4338)) - Fixed a bug in which `nn.models.GAT` did not produce `out_channels`-many output channels ([#4299](https://github.com/pyg-team/pytorch_geometric/pull/4299)) - Fixed mini-batching with empty lists as attributes ([#4293](https://github.com/pyg-team/pytorch_geometric/pull/4293)) - Fixed a bug in which `GCNConv` could not be combined with `to_hetero` on heterogeneous graphs with one node type ([#4279](https://github.com/pyg-team/pytorch_geometric/pull/4279)) - Added a scheduler to the Graph Sage OGBN Example [#9877](https://github.com/pyg-team/pytorch_geometric/pull/9877) ### Removed - Remove internal metrics in favor of `torchmetrics` ([#4287](https://github.com/pyg-team/pytorch_geometric/pull/4287)) ================================================ FILE: CITATION.cff ================================================ --- 
cff-version: 1.2.0 message: "Please cite our papers if you use this code in your own work." title: "Fast Graph Representation Learning with PyTorch Geometric" authors: - family-names: "Fey" given-names: "Matthias" date-released: 2019-05-06 license: MIT url: "https://github.com/pyg-team/pytorch_geometric" preferred-citation: type: article title: "PyG 2.0: Scalable Learning on Real World Graphs" authors: - family-names: "Fey" given-names: "Matthias" - family-names: "Sunil" given-names: "Jinu" - family-names: "Nitta" given-names: "Akihiro" - family-names: "Puri" given-names: "Rishi" - family-names: "Shah" given-names: "Manan" - family-names: "Stojanović" given-names: "Blaž" - family-names: "Bendias" given-names: "Ramona" - family-names: "Barghi" given-names: "Alexandria" - family-names: "Kocijan" given-names: "Vid" - family-names: "Zhang" given-names: "Zecheng" - family-names: "He" given-names: "Xinwei" - family-names: "Lenssen" given-names: "Jan Eric" - family-names: "Leskovec" given-names: "Jure" journal: "Temporal Graph Learning Workshop @ KDD" year: 2025 ================================================ FILE: LICENSE ================================================ Copyright (c) 2023 PyG Team Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================

______________________________________________________________________
[![PyPI Version][pypi-image]][pypi-url] [![PyPI Download][pypi-download-image]][pypi-download-url] [![Slack][slack-image]][slack-url] [![Contributing][contributing-image]][contributing-url] **[Documentation](https://pytorch-geometric.readthedocs.io)** | **[PyG 1.0 Paper](https://arxiv.org/abs/1903.02428)** | **[PyG 2.0 Paper](https://arxiv.org/abs/2507.16991)** | **[Colab Notebooks](https://pytorch-geometric.readthedocs.io/en/latest/get_started/colabs.html)** | **[External Resources](https://pytorch-geometric.readthedocs.io/en/latest/external/resources.html)** | **[OGB Examples](https://github.com/snap-stanford/ogb/tree/master/examples)**
**PyG** *(PyTorch Geometric)* is a library built upon [PyTorch](https://pytorch.org/) to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data. It consists of various methods for deep learning on graphs and other irregular structures, also known as *[geometric deep learning](http://geometricdeeplearning.com/)*, from a variety of published papers. In addition, it consists of easy-to-use mini-batch loaders for operating on many small and single giant graphs, [multi GPU-support](https://github.com/pyg-team/pytorch_geometric/tree/master/examples/multi_gpu), [`torch.compile`](https://pytorch-geometric.readthedocs.io/en/latest/advanced/compile.html) support, [`DataPipe`](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/datapipe.py) support, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point clouds. **[Click here to join our Slack community!][slack-url]**

______________________________________________________________________ - [Library Highlights](#library-highlights) - [Quick Tour for New Users](#quick-tour-for-new-users) - [Architecture Overview](#architecture-overview) - [Implemented GNN Models](#implemented-gnn-models) - [Installation](#installation) ## Library Highlights Whether you are a machine learning researcher or first-time user of machine learning toolkits, here are some reasons to try out PyG for machine learning on graph-structured data. - **Easy-to-use and unified API**: All it takes is 10-20 lines of code to get started with training a GNN model (see the next section for a [quick tour](#quick-tour-for-new-users)). PyG is *PyTorch-on-the-rocks*: It utilizes a tensor-centric API and keeps design principles close to vanilla PyTorch. If you are already familiar with PyTorch, utilizing PyG is straightforward. - **Comprehensive and well-maintained GNN models**: Most of the state-of-the-art Graph Neural Network architectures have been implemented by library developers or authors of research papers and are ready to be applied. - **Great flexibility**: Existing PyG models can easily be extended for conducting your own research with GNNs. Making modifications to existing models or creating new architectures is simple, thanks to its easy-to-use message passing API, and a variety of operators and utility functions. - **Large-scale real-world GNN models**: We focus on the need of GNN applications in challenging real-world scenarios, and support learning on diverse types of graphs, including but not limited to: scalable GNNs for graphs with millions of nodes; dynamic GNNs for node predictions over time; heterogeneous GNNs with multiple node types and edge types. ## Quick Tour for New Users In this quick tour, we highlight the ease of creating and training a GNN model with only a few lines of code. 
### Train your own GNN model In the first glimpse of PyG, we implement the training of a GNN for classifying papers in a citation graph. For this, we load the [Cora](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html) dataset, and create a simple 2-layer GCN model using the pre-defined [`GCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html): ```python import torch from torch import Tensor from torch_geometric.nn import GCNConv from torch_geometric.datasets import Planetoid dataset = Planetoid(root='.', name='Cora') class GCN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels) self.conv2 = GCNConv(hidden_channels, out_channels) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: # x: Node feature matrix of shape [num_nodes, in_channels] # edge_index: Graph connectivity matrix of shape [2, num_edges] x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) return x model = GCN(dataset.num_features, 16, dataset.num_classes) ```
We can now optimize the model in a training loop, similar to the standard PyTorch training procedure. ```python import torch.nn.functional as F data = dataset[0] optimizer = torch.optim.Adam(model.parameters(), lr=0.01) for epoch in range(200): pred = model(data.x, data.edge_index) loss = F.cross_entropy(pred[data.train_mask], data.y[data.train_mask]) # Backpropagation optimizer.zero_grad() loss.backward() optimizer.step() ```
More information about evaluating final model performance can be found in the corresponding [example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py). ### Create your own GNN layer In addition to the easy application of existing GNNs, PyG makes it simple to implement custom Graph Neural Networks (see [here](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html) for the accompanying tutorial). For example, this is all it takes to implement the [edge convolutional layer](https://arxiv.org/abs/1801.07829) from Wang *et al.*: $$x_i^{\\prime} ~ = ~ \\max\_{j \\in \\mathcal{N}(i)} ~ \\textrm{MLP}\_{\\theta} \\left( [ ~ x_i, ~ x_j - x_i ~ ] \\right)$$ ```python import torch from torch import Tensor from torch.nn import Sequential, Linear, ReLU from torch_geometric.nn import MessagePassing class EdgeConv(MessagePassing): def __init__(self, in_channels, out_channels): super().__init__(aggr="max") # "Max" aggregation. self.mlp = Sequential( Linear(2 * in_channels, out_channels), ReLU(), Linear(out_channels, out_channels), ) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: # x: Node feature matrix of shape [num_nodes, in_channels] # edge_index: Graph connectivity matrix of shape [2, num_edges] return self.propagate(edge_index, x=x) # shape [num_nodes, out_channels] def message(self, x_j: Tensor, x_i: Tensor) -> Tensor: # x_j: Source node features of shape [num_edges, in_channels] # x_i: Target node features of shape [num_edges, in_channels] edge_features = torch.cat([x_i, x_j - x_i], dim=-1) return self.mlp(edge_features) # shape [num_edges, out_channels] ``` ## Architecture Overview PyG provides a multi-layer framework that enables users to build Graph Neural Network solutions on both low and high levels. 
It comprises the following components: - The PyG **engine** utilizes the powerful PyTorch deep learning framework with full [`torch.compile`](https://pytorch-geometric.readthedocs.io/en/latest/advanced/compile.html) and [TorchScript](https://pytorch-geometric.readthedocs.io/en/latest/advanced/jit.html) support, as well as the addition of efficient CPU/CUDA libraries for operating on sparse data, *e.g.*, [`pyg-lib`](https://github.com/pyg-team/pyg-lib). - The PyG **storage** handles data processing, transformation and loading pipelines. It is capable of handling and processing large-scale graph datasets, and provides effective solutions for heterogeneous graphs. It further provides a variety of sampling solutions, which enable training of GNNs on large-scale graphs. - The PyG **operators** bundle essential functionalities for implementing Graph Neural Networks. PyG supports important GNN building blocks that can be combined and applied to various parts of a GNN model, ensuring rich flexibility of GNN design. - Finally, PyG provides an abundant set of GNN **models**, and examples that showcase GNN models on standard graph benchmarks. Thanks to its flexibility, users can easily build and modify custom GNN models to fit their specific needs.

## Implemented GNN Models We list currently supported PyG models, layers and operators according to category: **GNN layers:** All Graph Neural Network layers are implemented via the **[`nn.MessagePassing`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MessagePassing.html)** interface. A GNN layer specifies how to perform message passing, *i.e.* by designing different message, aggregation and update functions as defined [here](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html). These GNN layers can be stacked together to create Graph Neural Network models. - **[GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html)** from Kipf and Welling: [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907) (ICLR 2017) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py)\] - **[ChebConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.ChebConv.html)** from Defferrard *et al.*: [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375) (NIPS 2016) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py#L36-L37)\] - **[GATConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATConv.html)** from Veličković *et al.*: [Graph Attention Networks](https://arxiv.org/abs/1710.10903) (ICLR 2018) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gat.py)\]
Expand to see all implemented GNN layers... - **[GCN2Conv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCN2Conv.html)** from Chen *et al.*: [Simple and Deep Graph Convolutional Networks](https://arxiv.org/abs/2007.02133) (ICML 2020) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn2_cora.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn2_ppi.py)\] - **[SplineConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SplineConv.html)** from Fey *et al.*: [SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels](https://arxiv.org/abs/1711.08920) (CVPR 2018) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cora.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/faust.py)\] - **[NNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.NNConv.html)** from Gilmer *et al.*: [Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) (ICML 2017) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_nn_conv.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mnist_nn_conv.py)\] - **[CGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.CGConv.html)** from Xie and Grossman: [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301) (Physical Review Letters 120, 2018) - **[ECConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.ECConv.html)** from Simonovsky and Komodakis: [Edge-Conditioned Convolution on Graphs](https://arxiv.org/abs/1704.02901) (CVPR 2017) - 
**[EGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.EGConv.html)** from Tailor *et al.*: [Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions](https://arxiv.org/abs/2104.01481) (GNNSys 2021) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/egc.py)\] - **[GATv2Conv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATv2Conv.html)** from Brody *et al.*: [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491) (ICLR 2022) - **[TransformerConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.TransformerConv.html)** from Shi *et al.*: [Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification](https://arxiv.org/abs/2009.03509) (CoRR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/unimp_arxiv.py)\] - **[SAGEConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html)** from Hamilton *et al.*: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216) (NIPS 2017) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_train.py), [**Example3**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_sage_unsup.py), [**Example4**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_sage_unsup_ppi.py)\] - **[GraphConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GraphConv.html)** from, *e.g.*, Morris *et al.*: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244) (AAAI 2019) - 
**[GatedGraphConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GatedGraphConv.html)** from Li *et al.*: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) (ICLR 2016) - **[ResGatedGraphConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.ResGatedGraphConv.html)** from Bresson and Laurent: [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) (CoRR 2017) - **[GINConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GINConv.html)** from Xu *et al.*: [How Powerful are Graph Neural Networks?](https://arxiv.org/abs/1810.00826) (ICLR 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mutag_gin.py)\] - **[GINEConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GINEConv.html)** from Hu *et al.*: [Strategies for Pre-training Graph Neural Networks](https://arxiv.org/abs/1905.12265) (ICLR 2020) - **[ARMAConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.ARMAConv.html)** from Bianchi *et al.*: [Graph Neural Networks with Convolutional ARMA Filters](https://arxiv.org/abs/1901.01343) (CoRR 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/arma.py)\] - **[SGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SGConv.html)** from Wu *et al.*: [Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153) (CoRR 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/sgc.py)\] - **[APPNP](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.APPNP.html)** from Klicpera *et al.*: [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](https://arxiv.org/abs/1810.05997) (ICLR 2019) 
\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/appnp.py)\] - **[MFConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MFConv.html)** from Duvenaud *et al.*: [Convolutional Networks on Graphs for Learning Molecular Fingerprints](https://arxiv.org/abs/1509.09292) (NIPS 2015) - **[AGNNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.AGNNConv.html)** from Thekumparampil *et al.*: [Attention-based Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735) (CoRR 2017) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/agnn.py)\] - **[TAGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.TAGConv.html)** from Du *et al.*: [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/abs/1710.10370) (CoRR 2017) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tagcn.py)\] - **[PNAConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PNAConv.html)** from Corso *et al.*: [Principal Neighbourhood Aggregation for Graph Nets](https://arxiv.org/abs/2004.05718) (CoRR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pna.py)\] - **[FAConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.FAConv.html)** from Bo *et al.*: [Beyond Low-Frequency Information in Graph Convolutional Networks](https://arxiv.org/abs/2101.00797) (AAAI 2021) - **[PDNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PDNConv.html)** from Rozemberczki *et al.*: [Pathfinder Discovery Networks for Neural Message Passing](https://arxiv.org/abs/2010.12878) (WWW 2021) - **[RGCNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.RGCNConv.html)** from Schlichtkrull *et
al.*: [Modeling Relational Data with Graph Convolutional Networks](https://arxiv.org/abs/1703.06103) (ESWC 2018) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgcn.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgcn_link_pred.py)\] - **[RGATConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.RGATConv.html)** from Busbridge *et al.*: [Relational Graph Attention Networks](https://arxiv.org/abs/1904.05811) (CoRR 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgat.py)\] - **[FiLMConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.FiLMConv.html)** from Brockschmidt: [GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation](https://arxiv.org/abs/1906.12192) (ICML 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/film.py)\] - **[SignedConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SignedConv.html)** from Derr *et al.*: [Signed Graph Convolutional Network](https://arxiv.org/abs/1808.06354) (ICDM 2018) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/signed_gcn.py)\] - **[DNAConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.DNAConv.html)** from Fey: [Just Jump: Dynamic Neighborhood Aggregation in Graph Neural Networks](https://arxiv.org/abs/1904.04849) (ICLR-W 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/dna.py)\] - **[PANConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PANConv.html)** from Ma *et al.*: [Path Integral Based Convolution and Pooling for Graph Neural Networks](https://arxiv.org/abs/2006.16811) (NeurIPS 2020) - 
**[PointNetConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PointNetConv.html)** (including **[Iterative Farthest Point Sampling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.fps.html)**, dynamic graph generation based on **[nearest neighbor](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.knn_graph.html)** or **[maximum distance](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.radius_graph.html)**, and **[k-NN interpolation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.unpool.knn_interpolate.html)** for upsampling) from Qi *et al.*: [PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation](https://arxiv.org/abs/1612.00593) (CVPR 2017) and [PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](https://arxiv.org/abs/1706.02413) (NIPS 2017) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_classification.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_segmentation.py)\] - **[EdgeConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.EdgeConv.html)** from Wang *et al.*: [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829) (CoRR, 2018) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/dgcnn_classification.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/dgcnn_segmentation.py)\] - **[XConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.XConv.html)** from Li *et al.*: [PointCNN: Convolution On X-Transformed Points](https://arxiv.org/abs/1801.07791) (NeurIPS 2018) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_cnn.py)\] - 
**[PPFConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PPFConv.html)** from Deng *et al.*: [PPFNet: Global Context Aware Local Features for Robust 3D Point Matching](https://arxiv.org/abs/1802.02669) (CVPR 2018) - **[GMMConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GMMConv.html)** from Monti *et al.*: [Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs](https://arxiv.org/abs/1611.08402) (CVPR 2017) - **[FeaStConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.FeaStConv.html)** from Verma *et al.*: [FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis](https://arxiv.org/abs/1706.05206) (CVPR 2018) - **[PointTransformerConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PointTransformerConv.html)** from Zhao *et al.*: [Point Transformer](https://arxiv.org/abs/2012.09164) (2020) - **[HypergraphConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HypergraphConv.html)** from Bai *et al.*: [Hypergraph Convolution and Hypergraph Attention](https://arxiv.org/abs/1901.08150) (CoRR 2019) - **[GravNetConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GravNetConv.html)** from Qasim *et al.*: [Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks](https://arxiv.org/abs/1902.07987) (European Physical Journal C, 2019) - **[SuperGAT](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SuperGATConv.html)** from Kim and Oh: [How To Find Your Friendly Neighborhood: Graph Attention Design With Self-Supervision](https://openreview.net/forum?id=Wi5KUNlqWty) (ICLR 2021) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/super_gat.py)\] -
**[HGTConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HGTConv.html)** from Hu *et al.*: [Heterogeneous Graph Transformer](https://arxiv.org/abs/2003.01332) (WWW 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hgt_dblp.py)\] - **[HEATConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HEATConv.html)** from Mo *et al.*: [Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction](https://arxiv.org/abs/2106.07161) (CoRR 2021) - **[SSGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SSGConv.html)** from Zhu *et al.*: [Simple Spectral Graph Convolution](https://openreview.net/forum?id=CYO5T-YjWZV) (ICLR 2021) - **[FusedGATConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.FusedGATConv.html)** from Zhang *et al.*: [Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective](https://proceedings.mlsys.org/paper/2022/file/9a1158154dfa42caddbd0694a4e9bdc8-Paper.pdf) (MLSys 2022) - **[GPSConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html)** from Rampášek *et al.*: [Recipe for a General, Powerful, Scalable Graph Transformer](https://arxiv.org/abs/2205.12454) (NeurIPS 2022) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_gps.py)\]
**Pooling layers:** Graph pooling layers combine the vectorial representations of a set of nodes in a graph (or a subgraph) into a single vector representation that summarizes the properties of its nodes. Pooling is commonly applied to graph-level tasks, which require combining node features into a single graph representation. - **[Top-K Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.TopKPooling.html)** from Gao and Ji: [Graph U-Nets](https://arxiv.org/abs/1905.05178) (ICML 2019), Cangea *et al.*: [Towards Sparse Hierarchical Graph Classifiers](https://arxiv.org/abs/1811.01287) (NeurIPS-W 2018) and Knyazev *et al.*: [Understanding Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850) (ICLR-W 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_topk_pool.py)\] - **[DiffPool](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.dense_diff_pool.html)** from Ying *et al.*: [Hierarchical Graph Representation Learning with Differentiable Pooling](https://arxiv.org/abs/1806.08804) (NeurIPS 2018) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_diff_pool.py)\]
Expand to see all implemented pooling layers... - **[Attentional Aggregation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.AttentionalAggregation.html)** from Li *et al.*: [Graph Matching Networks for Learning the Similarity of Graph Structured Objects](https://arxiv.org/abs/1904.12787) (ICML 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)\] - **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.Set2Set.html)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)\] - **[Sort Aggregation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.SortAggregation.html)** from Zhang *et al.*: [An End-to-End Deep Learning Architecture for Graph Classification](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf) (AAAI 2018) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)\] - **[MinCut Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.dense_mincut_pool.html)** from Bianchi *et al.*: [Spectral Clustering with Graph Neural Networks for Graph Pooling](https://arxiv.org/abs/1907.00481) (ICML 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_mincut_pool.py)\] - **[DMoN Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.DMoNPooling.html)** from Tsitsulin *et al.*: [Graph Clustering with Graph Neural Networks](https://arxiv.org/abs/2006.16904) (CoRR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_dmon_pool.py)\] - **[Graclus 
Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.graclus.html)** from Dhillon *et al.*: [Weighted Graph Cuts without Eigenvectors: A Multilevel Approach](http://www.cs.utexas.edu/users/inderjit/public_papers/multilevel_pami.pdf) (PAMI 2007) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mnist_graclus.py)\] - **[Voxel Grid Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.voxel_grid.html)** from, *e.g.*, Simonovsky and Komodakis: [Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) (CVPR 2017) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mnist_voxel_grid.py)\] - **[SAG Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.SAGPooling.html)** from Lee *et al.*: [Self-Attention Graph Pooling](https://arxiv.org/abs/1904.08082) (ICML 2019) and Knyazev *et al.*: [Understanding Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850) (ICLR-W 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sag_pool.py)\] - **[Edge Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.EdgePooling.html)** from Diehl *et al.*: [Towards Graph Pooling by Edge Contraction](https://graphreason.github.io/papers/17.pdf) (ICML-W 2019) and Diehl: [Edge Contraction Pooling for Graph Neural Networks](https://arxiv.org/abs/1905.10990) (CoRR 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/edge_pool.py)\] - **[ASAPooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.ASAPooling.html)** from Ranjan *et al.*: [ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations](https://arxiv.org/abs/1911.07979) (AAAI 2020) 
\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/asap.py)\] - **[PANPooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.PANPooling.html)** from Ma *et al.*: [Path Integral Based Convolution and Pooling for Graph Neural Networks](https://arxiv.org/abs/2006.16811) (NeurIPS 2020) - **[MemPooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.MemPooling.html)** from Khasahmadi *et al.*: [Memory-Based Graph Networks](https://arxiv.org/abs/2002.09518) (ICLR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mem_pool.py)\] - **[Graph Multiset Transformer](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.GraphMultisetTransformer.html)** from Baek *et al.*: [Accurate Learning of Graph Representations with Graph Multiset Pooling](https://arxiv.org/abs/2102.11533) (ICLR 2021) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_gmt.py)\] - **[Equilibrium Aggregation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.EquilibriumAggregation.html)** from Bartunov *et al.*: [Equilibrium Aggregation: Encoding Sets via Optimization](https://arxiv.org/abs/2202.12795) (UAI 2022) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/equilibrium_median.py)\]
**GNN models:** Our supported GNN models incorporate multiple message passing layers, and users can directly use these pre-defined models to make predictions on graphs. Unlike simple stacking of GNN layers, these models could involve pre-processing, additional learnable parameters, skip connections, graph coarsening, etc. - **[SchNet](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.SchNet.html)** from Schütt *et al.*: [SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions](https://arxiv.org/abs/1706.08566) (NIPS 2017) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_pretrained_schnet.py)\] - **[DimeNet](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.DimeNet.html)** and **[DimeNetPlusPlus](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.DimeNetPlusPlus.html)** from Klicpera *et al.*: [Directional Message Passing for Molecular Graphs](https://arxiv.org/abs/2003.03123) (ICLR 2020) and [Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules](https://arxiv.org/abs/2011.14115) (NeurIPS-W 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_pretrained_dimenet.py)\] - **[Node2Vec](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.Node2Vec.html)** from Grover and Leskovec: [node2vec: Scalable Feature Learning for Networks](https://arxiv.org/abs/1607.00653) (KDD 2016) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/node2vec.py)\] - **[Deep Graph Infomax](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.DeepGraphInfomax.html)** from Veličković *et al.*: [Deep Graph Infomax](https://arxiv.org/abs/1809.10341) (ICLR 2019) 
\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/infomax_transductive.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/infomax_inductive.py)\] - **Deep Multiplex Graph Infomax** from Park *et al.*: [Unsupervised Attributed Multiplex Network Embedding](https://arxiv.org/abs/1911.06750) (AAAI 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/dmgi_unsup.py)\] - **[Masked Label Prediction](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MaskLabel.html)** from Shi *et al.*: [Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification](https://arxiv.org/abs/2009.03509) (CoRR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/unimp_arxiv.py)\] - **[PMLP](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.PMLP.html)** from Yang *et al.*: [Graph Neural Networks are Inherently Good Generalizers: Insights by Bridging GNNs and MLPs](https://arxiv.org/abs/2212.09034) (ICLR 2023)
Expand to see all implemented GNN models... - **[Jumping Knowledge](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.JumpingKnowledge.html)** from Xu *et al.*: [Representation Learning on Graphs with Jumping Knowledge Networks](https://arxiv.org/abs/1806.03536) (ICML 2018) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/gin.py#L54-L106)\] - A **[MetaLayer](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaLayer.html)** for building any kind of graph network similar to the [TensorFlow Graph Nets library](https://github.com/deepmind/graph_nets) from Battaglia *et al.*: [Relational Inductive Biases, Deep Learning, and Graph Networks](https://arxiv.org/abs/1806.01261) (CoRR 2018) - **[MetaPath2Vec](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaPath2Vec.html)** from Dong *et al.*: [metapath2vec: Scalable Representation Learning for Heterogeneous Networks](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) (KDD 2017) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/metapath2vec.py)\] - All variants of **[Graph Autoencoders](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GAE.html)** and **[Variational Autoencoders](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.VGAE.html)** from: - [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308) from Kipf and Welling (NIPS-W 2016) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/autoencoder.py)\] - [Adversarially Regularized Graph Autoencoder for Graph Embedding](https://arxiv.org/abs/1802.04407) from Pan *et al.* (IJCAI 2018) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/argva_node_clustering.py)\] - [Simple and Effective Graph 
Autoencoders with One-Hop Linear Models](https://arxiv.org/abs/2001.07614) from Salha *et al.* (ECML 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/autoencoder.py)\] - **[SEAL](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/seal_link_pred.py)** from Zhang and Chen: [Link Prediction Based on Graph Neural Networks](https://arxiv.org/pdf/1802.09691.pdf) (NeurIPS 2018) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/seal_link_pred.py)\] - **[RENet](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.RENet.html)** from Jin *et al.*: [Recurrent Event Network for Reasoning over Temporal Knowledge Graphs](https://arxiv.org/abs/1904.05530) (ICLR-W 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/renet.py)\] - **[GraphUNet](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphUNet.html)** from Gao and Ji: [Graph U-Nets](https://arxiv.org/abs/1905.05178) (ICML 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_unet.py)\] - **[AttentiveFP](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.AttentiveFP.html)** from Xiong *et al.*: [Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959) (J. Med. Chem. 
2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/attentive_fp.py)\] - **[DeepGCN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.DeepGCNLayer.html)** and the **[GENConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GENConv.html)** from Li *et al.*: [DeepGCNs: Can GCNs Go as Deep as CNNs?](https://arxiv.org/abs/1904.03751) (ICCV 2019) and [DeeperGCN: All You Need to Train Deeper GCNs](https://arxiv.org/abs/2006.07739) (CoRR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_proteins_deepgcn.py)\] - **[RECT](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.RECT_L.html)** from Wang *et al.*: [Network Embedding with Completely-imbalanced Labels](https://ieeexplore.ieee.org/document/8979355) (TKDE 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rect.py)\] - **[GNNExplainer](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.explain.algorithm.GNNExplainer.html)** from Ying *et al.*: [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894) (NeurIPS 2019) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer_ba_shapes.py), [**Example3**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer_link_pred.py)\] - **Graph-less Neural Networks** from Zhang *et al.*: [Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation](https://arxiv.org/abs/2110.08727) (CoRR 2021) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/glnn.py)\] - 
**[LINKX](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.LINKX.html)** from Lim *et al.*: [Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods](https://arxiv.org/abs/2110.14446) (NeurIPS 2021) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/linkx.py)\] - **[RevGNN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GroupAddRev.html)** from Li *et al.*: [Training Graph Neural Networks with 1000 Layers](https://arxiv.org/abs/2106.07476) (ICML 2021) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rev_gnn.py)\] - **[TransE](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.TransE.html)** from Bordes *et al.*: [Translating Embeddings for Modeling Multi-Relational Data](https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf) (NIPS 2013) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\] - **[ComplEx](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.ComplEx.html)** from Trouillon *et al.*: [Complex Embeddings for Simple Link Prediction](https://arxiv.org/abs/1606.06357) (ICML 2016) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\] - **[DistMult](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.DistMult.html)** from Yang *et al.*: [Embedding Entities and Relations for Learning and Inference in Knowledge Bases](https://arxiv.org/abs/1412.6575) (ICLR 2015) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\] - **[RotatE](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.RotatE.html)** from Sun *et al.*: [RotatE: Knowledge Graph Embedding by Relational Rotation in Complex 
Space](https://arxiv.org/abs/1902.10197) (ICLR 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\]
**GNN operators and utilities:** PyG comes with a rich set of neural network operators that are commonly used in many GNN models. They follow an extensible design: It is easy to apply these operators and graph utilities to existing GNN layers and models to further enhance model performance. - **[DropEdge](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_edge)** from Rong *et al.*: [DropEdge: Towards Deep Graph Convolutional Networks on Node Classification](https://openreview.net/forum?id=Hkx1qkrKPr) (ICLR 2020) - **[DropNode](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_node)**, **[MaskFeature](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.mask_feature)** and **[AddRandomEdge](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.add_random_edge)** from You *et al.*: [Graph Contrastive Learning with Augmentations](https://arxiv.org/abs/2010.13902) (NeurIPS 2020) - **[DropPath](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_path)** from Li *et al.*: [MaskGAE: Masked Graph Modeling Meets Graph Autoencoders](https://arxiv.org/abs/2205.10053) (arXiv 2022) - **[ShuffleNode](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.shuffle_node)** from Veličković *et al.*: [Deep Graph Infomax](https://arxiv.org/abs/1809.10341) (ICLR 2019) - **[GraphNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.GraphNorm.html)** from Cai *et al.*: [GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training](https://proceedings.mlr.press/v139/cai21e.html) (ICML 2021) - **[GDC](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.GDC.html)** from Klicpera *et al.*: [Diffusion Improves Graph 
Learning](https://arxiv.org/abs/1911.05485) (NeurIPS 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py)\]
Expand to see all implemented GNN operators and utilities... - **[GraphSizeNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.GraphSizeNorm.html)** from Dwivedi *et al.*: [Benchmarking Graph Neural Networks](https://arxiv.org/abs/2003.00982) (CoRR 2020) - **[PairNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.PairNorm.html)** from Zhao and Akoglu: [PairNorm: Tackling Oversmoothing in GNNs](https://arxiv.org/abs/1909.12223) (ICLR 2020) - **[MeanSubtractionNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.MeanSubtractionNorm.html)** from Yang *et al.*: [Revisiting "Over-smoothing" in Deep GCNs](https://arxiv.org/abs/2003.13663) (CoRR 2020) - **[DiffGroupNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.DiffGroupNorm.html)** from Zhou *et al.*: [Towards Deeper Graph Neural Networks with Differentiable Group Normalization](https://arxiv.org/abs/2006.06972) (NeurIPS 2020) - **[Tree Decomposition](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.tree_decomposition)** from Jin *et al.*: [Junction Tree Variational Autoencoder for Molecular Graph Generation](https://arxiv.org/abs/1802.04364) (ICML 2018) - **[TGN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.TGNMemory.html)** from Rossi *et al.*: [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637) (GRL+ 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py)\] - **[Weisfeiler Lehman Operator](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.WLConv.html)** from Weisfeiler and Lehman: [A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction](https://www.iti.zcu.cz/wl2018/pdf/wl_paper_translation.pdf) (Nauchno-Technicheskaya 
Informatsia 1968) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/wl_kernel.py)\] - **[Continuous Weisfeiler Lehman Operator](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.WLConvContinuous.html)** from Togninalli *et al.*: [Wasserstein Weisfeiler-Lehman Graph Kernels](https://arxiv.org/abs/1906.01277) (NeurIPS 2019) - **[Label Propagation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.LabelPropagation.html)** from Zhu and Ghahramani: [Learning from Labeled and Unlabeled Data with Label Propagation](http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf) (CMU-CALD 2002) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/label_prop.py)\] - **[Local Degree Profile](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.LocalDegreeProfile.html)** from Cai and Wang: [A Simple yet Effective Baseline for Non-attribute Graph Classification](https://arxiv.org/abs/1811.03508) (CoRR 2018) - **[CorrectAndSmooth](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.CorrectAndSmooth.html)** from Huang *et al.*: [Combining Label Propagation And Simple Models Out-performs Graph Neural Networks](https://arxiv.org/abs/2010.13993) (CoRR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/correct_and_smooth.py)\] - **[Gini](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.functional.gini.html)** and **[BRO](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.functional.bro.html)** regularization from Henderson *et al.*: [Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity](https://arxiv.org/abs/2105.04854) (ICML 2021) - 
**[RootedEgoNets](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.RootedEgoNets.html)** and **[RootedRWSubgraph](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.RootedRWSubgraph.html)** from Zhao *et al.*: [From Stars to Subgraphs: Uplifting Any GNN with Local Structure Awareness](https://arxiv.org/abs/2110.03753) (ICLR 2022) - **[FeaturePropagation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.FeaturePropagation.html)** from Rossi *et al.*: [On the Unreasonable Effectiveness of Feature Propagation in Learning on Graphs with Missing Node Features](https://arxiv.org/abs/2111.12128) (CoRR 2021)
**Scalable GNNs:** PyG supports the implementation of Graph Neural Networks that can scale to large-scale graphs. Such application is challenging since the entire graph, its associated features and the GNN parameters cannot fit into GPU memory. Many state-of-the-art scalability approaches tackle this challenge by sampling neighborhoods for mini-batch training, graph clustering and partitioning, or by using simplified GNN models. These approaches have been implemented in PyG, and can benefit from the above GNN layers, operators and models. - **[NeighborLoader](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.NeighborLoader)** from Hamilton *et al.*: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216) (NIPS 2017) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_train.py), [**Example3**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py)\] - **[ClusterGCN](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.ClusterLoader)** from Chiang *et al.*: [Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks](https://arxiv.org/abs/1905.07953) (KDD 2019) \[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cluster_gcn_reddit.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cluster_gcn_ppi.py)\] - **[GraphSAINT](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.GraphSAINTSampler)** from Zeng *et al.*: [GraphSAINT: Graph Sampling Based Inductive Learning Method](https://arxiv.org/abs/1907.04931) (ICLR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_saint.py)\]
Expand to see all implemented scalable GNNs... - **[ShaDow](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.ShaDowKHopSampler)** from Zeng *et al.*: [Decoupling the Depth and Scope of Graph Neural Networks](https://arxiv.org/abs/2201.07858) (NeurIPS 2021) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/shadow.py)\] - **[SIGN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.SIGN.html)** from Rossi *et al.*: [SIGN: Scalable Inception Graph Neural Networks](https://arxiv.org/abs/2004.11198) (CoRR 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/sign.py)\] - **[HGTLoader](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.loader.HGTLoader.html)** from Hu *et al.*: [Heterogeneous Graph Transformer](https://arxiv.org/abs/2003.01332) (WWW 2020) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py)\]
## Installation PyG is available for Python 3.10 to Python 3.14. From **PyG 2.3** onwards, you can install and use PyG **without any external library** required except for PyTorch. For this, simply run ``` pip install torch_geometric ``` ### Additional Libraries If you want to utilize the full set of features from PyG, there exist several additional libraries you may want to install: - **[`pyg-lib`](https://github.com/pyg-team/pyg-lib)**: Heterogeneous GNN operators, graph sampling routines, and [`SplineConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SplineConv.html) support - **[`torch-scatter`](https://github.com/rusty1s/pytorch_scatter)**: Accelerated and efficient sparse reductions - **[`torch-sparse`](https://github.com/rusty1s/pytorch_sparse)**: [`SparseTensor`](https://pytorch-geometric.readthedocs.io/en/latest/advanced/sparse_tensor.html) support - **[`torch-cluster`](https://github.com/rusty1s/pytorch_cluster)**: Graph clustering routines These packages come with their own CPU and GPU kernel implementations based on the [PyTorch C++/CUDA/hip(ROCm) extension interface](https://github.com/pytorch/extension-cpp). For basic usage of PyG, these dependencies are **fully optional**. We recommend starting with a minimal installation and installing additional dependencies once you actually need them. For ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl). #### PyTorch 2.10 To install the binaries for PyTorch 2.10, simply run ``` pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.10.0+${CUDA}.html ``` where `${CUDA}` should be replaced by either `cpu`, `cu126`, `cu128`, or `cu130` depending on your PyTorch installation. 
| | `cpu` | `cu126` | `cu128` | `cu130` | | ----------- | ----- | ------- | ------- | ------- | | **Linux** | ✅ | ✅ | ✅ | ✅ | | **Windows** | ✅ | ✅ | ✅ | ✅ | | **macOS** | ✅ | | | | #### PyTorch 2.9 To install the binaries for PyTorch 2.9, simply run ``` pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.9.0+${CUDA}.html ``` where `${CUDA}` should be replaced by either `cpu`, `cu126`, `cu128`, or `cu130` depending on your PyTorch installation. | | `cpu` | `cu126` | `cu128` | `cu130` | | ----------- | ----- | ------- | ------- | ------- | | **Linux** | ✅ | ✅ | ✅ | ✅ | | **Windows** | ✅ | ✅ | ✅ | ✅ | | **macOS** | ✅ | | | | #### PyTorch 2.8 To install the binaries for PyTorch 2.8, simply run ``` pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.8.0+${CUDA}.html ``` where `${CUDA}` should be replaced by either `cpu`, `cu126`, `cu128`, or `cu129` depending on your PyTorch installation. | | `cpu` | `cu126` | `cu128` | `cu129` | | ----------- | ----- | ------- | ------- | ------- | | **Linux** | ✅ | ✅ | ✅ | ✅ | | **Windows** | ✅ | ✅ | ✅ | ✅ | | **macOS** | ✅ | | | | **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, PyTorch 2.3.0/2.3.1, PyTorch 2.4.0/2.4.1, PyTorch 2.5.0/2.5.1, PyTorch 2.6.0, and PyTorch 2.7.0/2.7.1 (following the same procedure). **For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source. You can look up the latest supported version number [here](https://data.pyg.org/whl). 
### NVIDIA PyG Container NVIDIA provides a PyG docker container for effortlessly training and deploying GPU-accelerated GNNs with PyG, see [here](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg). ### Nightly and Master In case you want to experiment with the latest PyG features which are not fully released yet, either install the **nightly version** of PyG via ``` pip install pyg-nightly ``` or install PyG **from master** via ``` pip install git+https://github.com/pyg-team/pytorch_geometric.git ``` ### ROCm Wheels The external [`pyg-rocm-build` repository](https://github.com/Looong01/pyg-rocm-build) provides wheels and detailed instructions on how to install PyG for ROCm. If you have any questions about it, please open an issue [here](https://github.com/Looong01/pyg-rocm-build/issues). ## Cite Please cite our [PyG 1.0](https://arxiv.org/abs/1903.02428) and [PyG 2.0](https://www.arxiv.org/abs/2507.16991) papers if you use this code in your own work: ``` @inproceedings{Fey/Lenssen/2019, title={Fast Graph Representation Learning with {PyTorch Geometric}}, author={Fey, Matthias and Lenssen, Jan E.}, booktitle={ICLR Workshop on Representation Learning on Graphs and Manifolds}, year={2019}, } @inproceedings{Fey/etal/2025, title={{PyG} 2.0: Scalable Learning on Real World Graphs}, author={Fey, Matthias and Sunil, Jinu and Nitta, Akihiro and Puri, Rishi and Shah, Manan and Stojanovi{\v{c}}, Bla{\v{z}} and Bendias, Ramona and Barghi, Alexandria and Kocijan, Vid and Zhang, Zecheng and He, Xinwei and Lenssen, Jan E. and Leskovec, Jure}, booktitle={Temporal Graph Learning Workshop @ KDD}, year={2025}, } ``` Feel free to [email us](mailto:matthias.fey@tu-dortmund.de) if you wish your work to be listed in the [external resources](https://pytorch-geometric.readthedocs.io/en/latest/external/resources.html). If you notice anything unexpected, please open an [issue](https://github.com/pyg-team/pytorch_geometric/issues) and let us know. 
If you have any questions or are missing a specific feature, feel free [to discuss them with us](https://github.com/pyg-team/pytorch_geometric/discussions). We are motivated to constantly make PyG even better. [contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat&color=4B26A4 [contributing-url]: https://github.com/pyg-team/pytorch_geometric/blob/master/.github/CONTRIBUTING.md [pypi-download-image]: https://img.shields.io/pypi/dm/torch_geometric?color=4B26A4 [pypi-download-url]: https://pepy.tech/projects/torch_geometric [pypi-image]: https://img.shields.io/pypi/pyversions/torch-geometric?color=4B26A4 [pypi-url]: https://pypi.python.org/pypi/torch-geometric [slack-image]: https://img.shields.io/badge/slack-join-white.svg?logo=slack&color=4B26A4 [slack-url]: https://data.pyg.org/slack.html ================================================ FILE: benchmark/README.md ================================================ # PyG Benchmark Suite This benchmark suite provides evaluation scripts for **[semi-supervised node classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/citation)**, **[graph classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/kernel)**, **[point cloud classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/points)**, and **[runtimes](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/runtime)** in order to compare various methods in homogeneous evaluation scenarios. In particular, we take care not to perform hyperparameter and model selection on the test set and instead use an additional validation set. ## Installation ``` $ pip install -e .
``` ================================================ FILE: benchmark/citation/README.md ================================================ # Semi-supervised Node Classification Evaluation scripts for various methods on the Cora, CiteSeer and PubMed citation networks. Each experiment is repeated 100 times on either a fixed train/val/test split or on multiple random splits: - **[GCN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/gcn.py)**: `python gcn.py` - **[GAT](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/gat.py)**: `python gat.py` - **[Cheby](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/cheb.py)**: `python cheb.py` - **[SGC](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/sgc.py)**: `python sgc.py` - **[ARMA](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/arma.py)**: `python arma.py` - **[APPNP](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/appnp.py)**: `python appnp.py` Run the whole test suite via ``` $ ./run.sh ``` ================================================ FILE: benchmark/citation/__init__.py ================================================ from .datasets import get_planetoid_dataset from .train_eval import random_planetoid_splits, run __all__ = [ 'get_planetoid_dataset', 'random_planetoid_splits', 'run', ] ================================================ FILE: benchmark/citation/appnp.py ================================================ import argparse import torch import torch.nn.functional as F from citation import get_planetoid_dataset, random_planetoid_splits, run from torch.nn import Linear from torch_geometric.nn import APPNP from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True) parser.add_argument('--random_splits', action='store_true') parser.add_argument('--runs',
type=int, default=100) parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--weight_decay', type=float, default=0.0005) parser.add_argument('--early_stopping', type=int, default=10) parser.add_argument('--hidden', type=int, default=64) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--no_normalize_features', action='store_true') parser.add_argument('--K', type=int, default=10) parser.add_argument('--alpha', type=float, default=0.1) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() self.lin1 = Linear(dataset.num_features, args.hidden) self.lin2 = Linear(args.hidden, dataset.num_classes) self.prop1 = APPNP(args.K, args.alpha) def reset_parameters(self): self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index = data.x, data.edge_index x = F.dropout(x, p=args.dropout, training=self.training) x = F.relu(self.lin1(x)) x = F.dropout(x, p=args.dropout, training=self.training) x = self.lin2(x) x = self.prop1(x, edge_index) return F.log_softmax(x, dim=1) dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features) permute_masks = random_planetoid_splits if args.random_splits else None run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay, args.early_stopping, args.inference, args.profile, args.bf16, args.compile, permute_masks) if args.profile: rename_profile_file('citation', APPNP.__name__, args.dataset, str(args.random_splits), 'inference' if args.inference else 'train') ================================================ FILE: benchmark/citation/arma.py ================================================ import argparse import
torch import torch.nn.functional as F from citation import get_planetoid_dataset, random_planetoid_splits, run from torch_geometric.nn import ARMAConv from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True) parser.add_argument('--random_splits', action='store_true') parser.add_argument('--runs', type=int, default=100) parser.add_argument('--epochs', type=int, default=1000) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--weight_decay', type=float, default=0.0005) parser.add_argument('--early_stopping', type=int, default=100) parser.add_argument('--hidden', type=int, default=16) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--no_normalize_features', action='store_true') parser.add_argument('--num_stacks', type=int, default=1) parser.add_argument('--num_layers', type=int, default=1) parser.add_argument('--shared_weights', action='store_true') parser.add_argument('--skip_dropout', type=float, default=0.75) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() self.conv1 = ARMAConv(dataset.num_features, args.hidden, args.num_stacks, args.num_layers, args.shared_weights, dropout=args.skip_dropout) self.conv2 = ARMAConv(args.hidden, dataset.num_classes, args.num_stacks, args.num_layers, args.shared_weights, dropout=args.skip_dropout) def reset_parameters(self): self.conv1.reset_parameters() self.conv2.reset_parameters() def forward(self, data): x, edge_index = data.x, data.edge_index x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, p=args.dropout, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) dataset =
get_planetoid_dataset(args.dataset, not args.no_normalize_features) permute_masks = random_planetoid_splits if args.random_splits else None run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay, args.early_stopping, args.inference, args.profile, args.bf16, args.compile, permute_masks) if args.profile: rename_profile_file('citation', ARMAConv.__name__, args.dataset, str(args.random_splits), 'inference' if args.inference else 'train') ================================================ FILE: benchmark/citation/cheb.py ================================================ import argparse import torch import torch.nn.functional as F from citation import get_planetoid_dataset, random_planetoid_splits, run from torch_geometric.nn import ChebConv from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True) parser.add_argument('--random_splits', action='store_true') parser.add_argument('--runs', type=int, default=100) parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--weight_decay', type=float, default=0.0005) parser.add_argument('--early_stopping', type=int, default=10) parser.add_argument('--hidden', type=int, default=16) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--no_normalize_features', action='store_true') parser.add_argument('--num_hops', type=int, default=3) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() self.conv1 = ChebConv(dataset.num_features, args.hidden, args.num_hops) self.conv2 = ChebConv(args.hidden, dataset.num_classes, args.num_hops) def reset_parameters(self):
self.conv1.reset_parameters() self.conv2.reset_parameters() def forward(self, data): x, edge_index = data.x, data.edge_index x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, p=args.dropout, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features) permute_masks = random_planetoid_splits if args.random_splits else None run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay, args.early_stopping, args.inference, args.profile, args.bf16, args.compile, permute_masks) if args.profile: rename_profile_file('citation', ChebConv.__name__, args.dataset, str(args.random_splits), 'inference' if args.inference else 'train') ================================================ FILE: benchmark/citation/datasets.py ================================================ import os.path as osp import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid def get_planetoid_dataset(name, normalize_features=False, transform=None): path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name) dataset = Planetoid(path, name) if transform is not None and normalize_features: dataset.transform = T.Compose([T.NormalizeFeatures(), transform]) elif normalize_features: dataset.transform = T.NormalizeFeatures() elif transform is not None: dataset.transform = transform return dataset ================================================ FILE: benchmark/citation/gat.py ================================================ import argparse import torch import torch.nn.functional as F from citation import get_planetoid_dataset, random_planetoid_splits, run from torch_geometric.nn import GATConv from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True) parser.add_argument('--random_splits', action='store_true') parser.add_argument('--runs', type=int, default=100)
parser.add_argument('--epochs', type=int, default=1000) parser.add_argument('--lr', type=float, default=0.005) parser.add_argument('--weight_decay', type=float, default=0.0005) parser.add_argument('--early_stopping', type=int, default=100) parser.add_argument('--hidden', type=int, default=8) parser.add_argument('--dropout', type=float, default=0.6) parser.add_argument('--no_normalize_features', action='store_true') parser.add_argument('--heads', type=int, default=8) parser.add_argument('--output_heads', type=int, default=1) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() self.conv1 = GATConv(dataset.num_features, args.hidden, heads=args.heads, dropout=args.dropout) self.conv2 = GATConv(args.hidden * args.heads, dataset.num_classes, heads=args.output_heads, concat=False, dropout=args.dropout) def reset_parameters(self): self.conv1.reset_parameters() self.conv2.reset_parameters() def forward(self, data): x, edge_index = data.x, data.edge_index x = F.dropout(x, p=args.dropout, training=self.training) x = F.elu(self.conv1(x, edge_index)) x = F.dropout(x, p=args.dropout, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features) permute_masks = random_planetoid_splits if args.random_splits else None run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay, args.early_stopping, args.inference, args.profile, args.bf16, args.compile, permute_masks) if args.profile: rename_profile_file('citation', GATConv.__name__, args.dataset, str(args.random_splits), 'inference' if args.inference else 'train') ================================================ FILE: benchmark/citation/gcn.py
================================================ import argparse import torch import torch.nn.functional as F from citation import get_planetoid_dataset, random_planetoid_splits, run from torch_geometric.nn import GCNConv from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True) parser.add_argument('--random_splits', action='store_true') parser.add_argument('--runs', type=int, default=100) parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--weight_decay', type=float, default=0.0005) parser.add_argument('--early_stopping', type=int, default=10) parser.add_argument('--hidden', type=int, default=16) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--no_normalize_features', action='store_true') parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() self.conv1 = GCNConv(dataset.num_features, args.hidden) self.conv2 = GCNConv(args.hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() self.conv2.reset_parameters() def forward(self, data): x, edge_index = data.x, data.edge_index x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, p=args.dropout, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features) permute_masks = random_planetoid_splits if args.random_splits else None run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay, args.early_stopping, args.inference, args.profile, args.bf16, args.compile, permute_masks) if args.profile:
rename_profile_file('citation', GCNConv.__name__, args.dataset, str(args.random_splits), 'inference' if args.inference else 'train') ================================================ FILE: benchmark/citation/inference.sh ================================================ #!/bin/sh echo "Cora" echo "====" echo "GCN" python gcn.py --dataset=Cora --inference python gcn.py --dataset=Cora --random_splits --inference python gcn.py --dataset=Cora --inference --profile python gcn.py --dataset=Cora --random_splits --inference --profile echo "GAT" python gat.py --dataset=Cora --inference python gat.py --dataset=Cora --random_splits --inference python gat.py --dataset=Cora --inference --profile python gat.py --dataset=Cora --random_splits --inference --profile echo "Cheby" python cheb.py --dataset=Cora --num_hops=3 --inference python cheb.py --dataset=Cora --num_hops=3 --random_splits --inference python cheb.py --dataset=Cora --num_hops=3 --inference --profile python cheb.py --dataset=Cora --num_hops=3 --random_splits --inference --profile echo "SGC" python sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --inference python sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --random_splits --inference python sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --inference --profile python sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --random_splits --inference --profile echo "ARMA" python arma.py --dataset=Cora --num_stacks=2 --num_layers=1 --shared_weights --inference python arma.py --dataset=Cora --num_stacks=3 --num_layers=1 --shared_weights --random_splits --inference python arma.py --dataset=Cora --num_stacks=2 --num_layers=1 --shared_weights --inference --profile python arma.py --dataset=Cora --num_stacks=3 --num_layers=1 --shared_weights --random_splits --inference --profile echo "APPNP" python appnp.py --dataset=Cora --alpha=0.1 --inference python appnp.py --dataset=Cora --alpha=0.1 --random_splits --inference python appnp.py --dataset=Cora
--alpha=0.1 --inference --profile python appnp.py --dataset=Cora --alpha=0.1 --random_splits --inference --profile echo "CiteSeer" echo "========" echo "GCN" python gcn.py --dataset=CiteSeer --inference python gcn.py --dataset=CiteSeer --random_splits --inference python gcn.py --dataset=CiteSeer --inference --profile python gcn.py --dataset=CiteSeer --random_splits --inference --profile echo "GAT" python gat.py --dataset=CiteSeer --inference python gat.py --dataset=CiteSeer --random_splits --inference python gat.py --dataset=CiteSeer --inference --profile python gat.py --dataset=CiteSeer --random_splits --inference --profile echo "Cheby" python cheb.py --dataset=CiteSeer --num_hops=2 --inference python cheb.py --dataset=CiteSeer --num_hops=3 --random_splits --inference python cheb.py --dataset=CiteSeer --num_hops=2 --inference --profile python cheb.py --dataset=CiteSeer --num_hops=3 --random_splits --inference --profile echo "SGC" python sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --inference python sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --random_splits --inference python sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --inference --profile python sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --random_splits --inference --profile echo "ARMA" python arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights --inference python arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights --random_splits --inference python arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights --inference --profile python arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights --random_splits --inference --profile echo "APPNP" python appnp.py --dataset=CiteSeer --alpha=0.1 --inference python appnp.py --dataset=CiteSeer --alpha=0.1 --random_splits --inference python appnp.py --dataset=CiteSeer --alpha=0.1 --inference --profile python appnp.py --dataset=CiteSeer
--alpha=0.1 --random_splits --inference --profile echo "PubMed" echo "======" echo "GCN" python gcn.py --dataset=PubMed --inference python gcn.py --dataset=PubMed --random_splits --inference python gcn.py --dataset=PubMed --inference --profile python gcn.py --dataset=PubMed --random_splits --inference --profile echo "GAT" python gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --inference python gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --random_splits --inference python gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --inference --profile python gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --random_splits --inference --profile echo "Cheby" python cheb.py --dataset=PubMed --num_hops=2 --inference python cheb.py --dataset=PubMed --num_hops=2 --random_splits --inference python cheb.py --dataset=PubMed --num_hops=2 --inference --profile python cheb.py --dataset=PubMed --num_hops=2 --random_splits --inference --profile echo "SGC" python sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --inference python sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --random_splits --inference python sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --inference --profile python sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --random_splits --inference --profile echo "ARMA" python arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0 --inference python arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0.5 --random_splits --inference python arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0 --inference --profile python arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0.5 --random_splits --inference --profile echo "APPNP" python appnp.py --dataset=PubMed --alpha=0.1 --inference python appnp.py --dataset=PubMed --alpha=0.1 --random_splits --inference python appnp.py --dataset=PubMed --alpha=0.1
--inference --profile python appnp.py --dataset=PubMed --alpha=0.1 --random_splits --inference --profile ================================================ FILE: benchmark/citation/run.sh ================================================ #!/bin/sh echo "Cora" echo "====" echo "GCN" python gcn.py --dataset=Cora python gcn.py --dataset=Cora --random_splits echo "GAT" python gat.py --dataset=Cora python gat.py --dataset=Cora --random_splits echo "Cheby" python cheb.py --dataset=Cora --num_hops=3 python cheb.py --dataset=Cora --num_hops=3 --random_splits echo "SGC" python sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 python sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --random_splits echo "ARMA" python arma.py --dataset=Cora --num_stacks=2 --num_layers=1 --shared_weights python arma.py --dataset=Cora --num_stacks=3 --num_layers=1 --shared_weights --random_splits echo "APPNP" python appnp.py --dataset=Cora --alpha=0.1 python appnp.py --dataset=Cora --alpha=0.1 --random_splits echo "CiteSeer" echo "========" echo "GCN" python gcn.py --dataset=CiteSeer python gcn.py --dataset=CiteSeer --random_splits echo "GAT" python gat.py --dataset=CiteSeer python gat.py --dataset=CiteSeer --random_splits echo "Cheby" python cheb.py --dataset=CiteSeer --num_hops=2 python cheb.py --dataset=CiteSeer --num_hops=3 --random_splits echo "SGC" python sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 python sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --random_splits echo "ARMA" python arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights python arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights --random_splits echo "APPNP" python appnp.py --dataset=CiteSeer --alpha=0.1 python appnp.py --dataset=CiteSeer --alpha=0.1 --random_splits echo "PubMed" echo "======" echo "GCN" python gcn.py --dataset=PubMed python gcn.py --dataset=PubMed --random_splits echo "GAT" python gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8
python gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --random_splits echo "Cheby" python cheb.py --dataset=PubMed --num_hops=2 python cheb.py --dataset=PubMed --num_hops=2 --random_splits echo "SGC" python sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 python sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --random_splits echo "ARMA" python arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0 python arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0.5 --random_splits echo "APPNP" python appnp.py --dataset=PubMed --alpha=0.1 python appnp.py --dataset=PubMed --alpha=0.1 --random_splits ================================================ FILE: benchmark/citation/sgc.py ================================================ import argparse import torch import torch.nn.functional as F from citation import get_planetoid_dataset, random_planetoid_splits, run from torch_geometric.nn import SGConv from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True) parser.add_argument('--random_splits', action='store_true') parser.add_argument('--runs', type=int, default=100) parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--lr', type=float, default=0.1) parser.add_argument('--weight_decay', type=float, default=0.0005) parser.add_argument('--early_stopping', type=int, default=10) parser.add_argument('--no_normalize_features', action='store_true') parser.add_argument('--K', type=int, default=2) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() self.conv1 = SGConv(dataset.num_features, dataset.num_classes, K=args.K, cached=True) def 
reset_parameters(self): self.conv1.reset_parameters() def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) return F.log_softmax(x, dim=1) dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features) permute_masks = random_planetoid_splits if args.random_splits else None run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay, args.early_stopping, args.inference, args.profile, args.bf16, args.compile, permute_masks) if args.profile: rename_profile_file('citation', SGConv.__name__, args.dataset, str(args.random_splits), 'inference' if args.inference else 'train') ================================================ FILE: benchmark/citation/statistics.py ================================================ from citation import get_planetoid_dataset def print_dataset(dataset): data = dataset[0] print('Name', dataset) print('Nodes', data.num_nodes) print('Edges', data.num_edges // 2) print('Features', dataset.num_features) print('Classes', dataset.num_classes) print('Label rate', data.train_mask.sum().item() / data.num_nodes) print() for name in ['Cora', 'CiteSeer', 'PubMed']: print_dataset(get_planetoid_dataset(name)) ================================================ FILE: benchmark/citation/train_eval.py ================================================ import time import torch import torch.nn.functional as F from torch import tensor from torch.optim import Adam from torch_geometric.profile import timeit, torch_profile from torch_geometric.utils import index_to_mask if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') def random_planetoid_splits(data, num_classes): # Set new random planetoid splits: # * 20 * num_classes labels for training # * 500 labels for validation # * 1000 labels for testing indices = [] for i in range(num_classes): index = (data.y 
== i).nonzero().view(-1) index = index[torch.randperm(index.size(0))] indices.append(index) train_index = torch.cat([i[:20] for i in indices], dim=0) rest_index = torch.cat([i[20:] for i in indices], dim=0) rest_index = rest_index[torch.randperm(rest_index.size(0))] data.train_mask = index_to_mask(train_index, size=data.num_nodes) data.val_mask = index_to_mask(rest_index[:500], size=data.num_nodes) data.test_mask = index_to_mask(rest_index[500:1500], size=data.num_nodes) return data def run_train(dataset, model, runs, epochs, lr, weight_decay, early_stopping, profiling, use_compile, permute_masks=None, logger=None): val_losses, accs, durations = [], [], [] if use_compile: model = torch.compile(model) for run in range(runs): data = dataset[0] if permute_masks is not None: data = permute_masks(data, dataset.num_classes) data = data.to(device) model.to(device).reset_parameters() optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) if torch.cuda.is_available(): torch.cuda.synchronize() elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): try: torch.mps.synchronize() except ImportError: pass t_start = time.perf_counter() best_val_loss = float('inf') test_acc = 0 val_loss_history = [] for epoch in range(1, epochs + 1): if run == runs - 1 and epoch == epochs: with timeit(): train(model, optimizer, data) else: train(model, optimizer, data) eval_info = evaluate(model, data) eval_info['epoch'] = epoch if logger is not None: logger(eval_info) if eval_info['val_loss'] < best_val_loss: best_val_loss = eval_info['val_loss'] test_acc = eval_info['test_acc'] val_loss_history.append(eval_info['val_loss']) if early_stopping > 0 and epoch > epochs // 2: tmp = tensor(val_loss_history[-(early_stopping + 1):-1]) if eval_info['val_loss'] > tmp.mean().item(): break if torch.cuda.is_available(): torch.cuda.synchronize() elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): try: torch.mps.synchronize() except ImportError: pass 
t_end = time.perf_counter() val_losses.append(best_val_loss) accs.append(test_acc) durations.append(t_end - t_start) loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations) print(f'Val Loss: {float(loss.mean()):.4f}, ' f'Test Accuracy: {float(acc.mean()):.3f} ± {float(acc.std()):.3f}, ' f'Duration: {float(duration.mean()):.3f}s') if profiling: with torch_profile(): train(model, optimizer, data) @torch.no_grad() def run_inference(dataset, model, epochs, profiling, bf16, use_compile, permute_masks=None, logger=None): data = dataset[0] if permute_masks is not None: data = permute_masks(data, dataset.num_classes) data = data.to(device) model.to(device).reset_parameters() if use_compile: model = torch.compile(model) if torch.cuda.is_available(): amp = torch.amp.autocast('cuda', enabled=False) else: amp = torch.cpu.amp.autocast(enabled=bf16) if bf16: data.x = data.x.to(torch.bfloat16) with amp: for epoch in range(1, epochs + 1): if epoch == epochs: with timeit(): inference(model, data) else: inference(model, data) if profiling: with torch_profile(): inference(model, data) def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping, inference, profiling, bf16, use_compile, permute_masks=None, logger=None): if not inference: run_train(dataset, model, runs, epochs, lr, weight_decay, early_stopping, profiling, use_compile, permute_masks, logger) else: run_inference(dataset, model, epochs, profiling, bf16, use_compile, permute_masks, logger) def train(model, optimizer, data): model.train() optimizer.zero_grad() out = model(data) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() @torch.no_grad() def evaluate(model, data): model.eval() out = model(data) outs = {} for key in ['train', 'val', 'test']: mask = data[f'{key}_mask'] loss = float(F.nll_loss(out[mask], data.y[mask])) pred = out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() outs[f'{key}_loss'] = loss 
outs[f'{key}_acc'] = acc return outs @torch.no_grad() def inference(model, data): model.eval() model(data) ================================================ FILE: benchmark/inference/README.md ================================================ # Inference Benchmark ## Environment setup 1. Confirm that PyG is properly installed. 1. Install dataset package: ```bash pip install ogb ``` 1. Install `autoconf` required for `jemalloc` setup: ```bash sudo apt-get install autoconf ``` 1. Install `jemalloc` for performance benchmark: ```bash cd ${workspace} git clone https://github.com/jemalloc/jemalloc.git cd jemalloc git checkout 5.2.1 ./autogen.sh ./configure --prefix=${workspace}/jemalloc-bin make make install ``` ## Running benchmark 1. Set environment variables: ```bash source activate env_name export DNNL_PRIMITIVE_CACHE_CAPACITY=1024 export KMP_BLOCKTIME=1 export KMP_AFFINITY=granularity=fine,compact,1,0 jemalloc_lib=${workspace}/jemalloc-bin/lib/libjemalloc.so export LD_PRELOAD="$jemalloc_lib" export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000" ``` 1. Core binding, *e.g.*, single socket / single core / 4 cores per instance: ```bash OMP_NUM_THREADS=${CORES} numactl -C 0-${LAST_CORE} -m 0 CMD...... ``` 1. 
Execute benchmarks, *e.g.*: ```bash python -u inference_benchmark.py --datasets=Reddit --models=gcn --eval-batch-sizes=512 --num-layers=2 --num-hidden-channels=64 python -u inference_benchmark.py --datasets=Reddit --models=gcn --eval-batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --use-sparse-tensor python -u inference_benchmark.py --datasets=ogbn-products --models=sage --eval-batch-sizes=512 --num-layers=2 --num-hidden-channels=64 python -u inference_benchmark.py --datasets=ogbn-products --models=sage --eval-batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --use-sparse-tensor ``` ================================================ FILE: benchmark/inference/inference_benchmark.py ================================================ import argparse import warnings from collections import defaultdict from contextlib import nullcontext import torch from benchmark.utils import ( emit_itt, get_dataset_with_transformation, get_model, get_split_masks, save_benchmark_data, test, write_to_csv, ) from torch_geometric.io import fs from torch_geometric.loader import NeighborLoader from torch_geometric.nn import PNAConv from torch_geometric.profile import ( rename_profile_file, timeit, torch_profile, xpu_profile, ) supported_sets = { 'ogbn-mag': ['rgat', 'rgcn'], 'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'], 'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'], } @torch.no_grad() def full_batch_inference(model, data): model.eval() if hasattr(data, 'adj_t'): edge_index = data.adj_t else: edge_index = data.edge_index return model(data.x, edge_index) def run(args: argparse.ArgumentParser): csv_data = defaultdict(list) if args.write_csv == 'prof' and not args.profile: warnings.warn( "Cannot write profile data to CSV because profiling is " "disabled", stacklevel=2) if args.device == 'xpu': try: import intel_extension_for_pytorch as ipex except ImportError as e: raise RuntimeError( 'XPU device requires IPEX to be installed') from e if ((args.device == 'cuda' and 
not torch.cuda.is_available()) or (args.device == 'xpu' and not torch.xpu.is_available())): raise RuntimeError(f'{args.device.upper()} is not available') if args.device == 'cuda' and args.full_batch: raise RuntimeError('CUDA device is not suitable for full batch mode') device = torch.device(args.device) print('BENCHMARK STARTS') print(f'Running on {args.device.upper()}') for dataset_name in args.datasets: assert dataset_name in supported_sets.keys( ), f"Dataset {dataset_name} isn't supported." print(f'Dataset: {dataset_name}') load_time = timeit() if args.measure_load_time else nullcontext() with load_time: result = get_dataset_with_transformation(dataset_name, args.root, args.use_sparse_tensor, args.bf16) dataset, num_classes, transformation = result data = dataset.to(device) hetero = True if dataset_name == 'ogbn-mag' else False mask = ('paper', None) if dataset_name == 'ogbn-mag' else None _, _, test_mask = get_split_masks(data, dataset_name) degree = None if hetero and args.cached_loader: args.cached_loader = False print('Disabling CachedLoader, not supported in Hetero models') if args.num_layers != [1] and not hetero and args.num_steps != -1: raise ValueError("Layer-wise inference requires `steps=-1`") if args.device == 'cuda': amp = torch.amp.autocast('cuda', enabled=False) elif args.device == 'xpu': amp = torch.xpu.amp.autocast(enabled=False) else: amp = torch.cpu.amp.autocast(enabled=args.bf16) if args.device == 'xpu' and args.warmup < 1: print('XPU device requires warmup - setting warmup=1') args.warmup = 1 inputs_channels = data[ 'paper'].num_features if dataset_name == 'ogbn-mag' \ else dataset.num_features for model_name in args.models: if model_name not in supported_sets[dataset_name]: print(f'Configuration of {dataset_name} + {model_name} ' f'not supported. 
Skipping.') continue with_loader = not args.full_batch or (model_name == 'pna' and degree is None) print(f'Evaluation bench for {model_name}:') for batch_size in args.eval_batch_sizes: num_nodes = data[ 'paper'].num_nodes if hetero else data.num_nodes sampler = torch.utils.data.RandomSampler( range(num_nodes), num_samples=args.num_steps * batch_size ) if args.num_steps != -1 and with_loader else None kwargs = { 'batch_size': batch_size, 'shuffle': False, 'num_workers': args.num_workers, } if not hetero: subgraph_loader = NeighborLoader( data, num_neighbors=[-1], # layer-wise inference input_nodes=mask, sampler=sampler, **kwargs, ) if with_loader else None if args.evaluate and not args.full_batch: test_loader = NeighborLoader( data, num_neighbors=[-1], # layer-wise inference input_nodes=test_mask, sampler=None, **kwargs, ) for layers in args.num_layers: num_neighbors = [args.hetero_num_neighbors] * layers if hetero: # batch-wise inference subgraph_loader = NeighborLoader( data, num_neighbors=num_neighbors, input_nodes=mask, sampler=sampler, **kwargs, ) if with_loader else None if args.evaluate and not args.full_batch: test_loader = NeighborLoader( data, num_neighbors=num_neighbors, input_nodes=test_mask, sampler=None, **kwargs, ) for hidden_channels in args.num_hidden_channels: print('----------------------------------------------') print(f'Batch size={batch_size}, ' f'Layers amount={layers}, ' f'Num_neighbors={num_neighbors}, ' f'Hidden features size={hidden_channels}, ' f'Sparse tensor={args.use_sparse_tensor}') params = { 'inputs_channels': inputs_channels, 'hidden_channels': hidden_channels, 'output_channels': num_classes, 'num_heads': args.num_heads, 'num_layers': layers, } if model_name == 'pna': if degree is None: degree = PNAConv.get_degree_histogram( subgraph_loader) print(f'Calculated degree for {dataset_name}.') params['degree'] = degree model = get_model( model_name, params, metadata=data.metadata() if hetero else None) model = model.to(device) # TODO: 
Migrate to ModelHubMixin. if args.ckpt_path: state_dict = fs.torch_load(args.ckpt_path) model.load_state_dict(state_dict) model.eval() if args.device == 'xpu': model = ipex.optimize(model) # Define context manager parameters: if args.cpu_affinity and with_loader: cpu_affinity = subgraph_loader.enable_cpu_affinity( args.loader_cores) else: cpu_affinity = nullcontext() if args.profile and args.device == 'xpu': profile = xpu_profile(args.export_chrome_trace) elif args.profile: profile = torch_profile(args.export_chrome_trace, csv_data, args.write_csv) else: profile = nullcontext() itt = emit_itt( ) if args.vtune_profile else nullcontext() if args.full_batch and args.use_sparse_tensor: data = transformation(data) with cpu_affinity, amp, timeit() as time: inference_kwargs = dict(cache=args.cached_loader) if args.reuse_device_for_embeddings and not hetero: inference_kwargs['embedding_device'] = device for _ in range(args.warmup): if args.full_batch: full_batch_inference(model, data) else: model.inference( subgraph_loader, device, progress_bar=True, **inference_kwargs, ) if args.warmup > 0: time.reset() with itt, profile: if args.full_batch: y = full_batch_inference(model, data) if args.evaluate: mask = data.test_mask pred = y[mask].argmax(1) test_acc = pred.eq(data.y[mask]).sum( ).item() / mask.sum().item() print(f'Full Batch Test Accuracy: \ {test_acc:.4f}') else: y = model.inference( subgraph_loader, device, progress_bar=True, **inference_kwargs, ) if args.evaluate: test_acc = test( model, test_loader, device, hetero, progress_bar=True, ) print(f'Mini Batch Test Accuracy: \ {test_acc:.4f}') if args.profile and args.export_chrome_trace: rename_profile_file(model_name, dataset_name, str(batch_size), str(layers), str(hidden_channels), str(num_neighbors)) total_time = time.duration if args.num_steps != -1: total_num_samples = args.num_steps * batch_size else: total_num_samples = num_nodes throughput = total_num_samples / total_time latency = total_time / total_num_samples 
* 1000 print(f'Throughput: {throughput:.3f} samples/s') print(f'Latency: {latency:.3f} ms') num_records = 1 if args.write_csv == 'prof': # For profiling with PyTorch, we save the top-5 # most time consuming operations. Therefore, the # same data should be entered for each of them. num_records = 5 for _ in range(num_records): save_benchmark_data( csv_data, batch_size, layers, num_neighbors, hidden_channels, total_time, model_name, dataset_name, args.use_sparse_tensor, ) if args.write_csv: write_to_csv(csv_data, args.write_csv) if __name__ == '__main__': argparser = argparse.ArgumentParser('GNN inference benchmark') add = argparser.add_argument add('--device', choices=['cpu', 'cuda', 'xpu'], default='cpu', help='Device to run benchmark on') add('--reuse-device-for-embeddings', action='store_true', help='Use the same device for embeddings as specified in "--device"') add('--datasets', nargs='+', default=['ogbn-mag', 'ogbn-products', 'Reddit'], type=str) add('--use-sparse-tensor', action='store_true', help='use torch_sparse.SparseTensor as graph storage format') add('--models', nargs='+', default=['edge_cnn', 'gat', 'gcn', 'pna', 'rgat', 'rgcn'], type=str) add('--root', default='../../data', type=str, help='relative path to look for the datasets') add('--eval-batch-sizes', nargs='+', default=[512, 1024, 2048, 4096, 8192], type=int) add('--num-layers', nargs='+', default=[2, 3], type=int) add('--num-hidden-channels', nargs='+', default=[64, 128, 256], type=int) add('--num-heads', default=2, type=int, help='number of hidden attention heads, applies only for gat and rgat') add('--hetero-num-neighbors', default=10, type=int, help='number of neighbors to sample per layer for hetero workloads') add('--num-workers', default=0, type=int) add('--num-steps', default=-1, type=int, help='number of steps, -1 means iterating through all the data') add('--warmup', default=1, type=int) add('--profile', action='store_true') add('--vtune-profile', action='store_true') add('--bf16', 
action='store_true') add('--cpu-affinity', action='store_true', help='Use DataLoader affinitzation.') add('--loader-cores', nargs='+', default=[], type=int, help="List of CPU core IDs to use for DataLoader workers") add('--measure-load-time', action='store_true') add('--full-batch', action='store_true', help='Use full batch mode') add('--evaluate', action='store_true') add('--ckpt_path', type=str, help='Checkpoint path for loading a model') add('--write-csv', choices=[None, 'bench', 'prof'], default=None, help='Write benchmark or PyTorch profile data to CSV') add('--export-chrome-trace', default=True, type=bool, help='Export chrome trace file. Works only with PyTorch profiler') add('--cached-loader', action='store_true', help='Use CachedLoader') run(argparser.parse_args()) ================================================ FILE: benchmark/kernel/README.md ================================================ # Graph Classification Evaluation script for various methods on [common benchmark datasets](https://chrsmrrs.github.io/datasets/) via 10-fold cross validation, where a training fold is randomly sampled to serve as a validation set. 
Hyperparameter selection is performed for the number of hidden units and the number of layers with respect to the validation set: - **[GCN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/gcn.py)** - **[GraphSAGE](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/graph_sage.py)** - **[GIN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/gin.py)** - **[Graclus](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/graclus.py)** - **[Top-K Pooling](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/top_k.py)** - **[SAG Pooling](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sag_pool.py)** - **[DiffPool](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/diff_pool.py)** - **[EdgePool](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/edge_pool.py)** - **[GlobalAttention](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)** - **[Set2Set](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)** - **[SortPool](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)** - **[ASAPool](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/asap.py)** Run (or modify) the whole test suite via ``` $ python main.py ``` For more comprehensive time-measurement and memory usage information, you may use ``` $ python main_performance.py ``` ================================================ FILE: benchmark/kernel/__init__.py ================================================ from .datasets import get_dataset from .train_eval import cross_validation_with_val_set __all__ = [ 'get_dataset', 'cross_validation_with_val_set', ] ================================================ FILE: benchmark/kernel/asap.py ================================================ import torch 
import torch.nn.functional as F from torch.nn import Linear from torch_geometric.nn import ( ASAPooling, GraphConv, JumpingKnowledge, global_mean_pool, ) class ASAP(torch.nn.Module): def __init__(self, dataset, num_layers, hidden, ratio=0.8, dropout=0): super().__init__() self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean') self.convs = torch.nn.ModuleList() self.pools = torch.nn.ModuleList() self.convs.extend([ GraphConv(hidden, hidden, aggr='mean') for i in range(num_layers - 1) ]) self.pools.extend([ ASAPooling(hidden, ratio, dropout=dropout) for i in range((num_layers) // 2) ]) self.jump = JumpingKnowledge(mode='cat') self.lin1 = Linear(num_layers * hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() for pool in self.pools: pool.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch edge_weight = None x = F.relu(self.conv1(x, edge_index)) xs = [global_mean_pool(x, batch)] for i, conv in enumerate(self.convs): x = conv(x=x, edge_index=edge_index, edge_weight=edge_weight) x = F.relu(x) xs += [global_mean_pool(x, batch)] if i % 2 == 0 and i < len(self.convs) - 1: pool = self.pools[i // 2] x, edge_index, edge_weight, batch, _ = pool( x=x, edge_index=edge_index, edge_weight=edge_weight, batch=batch) x = self.jump(xs) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/datasets.py ================================================ import os.path as osp import torch import torch_geometric.transforms as T from torch_geometric.datasets import TUDataset from torch_geometric.utils import degree class NormalizedDegree: def __init__(self, mean, std): 
self.mean = mean self.std = std def __call__(self, data): deg = degree(data.edge_index[0], dtype=torch.float) deg = (deg - self.mean) / self.std data.x = deg.view(-1, 1) return data def get_dataset(name, sparse=True, cleaned=False): path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name) dataset = TUDataset(path, name, cleaned=cleaned) dataset.data.edge_attr = None if dataset.data.x is None: max_degree = 0 degs = [] for data in dataset: degs += [degree(data.edge_index[0], dtype=torch.long)] max_degree = max(max_degree, degs[-1].max().item()) if max_degree < 1000: dataset.transform = T.OneHotDegree(max_degree) else: deg = torch.cat(degs, dim=0).to(torch.float) mean, std = deg.mean().item(), deg.std().item() dataset.transform = NormalizedDegree(mean, std) if not sparse: num_nodes = max_num_nodes = 0 for data in dataset: num_nodes += data.num_nodes max_num_nodes = max(data.num_nodes, max_num_nodes) # Filter out a few really large graphs in order to apply DiffPool. if name == 'REDDIT-BINARY': num_nodes = min(int(num_nodes / len(dataset) * 1.5), max_num_nodes) else: num_nodes = min(int(num_nodes / len(dataset) * 5), max_num_nodes) indices = [] for i, data in enumerate(dataset): if data.num_nodes <= num_nodes: indices.append(i) dataset = dataset.copy(torch.tensor(indices)) if dataset.transform is None: dataset.transform = T.ToDense(num_nodes) else: dataset.transform = T.Compose( [dataset.transform, T.ToDense(num_nodes)]) return dataset ================================================ FILE: benchmark/kernel/diff_pool.py ================================================ from math import ceil import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.nn import DenseSAGEConv, JumpingKnowledge, dense_diff_pool class Block(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'): super().__init__() self.conv1 = DenseSAGEConv(in_channels, hidden_channels) self.conv2 = 
DenseSAGEConv(hidden_channels, out_channels) self.jump = JumpingKnowledge(mode) if mode == 'cat': self.lin = Linear(hidden_channels + out_channels, out_channels) else: self.lin = Linear(out_channels, out_channels) def reset_parameters(self): self.conv1.reset_parameters() self.conv2.reset_parameters() self.lin.reset_parameters() def forward(self, x, adj, mask=None): x1 = F.relu(self.conv1(x, adj, mask)) x2 = F.relu(self.conv2(x1, adj, mask)) return self.lin(self.jump([x1, x2])) class DiffPool(torch.nn.Module): def __init__(self, dataset, num_layers, hidden, ratio=0.25): super().__init__() num_nodes = ceil(ratio * dataset[0].num_nodes) self.embed_block1 = Block(dataset.num_features, hidden, hidden) self.pool_block1 = Block(dataset.num_features, hidden, num_nodes) self.embed_blocks = torch.nn.ModuleList() self.pool_blocks = torch.nn.ModuleList() for _ in range((num_layers // 2) - 1): num_nodes = ceil(ratio * num_nodes) self.embed_blocks.append(Block(hidden, hidden, hidden)) self.pool_blocks.append(Block(hidden, hidden, num_nodes)) self.jump = JumpingKnowledge(mode='cat') self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.embed_block1.reset_parameters() self.pool_block1.reset_parameters() for embed_block, pool_block in zip(self.embed_blocks, self.pool_blocks): embed_block.reset_parameters() pool_block.reset_parameters() self.jump.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, adj, mask = data.x, data.adj, data.mask s = self.pool_block1(x, adj, mask) x = F.relu(self.embed_block1(x, adj, mask)) xs = [x.mean(dim=1)] x, adj, _, _ = dense_diff_pool(x, adj, s, mask) for i, (embed_block, pool_block) in enumerate( zip(self.embed_blocks, self.pool_blocks)): s = pool_block(x, adj) x = F.relu(embed_block(x, adj)) xs.append(x.mean(dim=1)) if i < len(self.embed_blocks) - 1: x, adj, _, _ = dense_diff_pool(x, adj, s) x = 
self.jump(xs) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/edge_pool.py ================================================ import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.nn import ( EdgePooling, GraphConv, JumpingKnowledge, global_mean_pool, ) class EdgePool(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean') self.convs = torch.nn.ModuleList() self.pools = torch.nn.ModuleList() self.convs.extend([ GraphConv(hidden, hidden, aggr='mean') for i in range(num_layers - 1) ]) self.pools.extend( [EdgePooling(hidden) for i in range((num_layers) // 2)]) self.jump = JumpingKnowledge(mode='cat') self.lin1 = Linear(num_layers * hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() for pool in self.pools: pool.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) xs = [global_mean_pool(x, batch)] for i, conv in enumerate(self.convs): x = F.relu(conv(x, edge_index)) xs += [global_mean_pool(x, batch)] if i % 2 == 0 and i < len(self.convs) - 1: pool = self.pools[i // 2] x, edge_index, batch, _ = pool(x, edge_index, batch=batch) x = self.jump(xs) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/gcn.py ================================================ import torch import torch.nn.functional as F 
from torch.nn import Linear from torch_geometric.nn import GCNConv, JumpingKnowledge, global_mean_pool class GCN(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = GCNConv(dataset.num_features, hidden) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append(GCNConv(hidden, hidden)) self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) for conv in self.convs: x = F.relu(conv(x, edge_index)) x = global_mean_pool(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ class GCNWithJK(torch.nn.Module): def __init__(self, dataset, num_layers, hidden, mode='cat'): super().__init__() self.conv1 = GCNConv(dataset.num_features, hidden) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append(GCNConv(hidden, hidden)) self.jump = JumpingKnowledge(mode) if mode == 'cat': self.lin1 = Linear(num_layers * hidden, hidden) else: self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.jump.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) xs = [x] for conv in self.convs: x = F.relu(conv(x, edge_index)) xs += [x] x = self.jump(xs) x = global_mean_pool(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return 
F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/gin.py ================================================ import torch import torch.nn.functional as F from torch.nn import BatchNorm1d as BN from torch.nn import Linear, ReLU, Sequential from torch_geometric.nn import GINConv, JumpingKnowledge, global_mean_pool class GIN0(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = GINConv( Sequential( Linear(dataset.num_features, hidden), ReLU(), BN(hidden), Linear(hidden, hidden), ReLU(), BN(hidden), ), train_eps=False) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append( GINConv( Sequential( Linear(hidden, hidden), ReLU(), BN(hidden), Linear(hidden, hidden), ReLU(), BN(hidden), ), train_eps=False)) self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = self.conv1(x, edge_index) for conv in self.convs: x = conv(x, edge_index) x = global_mean_pool(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ class GIN0WithJK(torch.nn.Module): def __init__(self, dataset, num_layers, hidden, mode='cat'): super().__init__() self.conv1 = GINConv( Sequential( Linear(dataset.num_features, hidden), ReLU(), BN(hidden), Linear(hidden, hidden), ReLU(), BN(hidden), ), train_eps=False) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append( GINConv( Sequential( Linear(hidden, hidden), ReLU(), BN(hidden), Linear(hidden, hidden), ReLU(), BN(hidden), ), train_eps=False)) self.jump = 
JumpingKnowledge(mode) if mode == 'cat': self.lin1 = Linear(num_layers * hidden, hidden) else: self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.jump.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = self.conv1(x, edge_index) xs = [x] for conv in self.convs: x = conv(x, edge_index) xs += [x] x = self.jump(xs) x = global_mean_pool(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ class GIN(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = GINConv( Sequential( Linear(dataset.num_features, hidden), ReLU(), BN(hidden), Linear(hidden, hidden), ReLU(), BN(hidden), ), train_eps=True) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append( GINConv( Sequential( Linear(hidden, hidden), ReLU(), BN(hidden), Linear(hidden, hidden), ReLU(), BN(hidden), ), train_eps=True)) self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = self.conv1(x, edge_index) for conv in self.convs: x = conv(x, edge_index) x = global_mean_pool(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ class GINWithJK(torch.nn.Module): def __init__(self, dataset, num_layers, hidden, mode='cat'): super().__init__() self.conv1 = GINConv( 
Sequential( Linear(dataset.num_features, hidden), ReLU(), BN(hidden), Linear(hidden, hidden), ReLU(), BN(hidden), ), train_eps=True) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append( GINConv( Sequential( Linear(hidden, hidden), ReLU(), BN(hidden), Linear(hidden, hidden), ReLU(), BN(hidden), ), train_eps=True)) self.jump = JumpingKnowledge(mode) if mode == 'cat': self.lin1 = Linear(num_layers * hidden, hidden) else: self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.jump.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = self.conv1(x, edge_index) xs = [x] for conv in self.convs: x = conv(x, edge_index) xs += [x] x = self.jump(xs) x = global_mean_pool(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/global_attention.py ================================================ import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.nn import AttentionalAggregation, SAGEConv class GlobalAttentionNet(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = SAGEConv(dataset.num_features, hidden) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append(SAGEConv(hidden, hidden)) self.att = AttentionalAggregation(Linear(hidden, 1)) self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.att.reset_parameters() self.lin1.reset_parameters() 
self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) for conv in self.convs: x = F.relu(conv(x, edge_index)) x = self.att(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/graclus.py ================================================ import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.data import Batch from torch_geometric.nn import ( GraphConv, JumpingKnowledge, global_mean_pool, graclus, max_pool, ) class Graclus(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean') self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append(GraphConv(hidden, hidden, aggr='mean')) self.jump = JumpingKnowledge(mode='cat') self.lin1 = Linear(num_layers * hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.jump.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) xs = [global_mean_pool(x, batch)] for i, conv in enumerate(self.convs): x = F.relu(conv(x, edge_index)) xs += [global_mean_pool(x, batch)] if i % 2 == 0 and i < len(self.convs) - 1: cluster = graclus(edge_index, num_nodes=x.size(0)) data = Batch(x=x, edge_index=edge_index, batch=batch) data = max_pool(cluster, data) x, edge_index, batch = data.x, data.edge_index, data.batch x = self.jump(xs) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, 
dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/graph_sage.py ================================================ import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.nn import JumpingKnowledge, SAGEConv, global_add_pool class GraphSAGE(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = SAGEConv(dataset.num_features, hidden) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append(SAGEConv(hidden, hidden)) self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) for conv in self.convs: x = F.relu(conv(x, edge_index)) x = global_add_pool(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ class GraphSAGEWithJK(torch.nn.Module): def __init__(self, dataset, num_layers, hidden, mode='cat'): super().__init__() self.conv1 = SAGEConv(dataset.num_features, hidden) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append(SAGEConv(hidden, hidden)) self.jump = JumpingKnowledge(mode) if mode == 'cat': self.lin1 = Linear(num_layers * hidden, hidden) else: self.lin1 = Linear(hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.jump.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = 
F.relu(self.conv1(x, edge_index)) xs = [x] for conv in self.convs: x = F.relu(conv(x, edge_index)) xs += [x] x = self.jump(xs) x = global_add_pool(x, batch) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/main.py ================================================ import argparse from itertools import product from asap import ASAP from datasets import get_dataset from diff_pool import DiffPool from edge_pool import EdgePool from gcn import GCN, GCNWithJK from gin import GIN, GIN0, GIN0WithJK, GINWithJK from global_attention import GlobalAttentionNet from graclus import Graclus from graph_sage import GraphSAGE, GraphSAGEWithJK from sag_pool import SAGPool from set2set import Set2SetNet from sort_pool import SortPool from top_k import TopK from train_eval import cross_validation_with_val_set parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--lr_decay_factor', type=float, default=0.5) parser.add_argument('--lr_decay_step_size', type=int, default=50) args = parser.parse_args() layers = [1, 2, 3, 4, 5] hiddens = [16, 32, 64, 128] datasets = ['MUTAG', 'PROTEINS', 'IMDB-BINARY', 'REDDIT-BINARY'] # , 'COLLAB'] nets = [ GCNWithJK, GraphSAGEWithJK, GIN0WithJK, GINWithJK, Graclus, TopK, SAGPool, DiffPool, EdgePool, GCN, GraphSAGE, GIN0, GIN, GlobalAttentionNet, Set2SetNet, SortPool, ASAP, ] def logger(info): fold, epoch = info['fold'] + 1, info['epoch'] val_loss, test_acc = info['val_loss'], info['test_acc'] print(f'{fold:02d}/{epoch:03d}: Val Loss: {val_loss:.4f}, ' f'Test Accuracy: {test_acc:.3f}') results = [] for dataset_name, Net in product(datasets, nets): best_result = (float('inf'), 0, 0) # (loss, acc, 
std) print(f'--\n{dataset_name} - {Net.__name__}') for num_layers, hidden in product(layers, hiddens): dataset = get_dataset(dataset_name, sparse=Net != DiffPool) model = Net(dataset, num_layers, hidden) loss, acc, std = cross_validation_with_val_set( dataset, model, folds=10, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, lr_decay_factor=args.lr_decay_factor, lr_decay_step_size=args.lr_decay_step_size, weight_decay=0, logger=None, ) if loss < best_result[0]: best_result = (loss, acc, std) desc = f'{best_result[1]:.3f} ± {best_result[2]:.3f}' print(f'Best result - {desc}') results += [f'{dataset_name} - {model}: {desc}'] results = '\n'.join(results) print(f'--\n{results}') ================================================ FILE: benchmark/kernel/main_performance.py ================================================ import argparse from itertools import product import torch from datasets import get_dataset from gcn import GCN from gin import GIN from graph_sage import GraphSAGE from train_eval import eval_acc, inference_run, train from torch_geometric import seed_everything from torch_geometric.loader import DataLoader from torch_geometric.profile import rename_profile_file, timeit, torch_profile seed_everything(0) parser = argparse.ArgumentParser() parser.add_argument( '--datasets', type=str, nargs='+', default=['MUTAG', 'PROTEINS', 'IMDB-BINARY', 'REDDIT-BINARY']) parser.add_argument('--models', type=str, nargs='+', default=['GCN', 'GraphSAGE', 'GIN']) parser.add_argument('--layers', type=int, nargs='+', default=[1, 2, 3]) parser.add_argument('--hiddens', type=int, nargs='+', default=[16, 32]) parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--warmup_profile', type=int, default=1, help='Skip the first few runs') parser.add_argument('--goal_accuracy', type=int, default=1, help='The goal test accuracy') 
parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') if torch.cuda.is_available(): amp = torch.amp.autocast('cuda', enabled=False) else: amp = torch.cpu.amp.autocast(enabled=args.bf16) MODELS = { 'GCN': GCN, 'GraphSAGE': GraphSAGE, 'GIN': GIN, } def prepare_dataloader(dataset_name): dataset = get_dataset(dataset_name, sparse=True) num_train = int(len(dataset) * 0.8) num_val = int(len(dataset) * 0.1) train_dataset = dataset[:num_train] val_dataset = dataset[num_train:num_train + num_val] test_dataset = dataset[num_train + num_val:] train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) return dataset, train_loader, val_loader, test_loader def run_train(): for dataset_name, model_name in product(args.datasets, args.models): dataset, train_loader, val_loader, test_loader = prepare_dataloader( dataset_name) Model = MODELS[model_name] for num_layers, hidden in product(args.layers, args.hiddens): print('--') print(f'{dataset_name} - {model_name}- {num_layers} - {hidden}') model = Model(dataset, num_layers, hidden).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) if args.compile: model = torch.compile(model) loss_list = [] acc_list = [] for epoch in range(1, args.epochs + 1): if epoch == args.epochs: with timeit(): loss = train(model, optimizer, train_loader) else: loss = train(model, optimizer, train_loader) with timeit(log=False) as t: val_acc = eval_acc(model, val_loader) 
val_time = t.duration with timeit(log=False) as t: test_acc = eval_acc(model, test_loader) test_time = t.duration if epoch >= args.warmup_profile: loss_list.append(loss) acc_list.append([val_acc, val_time, test_acc, test_time]) if args.profile: with torch_profile(): train(model, optimizer, train_loader) rename_profile_file(model_name, dataset_name, str(num_layers), str(hidden), 'train') @torch.no_grad() def run_inference(): for dataset_name, model_name in product(args.datasets, args.models): dataset, _, _, test_loader = prepare_dataloader(dataset_name) Model = MODELS[model_name] for num_layers, hidden in product(args.layers, args.hiddens): print('--') print(f'{dataset_name} - {model_name}- {num_layers} - {hidden}') model = Model(dataset, num_layers, hidden).to(device) if args.compile: model = torch.compile(model) with amp: for epoch in range(1, args.epochs + 1): if epoch == args.epochs: with timeit(): inference_run(model, test_loader, args.bf16) else: inference_run(model, test_loader, args.bf16) if args.profile: with torch_profile(): inference_run(model, test_loader, args.bf16) rename_profile_file(model_name, dataset_name, str(num_layers), str(hidden), 'inference') if not args.inference: run_train() else: run_inference() ================================================ FILE: benchmark/kernel/sag_pool.py ================================================ import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.nn import ( GraphConv, JumpingKnowledge, SAGPooling, global_mean_pool, ) class SAGPool(torch.nn.Module): def __init__(self, dataset, num_layers, hidden, ratio=0.8): super().__init__() self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean') self.convs = torch.nn.ModuleList() self.pools = torch.nn.ModuleList() self.convs.extend([ GraphConv(hidden, hidden, aggr='mean') for i in range(num_layers - 1) ]) self.pools.extend( [SAGPooling(hidden, ratio) for i in range((num_layers) // 2)]) self.jump = 
JumpingKnowledge(mode='cat') self.lin1 = Linear(num_layers * hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() for pool in self.pools: pool.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) xs = [global_mean_pool(x, batch)] for i, conv in enumerate(self.convs): x = F.relu(conv(x, edge_index)) xs += [global_mean_pool(x, batch)] if i % 2 == 0 and i < len(self.convs) - 1: pool = self.pools[i // 2] x, edge_index, _, batch, _, _ = pool(x, edge_index, batch=batch) x = self.jump(xs) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/set2set.py ================================================ import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.nn import SAGEConv, Set2Set class Set2SetNet(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = SAGEConv(dataset.num_features, hidden) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append(SAGEConv(hidden, hidden)) self.set2set = Set2Set(hidden, processing_steps=4) self.lin1 = Linear(2 * hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.set2set.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) for conv in self.convs: x = F.relu(conv(x, edge_index)) x = self.set2set(x, batch) x = 
F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/sort_pool.py ================================================ import torch import torch.nn.functional as F from torch.nn import Conv1d, Linear from torch_geometric.nn import SAGEConv, SortAggregation class SortPool(torch.nn.Module): def __init__(self, dataset, num_layers, hidden): super().__init__() self.conv1 = SAGEConv(dataset.num_features, hidden) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append(SAGEConv(hidden, hidden)) self.pool = SortAggregation(k=30) self.conv1d = Conv1d(hidden, 32, 5) self.lin1 = Linear(32 * (30 - 5 + 1), hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.conv1d.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) for conv in self.convs: x = F.relu(conv(x, edge_index)) x = self.pool(x, batch) x = x.view(len(x), self.k, -1).permute(0, 2, 1) x = F.relu(self.conv1d(x)) x = x.view(len(x), -1) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/statistics.py ================================================ from kernel.datasets import get_dataset def print_dataset(dataset): num_nodes = num_edges = 0 for data in dataset: num_nodes += data.num_nodes num_edges += data.num_edges print('Name', dataset) print('Graphs', len(dataset)) print('Nodes', num_nodes / len(dataset)) print('Edges', (num_edges // 2) / len(dataset)) 
print('Features', dataset.num_features) print('Classes', dataset.num_classes) print() for name in ['MUTAG', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'REDDIT-BINARY']: print_dataset(get_dataset(name)) ================================================ FILE: benchmark/kernel/top_k.py ================================================ import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.nn import ( GraphConv, JumpingKnowledge, TopKPooling, global_mean_pool, ) class TopK(torch.nn.Module): def __init__(self, dataset, num_layers, hidden, ratio=0.8): super().__init__() self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean') self.convs = torch.nn.ModuleList() self.pools = torch.nn.ModuleList() self.convs.extend([ GraphConv(hidden, hidden, aggr='mean') for i in range(num_layers - 1) ]) self.pools.extend( [TopKPooling(hidden, ratio) for i in range((num_layers) // 2)]) self.jump = JumpingKnowledge(mode='cat') self.lin1 = Linear(num_layers * hidden, hidden) self.lin2 = Linear(hidden, dataset.num_classes) def reset_parameters(self): self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() for pool in self.pools: pool.reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) xs = [global_mean_pool(x, batch)] for i, conv in enumerate(self.convs): x = F.relu(conv(x, edge_index)) xs += [global_mean_pool(x, batch)] if i % 2 == 0 and i < len(self.convs) - 1: pool = self.pools[i // 2] x, edge_index, _, batch, _, _ = pool(x, edge_index, batch=batch) x = self.jump(xs) x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) def __repr__(self): return self.__class__.__name__ ================================================ FILE: benchmark/kernel/train_eval.py ================================================ import time 
import torch import torch.nn.functional as F from sklearn.model_selection import StratifiedKFold from torch import tensor from torch.optim import Adam from torch_geometric.loader import DataLoader from torch_geometric.loader import DenseDataLoader as DenseLoader if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') def cross_validation_with_val_set(dataset, model, folds, epochs, batch_size, lr, lr_decay_factor, lr_decay_step_size, weight_decay, logger=None): val_losses, accs, durations = [], [], [] for fold, (train_idx, test_idx, val_idx) in enumerate(zip(*k_fold(dataset, folds))): train_dataset = dataset[train_idx] test_dataset = dataset[test_idx] val_dataset = dataset[val_idx] if 'adj' in train_dataset[0]: train_loader = DenseLoader(train_dataset, batch_size, shuffle=True) val_loader = DenseLoader(val_dataset, batch_size, shuffle=False) test_loader = DenseLoader(test_dataset, batch_size, shuffle=False) else: train_loader = DataLoader(train_dataset, batch_size, shuffle=True) val_loader = DataLoader(val_dataset, batch_size, shuffle=False) test_loader = DataLoader(test_dataset, batch_size, shuffle=False) model.to(device).reset_parameters() optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) if torch.cuda.is_available(): torch.cuda.synchronize() elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): try: torch.mps.synchronize() except ImportError: pass t_start = time.perf_counter() for epoch in range(1, epochs + 1): train_loss = train(model, optimizer, train_loader) val_losses.append(eval_loss(model, val_loader)) accs.append(eval_acc(model, test_loader)) eval_info = { 'fold': fold, 'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_losses[-1], 'test_acc': accs[-1], } if logger is not None: logger(eval_info) if epoch % lr_decay_step_size == 0: for param_group in 
optimizer.param_groups: param_group['lr'] = lr_decay_factor * param_group['lr'] if torch.cuda.is_available(): torch.cuda.synchronize() elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize() t_end = time.perf_counter() durations.append(t_end - t_start) loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations) loss, acc = loss.view(folds, epochs), acc.view(folds, epochs) loss, argmin = loss.min(dim=1) acc = acc[torch.arange(folds, dtype=torch.long), argmin] loss_mean = loss.mean().item() acc_mean = acc.mean().item() acc_std = acc.std().item() duration_mean = duration.mean().item() print(f'Val Loss: {loss_mean:.4f}, Test Accuracy: {acc_mean:.3f} ' f'± {acc_std:.3f}, Duration: {duration_mean:.3f}') return loss_mean, acc_mean, acc_std def k_fold(dataset, folds): skf = StratifiedKFold(folds, shuffle=True, random_state=12345) test_indices, train_indices = [], [] for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y): test_indices.append(torch.from_numpy(idx).to(torch.long)) val_indices = [test_indices[i - 1] for i in range(folds)] for i in range(folds): train_mask = torch.ones(len(dataset), dtype=torch.bool) train_mask[test_indices[i]] = 0 train_mask[val_indices[i]] = 0 train_indices.append(train_mask.nonzero(as_tuple=False).view(-1)) return train_indices, test_indices, val_indices def num_graphs(data): if hasattr(data, 'num_graphs'): return data.num_graphs else: return data.x.size(0) def train(model, optimizer, loader): model.train() total_loss = 0 for data in loader: optimizer.zero_grad() data = data.to(device) out = model(data) loss = F.nll_loss(out, data.y.view(-1)) loss.backward() total_loss += loss.item() * num_graphs(data) optimizer.step() return total_loss / len(loader.dataset) def eval_acc(model, loader): model.eval() correct = 0 for data in loader: data = data.to(device) with torch.no_grad(): pred = model(data).max(1)[1] correct += pred.eq(data.y.view(-1)).sum().item() return correct / 
len(loader.dataset) def eval_loss(model, loader): model.eval() loss = 0 for data in loader: data = data.to(device) with torch.no_grad(): out = model(data) loss += F.nll_loss(out, data.y.view(-1), reduction='sum').item() return loss / len(loader.dataset) @torch.no_grad() def inference_run(model, loader, bf16): model.eval() for data in loader: data = data.to(device) if bf16: data.x = data.x.to(torch.bfloat16) model(data) ================================================ FILE: benchmark/loader/neighbor_loader.py ================================================ import argparse import ast import os.path as osp from contextlib import nullcontext from timeit import default_timer import tqdm from ogb.nodeproppred import PygNodePropPredDataset import torch_geometric.transforms as T from torch_geometric.datasets import OGB_MAG from torch_geometric.loader import NeighborLoader from torch_geometric.profile import torch_profile def run(args: argparse.ArgumentParser): for dataset_name in args.datasets: print(f"Dataset: {dataset_name}") root = osp.join(args.root, dataset_name) transform = T.ToSparseTensor( remove_edge_index=False) if args.use_sparse_tensor else None if dataset_name == 'mag': transform = (T.ToUndirected(merge=True) if transform is None else T.Compose([T.ToUndirected(merge=True), transform])) dataset = OGB_MAG(root=root, transform=transform) train_idx = ('paper', dataset[0]['paper'].train_mask) eval_idx = ('paper', None) neighbor_sizes = (args.hetero_neighbor_sizes if args.hetero_neighbor_sizes else None) else: dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root) split_idx = dataset.get_idx_split() train_idx = split_idx['train'] eval_idx = None neighbor_sizes = (args.homo_neighbor_sizes if args.homo_neighbor_sizes else None) data = dataset[0].to(args.device) average_times = [] profile = torch_profile() if args.profile else nullcontext() # run dataloader iteration if neighbor_sizes is not None: for num_neighbors in neighbor_sizes: print(f'Training sampling 
with {num_neighbors} neighbors') for batch_size in args.batch_sizes: train_loader = NeighborLoader( data, num_neighbors=num_neighbors, input_nodes=train_idx, batch_size=batch_size, shuffle=True, num_workers=args.num_workers, subgraph_type=args.subgraph_type, ) cpu_affinity = train_loader.enable_cpu_affinity( args.loader_cores ) if args.cpu_affinity else nullcontext() runtimes = [] num_iterations = 0 with profile, cpu_affinity: for _ in range(args.runs): start = default_timer() for _ in tqdm.tqdm(train_loader): num_iterations += 1 stop = default_timer() runtimes.append(round(stop - start, 3)) average_time = round(sum(runtimes) / args.runs, 3) print(f'batch size={batch_size}, ' f'iterations={num_iterations}, ' f'runtimes={runtimes}, ' f'average runtime={average_time}') average_times.append(average_time) eval_batch_sizes = (args.eval_batch_sizes if args.eval_batch_sizes else None) if eval_batch_sizes is not None: print('Evaluation sampling with all neighbors') for batch_size in eval_batch_sizes: subgraph_loader = NeighborLoader( data, num_neighbors=[-1], input_nodes=eval_idx, batch_size=batch_size, shuffle=False, num_workers=args.num_workers, ) cpu_affinity = subgraph_loader.enable_cpu_affinity( args.loader_cores) if args.cpu_affinity else nullcontext() runtimes = [] num_iterations = 0 with profile, cpu_affinity: for _ in range(args.runs): start = default_timer() for _ in tqdm.tqdm(subgraph_loader): num_iterations += 1 stop = default_timer() runtimes.append(round(stop - start, 3)) average_time = round(sum(runtimes) / args.runs, 3) print(f'batch size={batch_size}, ' f'iterations={num_iterations}, ' f'runtimes={runtimes}, ' f'average runtime={average_time}') average_times.append(average_time) print(f"Total time averages: {average_times}") if __name__ == '__main__': parser = argparse.ArgumentParser('NeighborLoader Sampling Benchmarking') add = parser.add_argument add('--device', default='cpu') add('--datasets', nargs="+", default=['arxiv', 'products', 'mag']) 
add('--root', default='../../data') add('--batch-sizes', default=[8192, 4096, 2048, 1024, 512], type=ast.literal_eval) add('--eval-batch-sizes', default=[16384, 8192, 4096, 2048, 1024, 512], type=ast.literal_eval) add('--homo-neighbor_sizes', default=[[10, 5], [15, 10, 5], [20, 15, 10]], type=ast.literal_eval) add('--hetero-neighbor_sizes', default=[[5], [10], [10, 5]], type=ast.literal_eval) add('--use-sparse-tensor', action='store_true', help='use torch_sparse.SparseTensor as graph storage format') add('--num-workers', type=int, default=0, help="Number of DataLoader workers to use.") add('--runs', type=int, default=3, help="Number of iterations for each test setting.") add('--profile', default=False, action='store_true', help="Run torch.profiler.") add('--cpu-affinity', default=False, action='store_true', help="Use DataLoader affinitzation.") add('--loader-cores', nargs='+', default=[], type=int, help="List of CPU core IDs to use for DataLoader workers.") add('--subgraph-type', type=str, default='directional', help="The type of the returned subgraph (directional, bidirectional)") run(parser.parse_args()) ================================================ FILE: benchmark/multi_gpu/training/README.md ================================================ # Training Benchmark ## Running benchmark on CUDA GPU Run the benchmark, e.g. assuming you have `n` NVIDIA GPUs: ``` python training_benchmark_cuda.py --dataset ogbn-products --model edge_cnn --num-epochs 3 --n-gpus <n> ``` ## Running benchmark on Intel GPU ### Environment setup ### Prerequisites - Intel Data Center GPU Max Series. You could try it through [Intel DevCloud](https://www.intel.com/content/www/us/en/developer/tools/devcloud/services.html). - Verify that the Intel GPU driver is installed; refer to the [guide](https://dgpu-docs.intel.com/driver/installation.html).
### docker setup If you want to run your scripts inside a Docker image, refer to the [dockerfile](https://github.com/pyg-team/pytorch_geometric/blob/master/docker/Dockerfile.xpu) and the corresponding [guide](https://github.com/pyg-team/pytorch_geometric/blob/master/docker). ### bare-metal setup If you prefer to run your scripts directly on a bare-metal server, we recommend following the installation guidance provided by [Intel® Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu&version=v2.1.30%2bxpu&os=linux%2fwsl2&package=pip). The following are the key steps: - Install [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html), including [Intel® oneAPI DPC++ Compiler](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compiler.html), [Intel® oneAPI Math Kernel Library (oneMKL)](https://www.intel.com/content/www/us/en/docs/oneapi/programming-guide/2024-1/intel-oneapi-math-kernel-library-onemkl.html), [Intel® oneAPI Collective Communications Library (oneCCL)](https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html), and [Intel® oneCCL Bindings for PyTorch](https://github.com/intel/torch-ccl).
```bash # Install oneCCL package on Ubuntu sudo apt install -y intel-oneapi-dpcpp-cpp-2024.1=2024.1.0-963 intel-oneapi-mkl-devel=2024.1.0-691 intel-oneapi-ccl-devel=2021.12.0-309 # Install oneccl_bindings_for_pytorch pip install oneccl_bind_pt==2.1.300+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ # Runtime Dynamic Linking source /opt/intel/oneapi/setvars.sh ``` - Install [Intel® Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch) and the corresponding version of PyTorch ```bash pip install torch==2.1.0.post2 intel-extension-for-pytorch==2.1.30+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ ``` ### Running benchmark This [guide](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/features/DDP.html) explains how to launch DDP training on Intel GPUs. To run the benchmark, e.g. assuming you have `n` XPUs: ``` mpirun -np <n> python training_benchmark_xpu.py --dataset ogbn-products --model edge_cnn --num-epochs 3 ``` ================================================ FILE: benchmark/multi_gpu/training/common.py ================================================ import argparse import ast from time import perf_counter from typing import Any, Callable, Tuple, Union import torch import torch.distributed as dist import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP from benchmark.utils import get_model, get_split_masks, test from torch_geometric.data import Data, HeteroData from torch_geometric.loader import NeighborLoader from torch_geometric.nn import PNAConv supported_sets = { 'ogbn-mag': ['rgat', 'rgcn'], 'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'], 'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'], } device_conditions = { 'xpu': (lambda: torch.xpu.is_available()), 'cuda': (lambda: torch.cuda.is_available()), } def train_homo(model: Any, loader: NeighborLoader, optimizer: torch.optim.Adam, device:
torch.device) -> torch.Tensor: for batch in loader: optimizer.zero_grad() batch = batch.to(device) out = model(batch.x, batch.edge_index) batch_size = batch.batch_size out = out[:batch_size] target = batch.y[:batch_size] loss = F.cross_entropy(out, target) loss.backward() optimizer.step() return loss def train_hetero(model: Any, loader: NeighborLoader, optimizer: torch.optim.Adam, device: torch.device) -> torch.Tensor: for batch in loader: optimizer.zero_grad() batch = batch.to(device) out = model(batch.x_dict, batch.edge_index_dict) batch_size = batch['paper'].batch_size out = out['paper'][:batch_size] target = batch['paper'].y[:batch_size] loss = F.cross_entropy(out, target) loss.backward() optimizer.step() return loss def maybe_synchronize(device: str): if device == 'xpu' and torch.xpu.is_available(): torch.xpu.synchronize() if device == 'cuda' and torch.cuda.is_available(): torch.cuda.synchronize() def create_mask_per_rank( global_mask: Union[torch.Tensor, Tuple[str, torch.Tensor]], rank: int, world_size: int, hetero: bool = False) -> Union[torch.Tensor, Tuple[str, torch.Tensor]]: mask = global_mask[-1] if hetero else global_mask nonzero = mask.nonzero().reshape(-1) rank_indices = nonzero.split(nonzero.size(0) // world_size, dim=0)[rank].clone() mask_per_rank = torch.full_like(mask, False) mask_per_rank[rank_indices] = True if hetero: return tuple((global_mask[0], mask_per_rank)) else: return mask_per_rank def run(rank: int, world_size: int, args: argparse.ArgumentParser, num_classes: int, data: Union[Data, HeteroData], custom_optimizer: Callable[[Any, Any], Tuple[Any, Any]] = None): if not device_conditions[args.device](): raise RuntimeError(f'{args.device.upper()} is not available') device = torch.device(f'{args.device}:{rank}') if rank == 0: print('BENCHMARK STARTS') print(f'Running on {args.device.upper()}') print(f'Dataset: {args.dataset}') hetero = True if args.dataset == 'ogbn-mag' else False mask, val_mask, test_mask = get_split_masks(data, 
args.dataset) mask = create_mask_per_rank(mask, rank, world_size, hetero) degree = None inputs_channels = data[ 'paper'].num_features if args.dataset == 'ogbn-mag' \ else data.num_features if args.model not in supported_sets[args.dataset]: err_msg = (f'Configuration of {args.dataset} + {args.model}' 'not supported') raise RuntimeError(err_msg) if rank == 0: print(f'Training bench for {args.model}:') num_nodes = int(mask[-1].sum()) if hetero else int(mask.sum()) num_neighbors = args.num_neighbors if type(num_neighbors) is list: if len(num_neighbors) == 1: num_neighbors = num_neighbors * args.num_layers elif type(num_neighbors) is int: num_neighbors = [num_neighbors] * args.num_layers if len(num_neighbors) != args.num_layers: err_msg = (f'num_neighbors={num_neighbors} length != num of' 'layers={args.num_layers}') kwargs = { 'num_neighbors': num_neighbors, 'batch_size': args.batch_size, 'num_workers': args.num_workers, } subgraph_loader = NeighborLoader( data, input_nodes=mask, sampler=None, **kwargs, ) if rank == 0 and args.evaluate: val_loader = NeighborLoader( data, input_nodes=val_mask, sampler=None, **kwargs, ) test_loader = NeighborLoader( data, input_nodes=test_mask, sampler=None, **kwargs, ) if rank == 0: print('----------------------------------------------') print( f'Batch size={args.batch_size}, ' f'Layers amount={args.num_layers}, ' f'Num_neighbors={num_neighbors}, ' f'Hidden features size={args.num_hidden_channels}', flush=True) params = { 'inputs_channels': inputs_channels, 'hidden_channels': args.num_hidden_channels, 'output_channels': num_classes, 'num_heads': args.num_heads, 'num_layers': args.num_layers, } if args.model == 'pna' and degree is None: degree = PNAConv.get_degree_histogram(subgraph_loader) print(f'Rank: {rank}, calculated degree for {args.dataset}.', flush=True) params['degree'] = degree dist.barrier() torch.manual_seed(12345) model = get_model(args.model, params, metadata=data.metadata() if hetero else None) model = model.to(device) if 
hetero: model.eval() x_keys = data.metadata()[0] edge_index_keys = data.metadata()[1] fake_x_dict = { k: torch.rand((32, inputs_channels), device=device) for k in x_keys } fake_edge_index_dict = { k: torch.randint(0, 32, (2, 8), device=device) for k in edge_index_keys } model.forward(fake_x_dict, fake_edge_index_dict) model = DDP(model, device_ids=[device], find_unused_parameters=hetero) model.train() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) if custom_optimizer: model, optimizer = custom_optimizer(model, optimizer) train = train_hetero if hetero else train_homo maybe_synchronize(args.device) dist.barrier() if rank == 0: beg = perf_counter() for epoch in range(args.num_epochs): loss = train( model, subgraph_loader, optimizer, device, ) dist.barrier() if rank == 0: print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}', flush=True) if rank == 0 and args.evaluate: # In evaluate, throughput and # latency are not accurate. val_acc = test(model, val_loader, device, hetero, progress_bar=False) print(f'Val Accuracy: {val_acc:.4f}') dist.barrier() maybe_synchronize(args.device) dist.barrier() if rank == 0: end = perf_counter() duration = end - beg if rank == 0 and args.evaluate: test_acc = test(model, test_loader, device, hetero, progress_bar=False) print(f'Test Accuracy: {test_acc:.4f}') dist.barrier() if rank == 0: num_nodes_total = num_nodes * world_size duration_per_epoch = duration / args.num_epochs throughput = num_nodes_total / duration_per_epoch latency = duration_per_epoch / num_nodes_total * 1000 print(f'Time: {duration_per_epoch:.4f}s') print(f'Throughput: {throughput:.3f} samples/s') print(f'Latency: {latency:.3f} ms', flush=True) dist.destroy_process_group() def get_predefined_args() -> argparse.ArgumentParser: argparser = argparse.ArgumentParser( 'GNN distributed (DDP) training benchmark') add = argparser.add_argument add('--dataset', choices=['ogbn-mag', 'ogbn-products', 'Reddit'], default='Reddit', type=str) add('--model', choices=['edge_cnn', 
'gat', 'gcn', 'pna', 'rgat', 'rgcn', 'sage'], default='sage', type=str) add('--root', default='../../data', type=str, help='relative path to look for the datasets') add('--batch-size', default=4096, type=int) add('--num-layers', default=3, type=int) add('--num-hidden-channels', default=128, type=int) add('--num-heads', default=2, type=int, help='number of hidden attention heads, applies only for gat and rgat') add('--num-neighbors', default=[10], type=ast.literal_eval, help='number of neighbors to sample per layer') add('--num-workers', default=0, type=int) add('--num-epochs', default=1, type=int) add('--evaluate', action='store_true') return argparser ================================================ FILE: benchmark/multi_gpu/training/training_benchmark_cuda.py ================================================ import argparse import os from typing import Union import torch import torch.distributed as dist import torch.multiprocessing as mp from benchmark.multi_gpu.training.common import ( get_predefined_args, run, supported_sets, ) from benchmark.utils import get_dataset from torch_geometric.data import Data, HeteroData def run_cuda(rank: int, world_size: int, args: argparse.ArgumentParser, num_classes: int, data: Union[Data, HeteroData]): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) run(rank, world_size, args, num_classes, data) if __name__ == '__main__': argparser = get_predefined_args() argparser.add_argument('--n-gpus', default=1, type=int) args = argparser.parse_args() args.device = 'cuda' assert args.dataset in supported_sets.keys(), \ f"Dataset {args.dataset} isn't supported." 
data, num_classes = get_dataset(args.dataset, args.root) max_world_size = torch.cuda.device_count() chosen_world_size = args.n_gpus if chosen_world_size <= max_world_size: world_size = chosen_world_size else: print(f'User selected {chosen_world_size} GPUs ' f'but only {max_world_size} GPUs are available') world_size = max_world_size print(f'Let\'s use {world_size} GPUs!') mp.spawn( run_cuda, args=(world_size, args, num_classes, data), nprocs=world_size, join=True, ) ================================================ FILE: benchmark/multi_gpu/training/training_benchmark_xpu.py ================================================ import os from typing import Any, Tuple import intel_extension_for_pytorch as ipex import oneccl_bindings_for_pytorch # noqa import torch.distributed as dist from benchmark.multi_gpu.training.common import ( get_predefined_args, run, supported_sets, ) from benchmark.utils import get_dataset def get_dist_params() -> Tuple[int, int, str]: master_addr = "127.0.0.1" master_port = "29500" os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = master_port mpi_rank = int(os.environ.get("PMI_RANK", -1)) mpi_world_size = int(os.environ.get("PMI_SIZE", -1)) rank = mpi_rank if mpi_world_size > 0 else os.environ.get("RANK", 0) world_size = (mpi_world_size if mpi_world_size > 0 else os.environ.get( "WORLD_SIZE", 1)) os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) init_method = f"tcp://{master_addr}:{master_port}" return rank, world_size, init_method def custom_optimizer(model: Any, optimizer: Any) -> Tuple[Any, Any]: return ipex.optimize(model, optimizer=optimizer) if __name__ == '__main__': rank, world_size, init_method = get_dist_params() dist.init_process_group(backend="ccl", init_method=init_method, world_size=world_size, rank=rank) argparser = get_predefined_args() args = argparser.parse_args() args.device = 'xpu' assert args.dataset in supported_sets.keys(), \ f"Dataset {args.dataset} isn't supported." 
# if the dataset is not present, it will be downloaded # only by process with rank=0, # other process will use the dataset cache from rank=0, # and will not re-download and process it if rank == 0: data, num_classes = get_dataset(args.dataset, args.root) dist.barrier() if rank != 0: data, num_classes = get_dataset(args.dataset, args.root) run(rank, world_size, args, num_classes, data, custom_optimizer) ================================================ FILE: benchmark/points/README.md ================================================ # Point Cloud classification Evaluation scripts for various methods on the ModelNet10 dataset: - **[MPNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/mpnn.py)**: `python mpnn.py` - **[PointNet++](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_net.py)**: `python point_net.py` - **[EdgeCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/edge_cnn.py)**: `python edge_cnn.py` - **[SplineCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/spline_cnn.py)**: `python spline_cnn.py` - **[PointCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_cnn.py)**: `python point_cnn.py` ================================================ FILE: benchmark/points/__init__.py ================================================ from .datasets import get_dataset from .train_eval import run __all__ = [ 'get_dataset', 'run', ] ================================================ FILE: benchmark/points/datasets.py ================================================ import os.path as osp import torch_geometric.transforms as T from torch_geometric.datasets import ModelNet def get_dataset(num_points): name = 'ModelNet10' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name) pre_transform = T.NormalizeScale() transform = T.SamplePoints(num_points) train_dataset = ModelNet(path, name='10', train=True, 
transform=transform, pre_transform=pre_transform) test_dataset = ModelNet(path, name='10', train=False, transform=transform, pre_transform=pre_transform) return train_dataset, test_dataset ================================================ FILE: benchmark/points/edge_cnn.py ================================================ import argparse import torch import torch.nn.functional as F from points.datasets import get_dataset from points.train_eval import run from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq from torch_geometric.nn import DynamicEdgeConv, global_max_pool from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--batch_size', type=int, default=8) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr_decay_factor', type=float, default=0.5) parser.add_argument('--lr_decay_step_size', type=int, default=50) parser.add_argument('--weight_decay', type=float, default=0) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, num_classes): super().__init__() nn = Seq(Lin(6, 64), ReLU(), Lin(64, 64), ReLU(), Lin(64, 64), ReLU()) self.conv1 = DynamicEdgeConv(nn, k=20, aggr='max') nn = Seq(Lin(128, 128), ReLU(), Lin(128, 128), ReLU(), Lin(128, 256), ReLU()) self.conv2 = DynamicEdgeConv(nn, k=20, aggr='max') self.lin0 = Lin(256, 512) self.lin1 = Lin(512, 256) self.lin2 = Lin(256, 256) self.lin3 = Lin(256, num_classes) def forward(self, pos, batch): x = self.conv1(pos, batch) x = self.conv2(x, batch) x = F.relu(self.lin0(x)) x = global_max_pool(x, batch) x = F.relu(self.lin1(x)) x = F.relu(self.lin2(x)) x = F.dropout(x, p=0.5, training=self.training) x = 
self.lin3(x) return F.log_softmax(x, dim=-1) train_dataset, test_dataset = get_dataset(num_points=1024) model = Net(train_dataset.num_classes) run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr, args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay, args.inference, args.profile, args.bf16, args.compile) if args.profile: rename_profile_file('points', DynamicEdgeConv.__name__) ================================================ FILE: benchmark/points/mpnn.py ================================================ import argparse import torch import torch.nn.functional as F from points.datasets import get_dataset from points.train_eval import run from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq from torch_geometric.nn import NNConv, fps, global_mean_pool, radius_graph from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--batch_size', type=int, default=8) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr_decay_factor', type=float, default=0.5) parser.add_argument('--lr_decay_step_size', type=int, default=50) parser.add_argument('--weight_decay', type=float, default=0) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, num_classes): super().__init__() nn = Seq(Lin(3, 25), ReLU(), Lin(25, 1 * 64)) self.conv1 = NNConv(1, 64, nn, aggr='mean') nn = Seq(Lin(3, 25), ReLU(), Lin(25, 64 * 64)) self.conv2 = NNConv(64, 64, nn, aggr='mean') nn = Seq(Lin(3, 25), ReLU(), Lin(25, 64 * 128)) self.conv3 = NNConv(64, 128, nn, aggr='mean') self.lin1 = torch.nn.Linear(128, 256) self.lin2 = torch.nn.Linear(256, 256) self.lin3 = 
torch.nn.Linear(256, num_classes) def forward(self, pos, batch): x = pos.new_ones((pos.size(0), 1)) radius = 0.2 edge_index = radius_graph(pos, r=radius, batch=batch) pseudo = pos[edge_index[1]] - pos[edge_index[0]] x = F.relu(self.conv1(x, edge_index, pseudo)) idx = fps(pos, batch, ratio=0.5) x, pos, batch = x[idx], pos[idx], batch[idx] radius = 0.4 edge_index = radius_graph(pos, r=radius, batch=batch) pseudo = pos[edge_index[1]] - pos[edge_index[0]] x = F.relu(self.conv2(x, edge_index, pseudo)) idx = fps(pos, batch, ratio=0.25) x, pos, batch = x[idx], pos[idx], batch[idx] radius = 1 edge_index = radius_graph(pos, r=radius, batch=batch) pseudo = pos[edge_index[1]] - pos[edge_index[0]] x = F.relu(self.conv3(x, edge_index, pseudo)) x = global_mean_pool(x, batch) x = F.relu(self.lin1(x)) x = F.relu(self.lin2(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin3(x) return F.log_softmax(x, dim=-1) train_dataset, test_dataset = get_dataset(num_points=1024) model = Net(train_dataset.num_classes) run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr, args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay, args.inference, args.profile, args.bf16, args.compile) if args.profile: rename_profile_file('points', NNConv.__name__) ================================================ FILE: benchmark/points/point_cnn.py ================================================ import argparse import torch import torch.nn.functional as F from points.datasets import get_dataset from points.train_eval import run from torch.nn import Linear as Lin from torch_geometric.nn import XConv, fps, global_mean_pool from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr_decay_factor', type=float, default=0.5) 
parser.add_argument('--lr_decay_step_size', type=int, default=50) parser.add_argument('--weight_decay', type=float, default=0) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, num_classes): super().__init__() self.conv1 = XConv(0, 48, dim=3, kernel_size=8, hidden_channels=32) self.conv2 = XConv(48, 96, dim=3, kernel_size=12, hidden_channels=64, dilation=2) self.conv3 = XConv(96, 192, dim=3, kernel_size=16, hidden_channels=128, dilation=2) self.conv4 = XConv(192, 384, dim=3, kernel_size=16, hidden_channels=256, dilation=2) self.lin1 = Lin(384, 256) self.lin2 = Lin(256, 128) self.lin3 = Lin(128, num_classes) def forward(self, pos, batch): x = F.relu(self.conv1(None, pos, batch)) idx = fps(pos, batch, ratio=0.375) x, pos, batch = x[idx], pos[idx], batch[idx] x = F.relu(self.conv2(x, pos, batch)) idx = fps(pos, batch, ratio=0.334) x, pos, batch = x[idx], pos[idx], batch[idx] x = F.relu(self.conv3(x, pos, batch)) x = F.relu(self.conv4(x, pos, batch)) x = global_mean_pool(x, batch) x = F.relu(self.lin1(x)) x = F.relu(self.lin2(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin3(x) return F.log_softmax(x, dim=-1) train_dataset, test_dataset = get_dataset(num_points=1024) model = Net(train_dataset.num_classes) run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr, args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay, args.inference, args.profile, args.bf16, args.compile) if args.profile: rename_profile_file('points', XConv.__name__) ================================================ FILE: benchmark/points/point_net.py ================================================ import argparse import torch import torch.nn.functional as F from points.datasets import get_dataset from points.train_eval 
import run from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq from torch_geometric.nn import PointNetConv, fps, global_max_pool, radius_graph from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--batch_size', type=int, default=8) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr_decay_factor', type=float, default=0.5) parser.add_argument('--lr_decay_step_size', type=int, default=50) parser.add_argument('--weight_decay', type=float, default=0) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, num_classes): super().__init__() nn = Seq(Lin(3, 64), ReLU(), Lin(64, 64)) self.conv1 = PointNetConv(local_nn=nn) nn = Seq(Lin(67, 128), ReLU(), Lin(128, 128)) self.conv2 = PointNetConv(local_nn=nn) nn = Seq(Lin(131, 256), ReLU(), Lin(256, 256)) self.conv3 = PointNetConv(local_nn=nn) self.lin1 = Lin(256, 256) self.lin2 = Lin(256, 256) self.lin3 = Lin(256, num_classes) def forward(self, pos, batch): radius = 0.2 edge_index = radius_graph(pos, r=radius, batch=batch) x = F.relu(self.conv1(None, pos, edge_index)) idx = fps(pos, batch, ratio=0.5) x, pos, batch = x[idx], pos[idx], batch[idx] radius = 0.4 edge_index = radius_graph(pos, r=radius, batch=batch) x = F.relu(self.conv2(x, pos, edge_index)) idx = fps(pos, batch, ratio=0.25) x, pos, batch = x[idx], pos[idx], batch[idx] radius = 1 edge_index = radius_graph(pos, r=radius, batch=batch) x = F.relu(self.conv3(x, pos, edge_index)) x = global_max_pool(x, batch) x = F.relu(self.lin1(x)) x = F.relu(self.lin2(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin3(x) return F.log_softmax(x, dim=-1) 
train_dataset, test_dataset = get_dataset(num_points=1024) model = Net(train_dataset.num_classes) run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr, args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay, args.inference, args.profile, args.bf16, args.compile) if args.profile: rename_profile_file('points', PointNetConv.__name__) ================================================ FILE: benchmark/points/spline_cnn.py ================================================ import argparse import torch import torch.nn.functional as F from points.datasets import get_dataset from points.train_eval import run from torch.nn import Linear as Lin from torch_geometric.nn import SplineConv, fps, global_mean_pool, radius_graph from torch_geometric.profile import rename_profile_file parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--batch_size', type=int, default=8) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr_decay_factor', type=float, default=0.5) parser.add_argument('--lr_decay_step_size', type=int, default=50) parser.add_argument('--weight_decay', type=float, default=0) parser.add_argument('--inference', action='store_true') parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, num_classes): super().__init__() self.conv1 = SplineConv(1, 64, dim=3, kernel_size=5) self.conv2 = SplineConv(64, 64, dim=3, kernel_size=5) self.conv3 = SplineConv(64, 128, dim=3, kernel_size=5) self.lin1 = Lin(128, 256) self.lin2 = Lin(256, 256) self.lin3 = Lin(256, num_classes) def forward(self, pos, batch): x = pos.new_ones((pos.size(0), 1)) radius = 0.2 edge_index = radius_graph(pos, r=radius, batch=batch) pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 0.5 pseudo = 
pseudo.clamp(min=0, max=1) x = F.elu(self.conv1(x, edge_index, pseudo)) idx = fps(pos, batch, ratio=0.5) x, pos, batch = x[idx], pos[idx], batch[idx] radius = 0.4 edge_index = radius_graph(pos, r=radius, batch=batch) pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 0.5 pseudo = pseudo.clamp(min=0, max=1) x = F.elu(self.conv2(x, edge_index, pseudo)) idx = fps(pos, batch, ratio=0.25) x, pos, batch = x[idx], pos[idx], batch[idx] radius = 1 edge_index = radius_graph(pos, r=radius, batch=batch) pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 0.5 pseudo = pseudo.clamp(min=0, max=1) x = F.elu(self.conv3(x, edge_index, pseudo)) x = global_mean_pool(x, batch) x = F.elu(self.lin1(x)) x = F.elu(self.lin2(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.lin3(x) return F.log_softmax(x, dim=-1) train_dataset, test_dataset = get_dataset(num_points=1024) model = Net(train_dataset.num_classes) run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr, args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay, args.inference, args.profile, args.bf16, args.compile) if args.profile: rename_profile_file('points', SplineConv.__name__) ================================================ FILE: benchmark/points/statistics.py ================================================ from points.datasets import get_dataset from torch_geometric.transforms import RadiusGraph def print_dataset(train_dataset, test_dataset): num_nodes = num_edges = 0 for data in train_dataset: data = RadiusGraph(0.2)(data) num_nodes += data.num_nodes num_edges += data.num_edges for data in test_dataset: data = RadiusGraph(0.2)(data) num_nodes += data.num_nodes num_edges += data.num_edges num_graphs = len(train_dataset) + len(test_dataset) print('Graphs', num_graphs) print('Nodes', num_nodes / num_graphs) print('Edges', (num_edges // 2) / num_graphs) print('Label rate', len(train_dataset) / num_graphs) print() 
print_dataset(*get_dataset(num_points=1024)) ================================================ FILE: benchmark/points/train_eval.py ================================================ import time import torch import torch.nn.functional as F from torch.optim import Adam from torch_geometric.loader import DataLoader from torch_geometric.profile import timeit, torch_profile if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') def run_train(train_dataset, test_dataset, model, epochs, batch_size, use_compile, lr, lr_decay_factor, lr_decay_step_size, weight_decay): model = model.to(device) if use_compile: model = torch.compile(model) optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) train_loader = DataLoader(train_dataset, batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size, shuffle=False) for epoch in range(1, epochs + 1): if torch.cuda.is_available(): torch.cuda.synchronize() elif (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()): torch.mps.synchronize() t_start = time.perf_counter() train(model, optimizer, train_loader, device) test_acc = test(model, test_loader, device) if torch.cuda.is_available(): torch.cuda.synchronize() elif (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()): torch.mps.synchronize() t_end = time.perf_counter() print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}, ' f'Duration: {t_end - t_start:.2f}') if epoch % lr_decay_step_size == 0: for param_group in optimizer.param_groups: param_group['lr'] = lr_decay_factor * param_group['lr'] @torch.no_grad() def run_inference(test_dataset, model, epochs, batch_size, profiling, bf16, use_compile): model = model.to(device) if use_compile: model = torch.compile(model) test_loader = DataLoader(test_dataset, batch_size, shuffle=False) if torch.cuda.is_available(): amp = 
torch.amp.autocast('cuda', enabled=False) else: amp = torch.cpu.amp.autocast(enabled=bf16) with amp: for epoch in range(1, epochs + 1): print("Epoch: ", epoch) if epoch == epochs: with timeit(): inference(model, test_loader, device, bf16) else: inference(model, test_loader, device, bf16) if profiling: with torch_profile(): inference(model, test_loader, device, bf16) def run(train_dataset, test_dataset, model, epochs, batch_size, lr, lr_decay_factor, lr_decay_step_size, weight_decay, inference, profiling, bf16, use_compile): if not inference: run_train(train_dataset, test_dataset, model, epochs, batch_size, use_compile, lr, lr_decay_factor, lr_decay_step_size, weight_decay) else: run_inference(test_dataset, model, epochs, batch_size, profiling, bf16, use_compile) def train(model, optimizer, train_loader, device): model.train() for data in train_loader: optimizer.zero_grad() data = data.to(device) out = model(data.pos, data.batch) loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() @torch.no_grad() def test(model, test_loader, device): model.eval() correct = 0 for data in test_loader: data = data.to(device) pred = model(data.pos, data.batch).max(1)[1] correct += pred.eq(data.y).sum().item() test_acc = correct / len(test_loader.dataset) return test_acc @torch.no_grad() def inference(model, test_loader, device, bf16): model.eval() for data in test_loader: data = data.to(device) if bf16: data.pos = data.pos.to(torch.bfloat16) model = model.to(torch.bfloat16) model(data.pos, data.batch) ================================================ FILE: benchmark/runtime/README.md ================================================ # Runtimes Run the test suite for PyG via ``` python main.py ``` Install `dgl` and run the test suite for DGL via ``` cd dgl python main.py ``` ================================================ FILE: benchmark/runtime/__init__.py ================================================ from .train import train_runtime __all__ = [ 'train_runtime', ] 
================================================ FILE: benchmark/runtime/dgl/gat.py ================================================ import dgl.function as fn import torch import torch.nn.functional as F from dgl.nn.pytorch import EdgeSoftmax from torch.nn import Parameter from torch_geometric.nn.inits import glorot, zeros class GATConv(torch.nn.Module): def __init__(self, g, in_channels, out_channels, heads=1, negative_slope=0.2, dropout=0): super().__init__() self.g = g self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.negative_slope = negative_slope self.dropout = dropout self.weight = Parameter(torch.empty(in_channels, heads * out_channels)) self.att = Parameter(torch.empty(1, heads, 2 * out_channels)) self.bias = Parameter(torch.empty(heads * out_channels)) self.reset_parameters() def reset_parameters(self): glorot(self.weight) glorot(self.att) zeros(self.bias) def gat_msg(self, edge): alpha = torch.cat([edge.src['x'], edge.dst['x']], dim=-1) alpha = (alpha * self.att).sum(dim=-1) alpha = F.leaky_relu(alpha, self.negative_slope) return {'m': edge.src['x'], 'a': alpha} def gat_reduce(self, node): alpha = torch.softmax(node.mailbox['a'], dim=1) alpha = F.dropout(alpha, p=self.dropout, training=self.training) x = (node.mailbox['m'] * alpha.unsqueeze(-1)).sum(dim=1) return {'x': x} def forward(self, x): x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels) self.g.ndata['x'] = x self.g.update_all(self.gat_msg, self.gat_reduce) x = self.g.ndata.pop('x') x = x.view(-1, self.heads * self.out_channels) x = x + self.bias return x class GAT(torch.nn.Module): def __init__(self, g, in_channels, out_channels): super().__init__() self.g = g self.conv1 = GATConv(g, in_channels, 8, 8, 0.6, 0.2) self.conv2 = GATConv(g, 64, out_channels, 1, 0.6, 0.2) def forward(self, x): x = F.dropout(x, p=0.6, training=self.training) x = F.elu(self.conv1(x)) x = F.dropout(x, p=0.6, training=self.training) x = self.conv2(x) return 
F.log_softmax(x, dim=1) class GATSPMVConv(torch.nn.Module): def __init__(self, g, in_channels, out_channels, heads=1, negative_slope=0.2, dropout=0): super().__init__() self.g = g self.out_channels = out_channels self.heads = heads self.negative_slope = negative_slope self.dropout = dropout self.weight = Parameter(torch.empty(in_channels, heads * out_channels)) self.att_l = Parameter(torch.empty(heads, out_channels, 1)) self.att_r = Parameter(torch.empty(heads, out_channels, 1)) self.bias = Parameter(torch.empty(heads * out_channels)) self.softmax = EdgeSoftmax() self.reset_parameters() def reset_parameters(self): glorot(self.weight) glorot(self.att_l) glorot(self.att_r) zeros(self.bias) def forward(self, x): x = torch.matmul(x, self.weight) x = x.reshape((x.size(0), self.heads, -1)) # NxHxD' head_x = x.transpose(0, 1) # HxNxD' a1 = torch.bmm(head_x, self.att_l).transpose(0, 1) # NxHx1 a2 = torch.bmm(head_x, self.att_r).transpose(0, 1) # NxHx1 self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2}) self.g.apply_edges(self.edge_attention) self.edge_softmax() self.g.update_all(fn.src_mul_edge('x', 'a', 'x'), fn.sum('x', 'x')) x = self.g.ndata['x'] / self.g.ndata['z'] # NxHxD' return x.view(-1, self.heads * self.out_channels) def edge_attention(self, edge): a = F.leaky_relu(edge.src['a1'] + edge.dst['a2'], self.negative_slope) return {'a': a} def edge_softmax(self): alpha, normalizer = self.softmax(self.g.edata['a'], self.g) self.g.ndata['z'] = normalizer if self.training and self.dropout > 0: alpha = F.dropout(alpha, p=self.dropout, training=True) self.g.edata['a'] = alpha class GATSPMV(torch.nn.Module): def __init__(self, g, in_channels, out_channels): super().__init__() self.g = g self.conv1 = GATSPMVConv(g, in_channels, 8, 8, 0.6, 0.2) self.conv2 = GATSPMVConv(g, 64, out_channels, 1, 0.6, 0.2) def forward(self, x): x = F.dropout(x, p=0.6, training=self.training) x = F.elu(self.conv1(x)) x = F.dropout(x, p=0.6, training=self.training) x = self.conv2(x) return 
F.log_softmax(x, dim=1) ================================================ FILE: benchmark/runtime/dgl/gcn.py ================================================ import dgl.function as fn import torch import torch.nn.functional as F from torch.nn import Parameter from torch_geometric.nn.inits import glorot, zeros class GCNConv(torch.nn.Module): def __init__(self, g, in_channels, out_channels): super().__init__() self.g = g self.weight = Parameter(torch.empty(in_channels, out_channels)) self.bias = Parameter(torch.empty(out_channels)) self.reset_parameters() def reset_parameters(self): glorot(self.weight) zeros(self.bias) def gcn_msg(self, edge): return {'m': edge.src['x'] * edge.src['norm']} def gcn_reduce(self, node): return {'x': node.mailbox['m'].sum(dim=1) * node.data['norm']} def forward(self, x): self.g.ndata['x'] = torch.matmul(x, self.weight) self.g.update_all(self.gcn_msg, self.gcn_reduce) x = self.g.ndata.pop('x') x = x + self.bias return x class GCN(torch.nn.Module): def __init__(self, g, in_channels, out_channels): super().__init__() self.conv1 = GCNConv(g, in_channels, 16) self.conv2 = GCNConv(g, 16, out_channels) def forward(self, x): x = F.relu(self.conv1(x)) x = F.dropout(x, training=self.training) x = self.conv2(x) return F.log_softmax(x, dim=1) class GCNSPMVConv(torch.nn.Module): def __init__(self, g, in_channels, out_channels): super().__init__() self.g = g self.weight = Parameter(torch.empty(in_channels, out_channels)) self.bias = Parameter(torch.empty(out_channels)) self.reset_parameters() def reset_parameters(self): glorot(self.weight) zeros(self.bias) def forward(self, x): x = torch.matmul(x, self.weight) self.g.ndata['x'] = x * self.g.ndata['norm'] self.g.update_all(fn.copy_src(src='x', out='m'), fn.sum(msg='m', out='x')) x = self.g.ndata.pop('x') * self.g.ndata['norm'] x = x + self.bias return x class GCNSPMV(torch.nn.Module): def __init__(self, g, in_channels, out_channels): super().__init__() self.conv1 = GCNSPMVConv(g, in_channels, 16) 
self.conv2 = GCNSPMVConv(g, 16, out_channels) def forward(self, x): x = F.relu(self.conv1(x)) x = F.dropout(x, training=self.training) x = self.conv2(x) return F.log_softmax(x, dim=1) ================================================ FILE: benchmark/runtime/dgl/hidden.py ================================================ import os import sys import warnings warnings.filterwarnings('ignore') class HiddenPrint: def __enter__(self): self._original_stdout = sys.stdout sys.stdout = open(os.devnull, 'w') def __exit__(self, exc_type, exc_val, exc_tb): sys.stdout.close() sys.stdout = self._original_stdout ================================================ FILE: benchmark/runtime/dgl/main.py ================================================ from itertools import product import dgl import torch from dgl import DGLGraph from dgl.contrib.data import load_data from dgl.data import citation_graph from runtime.dgl.gat import GAT, GATSPMV from runtime.dgl.gcn import GCN, GCNSPMV from runtime.dgl.hidden import HiddenPrint from runtime.dgl.rgcn import RGCN, RGCNSPMV from runtime.dgl.train import train_runtime if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') with HiddenPrint(): Cora = citation_graph.load_cora() CiteSeer = citation_graph.load_citeseer() PubMed = citation_graph.load_pubmed() MUTAG = load_data('mutag') # fair comparison # One training run before we start tracking duration to warm up GPU. 
g = DGLGraph(Cora.graph) g.set_n_initializer(dgl.init.zero_initializer) g.add_edges(g.nodes(), g.nodes()) norm = torch.pow(g.in_degrees().float(), -0.5) norm[torch.isinf(norm)] = 0 g.ndata['norm'] = norm.unsqueeze(1).to(device) model = GCNSPMV(g, Cora.features.shape[1], Cora.num_labels).to(device) train_runtime(model, Cora, epochs=200, device=device) for d, Net in product([Cora, CiteSeer, PubMed], [GCN, GCNSPMV, GAT, GATSPMV]): g = DGLGraph(d.graph) g.set_n_initializer(dgl.init.zero_initializer) g.add_edges(g.nodes(), g.nodes()) norm = torch.pow(g.in_degrees().float(), -0.5) norm[torch.isinf(norm)] = 0 g.ndata['norm'] = norm.unsqueeze(1).to(device) model = Net(g, d.features.shape[1], d.num_labels).to(device) t = train_runtime(model, d, epochs=200, device=device) print(f'{d.name} - {Net.__name__}: {t:.2f}s') for d, Net in product([MUTAG], [RGCN, RGCNSPMV]): g = DGLGraph() g.add_nodes(d.num_nodes) g.add_edges(d.edge_src, d.edge_dst) edge_type = torch.from_numpy(d.edge_type).to(device) edge_norm = torch.from_numpy(d.edge_norm).to(device) g.edata.update({'type': edge_type, 'norm': edge_norm}) g.ndata['id'] = torch.arange(d.num_nodes, dtype=torch.long, device=device) model = Net(g, d.num_nodes, d.num_classes, d.num_rels) t = train_runtime(model, d, epochs=200, device=device) print(f'{d.name} - {Net.__name__}: {t:.2f}s') ================================================ FILE: benchmark/runtime/dgl/rgcn.py ================================================ import dgl.function as fn import torch import torch.nn.functional as F from torch.nn import Parameter as Param from torch_geometric.nn.inits import uniform class RGCNConv(torch.nn.Module): def __init__(self, g, in_channels, out_channels, num_relations, num_bases): super().__init__() self.g = g self.in_channels = in_channels self.out_channels = out_channels self.num_relations = num_relations self.num_bases = num_bases self.basis = Param(torch.empty(num_bases, in_channels, out_channels)) self.att = 
Param(torch.empty(num_relations, num_bases)) self.root = Param(torch.empty(in_channels, out_channels)) self.bias = Param(torch.empty(out_channels)) self.reset_parameters() def reset_parameters(self): size = self.num_bases * self.in_channels uniform(size, self.basis) uniform(size, self.att) uniform(size, self.root) uniform(size, self.bias) def rgcn_reduce(self, node): return {'x': node.mailbox['m'].sum(dim=1)} def forward(self, x): self.w = torch.matmul(self.att, self.basis.view(self.num_bases, -1)) self.w = self.w.view(self.num_relations, self.in_channels, self.out_channels) if x is None: def msg_func(edge): w = self.w.view(-1, self.out_channels) index = edge.data['type'] * self.in_channels + edge.src['id'] m = w.index_select(0, index) * edge.data['norm'].unsqueeze(1) return {'m': m} else: self.g.ndata['x'] = x def msg_func(edge): w = self.w.index_select(0, edge.data['type']) m = torch.bmm(edge.src['x'].unsqueeze(1), w).squeeze() m = m * edge.data['norm'].unsqueeze(1) return {'m': m} self.g.update_all(msg_func, self.rgcn_reduce) out = self.g.ndata.pop('x') if x is None: out = out + self.root else: out = out + torch.matmul(x, self.root) out = out + self.bias return out class RGCN(torch.nn.Module): def __init__(self, g, in_channels, out_channels, num_relations): super().__init__() self.conv1 = RGCNConv(g, in_channels, 16, num_relations, num_bases=30) self.conv2 = RGCNConv(g, 16, out_channels, num_relations, num_bases=30) def forward(self, x): x = F.relu(self.conv1(None)) x = self.conv2(x) return F.log_softmax(x, dim=1) class RGCNSPMVConv(torch.nn.Module): def __init__(self, g, in_channels, out_channels, num_relations, num_bases): super().__init__() self.g = g self.in_channels = in_channels self.out_channels = out_channels self.num_relations = num_relations self.num_bases = num_bases self.basis = Param(torch.empty(num_bases, in_channels, out_channels)) self.att = Param(torch.empty(num_relations, num_bases)) self.root = Param(torch.empty(in_channels, out_channels)) 
self.bias = Param(torch.empty(out_channels)) self.reset_parameters() def reset_parameters(self): size = self.num_bases * self.in_channels uniform(size, self.basis) uniform(size, self.att) uniform(size, self.root) uniform(size, self.bias) def forward(self, x): self.w = torch.matmul(self.att, self.basis.view(self.num_bases, -1)) self.w = self.w.view(self.num_relations, self.in_channels, self.out_channels) if x is None: def msg_func(edge): w = self.w.view(-1, self.out_channels) index = edge.data['type'] * self.in_channels + edge.src['id'] m = w.index_select(0, index) * edge.data['norm'].unsqueeze(1) return {'m': m} else: self.g.ndata['x'] = x def msg_func(edge): w = self.w.index_select(0, edge.data['type']) m = torch.bmm(edge.src['x'].unsqueeze(1), w).squeeze() m = m * edge.data['norm'].unsqueeze(1) return {'m': m} self.g.update_all(msg_func, fn.sum(msg='m', out='x')) out = self.g.ndata.pop('x') if x is None: out = out + self.root else: out = out + torch.matmul(x, self.root) out = out + self.bias return out class RGCNSPMV(torch.nn.Module): def __init__(self, g, in_channels, out_channels, num_relations): super().__init__() self.conv1 = RGCNSPMVConv(g, in_channels, 16, num_relations, num_bases=30) self.conv2 = RGCNSPMVConv(g, 16, out_channels, num_relations, num_bases=30) def forward(self, x): x = F.relu(self.conv1(None)) x = self.conv2(x) return F.log_softmax(x, dim=1) ================================================ FILE: benchmark/runtime/dgl/train.py ================================================ import time import torch import torch.nn.functional as F def train_runtime(model, data, epochs, device): if hasattr(data, 'features'): x = torch.tensor(data.features, dtype=torch.float, device=device) else: x = None mask = data.train_mask if hasattr(data, 'train_mask') else data.train_idx y = torch.tensor(data.labels, dtype=torch.long, device=device)[mask] model = model.to(device) model.train() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) if 
torch.cuda.is_available(): torch.cuda.synchronize() elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize() t_start = time.perf_counter() for _ in range(epochs): optimizer.zero_grad() out = model(x) loss = F.nll_loss(out[mask], y.view(-1)) loss.backward() optimizer.step() if torch.cuda.is_available(): torch.cuda.synchronize() elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize() t_end = time.perf_counter() return t_end - t_start ================================================ FILE: benchmark/runtime/gat.py ================================================ import torch import torch.nn.functional as F from torch_geometric.nn import GATConv class GAT(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6) self.conv2 = GATConv(8 * 8, out_channels, dropout=0.6) def forward(self, data): x, edge_index = data.x, data.edge_index x = F.dropout(x, p=0.6, training=self.training) x = F.elu(self.conv1(x, edge_index)) x = F.dropout(x, p=0.6, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) ================================================ FILE: benchmark/runtime/gcn.py ================================================ import torch import torch.nn.functional as F from torch_geometric.nn import GCNConv class GCN(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, 16, cached=True) self.conv2 = GCNConv(16, out_channels, cached=True) def forward(self, data): x, edge_index = data.x, data.edge_index x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) ================================================ FILE: benchmark/runtime/main.py ================================================ import os.path as osp from itertools import product 
import torch from runtime.gat import GAT from runtime.gcn import GCN from runtime.rgcn import RGCN from runtime.train import train_runtime from torch_geometric.datasets import Entities, Planetoid if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data') Cora = Planetoid(osp.join(root, 'Cora'), 'Cora') CiteSeer = Planetoid(osp.join(root, 'CiteSeer'), 'CiteSeer') PubMed = Planetoid(osp.join(root, 'PubMed'), 'PubMed') MUTAG = Entities(osp.join(root, 'EntitiesMUTAG'), 'MUTAG') # One training run before we start tracking duration to warm up GPU. model = GCN(Cora.num_features, Cora.num_classes) train_runtime(model, Cora[0], epochs=200, device=device) for d, Net in product([Cora, CiteSeer, PubMed], [GCN, GAT]): model = Net(d.num_features, d.num_classes) t = train_runtime(model, d[0], epochs=200, device=device) print(f'{str(d)[:-2]} - {Net.__name__}: {t:.2f}s') for d, Net in product([MUTAG], [RGCN]): model = Net(d[0].num_nodes, d.num_classes, d.num_relations) t = train_runtime(model, d[0], epochs=200, device=device) print(f'{str(d)[:-2]} - {Net.__name__}: {t:.2f}s') ================================================ FILE: benchmark/runtime/rgcn.py ================================================ import torch import torch.nn.functional as F from torch_geometric.nn import FastRGCNConv class RGCN(torch.nn.Module): def __init__(self, in_channels, out_channels, num_relations): super().__init__() self.conv1 = FastRGCNConv(in_channels, 16, num_relations, num_bases=30) self.conv2 = FastRGCNConv(16, out_channels, num_relations, num_bases=30) def forward(self, data): edge_index, edge_type = data.edge_index, data.edge_type x = F.relu(self.conv1(None, edge_index, edge_type)) x = self.conv2(x, edge_index, edge_type) return F.log_softmax(x, dim=1) 
================================================ FILE: benchmark/runtime/train.py ================================================ import time import torch import torch.nn.functional as F def train_runtime(model, data, epochs, device): optimizer = torch.optim.Adam(model.parameters(), lr=0.01) model = model.to(device) data = data.to(device) model.train() mask = data.train_mask if 'train_mask' in data else data.train_idx y = data.y[mask] if 'train_mask' in data else data.train_y if torch.cuda.is_available(): torch.cuda.synchronize() elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize() t_start = time.perf_counter() for _ in range(epochs): optimizer.zero_grad() out = model(data) loss = F.nll_loss(out[mask], y) loss.backward() optimizer.step() if torch.cuda.is_available(): torch.cuda.synchronize() elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize() t_end = time.perf_counter() return t_end - t_start ================================================ FILE: benchmark/setup.py ================================================ from setuptools import find_packages, setup setup( name='torch_geometric_benchmark', version='0.1.0', description='PyG Benchmark Suite', author='Matthias Fey', author_email='matthias.fey@tu-dortmund.de', url='https://github.com/pyg-team/pytorch_geometric_benchmark', install_requires=['scikit-learn'], packages=find_packages(), ) ================================================ FILE: benchmark/training/README.md ================================================ # Training Benchmark ## Environment setup 1. Confirm that PyG is properly installed. 1. Install dataset package: ```bash pip install ogb ``` 1. Install `jemalloc` for performance benchmark: ```bash cd ${workspace} git clone https://github.com/jemalloc/jemalloc.git cd jemalloc git checkout 5.2.1 ./autogen.sh ./configure --prefix=${workspace}/jemalloc-bin make make install ``` ## Running benchmark 1. 
Set environment variables: ```bash source activate env_name export DNNL_PRIMITIVE_CACHE_CAPACITY=1024 export KMP_BLOCKTIME=1 export KMP_AFFINITY=granularity=fine,compact,1,0 jemalloc_lib=${workspace}/jemalloc-bin/lib/libjemalloc.so export LD_PRELOAD="$jemalloc_lib" export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000" ``` 1. Core binding, *e.g.*, single socket / single core / 4 cores per instance: ```bash OMP_NUM_THREADS=${CORES} numactl -C 0-${LAST_CORE} -m 0 CMD...... ``` 1. Execute benchmarks, *e.g.*: ```bash python training_benchmark.py --models=gcn --datasets=Reddit --num-workers=0 --batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --num-steps=50 python training_benchmark.py --models=gcn --datasets=Reddit --num-workers=0 --batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --num-steps=50 --use-sparse-tensor python training_benchmark.py --models=sage --datasets=ogbn-products --num-workers=0 --batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --num-steps=50 python training_benchmark.py --models=sage --datasets=ogbn-products --num-workers=0 --batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --num-steps=50 --use-sparse-tensor ``` ================================================ FILE: benchmark/training/training_benchmark.py ================================================ import argparse import ast import warnings from collections import defaultdict from contextlib import nullcontext import torch import torch.nn.functional as F from tqdm import tqdm from benchmark.utils import ( emit_itt, get_dataset, get_model, get_split_masks, save_benchmark_data, test, write_to_csv, ) from torch_geometric import compile from torch_geometric.loader import NeighborLoader from torch_geometric.nn import PNAConv from torch_geometric.profile import ( rename_profile_file, timeit, torch_profile, xpu_profile, ) supported_sets = { 'ogbn-mag': ['rgat', 'rgcn'], 'ogbn-products': 
['edge_cnn', 'gat', 'gcn', 'pna', 'sage'], 'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'], } device_conditions = { 'cpu': (lambda: True), 'cuda': (lambda: torch.cuda.is_available()), 'mps': (lambda: (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available())), 'xpu': (lambda: torch.xpu.is_available()), } def train_homo(model, loader, optimizer, device, progress_bar=True, desc="", trim=False): if progress_bar: loader = tqdm(loader, desc=desc) for batch in loader: optimizer.zero_grad() batch = batch.to(device) if 'adj_t' in batch: edge_index = batch.adj_t else: edge_index = batch.edge_index if not trim: out = model(batch.x, edge_index) else: out = model( batch.x, edge_index, num_sampled_nodes_per_hop=batch.num_sampled_nodes, num_sampled_edges_per_hop=batch.num_sampled_edges, ) batch_size = batch.batch_size out = out[:batch_size] target = batch.y[:batch_size] loss = F.cross_entropy(out, target) loss.backward() optimizer.step() def train_hetero(model, loader, optimizer, device, progress_bar=True, desc="", trim=False): if trim: warnings.warn("Trimming not yet implemented for heterogeneous graphs", stacklevel=2) if progress_bar: loader = tqdm(loader, desc=desc) for batch in loader: optimizer.zero_grad() batch = batch.to(device) if 'adj_t' in batch: edge_index_dict = batch.adj_t_dict else: edge_index_dict = batch.edge_index_dict out = model(batch.x_dict, edge_index_dict) batch_size = batch['paper'].batch_size out = out['paper'][:batch_size] target = batch['paper'].y[:batch_size] loss = F.cross_entropy(out, target) loss.backward() optimizer.step() def run(args: argparse.ArgumentParser): csv_data = defaultdict(list) if args.write_csv == 'prof' and not args.profile: warnings.warn( "Cannot write profile data to CSV because profiling is " "disabled", stacklevel=2) if args.device == 'xpu': try: import intel_extension_for_pytorch as ipex except ImportError as e: raise RuntimeError( 'XPU device requires IPEX to be installed') from e if not 
device_conditions[args.device](): raise RuntimeError(f'{args.device.upper()} is not available') device = torch.device(args.device) # If we use a custom number of steps, then we need to use RandomSampler, # which already does shuffle. shuffle = False if args.num_steps != -1 else True print('BENCHMARK STARTS') print(f'Running on {args.device.upper()}') for dataset_name in args.datasets: assert dataset_name in supported_sets.keys( ), f"Dataset {dataset_name} isn't supported." print(f'Dataset: {dataset_name}') load_time = timeit() if args.measure_load_time else nullcontext() with load_time: data, num_classes = get_dataset(dataset_name, args.root, args.use_sparse_tensor, args.bf16) hetero = True if dataset_name == 'ogbn-mag' else False mask, val_mask, test_mask = get_split_masks(data, dataset_name) degree = None if args.device == 'cpu': amp = torch.cpu.amp.autocast(enabled=args.bf16) elif args.device == 'cuda': amp = torch.amp.autocast('cuda', enabled=False) elif args.device == 'xpu': amp = torch.xpu.amp.autocast(enabled=False) else: amp = nullcontext() if args.device == 'xpu' and args.warmup < 1: print('XPU device requires warmup - setting warmup=1') args.warmup = 1 inputs_channels = data[ 'paper'].num_features if dataset_name == 'ogbn-mag' \ else data.num_features for model_name in args.models: if model_name not in supported_sets[dataset_name]: print(f'Configuration of {dataset_name} + {model_name} ' f'not supported. 
Skipping.') continue print(f'Training bench for {model_name}:') for batch_size in args.batch_sizes: num_nodes = int(mask[-1].sum()) if hetero else int(mask.sum()) sampler = torch.utils.data.RandomSampler( range(num_nodes), num_samples=args.num_steps * batch_size) if args.num_steps != -1 else None for layers in args.num_layers: num_neighbors = args.num_neighbors if type(num_neighbors) is list: if len(num_neighbors) == 1: num_neighbors = num_neighbors * layers elif type(num_neighbors) is int: num_neighbors = [num_neighbors] * layers assert len( num_neighbors) == layers, \ f'''num_neighbors={num_neighbors} length != num of layers={layers}''' kwargs = { 'num_neighbors': num_neighbors, 'batch_size': batch_size, 'shuffle': shuffle, 'num_workers': args.num_workers, } subgraph_loader = NeighborLoader( data, input_nodes=mask, sampler=sampler, **kwargs, ) if args.evaluate: val_loader = NeighborLoader( data, input_nodes=val_mask, sampler=None, **kwargs, ) test_loader = NeighborLoader( data, input_nodes=test_mask, sampler=None, **kwargs, ) for hidden_channels in args.num_hidden_channels: print('----------------------------------------------') print(f'Batch size={batch_size}, ' f'Layers amount={layers}, ' f'Num_neighbors={num_neighbors}, ' f'Hidden features size={hidden_channels}, ' f'Sparse tensor={args.use_sparse_tensor}') params = { 'inputs_channels': inputs_channels, 'hidden_channels': hidden_channels, 'output_channels': num_classes, 'num_heads': args.num_heads, 'num_layers': layers, } if model_name == 'pna': if degree is None: degree = PNAConv.get_degree_histogram( subgraph_loader) print(f'Calculated degree for {dataset_name}.') params['degree'] = degree model = get_model( model_name, params, metadata=data.metadata() if hetero else None) model = model.to(device) model.train() if args.compile: model = compile(model, dynamic=True) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) if args.device == 'xpu': model, optimizer = ipex.optimize( model, optimizer=optimizer) 
progress_bar = False if args.no_progress_bar else True train = train_hetero if hetero else train_homo # Define context manager parameters: cpu_affinity = subgraph_loader.enable_cpu_affinity( args.loader_cores ) if args.cpu_affinity else nullcontext() with amp, cpu_affinity: for _ in range(args.warmup): train( model, subgraph_loader, optimizer, device, progress_bar=progress_bar, desc="Warmup", trim=args.trim, ) with timeit(avg_time_divisor=args.num_epochs) as t: # becomes a no-op if vtune_profile == False with emit_itt(args.vtune_profile): for epoch in range(args.num_epochs): train( model, subgraph_loader, optimizer, device, progress_bar=progress_bar, desc=f"Epoch={epoch}", trim=args.trim, ) if args.evaluate: # In evaluate, throughput and # latency are not accurate. val_acc = test( model, val_loader, device, hetero, progress_bar=progress_bar) print( f'Val Accuracy: {val_acc:.4f}') if args.evaluate: test_acc = test(model, test_loader, device, hetero, progress_bar=progress_bar) print(f'Test Accuracy: {test_acc:.4f}') if args.profile: if args.device == 'xpu': profile = xpu_profile( args.export_chrome_trace) else: profile = torch_profile( args.export_chrome_trace, csv_data, args.write_csv) with profile: train(model, subgraph_loader, optimizer, device, progress_bar=progress_bar, desc="Profile training") if args.export_chrome_trace: rename_profile_file( model_name, dataset_name, str(batch_size), str(layers), str(hidden_channels), str(num_neighbors)) total_time = t.duration if args.num_steps != -1: total_num_samples = args.num_steps * batch_size else: total_num_samples = num_nodes throughput = total_num_samples / total_time latency = total_time / total_num_samples * 1000 print(f'Throughput: {throughput:.3f} samples/s') print(f'Latency: {latency:.3f} ms') num_records = 1 if args.write_csv == 'prof': # For profiling with PyTorch, we save the top-5 # most time consuming operations. Therefore, the # same data should be entered for each of them. 
num_records = 5 for _ in range(num_records): save_benchmark_data( csv_data, batch_size, layers, num_neighbors, hidden_channels, total_time, model_name, dataset_name, args.use_sparse_tensor, ) if args.write_csv: write_to_csv(csv_data, args.write_csv, training=True) if __name__ == '__main__': argparser = argparse.ArgumentParser('GNN training benchmark') add = argparser.add_argument add('--device', choices=['cpu', 'cuda', 'mps', 'xpu'], default='cpu', help='Device to run benchmark on') add('--datasets', nargs='+', default=['ogbn-mag', 'ogbn-products', 'Reddit'], type=str) add('--use-sparse-tensor', action='store_true', help='use torch_sparse.SparseTensor as graph storage format') add('--models', nargs='+', default=['edge_cnn', 'gat', 'gcn', 'pna', 'rgat', 'rgcn'], type=str) add('--root', default='../../data', type=str, help='relative path to look for the datasets') add('--batch-sizes', nargs='+', default=[512, 1024, 2048, 4096, 8192], type=int) add('--num-layers', nargs='+', default=[2, 3], type=int) add('--num-hidden-channels', nargs='+', default=[64, 128, 256], type=int) add('--num-heads', default=2, type=int, help='number of hidden attention heads, applies only for gat and rgat') add('--num-neighbors', default=[10], type=ast.literal_eval, help='number of neighbors to sample per layer') add('--num-workers', default=2, type=int) add('--warmup', default=1, type=int) add('--profile', action='store_true') add('--vtune-profile', action='store_true') add('--bf16', action='store_true') add('--no-progress-bar', action='store_true', default=False, help='turn off using progress bar') add('--num-epochs', default=1, type=int) add('--num-steps', default=-1, type=int, help='number of steps, -1 means iterating through all the data') add('--cpu-affinity', action='store_true', help="Use DataLoader affinitzation.") add('--loader-cores', nargs='+', default=[], type=int, help="List of CPU core IDs to use for DataLoader workers.") add('--measure-load-time', action='store_true') 
add('--evaluate', action='store_true') add('--write-csv', choices=[None, 'bench', 'prof'], default=None, help='Write benchmark or PyTorch profile data to CSV') add('--export-chrome-trace', default=True, type=bool, help='Export chrome trace file. Works only with PyTorch profiler') add('--trim', action='store_true', help="Use `trim_to_layer` optimization") add('--compile', action='store_true') args = argparser.parse_args() run(args) ================================================ FILE: benchmark/utils/__init__.py ================================================ from .utils import emit_itt from .utils import get_dataset, get_dataset_with_transformation from .utils import get_model from .utils import get_split_masks from .utils import save_benchmark_data, write_to_csv from .utils import test __all__ = [ 'emit_itt', 'get_dataset', 'get_dataset_with_transformation', 'get_model', 'get_split_masks', 'save_benchmark_data', 'write_to_csv', 'test', ] ================================================ FILE: benchmark/utils/hetero_gat.py ================================================ import torch from tqdm import tqdm from torch_geometric.nn import GAT, to_hetero class HeteroGAT(torch.nn.Module): def __init__(self, metadata, hidden_channels, num_layers, output_channels, num_heads): super().__init__() self.model = to_hetero( GAT((-1, -1), hidden_channels, num_layers, output_channels, add_self_loops=False, heads=num_heads), metadata) def forward(self, x_dict, edge_index_dict): return self.model(x_dict, edge_index_dict) @torch.no_grad() def inference(self, loader, device, progress_bar=False, **kwargs): self.model.eval() if progress_bar: loader = tqdm(loader, desc="Inference") for batch in loader: batch = batch.to(device) if 'adj_t' in batch: self.model(batch.x_dict, batch.adj_t_dict) else: self.model(batch.x_dict, batch.edge_index_dict) @torch.no_grad() def test(self, x, loader, device, progress_bar=False): self.model.eval() total_examples = total_correct = 0 if progress_bar: 
loader = tqdm(loader, desc="Evaluate") for batch in loader: batch = batch.to(device) if 'adj_t' in batch: out = self.model(batch.x_dict, batch.adj_t_dict) else: out = self.model(batch.x_dict, batch.edge_index_dict) batch_size = batch['paper'].batch_size out = out['paper'][:batch_size] pred = out.argmax(dim=-1) total_examples += batch_size total_correct += int((pred == batch['paper'].y[:batch_size]).sum()) return total_correct / total_examples ================================================ FILE: benchmark/utils/hetero_sage.py ================================================ import torch from tqdm import tqdm from torch_geometric.nn import GraphSAGE, to_hetero class HeteroGraphSAGE(torch.nn.Module): def __init__(self, metadata, hidden_channels, num_layers, output_channels): super().__init__() self.model = to_hetero( GraphSAGE((-1, -1), hidden_channels, num_layers, output_channels), metadata) def forward(self, x_dict, edge_index_dict): return self.model(x_dict, edge_index_dict) @torch.no_grad() def inference(self, loader, device, progress_bar=False, **kwargs): self.model.eval() if progress_bar: loader = tqdm(loader, desc="Inference") for batch in loader: batch = batch.to(device) if 'adj_t' in batch: self.model(batch.x_dict, batch.adj_t_dict) else: self.model(batch.x_dict, batch.edge_index_dict) @torch.no_grad() def test(self, loader, device, progress_bar=False): self.model.eval() total_examples = total_correct = 0 if progress_bar: loader = tqdm(loader, desc="Evaluate") for batch in loader: batch = batch.to(device) if 'adj_t' in batch: out = self.model(batch.x_dict, batch.adj_t_dict) else: out = self.model(batch.x_dict, batch.edge_index_dict) batch_size = batch['paper'].batch_size out = out['paper'][:batch_size] pred = out.argmax(dim=-1) total_examples += batch_size total_correct += int((pred == batch['paper'].y[:batch_size]).sum()) return total_correct / total_examples ================================================ FILE: benchmark/utils/utils.py 
================================================ import os import os.path as osp from datetime import datetime import torch from ogb.nodeproppred import PygNodePropPredDataset from tqdm import tqdm import torch_geometric.transforms as T from torch_geometric.data import HeteroData from torch_geometric.datasets import OGB_MAG, Reddit from torch_geometric.nn import GAT, GCN, PNA, EdgeCNN, GraphSAGE from torch_geometric.utils import index_to_mask from .hetero_gat import HeteroGAT from .hetero_sage import HeteroGraphSAGE try: from torch.autograd.profiler import emit_itt except ImportError: from contextlib import contextmanager @contextmanager def emit_itt(*args, **kwargs): yield models_dict = { 'edge_cnn': EdgeCNN, 'gat': GAT, 'gcn': GCN, 'pna': PNA, 'sage': GraphSAGE, 'rgat': HeteroGAT, 'rgcn': HeteroGraphSAGE, } def get_dataset_with_transformation(name, root, use_sparse_tensor=False, bf16=False): path = osp.join(osp.dirname(osp.realpath(__file__)), root, name) transform = T.ToSparseTensor( remove_edge_index=False) if use_sparse_tensor else None if name == 'ogbn-mag': if transform is None: transform = T.ToUndirected(merge=True) else: transform = T.Compose([T.ToUndirected(merge=True), transform]) dataset = OGB_MAG(root=path, preprocess='metapath2vec', transform=transform) elif name == 'ogbn-products': if transform is None: transform = T.RemoveDuplicatedEdges() else: transform = T.Compose([T.RemoveDuplicatedEdges(), transform]) dataset = PygNodePropPredDataset('ogbn-products', root=path, transform=transform) elif name == 'Reddit': dataset = Reddit(root=path, transform=transform) data = dataset[0] if name == 'ogbn-products': split_idx = dataset.get_idx_split() data.train_mask = index_to_mask(split_idx['train'], size=data.num_nodes) data.val_mask = index_to_mask(split_idx['valid'], size=data.num_nodes) data.test_mask = index_to_mask(split_idx['test'], size=data.num_nodes) data.y = data.y.squeeze() if bf16: if isinstance(data, HeteroData): for node_type in data.node_types: 
data[node_type].x = data[node_type].x.to(torch.bfloat16) else: data.x = data.x.to(torch.bfloat16) return data, dataset.num_classes, transform def get_dataset(name, root, use_sparse_tensor=False, bf16=False): data, num_classes, _ = get_dataset_with_transformation( name, root, use_sparse_tensor, bf16) return data, num_classes def get_model(name, params, metadata=None): Model = models_dict.get(name, None) assert Model is not None, f'Model {name} not supported!' if name == 'rgat': return Model(metadata, params['hidden_channels'], params['num_layers'], params['output_channels'], params['num_heads']) if name == 'rgcn': return Model(metadata, params['hidden_channels'], params['num_layers'], params['output_channels']) if name == 'gat': return Model(params['inputs_channels'], params['hidden_channels'], params['num_layers'], params['output_channels'], heads=params['num_heads']) if name == 'pna': return Model(params['inputs_channels'], params['hidden_channels'], params['num_layers'], params['output_channels'], aggregators=['mean', 'min', 'max', 'std'], scalers=['identity', 'amplification', 'attenuation'], deg=params['degree']) return Model(params['inputs_channels'], params['hidden_channels'], params['num_layers'], params['output_channels']) def get_split_masks(data, dataset_name): if dataset_name == 'ogbn-mag': train_mask = ('paper', data['paper'].train_mask) test_mask = ('paper', data['paper'].test_mask) val_mask = ('paper', data['paper'].val_mask) else: train_mask = data.train_mask val_mask = data.val_mask test_mask = data.test_mask return train_mask, val_mask, test_mask def save_benchmark_data(csv_data, batch_size, layers, num_neighbors, hidden_channels, total_time, model_name, dataset_name, use_sparse_tensor): config = f'Batch size={batch_size}, ' \ f'#Layers={layers}, ' \ f'#Neighbors={num_neighbors}, ' \ f'#Hidden features={hidden_channels}' csv_data['DATE'].append(datetime.now().date()) csv_data['TIME (s)'].append(round(total_time, 2)) 
csv_data['MODEL'].append(model_name) csv_data['DATASET'].append(dataset_name) csv_data['CONFIG'].append(config) csv_data['SPARSE'].append(use_sparse_tensor) def write_to_csv(csv_data, write_csv='bench', training=False): import pandas as pd results_path = osp.join(osp.dirname(osp.realpath(__file__)), '../results/') os.makedirs(results_path, exist_ok=True) name = 'training' if training else 'inference' if write_csv == 'bench': csv_file_name = f'TOTAL_{name}_benchmark.csv' else: csv_file_name = f'TOTAL_prof_{name}_benchmark.csv' csv_path = osp.join(results_path, csv_file_name) index_label = 'TEST_ID' if write_csv == 'bench' else 'ID' with_header = not osp.exists(csv_path) df = pd.DataFrame(csv_data) df.to_csv(csv_path, mode='a', index_label=index_label, header=with_header) @torch.no_grad() def test(model, loader, device, hetero, progress_bar=True, desc="Evaluation") -> None: if progress_bar: loader = tqdm(loader, desc=desc) total_examples = total_correct = 0 if hetero: for batch in loader: batch = batch.to(device) if 'adj_t' in batch: edge_index_dict = batch.adj_t_dict else: edge_index_dict = batch.edge_index_dict out = model(batch.x_dict, edge_index_dict) batch_size = batch['paper'].batch_size out = out['paper'][:batch_size] pred = out.argmax(dim=-1) total_examples += batch_size total_correct += int((pred == batch['paper'].y[:batch_size]).sum()) else: for batch in loader: batch = batch.to(device) if 'adj_t' in batch: edge_index = batch.adj_t else: edge_index = batch.edge_index out = model(batch.x, edge_index) batch_size = batch.batch_size out = out[:batch_size] pred = out.argmax(dim=-1) total_examples += batch_size total_correct += int((pred == batch.y[:batch_size]).sum()) return total_correct / total_examples ================================================ FILE: codecov.yml ================================================ # See: https://docs.codecov.io/docs/codecov-yaml coverage: range: 80..100 round: down precision: 2 status: project: default: target: 80% 
threshold: 1% patch: default: target: 80% threshold: 1% ================================================ FILE: docker/Dockerfile ================================================ FROM nvcr.io/nvidia/cuda-dl-base:24.09-cuda12.6-devel-ubuntu22.04 # Based on NGC PyG 24.09 image: # https://docs.nvidia.com/deeplearning/frameworks/pyg-release-notes/rel-24-09.html#rel-24-09 # install pip RUN apt-get update && apt-get install -y python3-pip # install PyTorch - latest stable version RUN pip install torch torchvision torchaudio # install graphviz - latest stable version RUN apt-get install -y graphviz graphviz-dev RUN pip install pygraphviz # install python packages with NGC PyG 24.09 image versions RUN pip install torch_geometric==2.6.0 RUN pip install triton==3.0.0 numba==0.59.0 requests==2.32.3 opencv-python==4.7.0.72 scipy==1.14.0 jupyterlab==4.2.5 # install cugraph RUN pip install cugraph-cu12 cugraph-pyg-cu12 --extra-index-url=https://pypi.nvidia.com ================================================ FILE: docker/Dockerfile.xpu ================================================ ARG BASE_IMAGE="intel/intel-extension-for-pytorch" ARG BASE_TAG="2.1.30-xpu" FROM ${BASE_IMAGE}:${BASE_TAG} # meta information LABEL org.opencontainers.image.version = "2.3.1" LABEL org.opencontainers.image.authors = "PyG authors" LABEL org.opencontainers.image.source = "https://github.com/pyg-team/pytorch_geometric" LABEL org.opencontainers.image.licenses = "MIT" LABEL org.opencontainers.image.base.name=${BASE_IMAGE}:${BASE_TAG} # Create a working directory RUN mkdir /app WORKDIR /app # Add the XPU-related package repository for the LTS releases RUN . 
/etc/os-release && \ wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \ sudo gpg --yes --dearmor --output /usr/share/keyrings/intel-graphics.gpg && \ echo "deb [arch=amd64 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu ${VERSION_CODENAME}/lts/2350 unified" | \ sudo tee /etc/apt/sources.list.d/intel-gpu-${VERSION_CODENAME}.list # Install oneCCL RUN sudo apt update && apt install -y intel-oneapi-ccl-devel=2021.12.0-309 python3-dev cmake vim RUN echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc # Install PyG RUN pip install ninja wheel ogb && pip install git+https://github.com/pyg-team/pyg-lib.git && \ pip install torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.5.0+cpu.html && \ pip install torch_geometric ================================================ FILE: docker/README.md ================================================ # Docker on NVIDIA GPU The recommended way to use Docker for NVIDIA hardware is described [here](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg). You can also run PyG with CUDA 12.6 inside a docker image. This method is deprecated, and we highly recommend the above-mentioned official NVIDIA docker containers instead. [Our dockerfile](https://github.com/pyg-team/pytorch_geometric/blob/master/docker/Dockerfile) is based on the dockerfiles provided by [NVIDIA](https://gitlab.com/nvidia/cuda/tree/ubuntu18.04) and [PyTorch](https://github.com/anibali/docker-pytorch). 1. Download the dockerfile to your host server. 1. `$ docker build -t "custom image name" .` 1. `$ docker run --rm -it --init --runtime=nvidia --ipc=host --network=host --volume=$PWD:/app -e NVIDIA_VISIBLE_DEVICES=0 "custom image name" /bin/bash` If you encounter any problems, please feel free to create a GitHub issue. # Docker on Intel GPU You can also run PyG with Intel GPUs inside a docker image.
[Our dockerfile](https://github.com/pyg-team/pytorch_geometric/blob/master/docker/Dockerfile.xpu) is based on the dockerfiles provided by [Intel](https://github.com/intel/intel-extension-for-pytorch/blob/xpu-main/docker/Dockerfile.prebuilt) and the installation guidance provided by [Intel® Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu&version=v2.1.30%2bxpu&os=linux%2fwsl2&package=pip). 1. Download the dockerfile to your host server. 1. `$ docker build -f docker/Dockerfile.xpu -t "custom image name" .` 1. `$ docker run --rm -it --ipc=host -v /dev/dri:/dev/dri --volume=$PWD:/app "custom image name" /bin/bash` # Singularity You can run PyG inside a singularity image. An example singularity file can be found in this folder. You might have to modify the script; depending on your needs, modify the following: - **cuda version 10.1**: If you need another version, change `From: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04` to the corresponding tag of the `nvidia/cuda` Docker image. Do the same if you want to use anything other than Ubuntu 18.04. Your host must have at least this cuda version! - **python version 3.7.2**: If you need another version, change `pyenv install 3.7.2` and the following lines to the corresponding version. - **pytorch version 1.3.0**: If you need another version, change `pip install torch==1.3.0`. - **pytorch_geometric versions**: This uses specific versions for each of the `pytorch_geometric` requirements (scatter: 1.4.0, sparse: 0.4.3, cluster: 1.4.5, geometric: 1.3.2). To change these, change the corresponding `git checkout` lines near the bottom. - **cuda compute capability 5.0 and 6.1**: If you run it on multiple systems (likely, with singularity), ensure that all their compute capabilities are listed in `TORCH_CUDA_ARCH_LIST`.
If you have the cuda samples installed, check them with (for example) `/usr/local/cuda-10.1/extras/demo_suite/deviceQuery | grep 'CUDA Capability'`; if not, check [here](https://en.wikipedia.org/wiki/CUDA#GPUs_supported). Note: If your hard disk fills up after multiple builds, this is a known issue and apparently working as intended; delete the `/tmp/sbuild-XXXXXXXXX` files. ## Building and Using the Container To build the container, run `sudo singularity build geometric.sif singularity` then wait. Once finished, you can run the GAT example in the folder you built the image in by calling ``` wget https://raw.githubusercontent.com/pyg-team/pytorch_geometric/master/examples/gat.py ``` (to download the sample), then ``` singularity exec geometric.sif python3 gat.py ``` to run on the CPU, or ``` singularity exec --nv geometric.sif python3 gat.py ``` to run on the GPU. ================================================ FILE: docker/singularity ================================================ Bootstrap: docker From: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 %post CURDIR=$(pwd) # Set timezone to Etc/UTC for tzdata. See issue #4365 for more details.
TZ=Etc/UTC && \ ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && \ echo $TZ > /etc/timezone apt-get update -y apt-get install -y tmux nano git wget apt-get install -y --no-install-recommends \ build-essential \ gfortran \ libssl-dev \ zlib1g-dev \ libbz2-dev \ libreadline-dev \ libsqlite3-dev \ wget \ curl \ llvm \ libncurses5-dev \ xz-utils \ tk-dev \ libxml2-dev \ libxmlsec1-dev \ libffi-dev \ liblzma-dev \ liblapack-dev \ libopenblas-dev \ libhdf5-dev export PYENV_ROOT=/opt/pyenv export PATH="/opt/pyenv/bin:$PATH" curl -L https://github.com/pyenv/pyenv-installer/raw/master/bin/pyenv-installer | bash pyenv install 3.7.2 echo 'export PATH=/opt/pyenv/versions/3.7.2/bin/:$PATH' >> $SINGULARITY_ENVIRONMENT export PATH=/opt/pyenv/versions/3.7.2/bin/:$PATH pip install torch==1.3.0 mkdir -p $SINGULARITY_ROOTFS/tmp/sing_build_cuda cd $SINGULARITY_ROOTFS/tmp/sing_build_cuda export TORCH_CUDA_ARCH_LIST="5.0 6.1" git clone https://github.com/rusty1s/pytorch_scatter.git && \ cd ./pytorch_scatter && \ git checkout 1.4.0 && \ python3 -m pip install . && \ cd .. git clone https://github.com/rusty1s/pytorch_sparse.git && \ cd ./pytorch_sparse && \ git checkout 0.4.3 && \ python3 -m pip install . && \ cd .. git clone https://github.com/rusty1s/pytorch_cluster.git && \ cd ./pytorch_cluster && \ git checkout 1.4.5 && \ python3 -m pip install . && \ cd .. git clone https://github.com/pyg-team/pytorch_geometric.git && \ cd ./pytorch_geometric && \ git checkout 1.3.2 && \ python3 -m pip install . && \ cd .. 
cd $CURDIR rm -rf $SINGULARITY_ROOTFS/tmp/sing_build_cuda ================================================ FILE: docs/Makefile ================================================ SPHINXBUILD = sphinx-build SPHINXPROJ = pytorch_geometric SOURCEDIR = source BUILDDIR = build .PHONY: help Makefile %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(0) ================================================ FILE: docs/README.md ================================================ # Building Documentation To build the documentation: 1. [Build and install](https://github.com/pyg-team/pytorch_geometric/blob/master/.github/CONTRIBUTING.md#developing-pytorch-geometric) PyG from source. 1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via ``` pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git ``` 1. Generate the documentation file via: ``` cd docs make html ``` The documentation is now available to view by opening `docs/build/html/index.html`. ================================================ FILE: docs/requirements.txt ================================================ https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl numpy>=1.19.5 git+https://github.com/pyg-team/pyg_sphinx_theme.git ================================================ FILE: docs/source/.gitignore ================================================ generated/ ================================================ FILE: docs/source/_figures/.gitignore ================================================ *.aux *.log *.pdf ================================================ FILE: docs/source/_figures/build.sh ================================================ #!/bin/sh for filename in *.tex; do basename=$(basename $filename .tex) pdflatex "$basename.tex" pdf2svg "$basename.pdf" "$basename.svg" done ================================================ FILE: docs/source/_figures/graph.tex ================================================ 
\documentclass{standalone} \usepackage{tikz} \begin{document} \begin{tikzpicture} \node[draw,circle,label= left:{$x_1=-1$}] (0) at (0, 0) {0}; \node[draw,circle,label=above:{$x_1=0$}] (1) at (1, 1) {1}; \node[draw,circle,label=right:{$x_1=1$}] (2) at (2, 0) {2}; \path[draw] (0) -- (1); \path[draw] (1) -- (2); \end{tikzpicture} \end{document} ================================================ FILE: docs/source/_figures/hg_example.tex ================================================ \documentclass{standalone} \usepackage{tikz} \begin{document} \begin{tikzpicture} \node[draw,rectangle, align=center] (0) at (0, 0) {\textbf{Author}\\ $1,134,649$ nodes}; \node[draw,rectangle, align=center] (1) at (4, 2) {\textbf{Paper}\\ $736,389$ nodes}; \node[draw,rectangle, align=center] (2) at (8, 0) {\textbf{Institution}\\ $8,740$ nodes}; \node[draw,rectangle, align=center] (3) at (4, 4) {\textbf{Field of Study}\\ $59,965$ nodes}; \path[->,>=stealth] (0) edge [above left] node[align=center] {\textbf{writes}\\$7,145,660$ edges} (1.south); \path[->,>=stealth] (0) edge [below] node[align=center] {\textbf{affiliated with}\\$1,043,998$ edges} (2); \path[->,>=stealth,every loop/.style={looseness=3}] (1) edge [out=350, in=10, loop, right] node[align=center] {\textbf{cites}\\$5,416,271$ edges} (1); \path[->,>=stealth] (1) edge [left] node[align=center] {\textbf{has topic}\\$7,505,078$ edges} (3); \end{tikzpicture} \end{document} ================================================ FILE: docs/source/_figures/to_hetero.tex ================================================ \documentclass{standalone} \usepackage{tikz} \definecolor{green}{RGB}{159,213,179} \definecolor{blue}{RGB}{10,153,201} \begin{document} \begin{tikzpicture} \tikzset{rect/.style={draw,rectangle,inner sep=0pt,minimum width=2cm,minimum height=0.6cm,rounded corners=2pt}} \tikzset{arrow/.style={draw,->,>=stealth}} \def\offset{1.2} \node[inner sep=0pt] at (0,0.7) {\strut\textbf{Homogeneous Model}}; \node[rect] (x) at (0,0) 
{\strut\texttt{x}}; \node[rect,fill=green!20!white] (conv1) at (0,-1*\offset) {\texttt{SAGEConv}}; \node[rect,fill=blue!20!white] (relu1) at (0,-2*\offset) {\texttt{ReLU}}; \node[rect,fill=green!20!white] (conv2) at (0,-3*\offset) {\texttt{SAGEConv}}; \node[rect,fill=blue!20!white] (relu2) at (0,-4*\offset) {\texttt{ReLU}}; \node[rect] (out) at (0,-5*\offset) {\strut\texttt{out}}; \draw[arrow] (x) -- (conv1); \draw[arrow] (conv1) -- (relu1); \draw[arrow] (relu1) -- (conv2); \draw[arrow] (conv2) -- (relu2); \draw[arrow] (relu2) -- (out); \node[inner sep=0pt] at (6,0.7) {\strut\textbf{Heterogeneous Model}}; \node[rect] (xpaper) at (3.5,-0) {\strut\texttt{x\_paper}}; \node[rect,fill=green!20!white] (conv1paper) at (3.5,-1*\offset) {\texttt{SAGEConv}}; \node[rect,fill=blue!20!white] (relu1paper) at (3.5,-2*\offset) {\texttt{ReLU}}; \node[rect,fill=green!20!white] (conv2paper) at (3.5,-3*\offset) {\texttt{SAGEConv}}; \node[rect,fill=blue!20!white] (relu2paper) at (3.5,-4*\offset) {\texttt{ReLU}}; \node[rect] (outpaper) at (3.5,-5*\offset) {\strut\texttt{out\_paper}}; \node[rect,fill=green!20!white] (conv1middle) at (6,-1*\offset) {\texttt{SAGEConv}}; \node[rect,fill=green!20!white] (conv2middle) at (6,-3*\offset) {\texttt{SAGEConv}}; \node[rect] (xauthor) at (8.5,-0) {\strut\texttt{x\_author}}; \node[rect,fill=green!20!white] (conv1author) at (8.5,-1*\offset) {\texttt{SAGEConv}}; \node[rect,fill=blue!20!white] (relu1author) at (8.5,-2*\offset) {\texttt{ReLU}}; \node[rect,fill=green!20!white] (conv2author) at (8.5,-3*\offset) {\texttt{SAGEConv}}; \node[rect,fill=blue!20!white] (relu2author) at (8.5,-4*\offset) {\texttt{ReLU}}; \node[rect] (outauthor) at (8.5,-5*\offset) {\strut\texttt{out\_author}}; \draw[arrow] (xpaper) -- (conv1paper); \draw[arrow,out=270+45,in=90+45] (xpaper.south) to (conv1middle.north); \draw[arrow] (xauthor) -- (conv1author); \draw[arrow] (conv1paper) -- (relu1paper); \draw[arrow,out=270+45,in=90+45] (conv1middle.south) to (relu1author.north); 
\draw[arrow,out=270-60,in=0] (conv1author.south) to (relu1paper.east); \draw[arrow] (relu1paper) -- (conv2paper); \draw[arrow,out=270+45,in=90+45] (relu1paper.south) to (conv2middle.north); \draw[arrow] (relu1author) -- (conv2author); \draw[arrow] (conv2paper) -- (relu2paper); \draw[arrow,out=270+45,in=90+45] (conv2middle.south) to (relu2author.north); \draw[arrow,out=270-60,in=0] (conv2author.south) to (relu2paper.east); \draw[arrow] (relu2paper) -- (outpaper); \draw[arrow] (relu2author) -- (outauthor); \end{tikzpicture} \end{document} ================================================ FILE: docs/source/_figures/to_hetero_with_bases.tex ================================================ \documentclass{standalone} \usepackage{tikz} \definecolor{green}{RGB}{159,213,179} \definecolor{blue}{RGB}{10,153,201} \begin{document} \begin{tikzpicture} \tikzset{rect/.style={draw,rectangle,inner sep=0pt,minimum width=2cm,minimum height=0.6cm,rounded corners=2pt}} \tikzset{arrow/.style={draw,->,>=stealth}} \def\offset{1.2} \node[inner sep=0pt] at (0,0.7) {\strut\textbf{Homogeneous Model}}; \node[rect] (x) at (0,0) {\strut\texttt{x}}; \node[rect,fill=green!20!white] (conv1) at (0,-1*\offset) {\texttt{SAGEConv}}; \node[rect,fill=blue!20!white] (relu1) at (0,-2*\offset) {\texttt{ReLU}}; \node[rect,fill=green!20!white] (conv2) at (0,-3*\offset) {\texttt{SAGEConv}}; \node[rect,fill=blue!20!white] (relu2) at (0,-4*\offset) {\texttt{ReLU}}; \node[rect] (out) at (0,-5*\offset) {\texttt{out}}; \draw[arrow] (x) -- (conv1); \draw[arrow] (conv1) -- (relu1); \draw[arrow] (relu1) -- (conv2); \draw[arrow] (conv2) -- (relu2); \draw[arrow] (relu2) -- (out); \node[inner sep=0pt] at (6,2*\offset+0.7) {\strut\textbf{Heterogeneous Model}}; \node[rect] (xpaper) at (3.5,2*\offset) {\strut\texttt{x\_paper}}; \node[rect] (xauthor) at (8.5,2*\offset) {\strut\texttt{x\_author}}; \node[rect] (linpaper) at (3.5,1*\offset) {\strut\texttt{Linear}}; \node[rect] (linauthor) at (8.5,1*\offset) 
{\strut\texttt{Linear}}; \node[rect] (x) at (6,0*\offset) {\strut\texttt{x}}; \node[rect,fill=green!20!white] (conv11) at (3.5,-1*\offset) {\texttt{SAGEConv}}; \node[rect,fill=green!20!white] (conv12) at (6,-1*\offset) {\texttt{SAGEConv}}; \node[rect,fill=green!20!white] (conv13) at (8.5,-1*\offset) {\texttt{SAGEConv}}; \node[inner sep=1pt] (aggr1) at (6,-1.5*\offset) {\footnotesize$+$}; \node[rect,fill=blue!20!white] (relu1) at (6,-2*\offset) {\texttt{ReLU}}; \node[rect,fill=green!20!white] (conv21) at (3.5,-3*\offset) {\texttt{SAGEConv}}; \node[rect,fill=green!20!white] (conv22) at (6,-3*\offset) {\texttt{SAGEConv}}; \node[rect,fill=green!20!white] (conv23) at (8.5,-3*\offset) {\texttt{SAGEConv}}; \node[inner sep=1pt] (aggr2) at (6,-3.5*\offset) {\footnotesize$+$}; \node[rect,fill=blue!20!white] (relu2) at (6,-4*\offset) {\texttt{ReLU}}; \node[rect] (outpaper) at (3.5,-5*\offset) {\strut\texttt{out\_paper}}; \node[rect] (outauthor) at (8.5,-5*\offset) {\strut\texttt{out\_author}}; \draw[arrow] (xpaper) -- (linpaper); \draw[arrow] (xauthor) -- (linauthor); \draw[arrow] (linpaper) -- (x); \draw[arrow] (linauthor) -- (x); \draw[arrow] (x) -- node[fill=white,inner sep=1pt] {\footnotesize $\mathbf{a}_{\mathcal{R}, 1}$} (conv11) ; \draw[arrow] (x) -- node[fill=white,inner sep=1pt] {\footnotesize $\mathbf{a}_{\mathcal{R}, 2}$} (conv12) ; \draw[arrow] (x) -- node[fill=white,inner sep=1pt] {\footnotesize $\mathbf{a}_{\mathcal{R}, 3}$} (conv13) ; \draw[arrow] (conv11) -- (aggr1); \draw[arrow] (conv12) -- (aggr1); \draw[arrow] (conv13) -- (aggr1); \draw[arrow] (aggr1) -- (relu1); \draw[arrow] (relu1) -- node[fill=white,inner sep=0pt] {\footnotesize $\mathbf{a}_{\mathcal{R}, 1}$} (conv21) ; \draw[arrow] (relu1) -- node[fill=white,inner sep=0pt] {\footnotesize $\mathbf{a}_{\mathcal{R}, 2}$} (conv22) ; \draw[arrow] (relu1) -- node[fill=white,inner sep=0pt] {\footnotesize $\mathbf{a}_{\mathcal{R}, 3}$} (conv23) ; \draw[arrow] (conv21) -- (aggr2); \draw[arrow] (conv22) -- 
(aggr2); \draw[arrow] (conv23) -- (aggr2); \draw[arrow] (aggr2) -- (relu2); \draw[arrow] (relu2) -- (outpaper); \draw[arrow] (relu2) -- (outauthor); \end{tikzpicture} \end{document} ================================================ FILE: docs/source/_static/js/version_alert.js ================================================ /* Inserts an "admonition note" before the first h1 of #pyg-documentation when the ReadTheDocs "latest" (unreleased) docs are being viewed. */ function warnOnLatestVersion() { if (!window.READTHEDOCS_DATA || window.READTHEDOCS_DATA.version !== "latest") { return; // not on ReadTheDocs, or not the "latest" version. } var note = document.createElement('div'); note.setAttribute('class', 'admonition note'); note.innerHTML = "

Note

" + "

" + "This documentation is for an unreleased development version. " + "Click here to access the documentation of the current stable release." + "

"; var parent = document.querySelector('#pyg-documentation'); if (parent) parent.insertBefore(note, parent.querySelector('h1')); } document.addEventListener('DOMContentLoaded', warnOnLatestVersion); ================================================ FILE: docs/source/_templates/autosummary/class.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} :show-inheritance: :members: ================================================ FILE: docs/source/_templates/autosummary/inherited_class.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} :show-inheritance: :members: :inherited-members: :special-members: __cat_dim__, __inc__ ================================================ FILE: docs/source/_templates/autosummary/metrics.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} :show-inheritance: :members: update, compute, reset ================================================ FILE: docs/source/_templates/autosummary/nn.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} {% if objname != "MessagePassing" %} .. autoclass:: {{ objname }} :show-inheritance: :members: :exclude-members: forward, reset_parameters, message, message_and_aggregate, edge_update, aggregate, update .. automethod:: forward .. automethod:: reset_parameters {% else %} .. autoclass:: {{ objname }} :show-inheritance: :members: {% endif %} ================================================ FILE: docs/source/_templates/autosummary/only_class.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} ..
autoclass:: {{ objname }} :show-inheritance: ================================================ FILE: docs/source/advanced/batching.rst ================================================ Advanced Mini-Batching ====================== The creation of mini-batching is crucial for letting the training of a deep learning model scale to huge amounts of data. Instead of processing examples one-by-one, a mini-batch groups a set of examples into a unified representation that can be efficiently processed in parallel. In the image or language domain, this procedure is typically achieved by rescaling or padding each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension. The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the :obj:`batch_size`. Since graphs are one of the most general data structures that can hold *any* number of nodes or edges, the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption. In :pyg:`PyG`, we opt for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension, *i.e.* .. math:: \mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad \mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad \mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}. This procedure has some crucial advantages over other batching procedures: 1. GNN operators that rely on a message passing scheme do not need to be modified since messages still cannot be exchanged between two nodes that belong to different graphs. 2. There is no computational or memory overhead.
For example, this batching procedure works completely without any padding of node or edge features. Note that there is no additional memory overhead for adjacency matrices since they are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges. :pyg:`PyG` automatically takes care of batching multiple graphs into a single giant graph with the help of the :class:`torch_geometric.loader.DataLoader` class. Internally, :class:`~torch_geometric.loader.DataLoader` is just a regular :pytorch:`PyTorch` :class:`torch.utils.data.DataLoader` that overwrites its :func:`collate` functionality, *i.e.*, the definition of how a list of examples should be grouped together. Therefore, all arguments that can be passed to a :pytorch:`PyTorch` :class:`~torch.utils.data.DataLoader` can also be passed to a :pyg:`PyG` :class:`~torch_geometric.loader.DataLoader`, *e.g.*, the number of workers :obj:`num_workers`. In its most general form, the :pyg:`PyG` :class:`~torch_geometric.loader.DataLoader` will automatically increment the :obj:`edge_index` tensor by the cumulative number of nodes of all graphs that got collated before the currently processed graph, and will concatenate :obj:`edge_index` tensors (that are of shape :obj:`[2, num_edges]`) in the second dimension. The same is true for :obj:`face` tensors, *i.e.*, face indices in meshes. All other tensors will just get concatenated in the first dimension without any further increment of their values. However, there are a few special use-cases (as outlined below) where the user actively wants to modify this behavior to their own needs. :pyg:`PyG` allows modification of the underlying batching procedure by overwriting the :meth:`torch_geometric.data.Data.__inc__` and :meth:`torch_geometric.data.Data.__cat_dim__` functionalities. Without any modifications, these are defined as follows in the :class:`~torch_geometric.data.Data` class: ..
code-block:: python def __inc__(self, key, value, *args, **kwargs): if 'index' in key: return self.num_nodes else: return 0 def __cat_dim__(self, key, value, *args, **kwargs): if 'index' in key: return 1 else: return 0 We can see that :meth:`~torch_geometric.data.Data.__inc__` defines the incremental count between two consecutive graph attributes. By default, :pyg:`PyG` increments attributes by the number of nodes whenever their attribute names contain the substring :obj:`index` (for historical reasons), which comes in handy for attributes such as :obj:`edge_index` or :obj:`node_index`. However, note that this may lead to unexpected behavior for attributes whose names contain the substring :obj:`index` but should not be incremented. To be safe, it is best practice to always double-check the output of batching. Furthermore, :meth:`~torch_geometric.data.Data.__cat_dim__` defines in which dimension graph tensors of the same attribute should be concatenated together. Both functions are called for each attribute stored in the :class:`~torch_geometric.data.Data` class, and get passed their specific :obj:`key` and :obj:`value` as arguments. In what follows, we present a few use-cases where the modification of :meth:`~torch_geometric.data.Data.__inc__` and :meth:`~torch_geometric.data.Data.__cat_dim__` might be absolutely necessary. Pairs of Graphs --------------- In case you want to store multiple graphs in a single :class:`~torch_geometric.data.Data` object, *e.g.*, for applications such as graph matching, you need to ensure correct batching behavior across all those graphs. For example, consider storing two graphs, a source graph :math:`\mathcal{G}_s` and a target graph :math:`\mathcal{G}_t` in a :class:`~torch_geometric.data.Data` object, *e.g.*: .. code-block:: python from torch_geometric.data import Data class PairData(Data): pass data = PairData(x_s=x_s, edge_index_s=edge_index_s, # Source graph. x_t=x_t, edge_index_t=edge_index_t) # Target graph.
In this case, :obj:`edge_index_s` should be increased by the number of nodes in the source graph :math:`\mathcal{G}_s`, *i.e.*, :obj:`x_s.size(0)`, and :obj:`edge_index_t` should be increased by the number of nodes in the target graph :math:`\mathcal{G}_t`, *i.e.*, :obj:`x_t.size(0)`: .. code-block:: python class PairData(Data): def __inc__(self, key, value, *args, **kwargs): if key == 'edge_index_s': return self.x_s.size(0) if key == 'edge_index_t': return self.x_t.size(0) return super().__inc__(key, value, *args, **kwargs) We can test our :class:`PairData` batching behavior by setting up a simple test script: .. code-block:: python from torch_geometric.loader import DataLoader x_s = torch.randn(5, 16) # 5 nodes. edge_index_s = torch.tensor([ [0, 0, 0, 0], [1, 2, 3, 4], ]) x_t = torch.randn(4, 16) # 4 nodes. edge_index_t = torch.tensor([ [0, 0, 0], [1, 2, 3], ]) data = PairData(x_s=x_s, edge_index_s=edge_index_s, x_t=x_t, edge_index_t=edge_index_t) data_list = [data, data] loader = DataLoader(data_list, batch_size=2) batch = next(iter(loader)) print(batch) >>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8], x_t=[8, 16], edge_index_t=[2, 6]) print(batch.edge_index_s) >>> tensor([[0, 0, 0, 0, 5, 5, 5, 5], [1, 2, 3, 4, 6, 7, 8, 9]]) print(batch.edge_index_t) >>> tensor([[0, 0, 0, 4, 4, 4], [1, 2, 3, 5, 6, 7]]) Everything looks good so far! :obj:`edge_index_s` and :obj:`edge_index_t` get correctly batched together, even when using different numbers of nodes for :math:`\mathcal{G}_s` and :math:`\mathcal{G}_t`. However, the :obj:`batch` attribute (that maps each node to its respective graph) is missing since :pyg:`PyG` fails to identify the actual graph in the :class:`PairData` object. That is where the :obj:`follow_batch` argument of the :class:`~torch_geometric.loader.DataLoader` comes into play. Here, we can specify for which attributes we want to maintain the batch information: ..
code-block:: python loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t']) batch = next(iter(loader)) print(batch) >>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8], x_s_batch=[10], x_t=[8, 16], edge_index_t=[2, 6], x_t_batch=[8]) print(batch.x_s_batch) >>> tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) print(batch.x_t_batch) >>> tensor([0, 0, 0, 0, 1, 1, 1, 1]) As one can see, :obj:`follow_batch=['x_s', 'x_t']` now successfully creates assignment vectors :obj:`x_s_batch` and :obj:`x_t_batch` for the node features :obj:`x_s` and :obj:`x_t`, respectively. That information can now be used to perform reduce operations, *e.g.*, global pooling, on multiple graphs in a single :class:`Batch` object. Bipartite Graphs ---------------- The adjacency matrix of a bipartite graph defines the relationship between nodes of two different node types. In general, the number of nodes for each node type does not need to match, resulting in a non-square adjacency matrix of shape :math:`\mathbf{A} \in \{ 0, 1 \}^{N \times M}` where potentially :math:`N \neq M`. In a mini-batching procedure of bipartite graphs, the source nodes of edges in :obj:`edge_index` should get increased differently than the target nodes of edges in :obj:`edge_index`. To achieve this, consider a bipartite graph between two node types with corresponding node features :obj:`x_s` and :obj:`x_t`, respectively: .. code-block:: python from torch_geometric.data import Data class BipartiteData(Data): pass data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index) For a correct mini-batching procedure in bipartite graphs, we need to tell :pyg:`PyG` that it should increment source and target nodes of edges in :obj:`edge_index` independently: ..
code-block:: python class BipartiteData(Data): def __inc__(self, key, value, *args, **kwargs): if key == 'edge_index': return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]]) return super().__inc__(key, value, *args, **kwargs) Here, :obj:`edge_index[0]` (the source nodes of edges) get incremented by :obj:`x_s.size(0)` while :obj:`edge_index[1]` (the target nodes of edges) get incremented by :obj:`x_t.size(0)`. We can again test our implementation by running a simple test script: .. code-block:: python from torch_geometric.loader import DataLoader x_s = torch.randn(2, 16) # 2 nodes. x_t = torch.randn(3, 16) # 3 nodes. edge_index = torch.tensor([ [0, 0, 1, 1], [0, 1, 1, 2], ]) data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index) data_list = [data, data] loader = DataLoader(data_list, batch_size=2) batch = next(iter(loader)) print(batch) >>> BipartiteDataBatch(x_s=[4, 16], x_t=[6, 16], edge_index=[2, 8]) print(batch.edge_index) >>> tensor([[0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 3, 4, 4, 5]]) Again, this is exactly the behavior we aimed for! Batching Along New Dimensions ----------------------------- Sometimes, attributes of :obj:`data` objects should be batched by gaining a new batch dimension (as in classical mini-batching), *e.g.*, for graph-level properties or targets. Specifically, a list of attributes of shape :obj:`[num_features]` should be returned as :obj:`[num_examples, num_features]` rather than :obj:`[num_examples * num_features]`. :pyg:`PyG` achieves this by returning a concatenation dimension of :obj:`None` in :meth:`~torch_geometric.data.Data.__cat_dim__`: .. 
code-block:: python from torch_geometric.data import Data from torch_geometric.loader import DataLoader class MyData(Data): def __cat_dim__(self, key, value, *args, **kwargs): if key == 'foo': return None return super().__cat_dim__(key, value, *args, **kwargs) edge_index = torch.tensor([ [0, 1, 1, 2], [1, 0, 2, 1], ]) foo = torch.randn(16) data = MyData(num_nodes=3, edge_index=edge_index, foo=foo) data_list = [data, data] loader = DataLoader(data_list, batch_size=2) batch = next(iter(loader)) print(batch) >>> MyDataBatch(num_nodes=6, edge_index=[2, 8], foo=[2, 16]) As desired, :obj:`batch.foo` is now described by two dimensions: The batch dimension and the feature dimension. ================================================ FILE: docs/source/advanced/compile.rst ================================================ Compiled Graph Neural Networks ============================== :meth:`torch.compile` is the latest method to speed up your :pytorch:`PyTorch` code in :obj:`torch >= 2.0.0`! :meth:`torch.compile` makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes. Under the hood, :meth:`torch.compile` captures :pytorch:`PyTorch` programs via :obj:`TorchDynamo`, canonicalizes over 2,000 :pytorch:`PyTorch` operators via :obj:`PrimTorch`, and finally generates fast code out of it across multiple accelerators and backends via the deep learning compiler :obj:`TorchInductor`. .. note:: See `here `__ for a general tutorial on how to leverage :meth:`torch.compile`, and `here `__ for a description of its interface. In this tutorial, we show how to optimize your custom :pyg:`PyG` model via :meth:`torch.compile`. .. note:: From :pyg:`PyG` 2.5 (and onwards), :meth:`torch.compile` is now fully compatible with all :pyg:`PyG` GNN layers. If you are on an earlier version of :pyg:`PyG`, consider using :meth:`torch_geometric.compile` instead.
Basic Usage ----------- Once you have a :pyg:`PyG` model defined, simply wrap it with :meth:`torch.compile` to obtain its optimized version: .. code-block:: python import torch from torch_geometric.nn import GraphSAGE model = GraphSAGE(in_channels, hidden_channels, num_layers, out_channels) model = model.to(device) model = torch.compile(model) and execute it as usual: .. code-block:: python from torch_geometric.datasets import Planetoid dataset = Planetoid(root, name="Cora") data = dataset[0].to(device) out = model(data.x, data.edge_index) Maximizing Performance ---------------------- The :meth:`torch.compile` method provides two important arguments to be aware of: * Most of the mini-batches observed in :pyg:`PyG` are dynamic by nature, meaning that their shape varies across different mini-batches. For these scenarios, we can enforce dynamic shape tracing in :pytorch:`PyTorch` via the :obj:`dynamic=True` argument: .. code-block:: python torch.compile(model, dynamic=True) With this, :pytorch:`PyTorch` will up-front attempt to generate a kernel that is as dynamic as possible to avoid recompilations when sizes change across mini-batches. Note that when :obj:`dynamic` is set to :obj:`False`, :pytorch:`PyTorch` will *never* generate dynamic kernels, and thus only work when graph sizes are guaranteed to never change (*e.g.*, in full-batch training on small graphs). By default, :obj:`dynamic` is set to :obj:`None` in :pytorch:`PyTorch` :obj:`>= 2.1.0`, and :pytorch:`PyTorch` will automatically detect if dynamism has occurred. Note that support for dynamic shape tracing requires :pytorch:`PyTorch` :obj:`>= 2.1.0` to be installed. * In order to maximize speedup, graph breaks in the compiled model should be limited. We can force compilation to raise an error upon the first graph break encountered by using the :obj:`fullgraph=True` argument: .. 
code-block:: python torch.compile(model, fullgraph=True) It is generally a good practice to confirm that your written model does not contain any graph breaks. Importantly, there exist a few operations in :pyg:`PyG` that will currently lead to graph breaks (but workarounds exist), *e.g.*: 1. :meth:`~torch_geometric.nn.pool.global_mean_pool` (and other pooling operators) perform device synchronization in case the batch size :obj:`size` is not passed, leading to a graph break. 2. :meth:`~torch_geometric.utils.remove_self_loops` and :meth:`~torch_geometric.utils.add_remaining_self_loops` mask the given :obj:`edge_index`, leading to a device synchronization to compute its final output shape. As such, we recommend augmenting your graph *before* inputting it into your GNN, *e.g.*, via the :class:`~torch_geometric.transforms.AddSelfLoops` or :class:`~torch_geometric.transforms.GCNNorm` transformations, and setting :obj:`add_self_loops=False`/:obj:`normalize=False` when initializing layers such as :class:`~torch_geometric.nn.conv.GCNConv`. Example Scripts --------------- We have incorporated multiple examples in :obj:`examples/compile` that further show the practical usage of :meth:`torch.compile`: #. `Node Classification `__ via :class:`~torch_geometric.nn.models.GCN` (:obj:`dynamic=False`) #. `Graph Classification `__ via :class:`~torch_geometric.nn.models.GIN` (:obj:`dynamic=True`) If you notice that :meth:`torch.compile` fails for a certain :pyg:`PyG` model, do not hesitate to reach out either on :github:`null` `GitHub `_ or :slack:`null` `Slack `_. We are very eager to improve :meth:`torch.compile` support across the whole :pyg:`PyG` code base. Benchmark --------- :meth:`torch.compile` works **fantastically well** for many :pyg:`PyG` models. 
**Overall, we observe runtime improvements of up to 300%.** Specifically, we benchmark :class:`~torch_geometric.nn.models.GCN`, :class:`~torch_geometric.nn.models.GraphSAGE` and :class:`~torch_geometric.nn.models.GIN` and compare runtimes obtained from traditional eager mode and :meth:`torch.compile`. We use a synthetic graph with 10,000 nodes and 200,000 edges, and a hidden feature dimensionality of 64. We report runtimes over 500 optimization steps: .. list-table:: :widths: 15 15 15 15 15 15 :header-rows: 1 * - Model - Mode - Forward - Backward - Total - Speedup * - :class:`~torch_geometric.nn.models.GCN` - Eager - 2.6396s - 2.1697s - 4.8093s - * - :class:`~torch_geometric.nn.models.GCN` - **Compiled** - **1.1082s** - **0.5896s** - **1.6978s** - **2.83x** * - :class:`~torch_geometric.nn.models.GraphSAGE` - Eager - 1.6023s - 1.6428s - 3.2451s - * - :class:`~torch_geometric.nn.models.GraphSAGE` - **Compiled** - **0.7033s** - **0.7465s** - **1.4498s** - **2.24x** * - :class:`~torch_geometric.nn.models.GIN` - Eager - 1.6701s - 1.6990s - 3.3690s - * - :class:`~torch_geometric.nn.models.GIN` - **Compiled** - **0.7320s** - **0.7407s** - **1.4727s** - **2.29x** To reproduce these results, run .. code-block:: console python test/nn/models/test_basic_gnn.py from the root folder of your checked out :pyg:`PyG` repository from :github:`GitHub`. ================================================ FILE: docs/source/advanced/cpu_affinity.rst ================================================ CPU Affinity for PyG Workloads ============================== The performance of :pyg:`PyG` workloads using CPU can be significantly improved by setting a proper affinity mask. Processor affinity, or core binding, is a modification of the native OS queue scheduling algorithm that enables an application to assign a specific set of cores to processes or threads launched during its execution on the CPU. 
In consequence, it increases the overall effective hardware utilisation by minimizing core stalls and memory bounds. It also secures CPU resources to critical processes or threads, even if the system is under heavy load. CPU affinity targets the two main performance-critical regions: * **Execution bind:** Indicates a core where process/thread will run. * **Memory bind:** Indicates a preferred memory area where memory pages will be bound (local areas in NUMA machine). The following article discusses readily available tools and environment settings that one can use to maximize the performance of Intel CPUs with :pyg:`PyG`. .. note:: Overall, CPU affinity can be a useful tool for improving the performance and predictability of certain types of applications, but one configuration does not necessarily fit all cases: it is important to carefully consider whether CPU affinity is appropriate for your use case, and to test and measure the impact of any changes you make. Using CPU affinity ------------------ Each :pyg:`PyG` workload can be parallelized using the :pytorch:`PyTorch` iterator class :class:`MultiProcessingDataLoaderIter`, which is automatically enabled in case :obj:`num_workers > 0` is passed to a :class:`torch.utils.data.DataLoader`. Under the hood, it creates :obj:`num_workers` many sub-processes that will run in parallel to the main process. Setting a CPU affinity mask for the data loading processes places :class:`~torch.utils.data.DataLoader` worker threads on specific CPU cores. In effect, it allows for more efficient data batch preparation by allocating pre-fetched batches in local memory. Every time a process or thread moves from one core to another, registers and caches need to be flushed and reloaded. This can become very costly if it happens often, and threads may also no longer be close to their data, or be able to share data in a cache. 
Since :pyg:`PyG` (2.3 and beyond), :class:`~torch_geometric.loader.NodeLoader` and :class:`~torch_geometric.loader.LinkLoader` classes officially support a native solution for CPU affinity via the :class:`torch_geometric.loader.AffinityMixin` class. CPU affinity can be enabled via the :meth:`~torch_geometric.loader.AffinityMixin.enable_cpu_affinity` context manager for :obj:`num_workers > 0` use-cases, and will guarantee that a separate core is assigned to each worker at initialization. A user-defined list of core IDs may be assigned using the :attr:`loader_cores` argument. Otherwise, cores will be assigned automatically, starting at core ID 0. As of now, only a single core can be assigned to a worker, hence multi-threading is disabled in workers' processes by default. The recommended number of workers to start with lies between :obj:`[2, 4]`, and the optimum may vary based on workload characteristics: .. code-block:: python loader = NeighborLoader( data, num_workers=3, ..., ) with loader.enable_cpu_affinity(loader_cores=[0, 1, 2]): for batch in loader: pass It is generally advisable to use :obj:`filter_per_worker=True` for any multi-process CPU workloads (:obj:`True` by default). The workers then prepare each mini-batch: first by sampling the node indices using a pre-defined sampler, and second by filtering node and edge features according to the sampled nodes and edges. The filtering function selects node feature vectors from the complete input :class:`~torch_geometric.data.Data` tensor loaded into DRAM. When :attr:`filter_per_worker` is set to :obj:`True`, each worker's subprocess performs the filtering within its own CPU resources. Hence, main process resources are relieved and can be reserved solely for GNN computation. Binding processes to physical cores ----------------------------------- Following general performance tuning principles, it is advisable to use only physical cores for deep learning workloads.
For example, while two logical threads run :obj:`GEMM` at the same time, they will be sharing the same core resources, causing the core to become front-end bound, such that the overhead from this front-end bottleneck is greater than the gain from running both logical threads at the same time. This is because OpenMP threads will contend for the same :obj:`GEMM` execution units, see `here `__. The binding can be done in many ways; the most common tools are (with :obj:`N` denoting the total number of physical cores): * :obj:`numactl` (only on Linux): .. code-block:: console --physcpubind=<cpus>, -C <cpus> or --cpunodebind=<nodes>, -N <nodes> * `Intel OMP `__ :obj:`libiomp`: .. code-block:: console export KMP_AFFINITY=granularity=fine,proclist=[0-(N-1)],explicit * GNU :obj:`libgomp`: .. code-block:: console export GOMP_CPU_AFFINITY="0-(N-1)" Isolating the :class:`~torch.utils.data.DataLoader` process ----------------------------------------------------------- For best performance, it is required to combine main process affinity, set via the tools listed above, with the multi-process :class:`~torch.utils.data.DataLoader` affinity settings. In each parallelized :pyg:`PyG` workload execution, the main process performs message passing updates over GNN layers, while the :class:`~torch.utils.data.DataLoader` worker sub-processes take care of fetching and pre-processing data to be passed to a GNN model. It is advisable to isolate the CPU resources made available to these two processes to achieve the best results. To do this, CPUs assigned to each affinity mask should be mutually exclusive. For example, if four :class:`~torch.utils.data.DataLoader` workers are assigned to CPUs :obj:`[0, 1, 2, 3]`, the main process should use the rest of the available cores, *i.e.* by calling: .. code-block:: console numactl -C 4-(N-1) --localalloc python … where :obj:`N` is the total number of physical cores, with the last CPU having core ID :obj:`N-1`. Adding :obj:`--localalloc` improves local memory allocation and keeps the cache closer to active cores.
Dual socket CPU separation ~~~~~~~~~~~~~~~~~~~~~~~~~~ With dual-socket CPUs, it might be beneficial to further isolate the processes between the sockets. This leads to a decreased frequency of remote memory calls for the main process. The goal is to `utilize high-speed cache on local memory and reduce memory bound caused by migrating cached data between NUMA nodes `__. This can be achieved by using :class:`~torch.utils.data.DataLoader` affinity, and launching the main process on the cores of the second socket, *i.e.* with: .. code-block:: console numactl -C M-(N-1) -m 1 python … where :obj:`M` is the :obj:`cpuid` of the first core of the second CPU socket. Adding the complementary memory-allocation flag :obj:`-m 1` binds memory allocation to the same NUMA node where the main process is running (alternatively, for less strict memory allocation, use :obj:`--preferred 1`). This makes the data readily available on the same socket where the computation takes place. Using this setting is very workload-specific and may require some fine-tuning, as one needs to manage a trade-off between using more OMP threads vs. limiting the number of remote memory calls.
For some datasets, it might be more beneficial to use a plain serial data loader, especially when the dimensions of the input :class:`~torch_geometric.data.Data` are relatively small. #. Enable multi-process data loaders by setting :obj:`num_workers > 0`. A good estimate for :obj:`num_workers` lies in the range :obj:`[2, 4]`. However, for more complex datasets you might want to experiment with a larger number of workers. Use the :meth:`~torch_geometric.loader.AffinityMixin.enable_cpu_affinity` feature to affinitize :class:`~torch.utils.data.DataLoader` cores. #. Bind execution to physical cores. Alternatively, hyperthreading can be disabled completely at the system level. #. Separate the cores used for the main process from the data loader workers' cores by using :obj:`numactl`, :obj:`KMP_AFFINITY` of the :obj:`libiomp5` library, or :obj:`GOMP_CPU_AFFINITY` of the :obj:`libgomp` library. #. Find the optimum number of OMP threads for your workload. A good starting point is :obj:`N - num_workers`. Generally, well-parallelized models will benefit from many OMP threads. However, if your model computation flow has interlaced parallel and serial regions, the performance will decrease due to resource allocation needed for spawning and maintaining threads between parallel regions. #. When using a dual-socket CPU, you might want to experiment with assigning data loading to one socket and the main process to the other socket, with memory allocation (:obj:`numactl -m`) on the same socket where the main process is executed. This leads to the best cache allocation and often outweighs the benefit of using more OMP threads. #. An additional boost in performance can be obtained by using a non-default memory allocator, such as :obj:`jemalloc` or :obj:`TCMalloc`. #. Finding an optimal setup for the CPU affinity mask is a problem of managing the proportion of CPU time spent in each iteration for loading and preparing the data vs. time spent during GNN execution.
Different results may be obtained by changing model hyperparameters, such as the batch size, number of sampled neighbors, and the number of layers. As a general rule, workloads which require sampling a complex graph may benefit more from reserving some CPU resources just for the data preparation step. Example results --------------- The figure below presents the outcome of applying a CPU affinity mask to :obj:`benchmark/training/training_benchmark.py`. Measurements were taken for a variable number of workers, while other hyperparameters for each benchmark were constant: :obj:`--warmup 0 --use-sparse-tensor --num-layers 3 --num-hidden-channels 128 --batch-sizes 2048`. Three different affinity configurations are presented: * **Baseline** - only :obj:`OMP_NUM_THREADS` changes: .. code-block:: console OMP_NUM_THREADS=(N-num_workers) python training_benchmark.py --num-workers … * **Aff** - data loader process on first socket, main process on first and second socket, 98-110 threads: .. code-block:: console LD_PRELOAD=(path)/libjemalloc.so (path)/libiomp5.so MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto OMP_NUM_THREADS=(N-num_workers) KMP_AFFINITY=granularity=fine,compact,1,0 KMP_BLOCKTIME=0 numactl -C num_workers-(N-1) --localalloc python training_benchmark.py --cpu-affinity --num-workers … * **Aff+SocketSep** - data loader process on first socket, main process on second socket, 60 threads: .. code-block:: console LD_PRELOAD=(path)/libjemalloc.so (path)/libiomp5.so MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto OMP_NUM_THREADS=(N-M) KMP_AFFINITY=granularity=fine,compact,1,0 KMP_BLOCKTIME=0 numactl -C M-(N-1) -m 1 python training_benchmark.py --cpu-affinity --num-workers ... Training times for each model/dataset combination were obtained by taking a mean of results at a variable number of dataloader workers: :obj:`[0, 2, 4, 8, 16]` for the baseline and :obj:`[2, 4, 8, 16]` workers for each affinity configuration.
Then, the affinity means were normalized with respect to the mean baseline measurement. This value is denoted on the :math:`y`-axis. The labels above each result indicate the end-to-end performance gain from using the discussed configuration. Over all model/dataset samples, the average training time is decreased by **1.53x** for plain affinity and **1.85x** for the affinity with socket separation. .. figure:: ../_figures/training_affinity.png :width: 100% Pre-production dual-socket Intel(R) Xeon(R) Platinum 8481C @ 2.0 GHz (2 x 56 cores) CPU. ================================================ FILE: docs/source/advanced/graphgym.rst ================================================ Managing Experiments with GraphGym ================================== GraphGym is a platform for **designing and evaluating Graph Neural Networks (GNNs)**, as originally proposed in the `"Design Space for Graph Neural Networks" `__ paper. We now officially support GraphGym as part of :pyg:`PyG`. .. warning:: The GraphGym API may change in the future as we are continuously working on better and deeper integration with :pyg:`PyG`. Highlights ---------- #. **Highly modularized pipeline for GNN:** - **Data:** Data loading and data splitting - **Model:** Modularized GNN implementations - **Tasks:** Node-level, edge-level and graph-level tasks - **Evaluation:** Accuracy, ROC AUC, ... #. **Reproducible experiment configuration:** - Each experiment is *fully described by a configuration file* #. **Scalable experiment management:** - Easily launch *thousands of GNN experiments in parallel* - *Auto-generate* experiment analyses and figures across random seeds and experiments #. **Flexible user customization:** - Easily *register your own modules*, such as data loaders, GNN layers, loss functions, etc. Why GraphGym? ------------- **TL;DR:** GraphGym is great for GNN beginners, domain experts and GNN researchers.
**Scenario 1:** You are a beginner to graph representation learning and want to understand how GNNs work: You have probably read many exciting papers on GNNs and are trying to write your own GNN implementation. Even if using raw :pyg:`PyG`, you still have to code up the essential pipeline on your own. GraphGym is a perfect place for you to start learning about *standardized GNN implementation and evaluation*. .. figure:: ../_figures/graphgym_design_space.png :align: center :width: 450px **Figure 1:** Modularized GNN implementation. **Scenario 2:** You want to apply GNNs to your exciting application: You probably know that there are hundreds of possible GNN models, and selecting the best model is notoriously hard. Even worse, the `GraphGym paper `__ shows that the best GNN designs for different tasks differ drastically. GraphGym provides a *simple interface to try out thousands of GNNs in parallel* and understand the best designs for your specific task. GraphGym also recommends a "go-to" GNN design space, after investigating 10 million GNN model-task combinations. .. figure:: ../_figures/graphgym_results.png :align: center :width: 100% **Figure 2:** A guideline for desirable GNN design choices. **Scenario 3:** You are a GNN researcher who wants to develop new GNN models or propose new GNN tasks: Say you have proposed a new GNN layer :class:`ExampleConv`. GraphGym can help you convincingly argue that :class:`ExampleConv` is better than, *e.g.*, :class:`~torch_geometric.nn.conv.GCNConv`: When randomly sampling from 10 million possible model-task combinations, how often will :class:`ExampleConv` outperform :class:`~torch_geometric.nn.conv.GCNConv` when everything else is fixed (including computational costs)? Moreover, GraphGym can help you easily do hyper-parameter search, and *visualize* what design choices are better. In sum, GraphGym can greatly facilitate your GNN research. ..
figure:: ../_figures/graphgym_evaluation.png :align: center :width: 100% **Figure 3:** Evaluation of a given GNN design dimension, *e.g.*, :obj:`BatchNorm`. Basic Usage ----------- .. note:: For using GraphGym, :pyg:`PyG` requires additional dependencies. You can install those by running :obj:`pip install torch-geometric[graphgym]`. To use GraphGym, you need to clone :pyg:`PyG` from :github:`GitHub`, then change to the :obj:`graphgym/` directory. .. code-block:: bash git clone https://github.com/pyg-team/pytorch_geometric.git cd pytorch_geometric/graphgym #. **Run a single experiment:** Run an experiment using GraphGym via :obj:`run_single.sh`. Configurations are specified in :obj:`configs/pyg/example_node.yaml`. The default experiment is about node classification on the :class:`~torch_geometric.datasets.Planetoid` datasets (using a random 80/20 train/validation split). .. code-block:: bash bash run_single.sh # run a single experiment #. **Run a batch of experiments:** Run a batch of experiments using GraphGym via :obj:`run_batch.sh`. Configurations are specified in :obj:`configs/pyg/example_node.yaml` (controls the basic architecture) and :obj:`grids/example.txt` (controls how to do grid search). The experiment examines 96 models in the recommended GNN design space, on 2 graph classification datasets. Each experiment is repeated 3 times, and we set up that 8 jobs can be concurrently run. Depending on your infrastructure, finishing all the experiments may take a long time; you can quit the experiment via :obj:`Ctrl-C` (GraphGym will properly kill all the processes). .. code-block:: bash bash run_batch.sh # run a batch of experiments #. **Run GraphGym with CPU backend:** GraphGym supports CPU backend as well -- you only need to add the line :obj:`accelerator: cpu` to the :obj:`*.yaml` file. In-Depth Usage -------------- To use GraphGym, you need to clone :pyg:`PyG` from :github:`GitHub`, then change to the :obj:`graphgym/` directory. .. 
code-block:: bash git clone https://github.com/pyg-team/pytorch_geometric.git cd pytorch_geometric/graphgym #. **Run a single experiment:** A full example is specified in :obj:`run_single.sh`. #. **Specify a configuration file:** In GraphGym, an experiment is fully specified by a :obj:`*.yaml` file. Unspecified configurations in the :obj:`*.yaml` file will be populated by the default values in :meth:`torch_geometric.graphgym.set_cfg`. For example, in :obj:`configs/pyg/example_node.yaml`, there are configurations for the dataset, training procedure, model, etc. A concrete description of each configuration is given in :meth:`~torch_geometric.graphgym.set_cfg`. #. **Launch an experiment:** For example, in :obj:`run_single.sh`: .. code-block:: bash python main.py --cfg configs/pyg/example_node.yaml --repeat 3 You can specify the number of different random seeds to repeat via :obj:`--repeat`. #. **Understand the results:** Experimental results will be automatically saved in :obj:`results/${CONFIG_NAME}/`. In the example above, this amounts to :obj:`results/pyg/example_node/`. Results for different random seeds will be saved in different subdirectories, *e.g.*, :obj:`results/pyg/example_node/2`. The aggregated results over all the random seeds are *automatically* generated into :obj:`results/pyg/example_node/agg`, including the mean and standard deviation :obj:`_std` for each metric. Train/validation/test results are further saved into subdirectories, such as :obj:`results/pyg/example_node/agg/val`. Here, :obj:`stats.json` stores the results after each epoch aggregated across random seeds, and :obj:`best.json` stores the results of *the epoch with the highest validation accuracy*.
The batch of experiments will be created by perturbing certain configurations of the base architecture. #. **(Optionally) specify a base file for computational budget:** Additionally, GraphGym allows a user to select a base architecture to *control the computational budget* for the grid search via :obj:`--config_budget`. The computational budget is currently measured by the number of trainable parameters, and the control is achieved by auto-adjusting the hidden dimensionality of the underlying GNN. If no :obj:`--config_budget` is provided, GraphGym will not control the computational budget. #. **Specify a grid file:** A grid file describes how to perturb the base file, in order to generate the batch of experiments. For example, the base file could specify an experiment of a 3-layer GCN for node classification on the :class:`~torch_geometric.datasets.Planetoid` datasets. Then, the grid file specifies how to perturb the experiment along different dimensions, such as the number of layers, the model architecture, the dataset, the level of task, etc. #. **Generate configuration files for the batch of experiments** based on the information specified above: For example, in :obj:`run_batch.sh`: .. code-block:: bash python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ --config_budget configs/${DIR}/${CONFIG}.yaml \ --grid grids/${DIR}/${GRID}.txt \ --out_dir configs #. **Launch the batch of experiments:** For example, in :obj:`run_batch.sh`: .. code-block:: bash bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $SLEEP Each experiment will be repeated :obj:`$REPEAT` times. We implemented a queue system to sequentially launch all the jobs, with :obj:`$MAX_JOBS` concurrent jobs running at the same time. In practice, our system works great when handling thousands of jobs. #. **Understand the results:** Experimental results will be automatically saved in the directory :obj:`results/${CONFIG_NAME}_grid_${GRID_NAME}/`.
In the example above, this amounts to :obj:`results/pyg/example_grid_example/`. After running each experiment, GraphGym additionally automatically averages across different models, saved in :obj:`results/pyg/example_grid_example/agg`. There, :obj:`val.csv` represents the validation accuracy for each model configuration at the *final* epoch, :obj:`val_best.csv` represents the results at the epoch with the highest average validation accuracy, and :obj:`val_best_epoch.csv` represents the results at the epoch with the highest validation accuracy averaged over different random seeds. When a test set split is provided, :obj:`test.csv` represents the test accuracy for each model configuration at the *final* epoch, :obj:`test_best.csv` represents the test set results at the epoch with the highest average validation accuracy, and :obj:`test_best_epoch.csv` represents the test set results at the epoch with the highest validation accuracy averaged over different random seeds. Customizing GraphGym -------------------- A highlight of GraphGym is that it allows you to easily register customized modules. For each project, you can have a unique GraphGym copy with different customized modules. For example, the `"Design Space for Graph Neural Networks" `__ and `"Identity-aware Graph Neural Networks" `__ papers represent two successful projects using customized GraphGym, and you may find more details about them `here `__. Eventually, every GraphGym-powered project will be unique. There are two ways for customizing GraphGym: #. Use the :obj:`graphgym/custom_graphgym` directory outside the :pyg:`PyG` package: You can register your customized modules here without touching :pyg:`PyG`. This use case will be great for your own customized project. #. 
Use the :obj:`torch_geometric/graphgym/contrib` directory inside the :pyg:`PyG` package: If you have come up with a nice customized module, you can directly copy your files into :obj:`torch_geometric/graphgym/contrib`, and **create a pull request** to :pyg:`PyG`. This way, your idea can ship with :pyg:`PyG` installations, and will have much higher visibility and impact. Concretely, the supported customized modules include: - Activations: :obj:`custom_graphgym/act/` - Customized configurations: :obj:`custom_graphgym/config/` - Feature augmentations: :obj:`custom_graphgym/feature_augment/` - Feature encoders: :obj:`custom_graphgym/feature_encoder/` - GNN heads: :obj:`custom_graphgym/head/` - GNN layers: :obj:`custom_graphgym/layer/` - Data loaders: :obj:`custom_graphgym/loader/` - Loss functions: :obj:`custom_graphgym/loss/` - GNN network architectures: :obj:`custom_graphgym/network/` - Optimizers: :obj:`custom_graphgym/optimizer/` - GNN global pooling layers (for graph classification only): :obj:`custom_graphgym/pooling/` - GNN stages: :obj:`custom_graphgym/stage/` - GNN training pipelines: :obj:`custom_graphgym/train/` - Data transformations: :obj:`custom_graphgym/transform/` Within each directory, at least one example is provided that shows how to register customized modules via :meth:`torch_geometric.graphgym.register`. Note that new customized modules may result in new configurations. In these cases, new configuration fields can be registered via :obj:`custom_graphgym/config/`. ================================================ FILE: docs/source/advanced/hgam.rst ================================================ Hierarchical Neighborhood Sampling ================================== One of the design principles of :pyg:`PyG` is that models and data loading routines should be exchangeable to allow for flexible GNN and data loading experimentation.
As such, models can usually be written in a data-loading-agnostic fashion, independent of whether one applies full-batch or mini-batch training strategies via, *e.g.*, :class:`~torch_geometric.loader.DataLoader`, :class:`~torch_geometric.loader.NeighborLoader` or :class:`~torch_geometric.loader.ClusterLoader`. However, in some scenarios, this flexibility comes at the cost of performance, as the model cannot exploit special characteristics of the underlying data loading routine. One such limitation is that a GNN trained with the :class:`~torch_geometric.loader.NeighborLoader` routine iteratively builds representations for *all* nodes at *all* depths of the network, although nodes sampled in later hops no longer contribute to the node representations of seed nodes in later GNN layers, thus performing useless computation. *Hierarchical Neighborhood Sampling* or *Hierarchical Graph Adjacency Matrix (HGAM)* is a technique available in :pyg:`PyG` to eliminate this overhead and speed up training and inference in mini-batch GNNs. Its main idea is to progressively trim the adjacency matrix of the returned subgraph before inputting it to each GNN layer. It works seamlessly across several models, basically reducing the amount of compute necessary to generate the representations for the seed nodes of the given mini-batch. Crucially, HGAM recognizes that the computation of the final node representations is only necessary for the seed nodes (which are the real target of the batch computation). Thus, HGAM allows for every layer of the GNN to compute only the representations of the nodes that are necessary for that layer, leading to a reduction of the computation and a speed up of the training process that grows with the depth of the GNN being considered. In practice, this is achieved by **trimming the adjacency matrix** and the various **feature matrices** as the computation proceeds throughout the GNN layers.
This is in line with the fact that in order to compute the representation for the seed/target nodes (from which the mini-batch was built via sampling methods), the depth of the relevant neighborhood shrinks as we proceed through the layers of the GNN. The trimming applied by HGAM is possible as the nodes of the subgraph built via sampling are ordered according to a *Breadth First Search (BFS)* strategy, meaning that the rows and columns of the adjacency matrix refer to a node ordering that starts with the seed nodes (in any order), followed by the 1-hop neighbors of the first seed node, followed by the 1-hop sampled neighbors of the second seed node, and so on. The BFS ordering of nodes in a mini-batch allows for incremental trimming (reduction) of the adjacency matrix of the subgraph. This progressive trimming is done in a computationally convenient manner thanks to the BFS ordering, which causes the nodes more distant from the seed nodes to appear farther away in the list of ordered nodes. To support this trimming and implement it effectively, the :class:`~torch_geometric.loader.NeighborLoader` implementation in :pyg:`PyG` and in :pyg:`pyg-lib` additionally returns the number of nodes and edges sampled in each hop. This information allows for fast manipulation of the adjacency matrix, which in turn leads to a large reduction in computation. The :class:`~torch_geometric.loader.NeighborLoader` prepares this metadata via the dedicated attributes :obj:`num_sampled_nodes` and :obj:`num_sampled_edges`. They can be accessed from the :class:`~torch_geometric.data.Batch` object returned for both homogeneous and heterogeneous graphs. To sum up, HGAM is a special data structure that enables efficient message passing computation in :class:`~torch_geometric.loader.NeighborLoader` scenarios. HGAM is implemented in :pyg:`PyG` and can be utilized via the special :meth:`~torch_geometric.utils.trim_to_layer` functionality.
HGAM is currently an option that :pyg:`PyG` users are free to switch on, or leave it off *(current default)*. Usage ----- Here, we show examples of how to use the HGAM functionality in combination with :class:`~torch_geometric.loader.NeighborLoader`: * **Homogeneous data example:** .. code-block:: python from torch_geometric.datasets import Planetoid from torch_geometric.loader import NeighborLoader data = Planetoid(path, name='Cora')[0] loader = NeighborLoader( data, num_neighbors=[10] * 3, batch_size=128, ) batch = next(iter(loader)) print(batch) >>> Data(x=[1883, 1433], edge_index=[2, 5441], y=[1883], train_mask=[1883], val_mask=[1883], test_mask=[1883], batch_size=128, num_sampled_nodes=[4], num_sampled_edges=[3]) print(batch.num_sampled_nodes) >>> [128, 425, 702, 628] # Number of sampled nodes per hop/layer. print(batch.num_sampled_edges) >>> [520, 2036, 2885] # Number of sampled edges per hop/layer. * **Heterogeneous data example:** .. code-block:: python from torch_geometric.datasets import OGB_MAG from torch_geometric.loader import NeighborLoader data = OGB_MAG(path)[0] loader = NeighborLoader( data, num_neighbors=[10] * 3, batch_size=128, input_nodes='paper', ) batch = next(iter(loader)) print(batch) >>> HeteroData( paper={ x=[2275, 128], num_sampled_nodes=[3], batch_size=128, }, author={ num_nodes=2541, num_sampled_nodes=[3], }, institution={ num_nodes=0, num_sampled_nodes=[3], }, field_of_study={ num_nodes=0, num_sampled_nodes=[3], }, (author, affiliated_with, institution)={ edge_index=[2, 0], num_sampled_edges=[2], }, (author, writes, paper)={ edge_index=[2, 3255], num_sampled_edges=[2], }, (paper, cites, paper)={ edge_index=[2, 2691], num_sampled_edges=[2], }, (paper, has_topic, field_of_study)={ edge_index=[2, 0], num_sampled_edges=[2], } ) print(batch['paper'].num_sampled_nodes) >>> [128, 508, 1598] # Number of sampled paper nodes per hop/layer. 
print(batch['author', 'writes', 'paper'].num_sampled_edges) >>> [629, 2621] # Number of sampled author<>paper edges per hop/layer. The attributes :obj:`num_sampled_nodes` and :obj:`num_sampled_edges` can be used by the :meth:`~torch_geometric.utils.trim_to_layer` function inside the GNN: .. code-block:: python from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborLoader from torch_geometric.nn import SAGEConv from torch_geometric.utils import trim_to_layer data = Reddit(path)[0] loader = NeighborLoader(data, num_neighbors=[10, 5, 5], ...) class GNN(torch.nn.Module): def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int): super().__init__() self.convs = ModuleList([SAGEConv(in_channels, hidden_channels)]) for _ in range(num_layers - 1): self.convs.append(SAGEConv(hidden_channels, hidden_channels)) self.lin = Linear(hidden_channels, out_channels) def forward( self, x: Tensor, edge_index: Tensor, num_sampled_nodes_per_hop: List[int], num_sampled_edges_per_hop: List[int], ) -> Tensor: for i, conv in enumerate(self.convs): # Trim edge and node information to the current layer `i`. x, edge_index, _ = trim_to_layer( i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop, x, edge_index) x = conv(x, edge_index).relu() return self.lin(x) Examples -------- We provide full examples of HGAM in the :pyg:`PyG` :obj:`examples/` folder: * :obj:`examples/hierarchical_sampling.py`: An `example `__ to show-case the basic usage of HGAM. * :obj:`examples/hetero/hierarchical_sage.py`: An `example `__ of HGAM on heterogeneous graphs. ================================================ FILE: docs/source/advanced/jit.rst ================================================ TorchScript Support =================== TorchScript is a way to create serializable and optimizable models from :pytorch:`PyTorch` code. Any TorchScript program can be saved from a :python:`Python` process and loaded in a process where there is no :python:`Python` dependency.
If you are unfamiliar with TorchScript, we recommend reading the official "`Introduction to TorchScript `_" tutorial first. Converting GNN Models --------------------- .. note:: From :pyg:`PyG` 2.5 (and onwards), GNN layers are now fully compatible with :meth:`torch.jit.script` without any modification needed. If you are on an earlier version of :pyg:`PyG`, consider converting your GNN layers into "jittable" instances first by calling :meth:`~torch_geometric.nn.conv.MessagePassing.jittable`. Converting your :pyg:`PyG` model to a TorchScript program is straightforward and requires only a few code changes. Let's consider the following model: .. code-block:: python import torch import torch.nn.functional as F from torch_geometric.nn import GCNConv class GNN(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, 64) self.conv2 = GCNConv(64, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index) x = F.relu(x) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) model = GNN(dataset.num_features, dataset.num_classes) The instantiated model can now be directly passed into :meth:`torch.jit.script`: .. code-block:: python model = torch.jit.script(model) That is all you need to know about how to convert your :pyg:`PyG` models to TorchScript programs. You can have a further look at our JIT examples that show-case how to obtain TorchScript programs for `node `_ and `graph classification `_ models. Creating Jittable GNN Operators -------------------------------- All :pyg:`PyG` :class:`~torch_geometric.nn.conv.MessagePassing` operators are tested to be convertible to a TorchScript program. However, if you want your own GNN module to be compatible with :meth:`torch.jit.script`, you need to account for the following two things: 1. As one would expect, your :meth:`forward` code may need to be adjusted so that it passes the TorchScript compiler requirements, *e.g.*, by adding type annotations. 2.
You need to tell the :class:`~torch_geometric.nn.conv.MessagePassing` module the types that you pass to its :meth:`~torch_geometric.nn.conv.MessagePassing.propagate` function. This can be achieved in two different ways: 1. Declaring the type of propagation arguments in a dictionary called :obj:`propagate_type`: .. code-block:: python from typing import Optional from torch import Tensor from torch_geometric.nn import MessagePassing class MyConv(MessagePassing): propagate_type = {'x': Tensor, 'edge_weight': Optional[Tensor] } def forward( self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, ) -> Tensor: return self.propagate(edge_index, x=x, edge_weight=edge_weight) 2. Declaring the type of propagation arguments as a comment inside your module: .. code-block:: python from typing import Optional from torch import Tensor from torch_geometric.nn import MessagePassing class MyConv(MessagePassing): def forward( self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, ) -> Tensor: # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) return self.propagate(edge_index, x=x, edge_weight=edge_weight) If none of these options are given, the :class:`~torch_geometric.nn.conv.MessagePassing` module will infer the arguments of :meth:`~torch_geometric.nn.conv.MessagePassing.propagate` to be of type :class:`torch.Tensor` (mimicking the default type that TorchScript is inferring for non-annotated arguments). ================================================ FILE: docs/source/advanced/remote.rst ================================================ Scaling Up GNNs via Remote Backends =================================== :pyg:`PyG` (2.2 and beyond) includes numerous primitives to easily integrate with simple paradigms for scalable graph machine learning, enabling users to train GNNs on graphs far larger than the size of their machine's available memory. 
It does so by introducing simple, easy-to-use, and extensible abstractions of a :class:`torch_geometric.data.FeatureStore` and a :class:`torch_geometric.data.GraphStore` that plug directly into existing familiar :pyg:`PyG` interfaces. Defining a :class:`~torch_geometric.data.FeatureStore` allows users to leverage node (and soon, edge) features stored remotely, and defining a :class:`~torch_geometric.data.GraphStore` allows users to leverage graph structure information stored remotely. Together, they allow for powerful GNN scalability with low developer friction. .. warning:: The remote backend APIs discussed here may change in the future as we continuously work to improve their ease-of-use and generalizability. .. note:: Currently, the :class:`~torch_geometric.data.FeatureStore` and :class:`~torch_geometric.data.GraphStore` only support *heterogeneous graphs*, and do not support edge features. Homogeneous graph and edge feature support is coming soon. Background ---------- An instantiated Graph Neural Network consists of two types of data: - **Node and/or edge feature information:** Dense vectors corresponding to attributes of the nodes and edges in a graph - **Graph structure information:** The nodes in the graph and the edges that connect them An immediate observation of GNNs is that scaling to data larger than the available memory of a chosen accelerator requires training on sampled subgraphs (which form mini-batches), instead of the full graph at once (full-batch training). While this method adds stochasticity to the learning process, it reduces the memory requirements of the accelerator to those of the sampled subgraphs. .. figure:: ../_figures/remote_1.png :align: center :width: 100% **Figure 1:** The classical mini-batch GNN training paradigm. However, while mini-batch training reduces the memory requirements of the chosen accelerator, it is not a silver bullet for all graph learning scalability problems. 
In particular, since one must sample subgraphs to pass to the accelerator at each iteration of the learning process, the graph and features are traditionally required to be stored in the CPU DRAM of a user's machine. At large scale, this requirement can become quite burdensome: - Acquiring instances with enough CPU DRAM to store a graph and features is challenging - Training with data parallelism requires replicating the graph and features in each compute node - Graphs and features can easily be much larger than the memory of a single machine Scalability to very large graphs and features beyond the memory requirements of a single machine thus requires moving these data structures out-of-core and only processing sampled subgraphs on a node that performs computation. In order to achieve this goal, :pyg:`PyG` relies on two primary abstractions to store feature information and graph structure: Features are stored in a key-value :class:`~torch_geometric.data.FeatureStore`, which must support efficient random access. Graph information is stored in a :class:`~torch_geometric.data.GraphStore`, which must support efficient sampling for the samplers defined to operate on the :class:`~torch_geometric.data.GraphStore` instance. .. figure:: ../_figures/remote_2.png :align: center :width: 100% **Figure 2:** Graph data storage layout between remote storage and a training instance. In :pyg:`PyG` (2.2 and beyond), the separation of graph data into its features and structure information, the storage of this information in locations potentially remote to the actual training node, and the interactions between these components, are all completely abstracted from the end user. As long as the :class:`~torch_geometric.data.FeatureStore` and :class:`~torch_geometric.data.GraphStore` are defined appropriately (keeping in mind the aforementioned performance requirements), :pyg:`PyG` handles the rest! 
Feature Store ------------- A :class:`torch_geometric.data.FeatureStore` holds features for the nodes and edges of a graph. Feature storage is often the primary storage bottleneck in graph learning applications, as storing a graph's layout information (*i.e.* the :obj:`edge_index`) is relatively cheap (~32 bytes per edge). :pyg:`PyG` provides a common interface for various :class:`~torch_geometric.data.FeatureStore` implementations to interface with its core learning API. The implementation details of a :class:`~torch_geometric.data.FeatureStore` are abstracted from :pyg:`PyG` through a CRUD-like interface. In particular, implementors of the :class:`~torch_geometric.data.FeatureStore` abstraction are expected to primarily override :meth:`~torch_geometric.data.FeatureStore.put_tensor`, :meth:`~torch_geometric.data.FeatureStore.get_tensor`, and :meth:`~torch_geometric.data.FeatureStore.remove_tensor` functionalities. Doing so both enables :pyg:`PyG` to leverage the features stored in the implementation and allows a user to employ a pythonic interface to inspect and modify the :class:`~torch_geometric.data.FeatureStore` elements: .. code-block:: python feature_store = CustomFeatureStore() paper_features = ... # [num_papers, num_paper_features] author_features = ... # [num_authors, num_author_features] # Add features: feature_store['paper', 'x', None] = paper_features feature_store['author', 'x', None] = author_features # Access features: assert torch.equal(feature_store['paper', 'x'], paper_features) assert torch.equal(feature_store['paper'].x, paper_features) assert torch.equal(feature_store['author', 'x', 0:20], author_features[0:20]) Common implementations of the :class:`~torch_geometric.data.FeatureStore` abstractions are key-value stores, *e.g.*, backends such as :obj:`memcached`, :obj:`LevelDB`, :obj:`RocksDB` are all viable performant options. 
Graph Store and Sampler ----------------------- A :class:`torch_geometric.data.GraphStore` holds the edge indices that define relationships between nodes in a graph. The goal of the :class:`~torch_geometric.data.GraphStore` is to store graph information in a manner that allows for efficient sampling from root nodes, according to a sampling algorithm of the developer's choice. Similar to the :class:`~torch_geometric.data.FeatureStore`, :pyg:`PyG` provides a common interface for various :class:`~torch_geometric.data.GraphStore` implementations to interface with its core learning API. However, unlike the :class:`~torch_geometric.data.FeatureStore`, the :class:`~torch_geometric.data.GraphStore` does not need to provide efficient random access for all its elements; rather, it needs to define a representation that provides efficient subgraph sampling. An example usage of the interface is shown below: .. code-block:: python graph_store = CustomGraphStore() edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) # Put edges: graph_store['edge', 'coo'] = edge_index # Access edges: row, col = graph_store['edge', 'coo'] assert torch.equal(row, edge_index[0]) assert torch.equal(col, edge_index[1]) Common implementations of the :class:`~torch_geometric.data.GraphStore` are graph databases; *e.g.*, :obj:`Neo4j`, :obj:`TigerGraph`, :obj:`ArangoDB`, and :obj:`Kùzu` are all viable performant options. We provide an example of using :pyg:`PyG` in combination with the :obj:`Kùzu` database `here `__. A graph sampler is tightly coupled to the given :class:`~torch_geometric.data.GraphStore`, and operates on the :class:`~torch_geometric.data.GraphStore` to produce sampled subgraphs from input nodes. Different sampling algorithms are implemented behind the :class:`torch_geometric.sampler.BaseSampler` interface.
By default, the in-memory sampler of :pyg:`PyG` pulls all edge indices from the :class:`~torch_geometric.data.GraphStore` into the training node memory, converts them to compressed sparse column (CSC) format, and leverages pre-built in-memory sampling routines. However, custom sampler implementations may choose to call specialized :class:`~torch_geometric.data.GraphStore` methods by implementing the :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes` and/or :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges` methods of the :class:`~torch_geometric.sampler.BaseSampler` class for efficiency reasons (*e.g.*, for performing sampling directly on the remote :class:`~torch_geometric.data.GraphStore`): .. code-block:: python # `CustomGraphSampler` knows how to sample on `CustomGraphStore`: node_sampler = CustomGraphSampler( graph_store=graph_store, num_neighbors=[10, 20], ... ) Data Loader ----------- :pyg:`PyG` does not define a domain-specific language for sampling that must be implemented by the :class:`~torch_geometric.data.GraphStore`; rather, the sampler and the :class:`~torch_geometric.data.GraphStore` are tightly coupled together through a data loader. :pyg:`PyG` provides two data loaders out-of-the-box: a :class:`torch_geometric.loader.NodeLoader` that samples subgraphs from input nodes for use in node classification tasks, and a :class:`torch_geometric.loader.LinkLoader` that samples subgraphs from either side of an edge for use in link prediction tasks. These data loaders require a :class:`~torch_geometric.data.FeatureStore`, a :class:`~torch_geometric.data.GraphStore`, and a graph sampler as input, and internally call the sampler's :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes` or :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges` method to perform subgraph sampling: ..
code-block:: python # Instead of passing PyG data objects, we now pass a tuple # of the `FeatureStore` and `GraphStore` as input data: loader = NodeLoader( data=(feature_store, graph_store), node_sampler=node_sampler, batch_size=20, input_nodes='paper', ) for batch in loader: pass Putting it All Together ----------------------- At a high level, the components listed above all work together to provide support for scaling up GNNs within :pyg:`PyG`. - The **data loader** (precisely, each worker) leverages a :class:`~torch_geometric.sampler.BaseSampler` to make a sampling request to the :class:`~torch_geometric.data.GraphStore`. - Upon receipt of a response, the data loader subsequently queries the :class:`~torch_geometric.data.FeatureStore` for features associated with the nodes and edges of the sampled subgraphs. - The data loader then constructs a final mini-batch from graph structure and feature information to send to the accelerator for forward/backward passes. - Repeat until convergence. All of the outlined classes speak through common interfaces, making them extensible, generalizable, and easy to integrate with the :pyg:`PyG` you use today: .. figure:: ../_figures/remote_3.png :align: center :width: 80% **Figure 3:** The common interfaces (and data flow) uniting the :class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`, graph sampler, and data loader. To get started with scalability, we recommend inspecting the interfaces listed above and defining your own :class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`, and :class:`~torch_geometric.sampler.BaseSampler` implementations behind them.
Once a :class:`~torch_geometric.data.FeatureStore`, a :class:`~torch_geometric.data.GraphStore`, and a :class:`~torch_geometric.sampler.BaseSampler` are correctly implemented, simply pass them as parameters to a :class:`~torch_geometric.loader.NodeLoader` or a :class:`~torch_geometric.loader.LinkLoader`, and the rest of :pyg:`PyG` will work seamlessly, similar to any pure in-memory application. Since this feature is still undergoing heavy development, please feel free to reach out to the :pyg:`PyG` core team either on :github:`null` `GitHub `_ or :slack:`null` `Slack `_ if you have any questions, comments or concerns. ================================================ FILE: docs/source/advanced/sparse_tensor.rst ================================================ Memory-Efficient Aggregations ============================= The :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface of :pyg:`PyG` relies on a gather-scatter scheme to aggregate messages from neighboring nodes. For example, consider the message passing layer .. math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \textrm{MLP}(\mathbf{x}_j - \mathbf{x}_i), that can be implemented as: .. code-block:: python from torch_geometric.nn import MessagePassing x = ... # Node features of shape [num_nodes, num_features] edge_index = ... # Edge indices of shape [2, num_edges] class MyConv(MessagePassing): def __init__(self): super().__init__(aggr="add") def forward(self, x, edge_index): return self.propagate(edge_index, x=x) def message(self, x_i, x_j): return MLP(x_j - x_i) Under the hood, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementation produces code that looks as follows: .. code-block:: python from torch_geometric.utils import scatter x = ... # Node features of shape [num_nodes, num_features] edge_index = ...
# Edge indices of shape [2, num_edges] x_j = x[edge_index[0]] # Source node features [num_edges, num_features] x_i = x[edge_index[1]] # Target node features [num_edges, num_features] msg = MLP(x_j - x_i) # Compute message for each edge # Aggregate messages based on target node indices out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum') While the gather-scatter formulation generalizes to a lot of useful GNN implementations, it has the disadvantage of explicitly materializing :obj:`x_j` and :obj:`x_i`, resulting in a high memory footprint on large and dense graphs. Luckily, not all GNNs need to be implemented by explicitly materializing :obj:`x_j` and/or :obj:`x_i`. In some cases, GNNs can also be implemented as a simple sparse-matrix multiplication. As a general rule of thumb, this holds true for GNNs that do not make use of the central node features :obj:`x_i` or multi-dimensional edge features when computing messages. For example, the :class:`~torch_geometric.nn.conv.GINConv` layer .. math:: \mathbf{x}^{\prime}_i = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right), is equivalent to computing .. math:: \mathbf{X}^{\prime} = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{X} + \mathbf{A}\mathbf{X} \right), where :math:`\mathbf{A}` denotes a sparse adjacency matrix of shape :obj:`[num_nodes, num_nodes]`. This formulation allows us to leverage dedicated and fast sparse-matrix multiplication implementations. In :pyg:`null` **PyG >= 1.6.0**, we officially introduce better support for sparse-matrix multiplication GNNs, resulting in a **lower memory footprint** and a **faster execution time**. To this end, we introduce the :class:`SparseTensor` class (from the :obj:`torch_sparse` package), which implements fast forward and backward passes for sparse-matrix multiplication based on the `"Design Principles for Sparse Matrix Multiplication on the GPU" `_ paper.
Using the :class:`SparseTensor` class is straightforward and similar to the way :obj:`scipy` treats sparse matrices: .. code-block:: python from torch_sparse import SparseTensor adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=..., sparse_sizes=(num_nodes, num_nodes)) # value is optional and can be None # Obtain different representations (COO, CSR, CSC): row, col, value = adj.coo() rowptr, col, value = adj.csr() colptr, row, value = adj.csc() adj = adj[:100, :100] # Slicing, indexing and masking support adj = adj.set_diag() # Add diagonal entries adj_t = adj.t() # Transpose out = adj.matmul(x) # Sparse-dense matrix multiplication adj = adj.matmul(adj) # Sparse-sparse matrix multiplication # Creating SparseTensor instances: adj = SparseTensor.from_dense(mat) adj = SparseTensor.eye(100, 100) adj = SparseTensor.from_scipy(mat) Our :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface can handle both :obj:`torch.Tensor` and :class:`SparseTensor` as input for propagating messages. However, when holding a directed graph in :class:`SparseTensor`, you need to make sure to input the **transposed sparse matrix** to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`: .. 
code-block:: python conv = GCNConv(16, 32) out1 = conv(x, edge_index) out2 = conv(x, adj.t()) assert torch.allclose(out1, out2) conv = GINConv(nn=Sequential(Linear(16, 32), ReLU(), Linear(32, 32))) out1 = conv(x, edge_index) out2 = conv(x, adj.t()) assert torch.allclose(out1, out2) To leverage sparse-matrix multiplications, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface introduces the :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.message_and_aggregate` function (which fuses the :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` and :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.aggregate` functions into a single computation step), which gets called whenever it is implemented and receives a :class:`SparseTensor` as input for :obj:`edge_index`. With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemented as follows: .. code-block:: python import torch_sparse class GINConv(MessagePassing): def __init__(self): super().__init__(aggr="add") def forward(self, x, edge_index): out = self.propagate(edge_index, x=x) return MLP((1 + eps) * x + out) def message(self, x_j): return x_j def message_and_aggregate(self, adj_t, x): return torch_sparse.matmul(adj_t, x, reduce=self.aggr) Playing around with the new :class:`SparseTensor` format is straightforward since all of our GNNs work with it out-of-the-box. To convert the :obj:`edge_index` format to the newly introduced :class:`SparseTensor` format, you can make use of the :class:`torch_geometric.transforms.ToSparseTensor` transform: .. code-block:: python import torch import torch.nn.functional as F from torch_geometric.nn import GCNConv import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor()) data = dataset[0] >>> Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)
class GNN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_features, 16, cached=True) self.conv2 = GCNConv(16, dataset.num_classes, cached=True) def forward(self, x, adj_t): x = self.conv1(x, adj_t) x = F.relu(x) x = self.conv2(x, adj_t) return F.log_softmax(x, dim=1) model = GNN() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(data): model.train() optimizer.zero_grad() out = model(data.x, data.adj_t) loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() return float(loss) for epoch in range(1, 201): loss = train(data) All code remains the same as before, except for the :obj:`data` transform via :obj:`T.ToSparseTensor()`. As an additional advantage, :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementations that utilize the :class:`SparseTensor` class are deterministic on the GPU since aggregations no longer rely on atomic operations. Notably, the GNN layer execution slightly changes in case GNNs incorporate single or multi-dimensional edge information :obj:`edge_weight` or :obj:`edge_attr` into their message passing formulation, respectively. In particular, it is now expected that these attributes are directly added as values to the :class:`SparseTensor` object. Instead of calling the GNN as .. code-block:: python conv = GMMConv(16, 32, dim=3) out = conv(x, edge_index, edge_attr) we now execute our GNN operator as .. code-block:: python conv = GMMConv(16, 32, dim=3) adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr) out = conv(x, adj.t()) .. note:: Since this feature is still experimental, some operations, *e.g.*, graph pooling methods, may still require you to input the :obj:`edge_index` format. You can convert :obj:`adj_t` back to :obj:`(edge_index, edge_attr)` via: .. 
code-block:: python row, col, edge_attr = adj_t.t().coo() edge_index = torch.stack([row, col], dim=0) Please let us know what you think of :class:`SparseTensor`, how we can improve it, and whenever you encounter any unexpected behavior. ================================================ FILE: docs/source/cheatsheet/data_cheatsheet.rst ================================================ Dataset Cheatsheet ================== .. note:: This dataset statistics table is a **work in progress**. Please consider helping us fill in its content by providing statistics for individual datasets. See `here `__ and `here `__ for examples on how to do so. Homogeneous Datasets -------------------- .. list-table:: :widths: 50 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes/#tasks {% for cls in torch_geometric.datasets.homo_datasets %} * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %} - {%if torch_geometric.datasets.utils.has_stats(cls) %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default=1) }}{% else %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default='') }}{% endif %} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', default='') }} {% for child in torch_geometric.datasets.utils.get_children(cls) %} * - └─ {{ child }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', child, default=1) }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', child, default='') }} - {{
torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }} {% endfor %} {% endfor %} Heterogeneous Datasets ---------------------- .. list-table:: :widths: 50 30 10 10 :header-rows: 1 * - Name - #nodes/#edges - #features - #classes/#tasks {% for cls in torch_geometric.datasets.hetero_datasets %} * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %} - - - {% for child in torch_geometric.datasets.utils.get_children(cls) %} * - └─ **{{torch_geometric.datasets.utils.get_type(child)}} Type**: {{ child }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes/#edges', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }} {% endfor %} {% endfor %} Synthetic Datasets ------------------ .. 
list-table:: :widths: 50 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes/#tasks {% for cls in torch_geometric.datasets.synthetic_datasets %} * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %} - {%if torch_geometric.datasets.utils.has_stats(cls) %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default=1) }}{% else %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default='') }}{% endif %} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', default='') }} {% for child in torch_geometric.datasets.utils.get_children(cls) %} * - └─ {{ child }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', child, default=1) }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }} {% endfor %} {% endfor %} ================================================ FILE: docs/source/cheatsheet/gnn_cheatsheet.rst ================================================ GNN Cheatsheet ============== * :class:`~torch_sparse.SparseTensor`: If checked (✓), supports message passing based on :class:`torch_sparse.SparseTensor`, *e.g.*, :obj:`GCNConv(...).forward(x, adj_t)`. See `here <../advanced/sparse_tensor.html>`__ for the accompanying tutorial. 
* :obj:`edge_weight`: If checked (✓), supports message passing with one-dimensional edge weight information, *e.g.*, :obj:`GraphConv(...).forward(x, edge_index, edge_weight)`. * :obj:`edge_attr`: If checked (✓), supports message passing with multi-dimensional edge feature information, *e.g.*, :obj:`GINEConv(...).forward(x, edge_index, edge_attr)`. * **bipartite**: If checked (✓), supports message passing in bipartite graphs with potentially different feature dimensionalities for source and destination nodes, *e.g.*, :obj:`SAGEConv(in_channels=(16, 32), out_channels=64)`. * **static**: If checked (✓), supports message passing in static graphs, *e.g.*, :obj:`GCNConv(...).forward(x, edge_index)` with :obj:`x` having shape :obj:`[batch_size, num_nodes, in_channels]`. * **lazy**: If checked (✓), supports lazy initialization of message passing layers, *e.g.*, :obj:`SAGEConv(in_channels=-1, out_channels=64)`. Graph Neural Network Operators ------------------------------ .. list-table:: :widths: 40 10 10 10 10 10 10 :header-rows: 1 * - Name - :class:`~torch_sparse.SparseTensor` - :obj:`edge_weight` - :obj:`edge_attr` - bipartite - static - lazy {% for cls in torch_geometric.nn.conv.classes[1:] %} {% if not torch_geometric.nn.conv.utils.processes_heterogeneous_graphs(cls) and not torch_geometric.nn.conv.utils.processes_hypergraphs(cls) and not torch_geometric.nn.conv.utils.processes_point_clouds(cls) %} * - :class:`~torch_geometric.nn.conv.{{ cls }}` {% if torch_geometric.nn.conv.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__){% endif %} - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% 
endif %} - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %} {% endif %} {% endfor %} Heterogeneous Graph Neural Network Operators -------------------------------------------- .. list-table:: :widths: 40 10 10 10 10 10 10 :header-rows: 1 * - Name - :class:`~torch_sparse.SparseTensor` - :obj:`edge_weight` - :obj:`edge_attr` - bipartite - static - lazy {% for cls in torch_geometric.nn.conv.classes[1:] %} {% if torch_geometric.nn.conv.utils.processes_heterogeneous_graphs(cls) %} * - :class:`~torch_geometric.nn.conv.{{ cls }}` {% if torch_geometric.nn.conv.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__){% endif %} - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %} {% endif %} {% endfor %} Hypergraph Neural Network Operators ----------------------------------- .. 
list-table:: :widths: 40 10 10 10 10 10 10 :header-rows: 1 * - Name - :class:`~torch_sparse.SparseTensor` - :obj:`edge_weight` - :obj:`edge_attr` - bipartite - static - lazy {% for cls in torch_geometric.nn.conv.classes[1:] %} {% if torch_geometric.nn.conv.utils.processes_hypergraphs(cls) %} * - :class:`~torch_geometric.nn.conv.{{ cls }}` {% if torch_geometric.nn.conv.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__){% endif %} - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %} {% endif %} {% endfor %} Point Cloud Neural Network Operators ------------------------------------ .. 
list-table:: :widths: 80 10 10 :header-rows: 1 * - Name - bipartite - lazy {% for cls in torch_geometric.nn.conv.classes[1:] %} {% if torch_geometric.nn.conv.utils.processes_point_clouds(cls) %} * - :class:`~torch_geometric.nn.conv.{{ cls }}` {% if torch_geometric.nn.conv.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__){% endif %} - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %} {% endif %} {% endfor %} ================================================ FILE: docs/source/conf.py ================================================
# Sphinx build configuration for the PyG documentation.
import datetime
import os.path as osp
import sys

import pyg_sphinx_theme

import torch_geometric

# Project metadata; the documented version tracks the installed package.
author = 'PyG Team'
project = 'pytorch_geometric'
version = torch_geometric.__version__
copyright = f'{datetime.datetime.now().year}, {author}'

# Put the theme's bundled `extension/` directory on the import path.
# Presumably this is where the custom 'pyg' extension listed below lives.
sys.path.append(osp.join(osp.dirname(pyg_sphinx_theme.__file__), 'extension'))

extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.intersphinx',
    'sphinx.ext.mathjax',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx_autodoc_typehints',
    'sphinx_copybutton',
    'nbsphinx',
    'pyg',
]

# HTML theme and assets (logo/favicon are fetched from the theme repository).
html_theme = 'pyg_sphinx_theme'
html_logo = ('https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/'
             'master/pyg_sphinx_theme/static/img/pyg_logo.png')
html_favicon = ('https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/'
                'master/pyg_sphinx_theme/static/img/favicon.png')
html_static_path = ['_static']
templates_path = ['_templates']

# autodoc: show short object names and keep members in source order.
add_module_names = False
autodoc_member_order = 'bysource'

suppress_warnings = ['autodoc.import_object']

# External projects that cross-references may resolve against.
intersphinx_mapping = {
    'python': ('https://docs.python.org/', None),
    # 'numpy': ('http://docs.scipy.org/doc/numpy', None),
    'pandas': ('https://pandas.pydata.org/docs', None),
    'torch': ('https://pytorch.org/docs/main', None),
}

# sphinx_autodoc_typehints options.
typehints_use_rtype = False
typehints_defaults = 'comma'

# Gallery thumbnail image for each tutorial notebook page.
# NOTE(review): 'tutorial/multi_node_multi_gpu_vanilla' reuses the
# multi_gpu_vanilla.png thumbnail; confirm this is intentional.
nbsphinx_thumbnails = {
    'tutorial/create_gnn': '_static/thumbnails/create_gnn.png',
    'tutorial/heterogeneous': '_static/thumbnails/heterogeneous.png',
    'tutorial/create_dataset': '_static/thumbnails/create_dataset.png',
    'tutorial/load_csv': '_static/thumbnails/load_csv.png',
    'tutorial/dataset_splitting': '_static/thumbnails/dataset_splitting.png',
    'tutorial/neighbor_loader': '_static/thumbnails/neighbor_loader.png',
    'tutorial/point_cloud': '_static/thumbnails/point_cloud.png',
    'tutorial/explain': '_static/thumbnails/explain.png',
    'tutorial/shallow_node_embeddings': '_static/thumbnails/shallow_node_embeddings.png',
    'tutorial/distributed_pyg': '_static/thumbnails/distributed_pyg.png',
    'tutorial/multi_gpu_vanilla': '_static/thumbnails/multi_gpu_vanilla.png',
    'tutorial/multi_node_multi_gpu_vanilla': '_static/thumbnails/multi_gpu_vanilla.png',
    'tutorial/graph_transformer': '_static/thumbnails/graph_transformer.png',
}


def rst_jinja_render(app, _, source):
    r"""Render a source document through the builder's Jinja template engine.

    Registered for the ``source-read`` event in :func:`setup`. Because
    :obj:`torch_geometric` is placed in the template context, ``.rst`` pages
    such as the dataset and GNN cheatsheets can contain template expressions
    like ``{% for cls in torch_geometric.nn.conv.classes[1:] %}``.

    Args:
        app: The Sphinx application.
        _: Unused; presumably the document name passed by ``source-read``.
        source: List-like container whose first element holds the document
            text. It is rewritten in place (``source[0] = ...``).
    """
    # Only builders exposing a ``templates`` attribute can render templates.
    # For any other builder the source is left untouched.
    if hasattr(app.builder, 'templates'):
        rst_context = {'torch_geometric': torch_geometric}
        source[0] = app.builder.templates.render_string(source[0], rst_context)


def setup(app):
    r"""Setup sphinx application.

    Registers :func:`rst_jinja_render` for the ``source-read`` event, adds
    the ``js/version_alert.js`` script to the HTML output, and removes all
    listeners registered for ``autodoc-process-signature``.
    """
    app.connect('source-read', rst_jinja_render)
    app.add_js_file('js/version_alert.js')
    # Do not drop type hints in signatures:
    # NOTE(review): this deletes *every* handler for the event (presumably the
    # one installed by sphinx_autodoc_typehints); confirm no other extension
    # depends on it.
    del app.events.listeners['autodoc-process-signature']
================================================ FILE: docs/source/external/resources.rst ================================================ External Resources ================== * Fey *et al.*: **PyG 2.0: Scalable Learning on Real World Graphs** [`Paper `__] * Matthias Fey and Jan E.
Lenssen: **Fast Graph Representation Learning with** :pyg:`null` **PyTorch Geometric** [`Paper `_, `Slides (3.3MB) `__, `Poster (2.3MB) `__, `Notebook `__] * :stanford:`Stanford CS224W: Machine Learning with Graphs`: **Graph Machine Learning lectures** [:youtube:`null` `YouTube `__] * :stanford:`Stanford University`: **A collection of graph machine learning tutorial blog posts**, fully realized with :pyg:`null` **PyG** [`Website `__] * Soumith Chintala: **Automatic Differentiation,** :pytorch:`null` **PyTorch and Graph Neural Networks** [`Talk (starting from 26:15) `__] * Stanford University: **Graph Neural Networks using** :pyg:`null` **PyTorch Geometric** [:youtube:`null` `YouTube (starting from 33:33) `__] * Antonio Longa, Gabriele Santin and Giovanni Pellegrini: :pyg:`null` **PyTorch Geometric Tutorial** [`Website `__, :github:`null` `GitHub `__] * DAIR.AI | elvis: **Introduction to GNNs with** :pyg:`null` **PyTorch Geometric** [`Website `__, :colab:`null` `Colab `__] * Nicolas Chaulet *et al.*: **PyTorch Points 3D** - A framework for running common deep learning models for point cloud analysis tasks that heavily relies on PyTorch Geometric [:github:`null` `GitHub `__, `Documentation `__] * Weihua Hu *et al.*: :ogb:`null` **Open Graph Benchmark** - A collection of large-scale benchmark datasets, data loaders, and evaluators for graph machine learning, including :pyg:`PyG` support and examples [`Website `__, :github:`null` `GitHub `__] * **DeepSNAP** - A :pytorch:`PyTorch` library that bridges between graph libraries such as NetworkX and :pyg:`PyG` [:github:`null` `GitHub `__, `Documentation `__] * **Quiver** - A distributed graph learning library for :pyg:`PyG` [:github:`null` `GitHub `__] * Benedek Rozemberczki: **PyTorch Geometric Temporal** - A temporal GNN library built upon :pyg:`PyG` [:github:`null` `GitHub `__, `Documentation `__] * Yixuan He: **PyTorch Geometric Signed Directed** - A signed and directed GNN library built upon :pyg:`PyG` [:github:`null`
`GitHub `__, `Documentation `__] * Steeve Huang: **Hands-on Graph Neural Networks with** :pytorch:`null` **PyTorch &** :pyg:`null` **PyTorch Geometric** [`Tutorial `__, `Code `__] * Francesco Landolfi: :pyg:`null` **PyTorch Geometric Tutorial** [`PDF (0.4MB) `__] * Sachin Sharma: **How to Deploy (almost) any** :pyg:`null` **PyTorch Geometric Model on Nvidia's Triton Inference Server with an Application to Amazon Product Recommendation and ArangoDB** [`Blog `__] * Amitoz Azad: **torch_pdegraph** - Solving PDEs on Graphs with :pyg:`PyG` [`Devpost `__, :github:`null` `GitHub `__] * Amitoz Azad: **Primal-Dual Algorithm for Total Variation Processing on Graphs** [`Jupyter `__] * Manan Goel: **Recommending Amazon Products using Graph Neural Networks in** :pyg:`null` **PyTorch Geometric** [:wandb:`null` `W&B Report `__] * Kùzu: **Remote Backend for** :pyg:`null` **PyTorch Geometric** [:colab:`null` `Colab `__] * Aniket Saxena: **Graph Neural Networks-based Explanation App using** :pyg:`null` **PyTorch Geometric** [`Website `__, :github:`null` `GitHub `__] * Mashaan Alshammari: **Graph Attention in** :pyg:`null` **PyTorch Geometric** [:youtube:`null` `YouTube `__, :github:`null` `GitHub `__] * Mashaan Alshammari: **Graph Convolutional Networks (GCNs) in** :pytorch:`null` **PyTorch** [:youtube:`null` `YouTube `__, :github:`null` `GitHub `__] * Mashaan Alshammari: **GCN and SGC in** :pytorch:`null` **PyTorch** [:youtube:`null` `YouTube `__, :github:`null` `GitHub `__] * Mashaan Alshammari: **GCN Variants SGC and ASGC in** :pytorch:`null` **PyTorch** [:youtube:`null` `YouTube `__, :github:`null` `GitHub `__] ================================================ FILE: docs/source/get_started/colabs.rst ================================================ Colab Notebooks and Video Tutorials =================================== Official Examples ----------------- We have prepared a list of :colab:`Colab` notebooks that practically introduce you to the world of **Graph Neural Networks**
with :pyg:`PyG`: 1. `Introduction: Hands-on Graph Neural Networks `__ 2. `Node Classification with Graph Neural Networks `__ 3. `Graph Classification with Graph Neural Networks `__ 4. `Scaling Graph Neural Networks `__ 5. `Point Cloud Classification with Graph Neural Networks `__ 6. `Explaining GNN Model Predictions using `__ :captum:`null` `Captum `__ 7. `Customizing Aggregations within Message Passing `__ 8. `Node Classification Instrumented with `__ :wandb:`null` `Weights&Biases `__ 9. `Graph Classification Instrumented with `__ :wandb:`null` `Weights&Biases `__ 10. `Link Prediction on MovieLens `__ 11. `Link Regression on MovieLens `__ 12. `Pooling in Graph Neural Networks with `__ :tgp:`null` `tgp `__ All :colab:`Colab` notebooks are released under the MIT license. Stanford CS224W Tutorials ------------------------- .. image:: https://data.pyg.org/img/cs224w_tutorials.png :align: center :width: 941px :target: https://medium.com/stanford-cs224w .. raw:: html
The :stanford:`null` `Stanford CS224W `__ course has collected a set of `graph machine learning tutorial blog posts `__, fully realized with :pyg:`PyG`. Students worked on projects spanning all kinds of tasks, model architectures and applications. All tutorials also link to a :colab:`Colab` with the code in the tutorial for you to follow along with as you read it! PyTorch Geometric Tutorial Project ---------------------------------- The :pyg:`null` `PyTorch Geometric Tutorial `__ project provides **video tutorials and** :colab:`null` **Colab notebooks** for a variety of different methods in :pyg:`PyG`: 1. Introduction [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 2. :pytorch:`PyTorch` basics [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 3. Graph Attention Networks (GATs) [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 4. Spectral Graph Convolutional Layers [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 5. Aggregation Functions in GNNs [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 6. (Variational) Graph Autoencoders (GAE and VGAE) [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 7. Adversarially Regularized Graph Autoencoders (ARGA and ARGVA) [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 8. Graph Generation [:youtube:`null` `YouTube `__] 9. Recurrent Graph Neural Networks [:youtube:`null` `YouTube `__, :colab:`null` `Colab (Part 1) `__, :colab:`null` `Colab (Part 2) `__] 10. DeepWalk and Node2Vec [:youtube:`null` `YouTube (Theory) `__, :youtube:`null` `YouTube (Practice) `__, :colab:`null` `Colab `__] 11. Edge analysis [:youtube:`null` `YouTube `__, :colab:`null` `Colab (Link Prediction) `__, :colab:`null` `Colab (Label Prediction) `__] 12. Data handling in :pyg:`PyG` (Part 1) [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 13. Data handling in :pyg:`PyG` (Part 2) [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 14. 
MetaPath2vec [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] 15. Graph pooling (DiffPool) [:youtube:`null` `YouTube `__, :colab:`null` `Colab `__] ================================================ FILE: docs/source/get_started/introduction.rst ================================================ Introduction by Example ======================= We briefly introduce the fundamental concepts of :pyg:`PyG` through self-contained examples. For an introduction to Graph Machine Learning, we refer the interested reader to the :stanford:`null` `Stanford CS224W: Machine Learning with Graphs `__ lectures. For an interactive introduction to :pyg:`PyG`, we recommend our carefully curated :colab:`null` `Google Colab `__ notebooks. At its core, :pyg:`PyG` provides the following main features: .. contents:: :local: Data Handling of Graphs ----------------------- A graph is used to model pairwise relations (edges) between objects (nodes). A single graph in :pyg:`PyG` is described by an instance of :class:`torch_geometric.data.Data`, which holds the following attributes by default: - :obj:`data.x`: Node feature matrix with shape :obj:`[num_nodes, num_node_features]` - :obj:`data.edge_index`: Graph connectivity in `COO format `_ with shape :obj:`[2, num_edges]` and type :obj:`torch.long` - :obj:`data.edge_attr`: Edge feature matrix with shape :obj:`[num_edges, num_edge_features]` - :obj:`data.y`: Target to train against (may have arbitrary shape), *e.g.*, node-level targets of shape :obj:`[num_nodes, *]` or graph-level targets of shape :obj:`[1, *]` - :obj:`data.pos`: Node position matrix with shape :obj:`[num_nodes, num_dimensions]` None of these attributes are required. In fact, the :class:`~torch_geometric.data.Data` object is not even restricted to these attributes. We can, *e.g.*, extend it by :obj:`data.face` to save the connectivity of triangles from a 3D mesh in a tensor with shape :obj:`[3, num_faces]` and type :obj:`torch.long`. ..
Note:: :pytorch:`PyTorch` and :obj:`torchvision` define an example as a tuple of an image and a target. We omit this notation in :pyg:`PyG` to allow for various data structures in a clean and understandable way. We show a simple example of an unweighted and undirected graph with three nodes and two (undirected) edges. Each node contains exactly one feature: .. code-block:: python import torch from torch_geometric.data import Data edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) x = torch.tensor([[-1], [0], [1]], dtype=torch.float) data = Data(x=x, edge_index=edge_index) >>> Data(edge_index=[2, 4], x=[3, 1]) .. image:: ../_figures/graph.svg :align: center :width: 300px | Note that :obj:`edge_index`, *i.e.* the tensor defining the source and target nodes of all edges, is **not** a list of index tuples. If you want to write your indices this way, you should transpose and call :obj:`contiguous` on it before passing them to the data constructor: .. code-block:: python import torch from torch_geometric.data import Data edge_index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long) x = torch.tensor([[-1], [0], [1]], dtype=torch.float) data = Data(x=x, edge_index=edge_index.t().contiguous()) >>> Data(edge_index=[2, 4], x=[3, 1]) Although the graph has only two edges, we need to define four index tuples to account for both directions of an edge. .. Note:: You can print out your data object at any time and receive a short summary of its attributes and their shapes. Note that it is necessary that the elements in :obj:`edge_index` only hold indices in the range :obj:`{ 0, ..., num_nodes - 1}`. This is needed as we want our final data representation to be as compact as possible, *e.g.*, we want to index the source and destination node features of the first edge :obj:`(0, 1)` via :obj:`x[0]` and :obj:`x[1]`, respectively.
You can always check that your final :class:`~torch_geometric.data.Data` objects fulfill these requirements by running :meth:`~torch_geometric.data.Data.validate`: .. code-block:: python data.validate(raise_on_error=True) Besides holding a number of node-level, edge-level or graph-level attributes, :class:`~torch_geometric.data.Data` provides a number of useful utility functions, *e.g.*: .. code-block:: python print(data.keys()) >>> ['x', 'edge_index'] print(data['x']) >>> tensor([[-1.0], [0.0], [1.0]]) for key, item in data: print(f'{key} found in data') >>> x found in data >>> edge_index found in data 'edge_attr' in data >>> False data.num_nodes >>> 3 data.num_edges >>> 4 data.num_node_features >>> 1 data.has_isolated_nodes() >>> False data.has_self_loops() >>> False data.is_directed() >>> False # Transfer data object to GPU. device = torch.device('cuda') data = data.to(device) You can find a complete list of all methods at :class:`torch_geometric.data.Data`. Common Benchmark Datasets ------------------------- :pyg:`PyG` contains a large number of common benchmark datasets, *e.g.*, all Planetoid datasets (Cora, Citeseer, Pubmed), all graph classification datasets from `TUDatasets `_ and their `cleaned versions `_, the QM7 and QM9 datasets, and a handful of 3D mesh/point cloud datasets like FAUST, ModelNet10/40 and ShapeNet. Initializing a dataset is straightforward. An initialization of a dataset will automatically download its raw files and process them into the previously described :class:`~torch_geometric.data.Data` format. *E.g.*, to load the ENZYMES dataset (consisting of 600 graphs within 6 classes), type: .. code-block:: python from torch_geometric.datasets import TUDataset dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') >>> ENZYMES(600) len(dataset) >>> 600 dataset.num_classes >>> 6 dataset.num_node_features >>> 3 We now have access to all 600 graphs in the dataset: ..
code-block:: python data = dataset[0] >>> Data(edge_index=[2, 168], x=[37, 3], y=[1]) data.is_undirected() >>> True We can see that the first graph in the dataset contains 37 nodes, each one having 3 features. There are 168/2 = 84 undirected edges and the graph is assigned to exactly one class. In addition, the data object is holding exactly one graph-level target. We can even use slices, long or bool tensors to split the dataset. *E.g.*, to create a 90/10 train/test split, type: .. code-block:: python train_dataset = dataset[:540] >>> ENZYMES(540) test_dataset = dataset[540:] >>> ENZYMES(60) If you are unsure whether the dataset is already shuffled before you split, you can randomly permute it by running: .. code-block:: python dataset = dataset.shuffle() >>> ENZYMES(600) This is equivalent to doing: .. code-block:: python perm = torch.randperm(len(dataset)) dataset = dataset[perm] >>> ENZYMES(600) Let's try another one! Let's download Cora, the standard benchmark dataset for semi-supervised graph node classification: .. code-block:: python from torch_geometric.datasets import Planetoid dataset = Planetoid(root='/tmp/Cora', name='Cora') >>> Cora() len(dataset) >>> 1 dataset.num_classes >>> 7 dataset.num_node_features >>> 1433 Here, the dataset contains only a single, undirected citation graph: ..
code-block:: python data = dataset[0] >>> Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708]) data.is_undirected() >>> True data.train_mask.sum().item() >>> 140 data.val_mask.sum().item() >>> 500 data.test_mask.sum().item() >>> 1000 This time, the :class:`~torch_geometric.data.Data` object holds a label for each node, and additional node-level attributes: :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask`, where - :obj:`train_mask` denotes against which nodes to train (140 nodes), - :obj:`val_mask` denotes which nodes to use for validation, *e.g.*, to perform early stopping (500 nodes), - :obj:`test_mask` denotes against which nodes to test (1000 nodes). Mini-batches ------------ Neural networks are usually trained in a batch-wise fashion. :pyg:`PyG` achieves parallelization over a mini-batch by creating sparse block diagonal adjacency matrices (defined by :obj:`edge_index`) and concatenating feature and target matrices in the node dimension. This composition allows differing numbers of nodes and edges across examples in one batch: .. math:: \mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad \mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad \mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix} :pyg:`PyG` contains its own :class:`torch_geometric.loader.DataLoader`, which already takes care of this concatenation process. Let's learn about it in an example: ..
code-block:: python from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True) loader = DataLoader(dataset, batch_size=32, shuffle=True) for batch in loader: batch >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32]) batch.num_graphs >>> 32 :class:`torch_geometric.data.Batch` inherits from :class:`torch_geometric.data.Data` and contains an additional attribute called :obj:`batch`. :obj:`batch` is a column vector which maps each node to its respective graph in the batch: .. math:: \mathrm{batch} = {\begin{bmatrix} 0 & \cdots & 0 & 1 & \cdots & n - 2 & n -1 & \cdots & n - 1 \end{bmatrix}}^{\top} You can use it to, *e.g.*, average node features in the node dimension for each graph individually: .. code-block:: python from torch_geometric.utils import scatter from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True) loader = DataLoader(dataset, batch_size=32, shuffle=True) for data in loader: data >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32]) data.num_graphs >>> 32 x = scatter(data.x, data.batch, dim=0, reduce='mean') x.size() >>> torch.Size([32, 21]) You can learn more about the internal batching procedure of :pyg:`PyG`, *e.g.*, how to modify its behavior, `here <../advanced/batching.html>`__. For documentation of scatter operations, we refer the interested reader to the :obj:`torch_scatter` `documentation `_. Data Transforms --------------- Transforms are a common way in :obj:`torchvision` to transform images and perform augmentation. :pyg:`PyG` comes with its own transforms, which expect a :class:`~torch_geometric.data.Data` object as input and return a new transformed :class:`~torch_geometric.data.Data` object. 
Transforms can be chained together using :class:`torch_geometric.transforms.Compose` and are applied before saving a processed dataset on disk (:obj:`pre_transform`) or before accessing a graph in a dataset (:obj:`transform`). Let's look at an example, where we apply transforms on the ShapeNet dataset (containing 17,000 3D shape point clouds and per point labels from 16 shape categories). .. code-block:: python from torch_geometric.datasets import ShapeNet dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane']) dataset[0] >>> Data(pos=[2518, 3], y=[2518]) We can convert the point cloud dataset into a graph dataset by generating nearest neighbor graphs from the point clouds via transforms: .. code-block:: python import torch_geometric.transforms as T from torch_geometric.datasets import ShapeNet dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'], pre_transform=T.KNNGraph(k=6)) dataset[0] >>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518]) .. note:: We use the :obj:`pre_transform` to convert the data before saving it to disk (leading to faster loading times). Note that the next time the dataset is initialized it will already contain graph edges, even if you do not pass any transform. If the :obj:`pre_transform` does not match with the one from the already processed dataset, you will be given a warning. In addition, we can use the :obj:`transform` argument to randomly augment a :class:`~torch_geometric.data.Data` object, *e.g.*, translating each node position by a small number: .. code-block:: python import torch_geometric.transforms as T from torch_geometric.datasets import ShapeNet dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'], pre_transform=T.KNNGraph(k=6), transform=T.RandomJitter(0.01)) dataset[0] >>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518]) You can find a complete list of all implemented transforms at :mod:`torch_geometric.transforms`. 
Learning Methods on Graphs -------------------------- After learning about data handling, datasets, loader and transforms in :pyg:`PyG`, it's time to implement our first graph neural network! We will use a simple GCN layer and replicate the experiments on the Cora citation dataset. For a high-level explanation on GCN, have a look at its `blog post `_. We first need to load the Cora dataset: .. code-block:: python from torch_geometric.datasets import Planetoid dataset = Planetoid(root='/tmp/Cora', name='Cora') >>> Cora() Note that we do not need to use transforms or a dataloader. Now let's implement a two-layer GCN: .. code-block:: python import torch import torch.nn.functional as F from torch_geometric.nn import GCNConv class GCN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_node_features, 16) self.conv2 = GCNConv(16, dataset.num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) The constructor defines two :class:`~torch_geometric.nn.conv.GCNConv` layers which get called in the forward pass of our network. Note that the non-linearity is not integrated in the :obj:`conv` calls and hence needs to be applied afterwards (something which is consistent across all operators in :pyg:`PyG`). Here, we chose to use ReLU as our intermediate non-linearity and finally output a softmax distribution over the number of classes. Let's train this model on the training nodes for 200 epochs: .. 
code-block:: python device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GCN().to(device) data = dataset[0].to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) model.train() for epoch in range(200): optimizer.zero_grad() out = model(data) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() Finally, we can evaluate our model on the test nodes: .. code-block:: python model.eval() pred = model(data).argmax(dim=1) correct = (pred[data.test_mask] == data.y[data.test_mask]).sum() acc = int(correct) / int(data.test_mask.sum()) print(f'Accuracy: {acc:.4f}') >>> Accuracy: 0.8150 This is all it takes to implement your first graph neural network. The easiest way to learn more about Graph Neural Networks is to study the examples in the :obj:`examples/` directory and to browse :mod:`torch_geometric.nn`. Happy hacking! Exercises --------- 1. What does :obj:`edge_index.t().contiguous()` do? 2. Load the :obj:`"IMDB-BINARY"` dataset from the :class:`~torch_geometric.datasets.TUDataset` benchmark suite and randomly split it into 80%/10%/10% training, validation and test graphs. 3. What does each number of the following output mean? .. code-block:: python print(batch) >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32]) ================================================ FILE: docs/source/index.rst ================================================ :github_url: https://github.com/pyg-team/pytorch_geometric PyG Documentation ================= :pyg:`null` **PyG** *(PyTorch Geometric)* is a library built upon :pytorch:`null` `PyTorch `_ to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data. It consists of various methods for deep learning on graphs and other irregular structures, also known as `geometric deep learning `_, from a variety of published papers. 
In addition, it consists of easy-to-use mini-batch loaders for operating on many small and single giant graphs, `multi GPU-support `_, `torch.compile `_ support, `DataPipe `_ support, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point clouds. .. slack_button:: .. toctree:: :maxdepth: 1 :caption: Install PyG install/installation .. toctree:: :maxdepth: 1 :caption: Get Started get_started/introduction get_started/colabs .. toctree:: :maxdepth: 1 :caption: Tutorials tutorial/gnn_design tutorial/dataset tutorial/application tutorial/distributed .. toctree:: :maxdepth: 1 :caption: Advanced Concepts advanced/batching advanced/sparse_tensor advanced/hgam advanced/compile advanced/jit advanced/remote advanced/graphgym advanced/cpu_affinity .. toctree:: :maxdepth: 1 :caption: Package Reference modules/root modules/nn modules/data modules/loader modules/sampler modules/datasets modules/llm modules/transforms modules/utils modules/explain modules/metrics modules/distributed modules/contrib modules/graphgym modules/profile .. toctree:: :maxdepth: 1 :caption: Cheatsheets cheatsheet/gnn_cheatsheet cheatsheet/data_cheatsheet .. toctree:: :maxdepth: 1 :caption: External Resources external/resources ================================================ FILE: docs/source/install/installation.rst ================================================ Installation ============ :pyg:`PyG` is available for :python:`Python 3.10` to :python:`Python 3.14`. .. note:: We do not recommend installation as a root user on your system :python:`Python`. Please setup a virtual environment, *e.g.*, via `venv `_, :conda:`null` `Anaconda/Miniconda `_, or create a `Docker image `_. Quick Start ----------- .. 
raw:: html :file: quick-start.html Installation via PyPI --------------------- From :pyg:`null` **PyG 2.3** onwards, you can install and use :pyg:`PyG` **without any external library** required except for :pytorch:`PyTorch`. For this, simply run: .. code-block:: none pip install torch_geometric Additional Libraries ~~~~~~~~~~~~~~~~~~~~ If you want to utilize the full set of features from :pyg:`PyG`, there exists several additional libraries you may want to install: * `pyg-lib `__: Heterogeneous GNN operators, graph sampling routines, and :class:`~torch_geometric.nn.conv.SplineConv` support * `torch-scatter `__: Accelerated and efficient sparse reductions * `torch-sparse `__: :class:`SparseTensor` support, see `here `__ * `torch-cluster `__: Graph clustering routines .. note:: ``torch-spline-conv`` is no longer required as a separate package. Its functionality has been migrated to ``pyg-lib>=0.6.0``. These packages come with their own CPU and GPU kernel implementations based on the :pytorch:`null` `PyTorch C++/CUDA/hip(ROCm) extension interface `_. For a basic usage of :pyg:`PyG`, these dependencies are **fully optional**. We recommend to start with a minimal installation, and install additional dependencies once you start to actually need them. Installation from Wheels ~~~~~~~~~~~~~~~~~~~~~~~~ For ease of installation of these extensions, we provide :obj:`pip` wheels for these packages for all major OS, :pytorch:`PyTorch` and CUDA combinations, see `here `__: #. Ensure that at least :pytorch:`PyTorch` 1.13.0 is installed: .. code-block:: none python -c "import torch; print(torch.__version__)" >>> 2.10.0 #. Find the CUDA version :pytorch:`PyTorch` was installed with: .. code-block:: none python -c "import torch; print(torch.version.cuda)" >>> 12.8 #. Install the relevant packages: .. 
code-block:: none pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html where :obj:`${TORCH}` and :obj:`${CUDA}` should be replaced by the specific :pytorch:`PyTorch` and CUDA versions, respectively: * :pytorch:`PyTorch` 2.10.*: :obj:`${TORCH}=2.10.0` and :obj:`${CUDA}=cpu|cu126|cu128|cu130` * :pytorch:`PyTorch` 2.9.*: :obj:`${TORCH}=2.9.0` and :obj:`${CUDA}=cpu|cu126|cu128|cu130` * :pytorch:`PyTorch` 2.8.*: :obj:`${TORCH}=2.8.0` and :obj:`${CUDA}=cpu|cu126|cu128|cu129` * :pytorch:`PyTorch` 2.7.*: :obj:`${TORCH}=2.7.0` and :obj:`${CUDA}=cpu|cu118|cu126|cu128` * :pytorch:`PyTorch` 2.6.*: :obj:`${TORCH}=2.6.0` and :obj:`${CUDA}=cpu|cu118|cu124|cu126` * :pytorch:`PyTorch` 2.5.*: :obj:`${TORCH}=2.5.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124` * :pytorch:`PyTorch` 2.4.*: :obj:`${TORCH}=2.4.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124` * :pytorch:`PyTorch` 2.3.*: :obj:`${TORCH}=2.3.0` and :obj:`${CUDA}=cpu|cu118|cu121` * :pytorch:`PyTorch` 2.2.*: :obj:`${TORCH}=2.2.0` and :obj:`${CUDA}=cpu|cu118|cu121` * :pytorch:`PyTorch` 2.1.*: :obj:`${TORCH}=2.1.0` and :obj:`${CUDA}=cpu|cu118|cu121` * :pytorch:`PyTorch` 2.0.*: :obj:`${TORCH}=2.0.0` and :obj:`${CUDA}=cpu|cu117|cu118` * :pytorch:`PyTorch` 1.13.*: :obj:`${TORCH}=1.13.0` and :obj:`${CUDA}=cpu|cu116|cu117` For example, for :pytorch:`PyTorch` 2.10.* and CUDA 13.0, type: .. code-block:: none pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.10.0+cu130.html For example, for :pytorch:`PyTorch` 2.9.* and CUDA 12.8, type: .. 
code-block:: none pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.9.0+cu128.html **Note:** Binaries of older versions are also provided for :pytorch:`PyTorch` 1.4.0, 1.5.0, 1.6.0, 1.7.0/1.7.1, 1.8.0/1.8.1, 1.9.0, 1.10.0/1.10.1/1.10.2, 1.11.0, 1.12.0/1.12.1, 1.13.0/1.13.1, 2.0.0/2.0.1, 2.1.0/2.1.1/2.1.2, 2.2.0/2.2.1/2.2.2, 2.3.0/2.3.1, 2.4.0/2.4.1, 2.5.0/2.5.1, 2.6.0, and 2.7.0/2.7.1 (following the same procedure). **For older versions, you need to explicitly specify the latest supported version number** or install via :obj:`pip install --no-index` in order to prevent a manual installation from source. You can look up the latest supported version number `here `__. **ROCm:** The external `pyg-rocm-build repository `__ provides wheels and detailed instructions on how to install :pyg:`PyG` for ROCm. If you have any questions about it, please open an issue `here `__. Installation from Source ~~~~~~~~~~~~~~~~~~~~~~~~ In case a specific version is not supported by `our wheels `_, you can alternatively install them from source: #. Ensure that your CUDA is setup correctly (optional): #. Check if :pytorch:`PyTorch` is installed with CUDA support: .. code-block:: none python -c "import torch; print(torch.cuda.is_available())" >>> True #. Add CUDA to :obj:`$PATH` and :obj:`$CPATH` (note that your actual CUDA path may vary from :obj:`/usr/local/cuda`): .. code-block:: none export PATH=/usr/local/cuda/bin:$PATH echo $PATH >>> /usr/local/cuda/bin:... export CPATH=/usr/local/cuda/include:$CPATH echo $CPATH >>> /usr/local/cuda/include:... #. Add CUDA to :obj:`$LD_LIBRARY_PATH` on Linux and to :obj:`$DYLD_LIBRARY_PATH` on macOS (note that your actual CUDA path may vary from :obj:`/usr/local/cuda`): .. code-block:: none export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH echo $LD_LIBRARY_PATH >>> /usr/local/cuda/lib64:... 
export DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH echo $DYLD_LIBRARY_PATH >>> /usr/local/cuda/lib:... #. Verify that :obj:`nvcc` is accessible from terminal: .. code-block:: none nvcc --version >>> 11.8 #. Ensure that :pytorch:`PyTorch` and system CUDA versions match: .. code-block:: none python -c "import torch; print(torch.version.cuda)" >>> 11.8 nvcc --version >>> 11.8 #. Install the relevant packages: .. code-block:: none pip install --verbose git+https://github.com/pyg-team/pyg-lib.git pip install --verbose torch_scatter pip install --verbose torch_sparse pip install --verbose torch_cluster In rare cases, CUDA or :python:`Python` path problems can prevent a successful installation. :obj:`pip` may even signal a successful installation, but execution simply crashes with :obj:`Segmentation fault (core dumped)`. We collected common installation errors in the `Frequently Asked Questions `__ subsection. In case the FAQ does not help you in solving your problem, please create an `issue `_. Before, please verify that your CUDA is set up correctly by following the official `installation guide `_. Installation via Anaconda ------------------------- .. warning:: Conda packages are no longer available since :pytorch:`PyTorch` :obj:`>2.5.0`. Please use :obj:`pip` instead. For earlier :pytorch:`PyTorch` versions (:obj:`torch<=2.5.0`), you can install :pyg:`PyG` via :conda:`null` `Anaconda `_ for all major OS, and CUDA combinations. If you have not yet installed :pytorch:`PyTorch`, install it via :conda:`null` :obj:`conda install` as described in its `official documentation `_. Given that you have :pytorch:`PyTorch` installed, run .. code-block:: none conda install pyg -c pyg If :conda:`null` :obj:`conda` does not pick up the correct CUDA version of :pyg:`PyG`, you can enforce it as follows: .. code-block:: none conda install pyg=*=*cu* -c pyg .. 
_install-cugraph: Accelerating PyG with NVIDIA cuGraph GNN ---------------------------------------- :pyg:`PyG` can optionally leverage NVIDIA's `cuGraph `_ to accelerate neighbor sampling and achieve better scalability for multi-GPU training on large-scale graphs (2x-8x data loading speedups on billion-edge graphs). NVIDIA currently recommends the `NVIDIA PyG Container `_ on NGC as the most reliable way to use cuGraph integration with :pyg:`PyG`. For other installation methods, refer to the `cuGraph GNN repository `_ and/or the `RAPIDS installation guide `_. .. note:: **cuGraph GNN is optional** — all :pyg:`PyG` functionality, including multi-GPU training, works without it. However, for users with NVIDIA GPUs, cuGraph can provide significant speedups and better scalability for neighbor sampling and data loading, especially on large-scale graphs. `cuGraph `_ is a collection of packages focused on GPU-accelerated graph analytics including support for property graphs and scaling up to thousands of GPUs. cuGraph supports the creation and manipulation of graphs followed by the execution of scalable fast graph algorithms. It is part of the `RAPIDS `_ accelerated data science framework. `cuGraph GNN `_ is a collection of GPU-accelerated plugins that support :pytorch:`PyTorch` and :pyg:`PyG` natively through the *cuGraph-PyG* and *WholeGraph* subprojects. cuGraph GNN is built on top of cuGraph, leveraging its low-level `pylibcugraph `_ API and C++ primitives for sampling and other GNN operations (`libcugraph `_). It also includes the :obj:`libwholegraph` and :obj:`pylibwholegraph` libraries for high-performance distributed edgelist and embedding storage. 
Users have the option of working with these lower-level libraries directly, or through the higher-level API in cuGraph-PyG that directly implements the :class:`~torch_geometric.data.GraphStore`, :class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.loader.NodeLoader`, and :class:`~torch_geometric.loader.LinkLoader` interfaces. Complete documentation on RAPIDS graph packages, including ``cugraph``, ``cugraph-pyg``, ``pylibwholegraph``, and ``pylibcugraph`` is available on the `RAPIDS docs pages `_. See `rapidsai/cugraph-gnn examples on GitHub `_ for fully scalable PyG example workflows. Frequently Asked Questions -------------------------- #. :obj:`undefined symbol: **make_function_schema**`: This issue signals (1) a **version conflict** between your installed :pytorch:`PyTorch` version and the :obj:`${TORCH}` version specified to install the extension packages, or (2) a version conflict between the installed CUDA version of :pytorch:`PyTorch` and the :obj:`${CUDA}` version specified to install the extension packages. Please verify that your :pytorch:`PyTorch` version and its CUDA version **match** with your installation command: .. code-block:: none python -c "import torch; print(torch.__version__)" python -c "import torch; print(torch.version.cuda)" nvcc --version For re-installation, ensure that you do not run into any caching issues by using the :obj:`pip --force-reinstall --no-cache-dir` flags. In addition, the :obj:`pip --verbose` option may help to track down any issues during installation. If you still do not find any success in installation, please try to install the extension packages `from source `__. ================================================ FILE: docs/source/install/quick-start.html ================================================
PyTorch
Your OS
Package
CUDA
Run:
================================================ FILE: docs/source/modules/contrib.rst ================================================ torch_geometric.contrib ======================= .. currentmodule:: torch_geometric.contrib :obj:`torch_geometric.contrib` is a staging area for early stage experimental code. Modules might be moved to the main library in the future. .. warning:: This module contains experimental code, which is not guaranteed to be stable. .. contents:: Contents :local: Convolutional Layers -------------------- .. currentmodule:: torch_geometric.contrib.nn.conv .. autosummary:: :nosignatures: {% for cls in torch_geometric.contrib.nn.conv.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.contrib.nn.conv :members: :undoc-members: :exclude-members: message, aggregate, message_and_aggregate, update, MessagePassing, training, initialize_parameters Models ------ .. currentmodule:: torch_geometric.contrib.nn.models .. autosummary:: :nosignatures: {% for cls in torch_geometric.contrib.nn.models.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.contrib.nn.models :members: :undoc-members: :exclude-members: message, aggregate, message_and_aggregate, update, MessagePassing, training, init_conv Datasets -------- .. currentmodule:: torch_geometric.contrib.datasets .. autosummary:: :nosignatures: {% for cls in torch_geometric.contrib.datasets.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.contrib.datasets :members: :exclude-members: download, process, processed_file_names, raw_file_names, num_classes, get Transforms ---------- .. currentmodule:: torch_geometric.contrib.transforms .. autosummary:: :nosignatures: {% for cls in torch_geometric.contrib.transforms.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.contrib.transforms :members: Explainer --------- .. currentmodule:: torch_geometric.contrib.explain .. 
autosummary:: :nosignatures: {% for cls in torch_geometric.contrib.explain.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.contrib.explain :members: ================================================ FILE: docs/source/modules/data.rst ================================================ torch_geometric.data ==================== .. contents:: Contents :local: Data Objects ------------ .. currentmodule:: torch_geometric.data .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/inherited_class.rst {% for name in torch_geometric.data.data_classes %} {{ name }} {% endfor %} Remote Backend Interfaces ------------------------- .. currentmodule:: torch_geometric.data .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.data.remote_backend_classes %} {{ name }} {% endfor %} Databases --------- .. currentmodule:: torch_geometric.data .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/inherited_class.rst {% for name in torch_geometric.data.database_classes %} {{ name }} {% endfor %} PyTorch Lightning Wrappers -------------------------- .. currentmodule:: torch_geometric.data.lightning .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.data.lightning.classes %} {{ name }} {% endfor %} Helper Functions ---------------- .. currentmodule:: torch_geometric.data .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.data.helper_functions %} {{ name }} {% endfor %} ================================================ FILE: docs/source/modules/datasets.rst ================================================ torch_geometric.datasets ======================== .. contents:: Contents :local: Homogeneous Datasets -------------------- .. currentmodule:: torch_geometric.datasets .. 
autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.datasets.homo_datasets %} {{ name }} {% endfor %} Heterogeneous Datasets ---------------------- .. currentmodule:: torch_geometric.datasets .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.datasets.hetero_datasets %} {{ name }} {% endfor %} Hypergraph Datasets ------------------- .. currentmodule:: torch_geometric.datasets .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.datasets.hyper_datasets %} {{ name }} {% endfor %} Synthetic Datasets ------------------ .. currentmodule:: torch_geometric.datasets .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.datasets.synthetic_datasets %} {{ name }} {% endfor %} Graph Generators ---------------- .. currentmodule:: torch_geometric.datasets.graph_generator .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.datasets.graph_generator.classes %} {{ name }} {% endfor %} Motif Generators ---------------- .. currentmodule:: torch_geometric.datasets.motif_generator .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.datasets.motif_generator.classes %} {{ name }} {% endfor %} ================================================ FILE: docs/source/modules/distributed.rst ================================================ torch_geometric.distributed =========================== .. warning:: ``torch_geometric.distributed`` has been deprecated since 2.7.0 and will no longer be maintained. For distributed training, refer to :ref:`our tutorials on distributed training ` or `cuGraph examples `_. .. currentmodule:: torch_geometric.distributed .. 
autosummary:: :nosignatures: {% for cls in torch_geometric.distributed.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.distributed :members: ================================================ FILE: docs/source/modules/explain.rst ================================================ torch_geometric.explain ======================= .. currentmodule:: torch_geometric.explain .. warning:: This module is in active development and may not be stable. Access requires installing :pyg:`PyG` from master. .. contents:: Contents :local: Philosophy ---------- This module provides a set of tools to explain the predictions of a PyG model or to explain the underlying phenomenon of a dataset (see the `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" `_ paper for more details). We represent explanations using the :class:`torch_geometric.explain.Explanation` class, which is a :class:`~torch_geometric.data.Data` object containing masks for the nodes, edges, features and any attributes of the data. The :class:`torch_geometric.explain.Explainer` class is designed to handle all explainability parameters (see the :class:`torch_geometric.explain.config.ExplainerConfig` class for more details): - which algorithm from the :class:`torch_geometric.explain.algorithm` module to use (*e.g.*, :class:`~torch_geometric.explain.algorithm.GNNExplainer`) - the type of explanation to compute (*e.g.*, :obj:`explanation_type="phenomenon"` or :obj:`explanation_type="model"`) - the different type of masks for node and edges (*e.g.*, :obj:`mask="object"` or :obj:`mask="attributes"`) - any postprocessing of the masks (*e.g.*, :obj:`threshold_type="topk"` or :obj:`threshold_type="hard"`) This class allows the user to easily compare different explainability methods and to easily switch between different types of masks, while making sure the high-level framework stays the same. Explainer --------- .. 
autoclass:: torch_geometric.explain.Explainer :show-inheritance: :members: :special-members: __call__ .. autoclass:: torch_geometric.explain.config.ExplainerConfig :members: .. autoclass:: torch_geometric.explain.config.ModelConfig :members: .. autoclass:: torch_geometric.explain.config.ThresholdConfig :members: Explanations ------------ .. autoclass:: torch_geometric.explain.Explanation :show-inheritance: :members: .. autoclass:: torch_geometric.explain.HeteroExplanation :show-inheritance: :members: Explainer Algorithms -------------------- .. currentmodule:: torch_geometric.explain.algorithm .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.explain.algorithm.classes %} {{ name }} {% endfor %} Explanation Metrics ------------------- The quality of an explanation can be judged by a variety of different methods. PyG supports the following metrics out-of-the-box: .. currentmodule:: torch_geometric.explain.metric .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.explain.metric.classes %} {{ name }} {% endfor %} ================================================ FILE: docs/source/modules/graphgym.rst ================================================ torch_geometric.graphgym ======================== .. contents:: Contents :local: Workflow and Register Modules ----------------------------- .. currentmodule:: torch_geometric.graphgym .. autosummary:: :nosignatures: {% for cls in torch_geometric.graphgym.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.graphgym :members: :exclude-members: Model Modules ------------- .. currentmodule:: torch_geometric.graphgym.models .. autosummary:: :nosignatures: {% for cls in torch_geometric.graphgym.models.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.graphgym.models :members: :exclude-members: forward Utility Modules --------------- .. currentmodule:: torch_geometric.graphgym.utils .. 
autosummary:: :nosignatures: {% for cls in torch_geometric.graphgym.utils.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.graphgym.utils :members: :exclude-members: ================================================ FILE: docs/source/modules/llm.rst ================================================ torch_geometric.llm ======================= .. currentmodule:: torch_geometric.llm .. autosummary:: :nosignatures: {% for cls in torch_geometric.llm.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.llm :members: Models ---------------- .. currentmodule:: torch_geometric.llm.models .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.llm.models.classes %} {{ name }} {% endfor %} Utils ---------------- .. currentmodule:: torch_geometric.llm.utils .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.llm.utils.classes %} {{ name }} {% endfor %} ================================================ FILE: docs/source/modules/loader.rst ================================================ torch_geometric.loader ====================== .. currentmodule:: torch_geometric.loader .. autosummary:: :nosignatures: {% for cls in torch_geometric.loader.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.loader :members: ================================================ FILE: docs/source/modules/metrics.rst ================================================ torch_geometric.metrics ======================= .. contents:: Contents :local: Link Prediction Metrics ----------------------- .. currentmodule:: torch_geometric.metrics .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/metrics.rst {% for name in torch_geometric.metrics.link_pred_metrics %} {{ name }} {% endfor %} ================================================ FILE: docs/source/modules/nn.rst ================================================ torch_geometric.nn ================== .. 
contents:: Contents :local: .. autoclass:: torch_geometric.nn.sequential.Sequential .. currentmodule:: torch_geometric.nn.dense {% for name in torch_geometric.nn.dense.lin_classes %} .. autoclass:: {{ name }} :members: {% endfor %} Convolutional Layers -------------------- .. currentmodule:: torch_geometric.nn.conv .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/nn.rst {% for name in torch_geometric.nn.conv.classes %} {{ name }} {% endfor %} Aggregation Operators --------------------- .. currentmodule:: torch_geometric.nn.aggr Aggregation functions play an important role in the message passing framework and the readout functions of Graph Neural Networks. Specifically, many works in the literature (`Hamilton et al. (2017) `__, `Xu et al. (2018) `__, `Corso et al. (2020) `__, `Li et al. (2020) `__, `Tailor et al. (2021) `__) demonstrate that the choice of aggregation functions contributes significantly to the representational power and performance of the model. For example, **mean aggregation** captures the distribution (or proportions) of elements, **max aggregation** proves to be advantageous to identify representative elements, and **sum aggregation** enables the learning of structural graph properties (`Xu et al. (2018) `__). Recent works also show that using **multiple aggregations** (`Corso et al. (2020) `__, `Tailor et al. (2021) `__) and **learnable aggregations** (`Li et al. (2020) `__) can potentially provide substantial improvements. Another line of research studies optimization-based and implicitly-defined aggregations (`Bartunov et al. (2022) `__). Furthermore, an interesting discussion concerns the trade-off between representational power (usually gained through learnable functions implemented as neural networks) and the formal property of permutation invariance (`Buterez et al. (2022) `__). 
To facilitate further experimentation and unify the concepts of aggregation within GNNs across both :class:`~torch_geometric.nn.conv.MessagePassing` and global readouts, we have made the concept of :class:`~torch_geometric.nn.aggr.Aggregation` a first-class principle in :pyg:`PyG`. As of now, :pyg:`PyG` provides support for various aggregations --- from rather simple ones (*e.g.*, :obj:`mean`, :obj:`max`, :obj:`sum`), to advanced ones (*e.g.*, :obj:`median`, :obj:`var`, :obj:`std`), learnable ones (*e.g.*, :class:`~torch_geometric.nn.aggr.SoftmaxAggregation`, :class:`~torch_geometric.nn.aggr.PowerMeanAggregation`, :class:`~torch_geometric.nn.aggr.SetTransformerAggregation`), and exotic ones (*e.g.*, :class:`~torch_geometric.nn.aggr.MLPAggregation`, :class:`~torch_geometric.nn.aggr.LSTMAggregation`, :class:`~torch_geometric.nn.aggr.SortAggregation`, :class:`~torch_geometric.nn.aggr.EquilibriumAggregation`): .. code-block:: python from torch_geometric.nn import aggr # Simple aggregations: mean_aggr = aggr.MeanAggregation() max_aggr = aggr.MaxAggregation() # Advanced aggregations: median_aggr = aggr.MedianAggregation() # Learnable aggregations: softmax_aggr = aggr.SoftmaxAggregation(learn=True) powermean_aggr = aggr.PowerMeanAggregation(learn=True) # Exotic aggregations: lstm_aggr = aggr.LSTMAggregation(in_channels=..., out_channels=...) sort_aggr = aggr.SortAggregation(k=4) We can then easily apply these aggregations over a batch of sets of potentially varying size. For this, an :obj:`index` vector defines the mapping from input elements to their location in the output: .. code-block:: python # Feature matrix holding 1000 elements with 64 features each: x = torch.randn(1000, 64) # Randomly assign elements to 100 sets: index = torch.randint(0, 100, (1000, )) output = mean_aggr(x, index) # Output shape: [100, 64] Notably, all aggregations share the same set of forward arguments, as described in detail in the :class:`torch_geometric.nn.aggr.Aggregation` base class. 
Each of the provided aggregations can be used within :class:`~torch_geometric.nn.conv.MessagePassing` as well as for hierarchical/global pooling to obtain graph-level representations: .. code-block:: python import torch from torch_geometric.nn import MessagePassing class MyConv(MessagePassing): def __init__(self, ...): # Use a learnable softmax neighborhood aggregation: super().__init__(aggr=aggr.SoftmaxAggregation(learn=True)) def forward(self, x, edge_index): ... class MyGNN(torch.nn.Module): def __init__(self, ...): super().__init__() self.conv = MyConv(...) # Use a global sort aggregation: self.global_pool = aggr.SortAggregation(k=4) self.classifier = torch.nn.Linear(...) def forward(self, x, edge_index, batch): x = self.conv(x, edge_index).relu() x = self.global_pool(x, batch) x = self.classifier(x) return x In addition, the aggregation package of :pyg:`PyG` introduces two new concepts: First, aggregations can be **resolved from pure strings** via a lookup table, following the design principles of the `class-resolver `__ library, *e.g.*, by simply passing in :obj:`"median"` to the :class:`~torch_geometric.nn.conv.MessagePassing` module. This will automatically resolve to the :obj:`~torch_geometric.nn.aggr.MedianAggregation` class: .. code-block:: python class MyConv(MessagePassing): def __init__(self, ...): super().__init__(aggr="median") Secondly, **multiple aggregations** can be combined and stacked via the :class:`~torch_geometric.nn.aggr.MultiAggregation` module in order to enhance the representational power of GNNs (`Corso et al. (2020) `__, `Tailor et al. (2021) `__): .. code-block:: python class MyConv(MessagePassing): def __init__(self, ...): # Combines a set of aggregations and concatenates their results, # i.e. its output will be `[num_nodes, 3 * out_channels]` here. # Note that the interface also supports automatic resolution.
super().__init__(aggr=aggr.MultiAggregation( ['mean', 'std', aggr.SoftmaxAggregation(learn=True)])) Importantly, :class:`~torch_geometric.nn.aggr.MultiAggregation` provides various options to combine the outputs of its underlying aggregations (*e.g.*, using concatenation, summation, attention, ...) via its :obj:`mode` argument. The default :obj:`mode` performs concatenation (:obj:`"cat"`). For combining via attention, we need to additionally specify the :obj:`in_channels`, :obj:`out_channels`, and :obj:`num_heads`: .. code-block:: python multi_aggr = aggr.MultiAggregation( aggrs=['mean', 'std'], mode='attn', mode_kwargs=dict(in_channels=64, out_channels=64, num_heads=4), ) If aggregations are given as a list, they will be automatically resolved to a :class:`~torch_geometric.nn.aggr.MultiAggregation`, *e.g.*, :obj:`aggr=['mean', 'std', 'median']`. Finally, we added full support for customizing aggregations in the :class:`~torch_geometric.nn.conv.SAGEConv` layer --- simply override its :obj:`aggr` argument and **utilize the power of aggregation within your GNN**. .. note:: You can read more about the :class:`torch_geometric.nn.aggr` package in this `blog post `__. .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.aggr.classes %} {{ name }} {% endfor %} Attention --------- .. currentmodule:: torch_geometric.nn.attention .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/nn.rst {% for name in torch_geometric.nn.attention.classes %} {{ name }} {% endfor %} Normalization Layers -------------------- .. currentmodule:: torch_geometric.nn.norm .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.norm.classes %} {{ name }} {% endfor %} Pooling Layers -------------- .. currentmodule:: torch_geometric.nn.pool ..
autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.pool.classes %} {{ name }} {% endfor %} Unpooling Layers ---------------- .. currentmodule:: torch_geometric.nn.unpool .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.unpool.classes %} {{ name }} {% endfor %} Models ------ .. currentmodule:: torch_geometric.nn.models .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/nn.rst {% for name in torch_geometric.nn.models.classes %} {{ name }} {% endfor %} KGE Models ---------- .. currentmodule:: torch_geometric.nn.kge .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.kge.classes %} {{ name }} {% endfor %} Encodings --------- .. currentmodule:: torch_geometric.nn.encoding .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.encoding.classes %} {{ name }} {% endfor %} Functional ---------- .. py:currentmodule:: torch_geometric.nn.functional .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.functional.classes %} {{ name }} {% endfor %} Dense Convolutional Layers -------------------------- .. currentmodule:: torch_geometric.nn.dense .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.dense.conv_classes %} {{ name }} {% endfor %} Dense Pooling Layers -------------------- .. currentmodule:: torch_geometric.nn.dense .. autosummary:: :nosignatures: :toctree: ../generated {% for name in torch_geometric.nn.dense.pool_classes %} {{ name }} {% endfor %} Model Transformations --------------------- .. autoclass:: torch_geometric.nn.fx.Transformer :members: :undoc-members: :exclude-members: graph, find_by_target, find_by_name .. autofunction:: torch_geometric.nn.to_hetero_transformer.to_hetero .. autofunction:: torch_geometric.nn.to_hetero_with_bases_transformer.to_hetero_with_bases DataParallel Layers ------------------- .. 
warning:: :class:`~torch_geometric.nn.data_parallel.DataParallel` is deprecated. Please use :class:`torch.nn.parallel.DistributedDataParallel` instead. .. automodule:: torch_geometric.nn.data_parallel :members: Model Hub --------- .. automodule:: torch_geometric.nn.model_hub :members: Model Summary ------------- .. automodule:: torch_geometric.nn.summary :members: ================================================ FILE: docs/source/modules/profile.rst ================================================ torch_geometric.profile ======================= .. currentmodule:: torch_geometric.profile .. autosummary:: :nosignatures: {% for cls in torch_geometric.profile.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.profile :members: :undoc-members: ================================================ FILE: docs/source/modules/root.rst ================================================ torch_geometric =============== Tensor Objects -------------- .. currentmodule:: torch_geometric .. autosummary:: :nosignatures: :toctree: ../generated Index EdgeIndex HashTensor Functions --------- .. automodule:: torch_geometric.seed :members: .. automodule:: torch_geometric.home :members: .. automodule:: torch_geometric._compile :members: :exclude-members: compile .. automodule:: torch_geometric.debug :members: .. automodule:: torch_geometric.experimental :members: ================================================ FILE: docs/source/modules/sampler.rst ================================================ torch_geometric.sampler ======================= .. currentmodule:: torch_geometric.sampler .. autosummary:: :nosignatures: {% for cls in torch_geometric.sampler.classes %} {{ cls }} {% endfor %} .. autoclass:: torch_geometric.sampler.BaseSampler :members: .. 
automodule:: torch_geometric.sampler :members: :exclude-members: sample_from_nodes, sample_from_edges, edge_permutation, BaseSampler ================================================ FILE: docs/source/modules/transforms.rst ================================================ torch_geometric.transforms ========================== .. contents:: Contents :local: Transforms are a general way to modify and customize :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects, either by implicitly passing them as an argument to a :class:`~torch_geometric.data.Dataset`, or by applying them explicitly to individual :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects: .. code-block:: python import torch_geometric.transforms as T from torch_geometric.datasets import TUDataset transform = T.Compose([T.ToUndirected(), T.AddSelfLoops()]) dataset = TUDataset(path, name='MUTAG', transform=transform) data = dataset[0] # Implicitly transform data on every access. data = TUDataset(path, name='MUTAG')[0] data = transform(data) # Explicitly transform data. General Transforms ------------------ .. currentmodule:: torch_geometric.transforms .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.transforms.general_transforms %} {{ name }} {% endfor %} Graph Transforms ---------------- .. currentmodule:: torch_geometric.transforms .. autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.transforms.graph_transforms %} {{ name }} {% endfor %} Vision Transforms ----------------- .. currentmodule:: torch_geometric.transforms .. 
autosummary:: :nosignatures: :toctree: ../generated :template: autosummary/only_class.rst {% for name in torch_geometric.transforms.vision_transforms %} {{ name }} {% endfor %} ================================================ FILE: docs/source/modules/utils.rst ================================================ torch_geometric.utils ===================== .. currentmodule:: torch_geometric.utils .. autosummary:: :nosignatures: {% for cls in torch_geometric.utils.classes %} {{ cls }} {% endfor %} .. automodule:: torch_geometric.utils :members: ================================================ FILE: docs/source/notes/batching.rst ================================================ :orphan: .. include:: ../advanced/batching.rst ================================================ FILE: docs/source/notes/cheatsheet.rst ================================================ :orphan: GNN Cheatsheet ============== * :class:`~torch_sparse.SparseTensor`: If checked (✓), supports message passing based on :class:`torch_sparse.SparseTensor`, *e.g.*, :obj:`GCNConv(...).forward(x, adj_t)`. See `here <../advanced/sparse_tensor.html>`__ for the accompanying tutorial. * :obj:`edge_weight`: If checked (✓), supports message passing with one-dimensional edge weight information, *e.g.*, :obj:`GraphConv(...).forward(x, edge_index, edge_weight)`. * :obj:`edge_attr`: If checked (✓), supports message passing with multi-dimensional edge feature information, *e.g.*, :obj:`GINEConv(...).forward(x, edge_index, edge_attr)`. * **bipartite**: If checked (✓), supports message passing in bipartite graphs with potentially different feature dimensionalities for source and destination nodes, *e.g.*, :obj:`SAGEConv(in_channels=(16, 32), out_channels=64)`. * **static**: If checked (✓), supports message passing in static graphs, *e.g.*, :obj:`GCNConv(...).forward(x, edge_index)` with :obj:`x` having shape :obj:`[batch_size, num_nodes, in_channels]`. 
* **lazy**: If checked (✓), supports lazy initialization of message passing layers, *e.g.*, :obj:`SAGEConv(in_channels=-1, out_channels=64)`. Graph Neural Network Operators ------------------------------ .. list-table:: :widths: 40 10 10 10 10 10 10 :header-rows: 1 * - Name - :class:`~torch_sparse.SparseTensor` - :obj:`edge_weight` - :obj:`edge_attr` - bipartite - static - lazy {% for cls in torch_geometric.nn.conv.classes[1:] %} {% if not torch_geometric.nn.conv.utils.processes_heterogeneous_graphs(cls) and not torch_geometric.nn.conv.utils.processes_hypergraphs(cls) and not torch_geometric.nn.conv.utils.processes_point_clouds(cls) %} * - :class:`~torch_geometric.nn.conv.{{ cls }}` (`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__) - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %} {% endif %} {% endfor %} Heterogeneous Graph Neural Network Operators -------------------------------------------- .. 
list-table:: :widths: 40 10 10 10 10 10 10 :header-rows: 1 * - Name - :class:`~torch_sparse.SparseTensor` - :obj:`edge_weight` - :obj:`edge_attr` - bipartite - static - lazy {% for cls in torch_geometric.nn.conv.classes[1:] %} {% if torch_geometric.nn.conv.utils.processes_heterogeneous_graphs(cls) %} * - :class:`~torch_geometric.nn.conv.{{ cls }}` (`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__) - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %} {% endif %} {% endfor %} Hypergraph Neural Network Operators ----------------------------------- .. 
list-table:: :widths: 40 10 10 10 10 10 10 :header-rows: 1 * - Name - :class:`~torch_sparse.SparseTensor` - :obj:`edge_weight` - :obj:`edge_attr` - bipartite - static - lazy {% for cls in torch_geometric.nn.conv.classes[1:] %} {% if torch_geometric.nn.conv.utils.processes_hypergraphs(cls) %} * - :class:`~torch_geometric.nn.conv.{{ cls }}` (`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__) - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %} {% endif %} {% endfor %} Point Cloud Neural Network Operators ------------------------------------ .. list-table:: :widths: 80 10 10 :header-rows: 1 * - Name - bipartite - lazy {% for cls in torch_geometric.nn.conv.classes[1:] %} {% if torch_geometric.nn.conv.utils.processes_point_clouds(cls) %} * - :class:`~torch_geometric.nn.conv.{{ cls }}` (`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__) - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %} - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %} {% endif %} {% endfor %} ================================================ FILE: docs/source/notes/colabs.rst ================================================ :orphan: .. include:: ../get_started/colabs.rst ================================================ FILE: docs/source/notes/create_dataset.rst ================================================ :orphan: .. 
include:: ../tutorial/create_dataset.rst ================================================ FILE: docs/source/notes/create_gnn.rst ================================================ :orphan: .. include:: ../tutorial/create_gnn.rst ================================================ FILE: docs/source/notes/data_cheatsheet.rst ================================================ :orphan: Dataset Cheatsheet ================== .. note:: This dataset statistics table is a **work in progress**. Please consider helping us fill in its content by providing statistics for individual datasets. See `here `__ and `here `__ for examples on how to do so. Homogeneous Datasets -------------------- .. list-table:: :widths: 50 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes/#tasks {% for cls in torch_geometric.datasets.homo_datasets %} * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %} - {%if torch_geometric.datasets.utils.has_stats(cls) %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default=1) }}{% else %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default='') }}{% endif %} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', default='') }} {% for child in torch_geometric.datasets.utils.get_children(cls) %} * - └─ {{ child }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', child, default=1) }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', child, default='') }} - {{
torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }} {% endfor %} {% endfor %} Heterogeneous Datasets ---------------------- .. list-table:: :widths: 50 30 10 10 :header-rows: 1 * - Name - #nodes/#edges - #features - #classes/#tasks {% for cls in torch_geometric.datasets.hetero_datasets %} * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %} - - - {% for child in torch_geometric.datasets.utils.get_children(cls) %} * - └─ **{{torch_geometric.datasets.utils.get_type(child)}} Type**: {{ child }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes/#edges', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }} {% endfor %} {% endfor %} Synthetic Datasets ------------------ .. 
list-table:: :widths: 50 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes/#tasks {% for cls in torch_geometric.datasets.synthetic_datasets %} * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %} - {%if torch_geometric.datasets.utils.has_stats(cls) %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default=1) }}{% else %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default='') }}{% endif %} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', default='') }} {% for child in torch_geometric.datasets.utils.get_children(cls) %} * - └─ {{ child }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', child, default=1) }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }} - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }} {% endfor %} {% endfor %} ================================================ FILE: docs/source/notes/explain.rst ================================================ :orphan: .. include:: ../tutorial/explain.rst ================================================ FILE: docs/source/notes/graphgym.rst ================================================ :orphan: .. 
include:: ../advanced/graphgym.rst ================================================ FILE: docs/source/notes/heterogeneous.rst ================================================ :orphan: .. include:: ../tutorial/heterogeneous.rst ================================================ FILE: docs/source/notes/installation.rst ================================================ :orphan: .. meta:: :http-equiv=refresh: 0; URL=../install/installation.html Installation ============ This page has moved to :doc:`/install/installation`. ================================================ FILE: docs/source/notes/introduction.rst ================================================ :orphan: .. include:: ../get_started/introduction.rst ================================================ FILE: docs/source/notes/jit.rst ================================================ :orphan: .. include:: ../advanced/jit.rst ================================================ FILE: docs/source/notes/load_csv.rst ================================================ :orphan: .. include:: ../tutorial/load_csv.rst ================================================ FILE: docs/source/notes/remote.rst ================================================ :orphan: .. include:: ../advanced/remote.rst ================================================ FILE: docs/source/notes/resources.rst ================================================ :orphan: .. include:: ../external/resources.rst ================================================ FILE: docs/source/notes/sparse_tensor.rst ================================================ :orphan: .. include:: ../advanced/sparse_tensor.rst ================================================ FILE: docs/source/tutorial/application.rst ================================================ Use-Cases & Applications ======================== .. 
nbgallery:: :name: rst-gallery neighbor_loader point_cloud explain shallow_node_embeddings graph_transformer ================================================ FILE: docs/source/tutorial/compile.rst ================================================ :orphan: .. include:: ../advanced/compile.rst ================================================ FILE: docs/source/tutorial/create_dataset.rst ================================================ Creating Graph Datasets ======================= Although :pyg:`PyG` already contains a lot of useful datasets, you may wish to create your own dataset with self-recorded or non-publicly available data. Implementing datasets by yourself is straightforward and you may want to take a look at the source code to find out how the various datasets are implemented. Nevertheless, we give a brief introduction to what is needed to set up your own dataset. We provide two abstract classes for datasets: :class:`torch_geometric.data.Dataset` and :class:`torch_geometric.data.InMemoryDataset`. :class:`torch_geometric.data.InMemoryDataset` inherits from :class:`torch_geometric.data.Dataset` and should be used if the whole dataset fits into CPU memory. Following the :obj:`torchvision` convention, each dataset gets passed a root folder which indicates where the dataset should be stored. We split up the root folder into two folders: the :obj:`raw_dir`, where the dataset gets downloaded to, and the :obj:`processed_dir`, where the processed dataset is being saved. In addition, each dataset can be passed a :obj:`transform`, a :obj:`pre_transform` and a :obj:`pre_filter` function, which are :obj:`None` by default. The :obj:`transform` function dynamically transforms the data object before accessing (so it is best used for data augmentation). The :obj:`pre_transform` function applies the transformation before saving the data objects to disk (so it is best used for heavy precomputation which only needs to be done once).
The :obj:`pre_filter` function can manually filter out data objects before saving. Use cases may involve restricting data objects to a specific class. Creating "In Memory Datasets" ----------------------------- In order to create a :class:`torch_geometric.data.InMemoryDataset`, you need to implement four fundamental methods: * :func:`torch_geometric.data.InMemoryDataset.raw_file_names`: A list of files in the :obj:`raw_dir` which need to be found in order to skip the download. * :func:`torch_geometric.data.InMemoryDataset.processed_file_names`: A list of files in the :obj:`processed_dir` which need to be found in order to skip the processing. * :func:`torch_geometric.data.InMemoryDataset.download`: Downloads raw data into :obj:`raw_dir`. * :func:`torch_geometric.data.InMemoryDataset.process`: Processes raw data and saves it into the :obj:`processed_dir`. You can find helpful methods to download and extract data in :mod:`torch_geometric.data`. The real magic happens in the body of :meth:`~torch_geometric.data.InMemoryDataset.process`. Here, we need to read and create a list of :class:`~torch_geometric.data.Data` objects and save it into the :obj:`processed_dir`. Because saving a huge python list is quite slow, we collate the list into one huge :class:`~torch_geometric.data.Data` object via :meth:`torch_geometric.data.InMemoryDataset.collate` before saving. The collated data object concatenates all examples into one big data object and, in addition, returns a :obj:`slices` dictionary to reconstruct single examples from this object. Finally, we need to load these two objects in the constructor into the properties :obj:`self.data` and :obj:`self.slices`. .. note:: From :pyg:`null` **PyG >= 2.4**, the functionalities of :meth:`torch.save` and :meth:`torch_geometric.data.InMemoryDataset.collate` are unified and implemented behind :meth:`torch_geometric.data.InMemoryDataset.save`.
Additionally, :obj:`self.data` and :obj:`self.slices` are implicitly loaded via :meth:`torch_geometric.data.InMemoryDataset.load`. Let's see this process in a simplified example: .. code-block:: python import torch from torch_geometric.data import InMemoryDataset, download_url class MyOwnDataset(InMemoryDataset): def __init__(self, root, transform=None, pre_transform=None, pre_filter=None): super().__init__(root, transform, pre_transform, pre_filter) self.load(self.processed_paths[0]) # For PyG<2.4: # self.data, self.slices = torch.load(self.processed_paths[0]) @property def raw_file_names(self): return ['some_file_1', 'some_file_2', ...] @property def processed_file_names(self): return ['data.pt'] def download(self): # Download to `self.raw_dir`. download_url(url, self.raw_dir) ... def process(self): # Read data into huge `Data` list. data_list = [...] if self.pre_filter is not None: data_list = [data for data in data_list if self.pre_filter(data)] if self.pre_transform is not None: data_list = [self.pre_transform(data) for data in data_list] self.save(data_list, self.processed_paths[0]) # For PyG<2.4: # torch.save(self.collate(data_list), self.processed_paths[0]) Creating "Larger" Datasets -------------------------- For creating datasets which do not fit into memory, the :class:`torch_geometric.data.Dataset` can be used, which closely follows the concepts of the :obj:`torchvision` datasets. It expects the following methods to be implemented in addition: * :func:`torch_geometric.data.Dataset.len`: Returns the number of examples in your dataset. * :func:`torch_geometric.data.Dataset.get`: Implements the logic to load a single graph. Internally, :meth:`torch_geometric.data.Dataset.__getitem__` gets data objects from :meth:`torch_geometric.data.Dataset.get` and optionally transforms them according to :obj:`transform`. Let's see this process in a simplified example: .. 
code-block:: python import os.path as osp import torch from torch_geometric.data import Data, Dataset, download_url class MyOwnDataset(Dataset): def __init__(self, root, transform=None, pre_transform=None, pre_filter=None): super().__init__(root, transform, pre_transform, pre_filter) @property def raw_file_names(self): return ['some_file_1', 'some_file_2', ...] @property def processed_file_names(self): return ['data_1.pt', 'data_2.pt', ...] def download(self): # Download to `self.raw_dir`. path = download_url(url, self.raw_dir) ... def process(self): idx = 0 for raw_path in self.raw_paths: # Read data from `raw_path`. data = Data(...) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt')) idx += 1 def len(self): return len(self.processed_file_names) def get(self, idx): data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt')) return data Here, each graph data object gets saved individually in :meth:`~torch_geometric.data.Dataset.process`, and is manually loaded in :meth:`~torch_geometric.data.Dataset.get`. Frequently Asked Questions -------------------------- #. **How can I skip the execution of** :meth:`download` **and/or** :meth:`process` **?** You can skip downloading and/or processing by just not overriding the :meth:`download` and :meth:`process` methods: .. code-block:: python class MyOwnDataset(Dataset): def __init__(self, transform=None, pre_transform=None): super().__init__(None, transform, pre_transform) #. **Do I really need to use these dataset interfaces?** No! Just as in regular :pytorch:`PyTorch`, you do not have to use datasets, *e.g.*, when you want to create synthetic data on the fly without saving them explicitly to disk. In this case, simply create a regular Python list holding :class:`torch_geometric.data.Data` objects and pass it to :class:`torch_geometric.loader.DataLoader`: ..
code-block:: python from torch_geometric.data import Data from torch_geometric.loader import DataLoader data_list = [Data(...), ..., Data(...)] loader = DataLoader(data_list, batch_size=32) Exercises --------- Consider the following :class:`~torch_geometric.data.InMemoryDataset` constructed from a list of :obj:`~torch_geometric.data.Data` objects: .. code-block:: python class MyDataset(InMemoryDataset): def __init__(self, root, data_list, transform=None): self.data_list = data_list super().__init__(root, transform) self.load(self.processed_paths[0]) @property def processed_file_names(self): return 'data.pt' def process(self): self.save(self.data_list, self.processed_paths[0]) 1. What is the output of :obj:`self.processed_paths[0]`? 2. What does :meth:`~torch_geometric.data.InMemoryDataset.save` do? ================================================ FILE: docs/source/tutorial/create_gnn.rst ================================================ Creating Message Passing Networks ================================= Generalizing the convolution operator to irregular domains is typically expressed as a *neighborhood aggregation* or *message passing* scheme. With :math:`\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F` denoting node features of node :math:`i` in layer :math:`(k-1)` and :math:`\mathbf{e}_{j,i} \in \mathbb{R}^D` denoting (optional) edge features from node :math:`j` to node :math:`i`, message passing graph neural networks can be described as .. math:: \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), where :math:`\bigoplus` denotes a differentiable, permutation invariant function, *e.g.*, sum, mean or max, and :math:`\gamma` and :math:`\phi` denote differentiable functions such as MLPs (Multi Layer Perceptrons). .. 
contents:: :local: The "MessagePassing" Base Class ------------------------------- :pyg:`PyG` provides the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation. The user only has to define the functions :math:`\phi` , *i.e.* :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, and :math:`\gamma` , *i.e.* :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`, as well as the aggregation scheme to use, *i.e.* :obj:`aggr="add"`, :obj:`aggr="mean"` or :obj:`aggr="max"`. This is done with the help of the following methods: * :obj:`MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)`: Defines the aggregation scheme to use (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`) and the flow direction of message passing (either :obj:`"source_to_target"` or :obj:`"target_to_source"`). Furthermore, the :obj:`node_dim` attribute indicates along which axis to propagate. * :obj:`MessagePassing.propagate(edge_index, size=None, **kwargs)`: The initial call to start propagating messages. Takes in the edge indices and all additional data which is needed to construct messages and to update node embeddings. Note that :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate` is not limited to exchanging messages in square adjacency matrices of shape :obj:`[N, N]` only, but can also exchange messages in general sparse assignment matrices, *e.g.*, bipartite graphs, of shape :obj:`[N, M]` by passing :obj:`size=(N, M)` as an additional argument. If set to :obj:`None`, the assignment matrix is assumed to be a square matrix. For bipartite graphs with two independent sets of nodes and indices, and each set holding its own information, this split can be marked by passing the information as a tuple, *e.g.* :obj:`x=(x_N, x_M)`. 
* :obj:`MessagePassing.message(...)`: Constructs messages to node :math:`i` in analogy to :math:`\phi` for each edge :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`. Can take any argument which was initially passed to :meth:`propagate`. In addition, tensors passed to :meth:`propagate` can be mapped to the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`. Note that we generally refer to :math:`i` as the central node that aggregates information, and to :math:`j` as the neighboring nodes, since this is the most common notation. * :obj:`MessagePassing.update(aggr_out, ...)`: Updates node embeddings in analogy to :math:`\gamma` for each node :math:`i \in \mathcal{V}`. Takes in the output of aggregation as its first argument and any argument which was initially passed to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`. Let us verify this by re-implementing two popular GNN variants, the `GCN layer from Kipf and Welling `_ and the `EdgeConv layer from Wang et al. `_. Implementing the GCN Layer -------------------------- The `GCN layer `_ is mathematically defined as .. math:: \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b}, where neighboring node features are first transformed by a weight matrix :math:`\mathbf{W}`, normalized by their degree, and finally summed up. Lastly, we apply the bias vector :math:`\mathbf{b}` to the aggregated output. This formula can be divided into the following steps: 1. Add self-loops to the adjacency matrix. 2. Linearly transform node feature matrix. 3. Compute normalization coefficients. 4. Normalize node features in :math:`\phi`. 5. Sum up neighboring node features (:obj:`"add"` aggregation). 6.
Apply a final bias vector. Steps 1-3 are typically computed before message passing takes place. Steps 4-5 can be easily processed using the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class. The full layer implementation is shown below: .. code-block:: python import torch from torch.nn import Linear, Parameter from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops, degree class GCNConv(MessagePassing): def __init__(self, in_channels, out_channels): super().__init__(aggr='add') # "Add" aggregation (Step 5). self.lin = Linear(in_channels, out_channels, bias=False) self.bias = Parameter(torch.empty(out_channels)) self.reset_parameters() def reset_parameters(self): self.lin.reset_parameters() self.bias.data.zero_() def forward(self, x, edge_index): # x has shape [N, in_channels] # edge_index has shape [2, E] # Step 1: Add self-loops to the adjacency matrix. edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) # Step 2: Linearly transform node feature matrix. x = self.lin(x) # Step 3: Compute normalization. row, col = edge_index deg = degree(col, x.size(0), dtype=x.dtype) deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # Step 4-5: Start propagating messages. out = self.propagate(edge_index, x=x, norm=norm) # Step 6: Apply a final bias vector. out = out + self.bias return out def message(self, x_j, norm): # x_j has shape [E, out_channels] # Step 4: Normalize node features. return norm.view(-1, 1) * x_j :class:`~torch_geometric.nn.conv.GCNConv` inherits from :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` with :obj:`"add"` propagation. All the logic of the layer takes place in its :meth:`forward` method. 
Here, we first add self-loops to our edge indices using the :meth:`torch_geometric.utils.add_self_loops` function (step 1), as well as linearly transform node features by calling the :class:`torch.nn.Linear` instance (step 2). The normalization coefficients are derived from the node degrees :math:`\deg(i)` for each node :math:`i` which gets transformed to :math:`1/(\sqrt{\deg(i)} \cdot \sqrt{\deg(j)})` for each edge :math:`(j,i) \in \mathcal{E}`. The result is saved in the tensor :obj:`norm` of shape :obj:`[num_edges, ]` (step 3). We then call :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`, which internally calls :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.aggregate` and :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`. We pass the node embeddings :obj:`x` and the normalization coefficients :obj:`norm` as additional arguments for message propagation. In the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we need to normalize the neighboring node features :obj:`x_j` by :obj:`norm`. Here, :obj:`x_j` denotes a *lifted* tensor, which contains the source node features of each edge, *i.e.*, the neighbors of each node. Node features can be automatically lifted by appending :obj:`_i` or :obj:`_j` to the variable name. In fact, any tensor can be converted this way, as long as it holds source or destination node features. That is all that it takes to create a simple message passing layer. You can use this layer as a building block for deep architectures. Initializing and calling it is straightforward: .. code-block:: python conv = GCNConv(16, 32) x = conv(x, edge_index) Implementing the Edge Convolution --------------------------------- The `edge convolutional layer `_ processes graphs or point clouds and is mathematically defined as ..
math:: \mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right), where :math:`h_{\mathbf{\Theta}}` denotes an MLP. In analogy to the GCN layer, we can use the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` class to implement this layer, this time using the :obj:`"max"` aggregation: .. code-block:: python import torch from torch.nn import Sequential as Seq, Linear, ReLU from torch_geometric.nn import MessagePassing class EdgeConv(MessagePassing): def __init__(self, in_channels, out_channels): super().__init__(aggr='max') # "Max" aggregation. self.mlp = Seq(Linear(2 * in_channels, out_channels), ReLU(), Linear(out_channels, out_channels)) def forward(self, x, edge_index): # x has shape [N, in_channels] # edge_index has shape [2, E] return self.propagate(edge_index, x=x) def message(self, x_i, x_j): # x_i has shape [E, in_channels] # x_j has shape [E, in_channels] tmp = torch.cat([x_i, x_j - x_i], dim=1) # tmp has shape [E, 2 * in_channels] return self.mlp(tmp) Inside the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we use :obj:`self.mlp` to transform both the target node features :obj:`x_i` and the relative source node features :obj:`x_j - x_i` for each edge :math:`(j,i) \in \mathcal{E}`. The edge convolution is actually a dynamic convolution, which recomputes the graph for each layer using nearest neighbors in the feature space. Luckily, :pyg:`PyG` comes with a GPU accelerated batch-wise k-NN graph generation method named :meth:`torch_geometric.nn.pool.knn_graph`: .. 
code-block:: python from torch_geometric.nn import knn_graph class DynamicEdgeConv(EdgeConv): def __init__(self, in_channels, out_channels, k=6): super().__init__(in_channels, out_channels) self.k = k def forward(self, x, batch=None): edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow) return super().forward(x, edge_index) Here, :meth:`~torch_geometric.nn.pool.knn_graph` computes a nearest neighbor graph, which is further used to call the :meth:`forward` method of :class:`~torch_geometric.nn.conv.EdgeConv`. This leaves us with a clean interface for initializing and calling this layer: .. code-block:: python conv = DynamicEdgeConv(3, 128, k=6) x = conv(x, batch) Exercises --------- Imagine we are given the following :obj:`~torch_geometric.data.Data` object: .. code-block:: python import torch from torch_geometric.data import Data edge_index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long) x = torch.tensor([[-1], [0], [1]], dtype=torch.float) data = Data(x=x, edge_index=edge_index.t().contiguous()) Try to answer the following questions related to :class:`~torch_geometric.nn.conv.GCNConv`: 1. What information do :obj:`row` and :obj:`col` hold? 2. What does :meth:`~torch_geometric.utils.degree` do? 3. Why do we use :obj:`degree(col, ...)` rather than :obj:`degree(row, ...)`? 4. What do :obj:`deg_inv_sqrt[col]` and :obj:`deg_inv_sqrt[row]` do? 5. What information does :obj:`x_j` hold in the :meth:`~torch_geometric.nn.conv.MessagePassing.message` function? If :obj:`self.lin` denotes the identity function, what is the exact content of :obj:`x_j`? 6. Add an :meth:`~torch_geometric.nn.conv.MessagePassing.update` function to :class:`~torch_geometric.nn.conv.GCNConv` that adds transformed central node features to the aggregated output. Try to answer the following questions related to :class:`~torch_geometric.nn.conv.EdgeConv`: 1. What are :obj:`x_i` and :obj:`x_j - x_i`? 2. What does :obj:`torch.cat([x_i, x_j - x_i], dim=1)` do?
Why :obj:`dim = 1`? ================================================ FILE: docs/source/tutorial/dataset.rst ================================================ Working with Graph Datasets =========================== .. nbgallery:: :name: rst-gallery create_dataset load_csv dataset_splitting ================================================ FILE: docs/source/tutorial/dataset_splitting.rst ================================================ Dataset Splitting ================= Dataset splitting is a critical step in graph machine learning, where we divide our dataset into subsets for training, validation, and testing. It ensures that our models are evaluated properly, preventing overfitting, and enabling generalization. In this tutorial, we will explore the basics of dataset splitting, focusing on three fundamental tasks: node prediction, link prediction, and graph prediction. We will introduce commonly used techniques, including :class:`~torch_geometric.transforms.RandomNodeSplit` and :class:`~torch_geometric.transforms.RandomLinkSplit` transformations. Additionally, we will also cover how to create custom dataset splits beyond random ones. Node Prediction --------------- .. note:: In this section, we'll learn how to use :class:`~torch_geometric.transforms.RandomNodeSplit` of :pyg:`PyG` to randomly divide nodes into training, validation, and test sets. A fully working example on dataset :class:`~torch_geometric.datasets.Planetoid` is available in `examples/cora.py `_. The :class:`~torch_geometric.transforms.RandomNodeSplit` is initialized to split nodes for both a :pyg:`PyG` :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` object. * :obj:`split` defines the dataset's split type. * :obj:`num_splits` defines the number of splits to add. * :obj:`num_train_per_class` defines the number of training nodes per class. * :obj:`num_val` defines the number of validation nodes after data splitting. 
* :obj:`num_test` defines the number of test nodes after data splitting. * :obj:`key` defines the name of the ground-truth labels. .. code-block:: python import torch from torch_geometric.data import Data from torch_geometric.transforms import RandomNodeSplit x = torch.randn(8, 32) # Node features of shape [num_nodes, num_features] y = torch.randint(0, 4, (8, )) # Node labels of shape [num_nodes] edge_index = torch.tensor([ [2, 3, 3, 4, 5, 6, 7], [0, 0, 1, 1, 2, 3, 4]], ) # 0 1 # / \/ \ # 2 3 4 # | | | # 5 6 7 data = Data(x=x, y=y, edge_index=edge_index) node_transform = RandomNodeSplit(num_val=2, num_test=3) node_splits = node_transform(data) Here, we initialize a :class:`~torch_geometric.transforms.RandomNodeSplit` transformation to split the graph data by nodes. After the transformation, :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask` will be attached to the graph data. .. code-block:: python node_splits.train_mask >>> tensor([ True, False, False, False, True, True, False, False]) node_splits.val_mask >>> tensor([False, False, False, False, False, False, True, True]) node_splits.test_mask >>> tensor([False, True, True, True, False, False, False, False]) In this example, there are 8 nodes: we want to sample 2 nodes for validation, 3 nodes for testing, and use the rest for training. As a result, we get nodes :obj:`0, 4, 5` as the training set, nodes :obj:`6, 7` as the validation set, and nodes :obj:`1, 2, 3` as the test set. Link Prediction --------------- .. note:: In this section, we'll learn how to use :class:`~torch_geometric.transforms.RandomLinkSplit` of :pyg:`PyG` to randomly divide edges into training, validation, and test sets. A fully working example on dataset :class:`~torch_geometric.datasets.Planetoid` is available in `examples/link_pred.py `_. The :class:`~torch_geometric.transforms.RandomLinkSplit` is initialized to split edges for both a :pyg:`PyG` :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` object.
* :obj:`num_val` defines the number (or ratio, if given as a float) of validation edges after data splitting. * :obj:`num_test` defines the number (or ratio, if given as a float) of test edges after data splitting. * :obj:`is_undirected` defines whether the graph is assumed to be undirected. * :obj:`key` defines the name of the edge-level ground-truth labels; the corresponding supervision edges are stored in :obj:`{key}_index` (*e.g.*, :obj:`edge_y_index` below). .. code-block:: python import torch from torch_geometric.data import Data from torch_geometric.transforms import RandomLinkSplit x = torch.randn(8, 32) # Node features of shape [num_nodes, num_features] y = torch.randint(0, 4, (8, )) # Node labels of shape [num_nodes] edge_index = torch.tensor([ [2, 3, 3, 4, 5, 6, 7], [0, 0, 1, 1, 2, 3, 4]], ) edge_y = torch.tensor([0, 0, 0, 0, 1, 1, 1]) # 0 1 # / \/ \ # 2 3 4 # | | | # 5 6 7 data = Data(x=x, y=y, edge_index=edge_index, edge_y=edge_y) edge_transform = RandomLinkSplit(num_val=0.2, num_test=0.2, key='edge_y', is_undirected=False, add_negative_train_samples=False) train_data, val_data, test_data = edge_transform(data) Similar to node splitting, we initialize a :class:`~torch_geometric.transforms.RandomLinkSplit` transformation to split the graph data by edges. Below, we can see the splitting results. .. code-block:: python train_data >>> Data(x=[8, 32], edge_index=[2, 5], y=[8], edge_y=[5], edge_y_index=[2, 5]) val_data >>> Data(x=[8, 32], edge_index=[2, 5], y=[8], edge_y=[2], edge_y_index=[2, 2]) test_data >>> Data(x=[8, 32], edge_index=[2, 6], y=[8], edge_y=[2], edge_y_index=[2, 2]) :obj:`train_data.edge_index` and :obj:`val_data.edge_index` refer to the edges that are used for message passing. As such, during training and validation, we are allowed to propagate information based on the training edges. During testing, we can propagate information based on the union of training and validation edges. For validation and testing, :obj:`val_data.edge_y_index` and :obj:`test_data.edge_y_index` (with their labels stored in :obj:`edge_y`) hold a batch of positive and negative samples that should be used to evaluate and test our model.
note:: In this section, we'll learn how to randomly divide graphs into training, validation, and test sets. A fully working example on dataset :class:`~torch_geometric.datasets.PPI` is available in `examples/ppi.py `_. In graph prediction tasks, each graph is an independent sample. Usually, we need to divide a graph dataset according to a certain ratio. :pyg:`PyG` provides some datasets that already come with predefined splits for training, validation and testing, such as :class:`~torch_geometric.datasets.PPI`. .. code-block:: python from torch_geometric.datasets import PPI path = './data/PPI' train_dataset = PPI(path, split='train') val_dataset = PPI(path, split='val') test_dataset = PPI(path, split='test') In addition, we can also use :obj:`scikit-learn` or :obj:`numpy` to randomly divide a :pyg:`PyG` dataset. Creating Custom Splits ---------------------- If random splitting doesn't suit our specific use case, then we can create custom splits. This requirement generally occurs in real business scenarios. For example, there are large-scale heterogeneous graphs in e-commerce scenarios, and nodes can be used to represent users, products, merchants, etc. We may divide new and old users to evaluate the performance of the model on new users. Since such splits are highly application-specific, we do not provide a concrete example here. ================================================ FILE: docs/source/tutorial/distributed.rst ================================================ .. _distributed_tutorials: Distributed Training ==================== .. nbgallery:: :name: rst-gallery multi_gpu_vanilla multi_node_multi_gpu_vanilla distributed_pyg ================================================ FILE: docs/source/tutorial/distributed_pyg.rst ================================================ Distributed Training in PyG =========================== .. warning:: ``torch_geometric.distributed`` has been deprecated and will no longer be maintained.
For distributed training with cuGraph, refer to `cuGraph examples `_. .. figure:: ../_figures/intel_kumo.png :width: 400px .. note:: We are thrilled to announce the first **in-house distributed training solution** for :pyg:`PyG` via :class:`torch_geometric.distributed`, available from version 2.5 onwards. Developers and researchers can now take full advantage of distributed training on large-scale datasets which cannot be fully loaded into the memory of a single machine at the same time. This implementation doesn't require any additional packages to be installed on top of the default :pyg:`PyG` stack. In real-life applications, graphs often consist of billions of nodes that cannot fit into the memory of a single system. This is when distributed training of Graph Neural Networks comes in handy. By allocating a number of partitions of the large graph into a cluster of CPUs, one can deploy synchronized model training on the whole dataset at once by making use of :pytorch:`PyTorch's` `Distributed Data Parallel (DDP) `_ capabilities. This architecture seamlessly distributes training of Graph Neural Networks across multiple nodes via `Remote Procedure Calls (RPCs) `_ for efficient sampling and retrieval of non-local features, combined with traditional DDP for model training. This new technique in :pyg:`PyG` was produced by engineers from `Intel `_ and `Kumo AI `_. Key Advantages -------------- #. **Balanced graph partitioning** via METIS ensures minimal communication overhead when sampling subgraphs across compute nodes. #. Utilizing **DDP for model training** in conjunction with **RPC for remote sampling and feature fetching routines** (with TCP/IP protocol and `gloo `_ communication backend) allows for data parallelism with distinct data partitions at each node. #.
The implementation via custom :class:`~torch_geometric.data.GraphStore` and :class:`~torch_geometric.data.FeatureStore` APIs provides a flexible and tailored interface for distributing large graph structure information and feature storage. #. **Distributed neighbor sampling** is capable of sampling in both local and remote partitions through RPC communication channels. All advanced functionality of single-node sampling is also applicable for distributed training, *e.g.*, heterogeneous sampling, link-level sampling, temporal sampling, *etc*. #. **Distributed data loaders** offer a high-level abstraction for managing sampler processes, ensuring simplicity and seamless integration with standard :pyg:`PyG` data loaders. #. Incorporating the Python `asyncio `_ library for asynchronous processing on top of :pytorch:`PyTorch`-based RPCs further enhances the system's responsiveness and overall performance. Architecture Components ----------------------- .. note:: The purpose of this tutorial is to guide you through the most important steps of deploying distributed training applications in :pyg:`PyG`. For code examples, please refer to `examples/distributed/pyg `_. Overall, :class:`torch_geometric.distributed` is divided into the following components: * :class:`~torch_geometric.distributed.Partitioner` partitions the graph into multiple parts, such that each node only needs to load its local data in memory. * :class:`~torch_geometric.distributed.LocalGraphStore` and :class:`~torch_geometric.distributed.LocalFeatureStore` store the graph topology and features per partition, respectively. In addition, they maintain a mapping between local and global IDs for efficient assignment of nodes and feature lookup. * :class:`~torch_geometric.distributed.DistNeighborSampler` implements the distributed sampling algorithm, which includes local+remote sampling and the final merge between local/remote sampling results based on :pytorch:`PyTorch's` RPC mechanisms.
* :class:`~torch_geometric.distributed.DistNeighborLoader` manages the distributed neighbor sampling and feature fetching processes via multiple RPC workers. Finally, it takes care of assembling sampled nodes, edges, and their features into the classic :pyg:`PyG` data format. .. figure:: ../_figures/dist_proc.png :align: center :width: 100% Schematic breakdown of the main components of :class:`torch_geometric.distributed`. Graph Partitioning ~~~~~~~~~~~~~~~~~~ The first step for distributed training is to split the graph into multiple smaller portions, which can then be loaded locally into nodes of the cluster. Partitioning is built on top of :pyg:`null` :obj:`pyg-lib`'s `implementation `_ of the METIS algorithm, suitable to perform graph partitioning efficiently, even on large-scale graphs. Note that METIS requires undirected, homogeneous graphs as input. :class:`~torch_geometric.distributed.Partitioner` performs necessary processing steps to partition heterogeneous data objects with correct distribution and indexing. By default, METIS tries to balance the number of nodes of each type in each partition while minimizing the number of edges between partitions. This ensures that the resulting partitions provide maximal local access of neighbors, enabling samplers to perform local computations without the need for communication between different compute nodes. Through this partitioning approach, every node receives a distinct assignment, while "halo nodes" (1-hop neighbors that fall into a different partition) are replicated. Halo nodes ensure that neighbor sampling for a single node in a single layer stays purely local. .. figure:: ../_figures/dist_part.png :align: center :width: 100% Graph partitioning with halo nodes. In our distributed training example, we prepared the `partition_graph.py `_ script to demonstrate how to apply partitioning on a selected subset of both homogeneous and heterogeneous graphs.
The :class:`~torch_geometric.distributed.Partitioner` can also preserve node features, edge features, and any temporal attributes at the level of nodes and edges. Later on, each node in the cluster owns a single partition of this graph. .. warning:: Partitioning via METIS is non-deterministic and as such may differ between runs. However, all compute nodes should access the same partition data. Therefore, generate the partitions on one node and copy the data to all members of the cluster, or place the folder into a shared location. The resulting structure of partitioning for a two-part split on the homogeneous :obj:`ogbn-products` is shown below: .. code-block:: none partitions └─ ogbn-products ├─ ogbn-products-partitions │ ├─ part_0 │ ├─ part_1 │ ├─ META.json │ ├─ node_map.pt │ └─ edge_map.pt ├─ ogbn-products-label │ └─ label.pt ├─ ogbn-products-test-partitions │ ├─ partition0.pt │ └─ partition1.pt └─ ogbn-products-train-partitions ├─ partition0.pt └─ partition1.pt Distributed Data Storage ~~~~~~~~~~~~~~~~~~~~~~~~ To maintain distributed data partitions, we utilize instantiations of :pyg:`PyG's` :class:`~torch_geometric.data.GraphStore` and :class:`~torch_geometric.data.FeatureStore` remote interfaces. Together with an integrated API for sending and receiving RPC requests, they provide a powerful tool for inter-connected distributed data storage. Both stores can be filled with data in a number of ways, *e.g.*, from :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` objects or initialized directly from generated partition files. :class:`~torch_geometric.distributed.LocalGraphStore` is a class designed to act as a **container for graph topology information**. It holds the edge indices that define relationships between nodes in a graph. It offers methods that provide mapping information for nodes and edges to individual partitions and supports both homogeneous and heterogeneous data formats.
**Key Features:** * It only stores information about local graph connections and its halo nodes within a partition. * Remote connectivity: The affiliation information of individual nodes and edges to partitions (both local and global) can be retrieved through node and edge "partition books", *i.e.* mappings of global node/edge IDs to partition IDs. * It maintains global identifiers for nodes and edges, allowing for consistent mapping across partitions. :class:`~torch_geometric.distributed.LocalFeatureStore` is a class that serves as both a **node-level and edge-level feature storage**. It provides efficient :obj:`put` and :obj:`get` routines for attribute retrieval for both local and remote node/edge IDs. The :class:`~torch_geometric.distributed.LocalFeatureStore` is responsible for retrieving and updating features across different partitions and machines during the training process. **Key Features:** * It provides functionalities for storing, retrieving, and distributing node and edge features. Within the managed partition of a machine, node and edge features are stored locally. * Remote feature lookup: It implements mechanisms for looking up features in both local and remote nodes during distributed training processes through RPC requests. The class is designed to work seamlessly in distributed training scenarios, allowing for efficient feature handling across partitions. * It maintains global identifiers for nodes and edges, allowing for consistent mapping across partitions. Below is an example of how :class:`~torch_geometric.distributed.LocalFeatureStore` is used internally to retrieve both local+remote features: .. code-block:: python import torch from torch_geometric.distributed import LocalFeatureStore from torch_geometric.distributed.event_loop import to_asyncio_future # Create a `LocalFeatureStore` instance: feature_store = LocalFeatureStore(...)
async def get_node_features(): # Retrieve node features for specific node IDs: node_id = torch.tensor([1]) future = feature_store.lookup_features(node_id) return await to_asyncio_future(future) Distributed Neighbor Sampling ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ :class:`~torch_geometric.distributed.DistNeighborSampler` is a class designed for efficient distributed training of Graph Neural Networks. It addresses the challenges of sampling neighbors in a distributed environment, whereby graph data is partitioned across multiple machines or devices. The sampler ensures that GNNs can effectively learn from large-scale graphs, maintaining scalability and performance. **Asynchronous Neighbor Sampling and Feature Collection:** Distributed neighbor sampling is implemented using asynchronous :class:`torch.distributed.rpc` calls. It allows machines to independently sample neighbors without strict synchronization. Each machine autonomously selects neighbors from its local graph partition, without waiting for others to complete their sampling processes. This approach enhances parallelism, as machines can progress asynchronously, and leads to faster training. In addition to asynchronous sampling, distributed neighbor sampling also provides asynchronous feature collection. **Customizable Sampling Strategies:** Users can customize neighbor sampling strategies based on their specific requirements. The :class:`~torch_geometric.distributed.DistNeighborSampler` class provides full flexibility in defining sampling techniques, such as: * Node sampling vs. edge sampling * Homogeneous vs. heterogeneous sampling * Temporal sampling vs. static sampling **Distributed Neighbor Sampling Workflow:** A batch of seed nodes follows three main steps before it is made available for the model's :meth:`forward` pass by the data loader: #.
**Distributed node sampling:** While the underlying principles of neighbor sampling hold for the distributed case as well, the implementation diverges slightly from single-machine sampling. In distributed training, seed nodes can belong to different partitions, leading to simultaneous sampling on multiple machines for a single batch. Consequently, synchronization of sampling results across machines is necessary to obtain seed nodes for the subsequent layer, requiring modifications to the basic algorithm. For nodes within a local partition, the sampling occurs on the local machine. Conversely, for nodes associated with a remote partition, the neighbor sampling is conducted on the machine responsible for storing the respective partition. Sampling then happens layer-wise, where sampled nodes act as seed nodes in follow-up layers. #. **Distributed feature lookup:** Each partition stores an array of features of nodes and edges that are within that partition. Consequently, if the output of a sampler on a specific machine includes sampled nodes or edges that do not belong to its partition, the machine initiates an RPC request to the remote server to which these nodes (or edges) belong. #. **Data conversion:** Based on the sampler output and the acquired node (or edge) features, a :pyg:`PyG` :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object is created. This object forms a batch used in subsequent computational operations of the model. .. figure:: ../_figures/dist_sampling.png :align: center :width: 450px Local and remote neighbor sampling. Distributed Data Loading ~~~~~~~~~~~~~~~~~~~~~~~~ Distributed data loaders such as :class:`~torch_geometric.distributed.DistNeighborLoader` and :class:`~torch_geometric.distributed.DistLinkNeighborLoader` provide a simple API for the sampling engine described above because they entirely wrap initialization and cleanup of sampler processes internally.
Notably, the distributed data loaders inherit from the standard :pyg:`PyG` single-node :class:`~torch_geometric.loader.NodeLoader` and :class:`~torch_geometric.loader.LinkLoader` loaders, making their application inside training scripts nearly identical. Batch generation is slightly different from the single-node case in that the step of (local+remote) feature fetching happens within the sampler, rather than being split into two separate steps (sampling->feature fetching). This allows limiting the number of RPCs. Due to the asynchronous processing between all sampler sub-processes, the samplers then return their output to a :class:`torch.multiprocessing.Queue`. Setting up Communication using DDP & RPC ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In this distributed training implementation, two :class:`torch.distributed` communication technologies are used: * :class:`torch.distributed.rpc` for remote sampling calls and distributed feature retrieval * :class:`torch.distributed.ddp` for data parallel model training Our solution opts for :class:`torch.distributed.rpc` over alternatives such as gRPC because :pytorch:`PyTorch` RPC inherently comprehends tensor-type data. Unlike other RPC methods, which require the serialization or digitization of JSON or other user data into tensor types, using this method helps avoid additional serialization and digitization overhead. The DDP group is initialized in a standard way in the main training script: .. code-block:: python torch.distributed.init_process_group( backend='gloo', rank=current_ctx.rank, world_size=current_ctx.world_size, init_method=f'tcp://{master_addr}:{ddp_port}', ) .. note:: For CPU-based sampling, we recommend the `gloo `_ communication backend.
RPC group initialization is more complicated because it happens in each sampler subprocess, which is achieved via the :meth:`~torch.utils.data.DataLoader.worker_init_fn` of the data loader, which is called by :pytorch:`PyTorch` directly at the initialization step of worker processes. This function first defines a distributed context for each worker and assigns it a group and rank, subsequently initializes its own distributed neighbor sampler, and finally registers a new member in the RPC group. This RPC connection remains open as long as the subprocess exists. Additionally, we opted for the `atexit `_ module to register additional cleanup behaviors that are triggered when the process is terminated. Results and Performance ----------------------- We collected the benchmarking results on :pytorch:`PyTorch` 2.1 using the system configuration at the bottom of this blog. The below table shows the scaling performance on the :obj:`ogbn-products` dataset of a :class:`~torch_geometric.nn.models.GraphSAGE` model under different partition configurations (1/2/4/8/16). .. 
list-table:: :widths: 15 15 15 15 :header-rows: 1 * - #Partitions - :obj:`batch_size=1024` - :obj:`batch_size=4096` - :obj:`batch_size=8192` * - 1 - 98s - 47s - 38s * - 2 - 45s - 30s - 24s * - 4 - 38s - 21s - 16s * - 8 - 29s - 14s - 10s * - 16 - 22s - 13s - 9s * **Hardware:** 2x Intel(R) Xeon(R) Platinum 8360Y CPU @ 2.40GHz, 36 cores, HT On, Turbo On, NUMA 2, Integrated Accelerators Available [used]: DLB 0 [0], DSA 0 [0], IAA 0 [0], QAT 0 [0], Total Memory 256GB (16x16GB DDR4 3200 MT/s [3200 MT/s]), BIOS SE5C620.86B.01.01.0003.2104260124, microcode 0xd000389, 2x Ethernet Controller X710 for 10GbE SFP+, 1x MT28908 Family [ConnectX-6], 1x 894.3G INTEL SSDSC2KG96, Rocky Linux 8.8 (Green Obsidian), 4.18.0-477.21.1.el8_8.x86_64 * **Software:** :python:`Python` 3.9, :pytorch:`PyTorch` 2.1, :pyg:`PyG` 2.5, :pyg:`null` :obj:`pyg-lib` 0.4.0 ================================================ FILE: docs/source/tutorial/explain.rst ================================================ Explaining Graph Neural Networks ================================ Interpreting GNN models is crucial for many use cases. :pyg:`PyG` (2.3 and beyond) provides the :class:`torch_geometric.explain` package for first-class GNN explainability support that currently includes #. a flexible interface to generate a variety of explanations via the :class:`~torch_geometric.explain.Explainer` class, #. several underlying explanation algorithms including, *e.g.*, :class:`~torch_geometric.explain.algorithm.GNNExplainer`, :class:`~torch_geometric.explain.algorithm.PGExplainer` and :class:`~torch_geometric.explain.algorithm.CaptumExplainer`, #. support to visualize explanations via the :class:`~torch_geometric.explain.Explanation` or the :class:`~torch_geometric.explain.HeteroExplanation` class, #. and metrics to evaluate explanations via the :class:`~torch_geometric.explain.metric` package. .. 
warning:: The explanation APIs discussed here may change in the future as we continuously work to improve their ease-of-use and generalizability. Explainer Interface ------------------- The :class:`torch_geometric.explain.Explainer` class is designed to handle all explainability parameters (see the :class:`~torch_geometric.explain.config.ExplainerConfig` class for more details): #. which algorithm from the :class:`torch_geometric.explain.algorithm` module to use (*e.g.*, :class:`~torch_geometric.explain.algorithm.GNNExplainer`) #. the type of explanation to compute, *i.e.* :obj:`explanation_type="phenomenon"` to explain the underlying phenomenon of a dataset, and :obj:`explanation_type="model"` to explain the prediction of a GNN model (see the `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" `_ paper for more details). #. the different types of masks for nodes and edges (*e.g.*, :obj:`node_mask_type="attributes"` or :obj:`edge_mask_type="object"`) #. any postprocessing of the masks (*e.g.*, :obj:`threshold_type="topk"` or :obj:`threshold_type="hard"`) This class allows the user to easily compare different explainability methods and to easily switch between different types of masks, while making sure the high-level framework stays the same. The :class:`~torch_geometric.explain.Explainer` generates an :class:`~torch_geometric.explain.Explanation` or :class:`~torch_geometric.explain.HeteroExplanation` object which contains the final information about which nodes, edges and features are crucial to explain a GNN model. .. note:: You can read more about the :class:`torch_geometric.explain` package in this `blog post `__. Examples -------- In what follows, we discuss a few use-cases with corresponding code examples. Explaining node classification on a homogeneous graph ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Assume we have a GNN :obj:`model` that does node classification on a homogeneous graph.
We can use the :class:`torch_geometric.explain.algorithm.GNNExplainer` algorithm to generate an :class:`~torch_geometric.explain.Explanation`. We configure the :class:`~torch_geometric.explain.Explainer` to use both a :obj:`node_mask_type` and an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.Explanation` object contains (1) a :obj:`node_mask` (indicating which nodes and features are crucial for prediction), and (2) an :obj:`edge_mask` (indicating which edges are crucial for prediction). .. code-block:: python from torch_geometric.data import Data from torch_geometric.explain import Explainer, GNNExplainer data = Data(...) # A homogeneous graph data object. explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=200), explanation_type='model', node_mask_type='attributes', edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='log_probs', # Model returns log probabilities. ), ) # Generate explanation for the node at index `10`: explanation = explainer(data.x, data.edge_index, index=10) print(explanation.edge_mask) print(explanation.node_mask) Finally, we can visualize both feature importance and the crucial subgraph of the explanation: .. code-block:: python explanation.visualize_feature_importance(top_k=10) explanation.visualize_graph() To evaluate the explanation from the :class:`~torch_geometric.explain.algorithm.GNNExplainer`, we can utilize the :class:`torch_geometric.explain.metric` module. For example, to compute the :meth:`~torch_geometric.explain.metric.unfaithfulness` of an explanation, run: .. code-block:: python from torch_geometric.explain import unfaithfulness metric = unfaithfulness(explainer, explanation) print(metric) Explaining node classification on a heterogeneous graph ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Assume we have a heterogeneous GNN :obj:`model` that does node classification on a heterogeneous graph. 
We can use the :class:`IntegratedGradients` attribution method from :captum:`null` `Captum `__ via the :class:`torch_geometric.explain.algorithm.CaptumExplainer` algorithm to generate a :class:`~torch_geometric.explain.HeteroExplanation`. .. note:: :class:`~torch_geometric.explain.algorithm.CaptumExplainer` is a wrapper around the :captum:`null` `Captum `__ library with support for most attribution methods to explain *any* homogeneous or heterogeneous :pyg:`PyG` model. We configure the :class:`~torch_geometric.explain.Explainer` to use both a :obj:`node_mask_type` and an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.HeteroExplanation` object contains (1) a :obj:`node_mask` for *each* node type (indicating which nodes and features for each node type are crucial for prediction), and (2) an :obj:`edge_mask` for *each* edge type (indicating which edges for each edge type are crucial for prediction). .. code-block:: python from torch_geometric.data import HeteroData from torch_geometric.explain import Explainer, CaptumExplainer hetero_data = HeteroData(...) # A heterogeneous graph data object. explainer = Explainer( model, # It is assumed that model outputs a single tensor. algorithm=CaptumExplainer('IntegratedGradients'), explanation_type='model', node_mask_type='attributes', edge_mask_type='object', model_config = dict( mode='multiclass_classification', task_level=task_level, return_type='probs', # Model returns probabilities. ), ) # Generate batch-wise heterogeneous explanations for # the nodes at index `1` and `3`: hetero_explanation = explainer( hetero_data.x_dict, hetero_data.edge_index_dict, index=torch.tensor([1, 3]), ) print(hetero_explanation.edge_mask_dict) print(hetero_explanation.node_mask_dict) Explaining graph regression on a homogeneous graph ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Assume we have a GNN :obj:`model` that does graph regression on a homogeneous graph.
We can use the :class:`torch_geometric.explain.algorithm.PGExplainer` algorithm to generate an :class:`~torch_geometric.explain.Explanation`. We configure the :class:`~torch_geometric.explain.Explainer` to use an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.Explanation` object contains an :obj:`edge_mask` (indicating which edges are crucial for prediction). Importantly, passing a :obj:`node_mask_type` to the :class:`~torch_geometric.explain.Explainer` will throw an error since :class:`~torch_geometric.explain.algorithm.PGExplainer` cannot explain the importance of nodes: .. code-block:: python from torch_geometric.data import Data from torch_geometric.explain import Explainer, PGExplainer dataset = ... loader = DataLoader(dataset, batch_size=1, shuffle=True) explainer = Explainer( model=model, algorithm=PGExplainer(epochs=30, lr=0.003), explanation_type='phenomenon', edge_mask_type='object', model_config=dict( mode='regression', task_level='graph', return_type='raw', ), # Include only the top 10 most important edges: threshold_config=dict(threshold_type='topk', value=10), ) # PGExplainer needs to be trained separately since it is a parametric # explainer i.e it uses a neural network to generate explanations: for epoch in range(30): for batch in loader: loss = explainer.algorithm.train( epoch, model, batch.x, batch.edge_index, target=batch.target) # Generate the explanation for a particular graph: explanation = explainer(dataset[0].x, dataset[0].edge_index) print(explanation.edge_mask) Since this feature is still undergoing heavy development, please feel free to reach out to the :pyg:`PyG` core team either on :github:`null` `GitHub `_ or :slack:`null` `Slack `_ if you have any questions, comments or concerns. ================================================ FILE: docs/source/tutorial/gnn_design.rst ================================================ Design of Graph Neural Networks =============================== .. 
nbgallery:: :name: rst-gallery create_gnn heterogeneous ================================================ FILE: docs/source/tutorial/graph_transformer.rst ================================================ Graph Transformer ================= `Transformer `_ is an effective architecture in `natural language processing `_ and `computer vision `_. Recently, there have been some applications (`Grover `_, `GraphGPS `_, etc.) that apply transformers to graphs. In this tutorial, we will present how to build a graph transformer model via :pyg:`PyG`. See `our webinar `_ for in-depth learning on this topic. .. note:: Click `here `_ to download the full example code Transformers on Graphs ------------------------------ Compared to Graph Transformers, MPNNs have several drawbacks: (1) WL test: 1-order MPNNs have limited expressivity; (2) Over-smoothing: the features tend to converge to the same value as the number of GNN layers increases; (3) Over-squashing: information is lost when trying to aggregate messages from many neighbors into a single vector; (4) They cannot capture long-range dependencies. Feeding the whole graph into the Transformer also brings several pros and cons: **Pros** * Computation graph structure is decoupled from the input graph structure. * Long-range connections can be handled as all nodes are connected to each other. **Cons** * Loss of the inductive bias that enables GNNs to work so well on graphs with pronounced locality, particularly in graphs where edges represent relatedness/closeness. * Language input is sequential, but graphs are permutation invariant to node ordering. * Quadratic computational complexity :math:`O(N^2)` in the number of nodes, whereas message passing GNNs are linear in the number of edges :math:`O(E)`. Graphs are often sparse :math:`N \approx E`. Attention +++++++++ .. math:: Q = XW_Q, K = XW_K, V = XW_V ..
math:: Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V In Transformers, attention can be multi-headed, i.e., it computes multiple sets of attention weights in parallel. Positional and Structural Encodings +++++++++++++++++++++++++++++++++++ We organize PE/SE into 3 categories based on their locality: (1) Local, (2) Global, (3) Relative. Positional encodings (PE) provide an idea of the position in space of a given node within the graph. When two nodes are close to each other within a graph or subgraph, their PE should also be close. Structure encodings (SE) provide an embedding of the structure of graphs or subgraphs to help increase the expressivity and the generalizability of GNNs. When two nodes share similar subgraphs, or when two graphs are similar, their SE should also be close. .. list-table:: :widths: 10 20 20 :header-rows: 1 * - Encoding type - Positional encodings (PE) - Structure encodings (SE) * - Local (node) - (1)Distance to cluster center; (2)Sum of non-diagonal elements in m-step random walk. - (1)Node degree; (2)Random walk diagonals; (3) Enumerate substructures (triangles, rings). * - Global (node) - (1)Eigenvectors of A/L or distance matrix; (2)Distance to graph's centroid; (3)Unique ID for each node. - (1)Eigenvalues of A/L; (2) Graph diameter, girth, degree, etc. * - Relative (edge) - (1)Pair-wise distance from: Heat Kernels, Random Walks, Graph geodesic, etc; (2)Gradient of eigenvectors - (1)Gradient of any Local SE; (2)Gradient of sub-structure enumeration GPS Layer and GraphGPS Model -------------------------------------- First, we introduce the GPS layer, which combines a local MPNN with a global Transformer, followed by a 2-layer MLP and skip-connections. The local MPNN provides a locality bias that is difficult or expensive to achieve in a Transformer. In addition, edge features can be updated and encoded into the node features (`GatedGCN `_, `GINE `_). The Transformer can utilize positional and structural encodings.
As we don't need to consider edge features, we can use existing linear Transformer architectures, such as `Performer `_ and `BigBird `_, to reduce the time complexity from :math:`O(N^2)` to :math:`O(N + E)`. .. warning:: `BigBird `_ is currently not supported and will be added in the future. .. figure:: ../_figures/graphgps_layer.png :align: center :width: 100% The update function of each layer is described by the equations below. Local MPNN ++++++++++ .. math:: \hat{X}_M^{l + 1}, E^{l + 1} = MPNN_e^l(X^l, E^l, A) .. math:: X_M^{l + 1} = BatchNorm(Dropout(\hat{X}_M^{l + 1}) + X^l) .. code-block:: python h = self.conv(x, edge_index, **kwargs) h = F.dropout(h, p=self.dropout, training=self.training) h = h + x if self.norm1 is not None: if self.norm_with_batch: h = self.norm1(h, batch=batch) else: h = self.norm1(h) hs.append(h) Global Attention ++++++++++++++++ .. math:: \hat{X}_T^{l + 1} = GlobalAttn^l(X^l) .. math:: X_T^{l + 1} = BatchNorm(Dropout(\hat{X}_T^{l + 1}) + X^l) .. code-block:: python h, mask = to_dense_batch(x, batch) if isinstance(self.attn, torch.nn.MultiheadAttention): h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False) elif isinstance(self.attn, PerformerAttention): h = self.attn(h, mask=mask) h = h[mask] h = F.dropout(h, p=self.dropout, training=self.training) h = h + x # Residual connection. if self.norm2 is not None: if self.norm_with_batch: h = self.norm2(h, batch=batch) else: h = self.norm2(h) hs.append(h) Combine local and global outputs ++++++++++++++++++++++++++++++++ .. math:: X^{l + 1} = MLP^l(X_M^{l + 1} + X_T^{l + 1}) .. code-block:: python out = sum(hs) out = out + self.mlp(out) if self.norm3 is not None: if self.norm_with_batch: out = self.norm3(out, batch=batch) else: out = self.norm3(out) Next, we introduce the GraphGPS architecture. The difference between `GraphGPS `_ and `GraphTrans `_ lies in how the MPNN and the Transformer are organized.
In GraphTrans, a few MPNN layers precede the Transformer, so the model may be limited by over-smoothing, over-squashing and low expressivity with respect to the WL test. These layers could irreparably lose some information at an early stage. GraphGPS instead stacks hybrid MPNN + Transformer layers, which resolves the local expressivity bottlenecks by allowing information to spread across the graph via full connectivity. Train GraphGPS on graph-structured data -------------------------------------------------- In this part, we'll show how to train a :class:`~torch_geometric.nn.GPSConv` GNN model on the :class:`~torch_geometric.datasets.ZINC` dataset. Load dataset ++++++++++++ .. code-block:: python transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe') train_dataset = ZINC(path, subset=True, split='train', pre_transform=transform) val_dataset = ZINC(path, subset=True, split='val', pre_transform=transform) test_dataset = ZINC(path, subset=True, split='test', pre_transform=transform) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=64) test_loader = DataLoader(test_dataset, batch_size=64) Define model ++++++++++++ ..
code-block:: python class RedrawProjection: def __init__(self, model: torch.nn.Module, redraw_interval: Optional[int] = None): self.model = model self.redraw_interval = redraw_interval self.num_last_redraw = 0 def redraw_projections(self): if not self.model.training or self.redraw_interval is None: return if self.num_last_redraw >= self.redraw_interval: fast_attentions = [ module for module in self.model.modules() if isinstance(module, PerformerAttention) ] for fast_attention in fast_attentions: fast_attention.redraw_projection_matrix() self.num_last_redraw = 0 return self.num_last_redraw += 1 class GPS(torch.nn.Module): def __init__(self, channels: int, pe_dim: int, num_layers: int, attn_type: str, attn_kwargs: Dict[str, Any]): super().__init__() self.node_emb = Embedding(28, channels - pe_dim) self.pe_lin = Linear(20, pe_dim) self.pe_norm = BatchNorm1d(20) self.edge_emb = Embedding(4, channels) self.convs = ModuleList() for _ in range(num_layers): nn = Sequential( Linear(channels, channels), ReLU(), Linear(channels, channels), ) conv = GPSConv(channels, GINEConv(nn), heads=4, attn_type=attn_type, attn_kwargs=attn_kwargs) self.convs.append(conv) self.mlp = Sequential( Linear(channels, channels // 2), ReLU(), Linear(channels // 2, channels // 4), ReLU(), Linear(channels // 4, 1), ) self.redraw_projection = RedrawProjection( self.convs, redraw_interval=1000 if attn_type == 'performer' else None) def forward(self, x, pe, edge_index, edge_attr, batch): x_pe = self.pe_norm(pe) x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1) edge_attr = self.edge_emb(edge_attr) for conv in self.convs: x = conv(x, edge_index, batch, edge_attr=edge_attr) x = global_add_pool(x, batch) return self.mlp(x) Train and evaluate +++++++++++++++++++ .. 
code-block:: python device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') attn_kwargs = {'dropout': 0.5} model = GPS(channels=64, pe_dim=8, num_layers=10, attn_type=args.attn_type, attn_kwargs=attn_kwargs).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5) scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=0.00001) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() model.redraw_projection.redraw_projections() out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch) loss = (out.squeeze() - data.y).abs().mean() loss.backward() total_loss += loss.item() * data.num_graphs optimizer.step() return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() total_error = 0 for data in loader: data = data.to(device) out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch) total_error += (out.squeeze() - data.y).abs().sum().item() return total_error / len(loader.dataset) for epoch in range(1, 101): loss = train() val_mae = test(val_loader) test_mae = test(test_loader) scheduler.step(val_mae) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, ' f'Test: {test_mae:.4f}') .. 
code-block:: text Epoch: 01, Loss: 0.7216, Val: 0.5316, Test: 0.5454 Epoch: 02, Loss: 0.5519, Val: 0.5895, Test: 0.6288 Epoch: 03, Loss: 0.5009, Val: 0.5029, Test: 0.4924 Epoch: 04, Loss: 0.4751, Val: 0.4801, Test: 0.4786 Epoch: 05, Loss: 0.4363, Val: 0.4438, Test: 0.4352 Epoch: 06, Loss: 0.4276, Val: 0.4931, Test: 0.4994 Epoch: 07, Loss: 0.3956, Val: 0.3502, Test: 0.3439 Epoch: 08, Loss: 0.4021, Val: 0.3143, Test: 0.3296 Epoch: 09, Loss: 0.3761, Val: 0.4012, Test: 0.3858 Epoch: 10, Loss: 0.3739, Val: 0.3343, Test: 0.3032 Epoch: 11, Loss: 0.3532, Val: 0.3679, Test: 0.3334 Epoch: 12, Loss: 0.3683, Val: 0.3094, Test: 0.2754 Epoch: 13, Loss: 0.3457, Val: 0.4007, Test: 0.4023 Epoch: 14, Loss: 0.3460, Val: 0.3986, Test: 0.3589 Epoch: 15, Loss: 0.3369, Val: 0.3478, Test: 0.3124 Epoch: 16, Loss: 0.3222, Val: 0.3043, Test: 0.2651 Epoch: 17, Loss: 0.3190, Val: 0.4496, Test: 0.4070 Epoch: 18, Loss: 0.3317, Val: 0.3803, Test: 0.3450 Epoch: 19, Loss: 0.3179, Val: 0.2671, Test: 0.2408 Epoch: 20, Loss: 0.3143, Val: 0.4168, Test: 0.3901 Epoch: 21, Loss: 0.3238, Val: 0.3183, Test: 0.2926 Epoch: 22, Loss: 0.3132, Val: 0.9534, Test: 1.0879 Epoch: 23, Loss: 0.3088, Val: 0.3705, Test: 0.3360 Epoch: 24, Loss: 0.3032, Val: 0.3051, Test: 0.2692 Epoch: 25, Loss: 0.2968, Val: 0.2829, Test: 0.2571 Epoch: 26, Loss: 0.2915, Val: 0.3145, Test: 0.2820 Epoch: 27, Loss: 0.2871, Val: 0.3127, Test: 0.2965 Epoch: 28, Loss: 0.2953, Val: 0.4415, Test: 0.4144 Epoch: 29, Loss: 0.2916, Val: 0.3118, Test: 0.2733 Epoch: 30, Loss: 0.3074, Val: 0.4497, Test: 0.4418 ================================================ FILE: docs/source/tutorial/heterogeneous.rst ================================================ Heterogeneous Graph Learning ============================ A large set of real-world datasets are stored as heterogeneous graphs, motivating the introduction of specialized functionality for them in :pyg:`PyG`. 
For example, most graphs in the area of recommendation, such as social graphs, are heterogeneous, as they store information about different types of entities and their different types of relations. This tutorial introduces how heterogeneous graphs are mapped to :pyg:`PyG` and how they can be used as input to Graph Neural Network models. Heterogeneous graphs come with different types of information attached to nodes and edges. Thus, a single node or edge feature tensor cannot hold all node or edge features of the whole graph, due to differences in type and dimensionality. Instead, a set of types needs to be specified for nodes and edges, respectively, each having its own data tensors. As a consequence of the different data structure, the message passing formulation changes accordingly, allowing the computation of message and update functions conditioned on node or edge type. Example Graph ------------- As a guiding example, we take a look at the heterogeneous `ogbn-mag `__ network from the :ogb:`null` `dataset suite `_: .. image:: ../_figures/hg_example.svg :align: center :width: 500px The given heterogeneous graph has 1,939,743 nodes, split between the four node types **author**, **paper**, **institution** and **field of study**. It further has 21,111,007 edges, each of which is also of one of four types: * **writes**: An author *writes* a specific paper * **affiliated with**: An author is *affiliated with* a specific institution * **cites**: A paper *cites* another paper * **has topic**: A paper *has a topic* of a specific field of study The task for this graph is to infer the venue of each paper (conference or journal) given the information stored in the graph. Creating Heterogeneous Graphs ----------------------------- First, we can create a data object of type :class:`torch_geometric.data.HeteroData`, for which we define node feature tensors, edge index tensors and edge feature tensors individually for each type: ..
code-block:: python from torch_geometric.data import HeteroData data = HeteroData() data['paper'].x = ... # [num_papers, num_features_paper] data['author'].x = ... # [num_authors, num_features_author] data['institution'].x = ... # [num_institutions, num_features_institution] data['field_of_study'].x = ... # [num_field, num_features_field] data['paper', 'cites', 'paper'].edge_index = ... # [2, num_edges_cites] data['author', 'writes', 'paper'].edge_index = ... # [2, num_edges_writes] data['author', 'affiliated_with', 'institution'].edge_index = ... # [2, num_edges_affiliated] data['paper', 'has_topic', 'field_of_study'].edge_index = ... # [2, num_edges_topic] data['paper', 'cites', 'paper'].edge_attr = ... # [num_edges_cites, num_features_cites] data['author', 'writes', 'paper'].edge_attr = ... # [num_edges_writes, num_features_writes] data['author', 'affiliated_with', 'institution'].edge_attr = ... # [num_edges_affiliated, num_features_affiliated] data['paper', 'has_topic', 'field_of_study'].edge_attr = ... # [num_edges_topic, num_features_topic] Node or edge tensors will be automatically created upon first access and indexed by string keys. Node types are identified by a single string while edge types are identified by using a triplet :obj:`(source_node_type, edge_type, destination_node_type)` of strings: the edge type identifier and the two node types between which the edge type can exist. As such, the data object allows different feature dimensionalities for each type. Dictionaries containing the heterogeneous information grouped by attribute names rather than by node or edge type can directly be accessed via :obj:`data.{attribute_name}_dict` and serve as input to GNN models later: .. code-block:: python model = HeteroGNN(...) output = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) If the dataset exists in the `list of Pytorch Geometric datasets `_, it can directly be imported and used. 
In particular, it will be downloaded to :obj:`root` and processed automatically. .. code-block:: python from torch_geometric.datasets import OGB_MAG dataset = OGB_MAG(root='./data', preprocess='metapath2vec') data = dataset[0] The :obj:`data` object can be printed for verification. .. code-block:: text HeteroData( paper={ x=[736389, 128], y=[736389], train_mask=[736389], val_mask=[736389], test_mask=[736389] }, author={ x=[1134649, 128] }, institution={ x=[8740, 128] }, field_of_study={ x=[59965, 128] }, (author, affiliated_with, institution)={ edge_index=[2, 1043998] }, (author, writes, paper)={ edge_index=[2, 7145660] }, (paper, cites, paper)={ edge_index=[2, 5416271] }, (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] } ) .. note:: The original `ogbn-mag `__ network only provides features for "paper" nodes. In :class:`~torch_geometric.datasets.OGB_MAG`, we provide the option to download a processed version of it in which structural features (obtained from either :obj:`"metapath2vec"` or :obj:`"TransE"`) are added to featureless nodes, as is commonly done in the top-ranked submissions to the `OGB leaderboards `_. Utility Functions ~~~~~~~~~~~~~~~~~ The :class:`torch_geometric.data.HeteroData` class provides a number of useful utility functions to modify and analyze the given graph. For example, single node or edge stores can be individually indexed: .. code-block:: python paper_node_data = data['paper'] cites_edge_data = data['paper', 'cites', 'paper'] In case the edge type can be uniquely identified by only the pair of source and destination node types, or by the edge type name alone, the following operations work as well: .. code-block:: python cites_edge_data = data['paper', 'paper'] cites_edge_data = data['cites'] We can add new node types or tensors and remove them: .. code-block:: python data['paper'].year = ...
# Setting a new paper attribute del data['field_of_study'] # Deleting 'field_of_study' node type del data['has_topic'] # Deleting 'has_topic' edge type We can access the meta-data of the :obj:`data` object, holding information about all present node and edge types: .. code-block:: python node_types, edge_types = data.metadata() print(node_types) ['paper', 'author', 'institution'] print(edge_types) [('paper', 'cites', 'paper'), ('author', 'writes', 'paper'), ('author', 'affiliated_with', 'institution')] The :obj:`data` object can be transferred between devices as usual: .. code-block:: python data = data.to('cuda:0') data = data.cpu() We further have access to additional helper functions to analyze the given graph .. code-block:: python data.has_isolated_nodes() data.has_self_loops() data.is_undirected() and can convert it to a homogeneous "typed" graph via :meth:`~torch_geometric.data.HeteroData.to_homogeneous` which is able to maintain features in case their dimensionalities match across different types: .. code-block:: python homogeneous_data = data.to_homogeneous() print(homogeneous_data) Data(x=[1879778, 128], edge_index=[2, 13605929], edge_type=[13605929]) Here, :obj:`homogeneous_data.edge_type` represents an edge-level vector that holds the edge type of each edge as an integer. Heterogeneous Graph Transformations ----------------------------------- Most `transformations `_ for preprocessing regular graphs also work on the heterogeneous graph :obj:`data` object. .. code-block:: python import torch_geometric.transforms as T data = T.ToUndirected()(data) data = T.AddSelfLoops()(data) data = T.NormalizeFeatures()(data) Here, :meth:`~torch_geometric.transforms.ToUndirected` transforms a directed graph into (the :pyg:`PyG` representation of) an undirected graph, by adding reverse edges for all edges in the graph. Thus, future message passing is performed in both directions of all edges.
The function may add reverse edge types to the heterogeneous graph, if necessary. For all nodes of type :obj:`'node_type'` and all existing edge types of the form :obj:`('node_type', 'edge_type', 'node_type')`, the function :meth:`~torch_geometric.transforms.AddSelfLoops` will add self-loop edges. As a result, each node might receive one or more (one per appropriate edge type) messages from itself during message passing. The transform :meth:`~torch_geometric.transforms.NormalizeFeatures` works as in the homogeneous case, and normalizes all specified features (of all types) to sum up to one. Creating Heterogeneous GNNs --------------------------- Standard Message Passing GNNs (MP-GNNs) cannot trivially be applied to heterogeneous graph data, as node and edge features from different types cannot be processed by the same functions due to differences in feature type. A natural way to circumvent this is to implement message and update functions individually for each edge type. During runtime, the MP-GNN algorithm would need to iterate over edge type dictionaries during message computation and over node type dictionaries during node updates. To avoid unnecessary runtime overheads and to make the creation of heterogeneous MP-GNNs as simple as possible, PyTorch Geometric provides three ways for the user to create models on heterogeneous graph data: #. Automatically convert a homogeneous GNN model to a heterogeneous GNN model by making use of :meth:`torch_geometric.nn.to_hetero` or :meth:`torch_geometric.nn.to_hetero_with_bases` #. Define individual functions for different types using :pyg:`PyG's` wrapper :class:`torch_geometric.nn.conv.HeteroConv` for heterogeneous convolution #. Deploy existing (or write your own) heterogeneous GNN operators In the following, each option is introduced in detail.
Automatically Converting GNN Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PyTorch Geometric allows you to automatically convert any :pyg:`PyG` GNN model to a model for heterogeneous input graphs, using the built-in functions :meth:`torch_geometric.nn.to_hetero` or :meth:`torch_geometric.nn.to_hetero_with_bases`. The following `example `__ shows how to apply it: .. code-block:: python import torch_geometric.transforms as T from torch_geometric.datasets import OGB_MAG from torch_geometric.nn import SAGEConv, to_hetero dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected()) data = dataset[0] class GNN(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv((-1, -1), hidden_channels) self.conv2 = SAGEConv((-1, -1), out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) return x model = GNN(hidden_channels=64, out_channels=dataset.num_classes) model = to_hetero(model, data.metadata(), aggr='sum') The process takes an existing GNN model and duplicates the message functions to work on each edge type individually, as detailed in the following figure. .. image:: ../_figures/to_hetero.svg :align: center :width: 90% As a result, the model now expects dictionaries with node and edge types as keys as input arguments, rather than the single tensors used for homogeneous graphs. Note that we pass in a tuple of :obj:`in_channels` to :class:`~torch_geometric.nn.conv.SAGEConv` in order to allow for message passing in bipartite graphs. .. _lazyinit: .. note:: Since the number of input features and thus the size of tensors varies between different types, :pyg:`PyG` can make use of **lazy initialization** to initialize parameters in heterogeneous GNNs (as denoted by :obj:`-1` as the :obj:`in_channels` argument). This allows us to avoid calculating and keeping track of all tensor sizes of the computation graph.
Lazy initialization is supported for all existing :pyg:`PyG` operators. We can initialize the model's parameters by calling it once: .. code-block:: python with torch.no_grad(): # Initialize lazy modules. out = model(data.x_dict, data.edge_index_dict) Both :meth:`~torch_geometric.nn.to_hetero` and :meth:`~torch_geometric.nn.to_hetero_with_bases` are very flexible with respect to the homogeneous architectures that can be automatically converted to heterogeneous ones, *e.g.*, applying skip-connections, jumping knowledge or other techniques are supported out-of-the-box. For example, this is all it takes to implement a heterogeneous graph attention network with learnable skip-connections: .. code-block:: python from torch_geometric.nn import GATConv, Linear, to_hetero class GAT(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False) self.lin1 = Linear(-1, hidden_channels) self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False) self.lin2 = Linear(-1, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index) + self.lin1(x) x = x.relu() x = self.conv2(x, edge_index) + self.lin2(x) return x model = GAT(hidden_channels=64, out_channels=dataset.num_classes) model = to_hetero(model, data.metadata(), aggr='sum') Note that we disable the creation of self loops via the :obj:`add_self_loops=False` argument. This is done because the concept of self-loops is not well-defined in bipartite graphs (message passing for an edge type with distinct source and destination node types), and we would mistakenly add the edges :obj:`[(0, 0), (1, 1), ...]` to the bipartite graph. 
To preserve central node information, we thus utilize a learnable skip-connection via :obj:`conv(x, edge_index) + lin(x)` instead, in which the output of the attention-based message passing from source to destination nodes is summed with a learned linear transformation of the existing destination node features. Afterwards, the created model can be trained as usual: .. _trainfunc: .. code-block:: python def train(): model.train() optimizer.zero_grad() out = model(data.x_dict, data.edge_index_dict) mask = data['paper'].train_mask loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask]) loss.backward() optimizer.step() return float(loss) Using the Heterogeneous Convolution Wrapper ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The heterogeneous convolution wrapper :class:`torch_geometric.nn.conv.HeteroConv` allows us to define custom heterogeneous message and update functions to build arbitrary MP-GNNs for heterogeneous graphs from scratch. While the automatic converter :meth:`~torch_geometric.nn.to_hetero` uses the same operator for all edge types, the wrapper allows us to define different operators for different edge types. Here, :class:`~torch_geometric.nn.conv.HeteroConv` takes a dictionary of submodules as input, one for each edge type in the graph data. The following `example `__ shows how to apply it. ..
code-block:: python import torch_geometric.transforms as T from torch_geometric.datasets import OGB_MAG from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected()) data = dataset[0] class HeteroGNN(torch.nn.Module): def __init__(self, hidden_channels, out_channels, num_layers): super().__init__() self.convs = torch.nn.ModuleList() for _ in range(num_layers): conv = HeteroConv({ ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels), ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels), ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels, add_self_loops=False), }, aggr='sum') self.convs.append(conv) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) x_dict = {key: x.relu() for key, x in x_dict.items()} return self.lin(x_dict['author']) model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes, num_layers=2) We can initialize the model by calling it once (see :ref:`here` for more details about lazy initialization) .. code-block:: python with torch.no_grad(): # Initialize lazy modules. out = model(data.x_dict, data.edge_index_dict) and run the standard training procedure as outlined :ref:`here`. Deploy Existing Heterogeneous Operators ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ :pyg:`PyG` provides operators (*e.g.*, :class:`torch_geometric.nn.conv.HGTConv`), which are specifically designed for heterogeneous graphs. These operators can be directly used to build heterogeneous GNN models as can be seen in the following `example `__: .. 
code-block:: python import torch_geometric.transforms as T from torch_geometric.datasets import OGB_MAG from torch_geometric.nn import HGTConv, Linear dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected()) data = dataset[0] class HGT(torch.nn.Module): def __init__(self, hidden_channels, out_channels, num_heads, num_layers): super().__init__() self.lin_dict = torch.nn.ModuleDict() for node_type in data.node_types: self.lin_dict[node_type] = Linear(-1, hidden_channels) self.convs = torch.nn.ModuleList() for _ in range(num_layers): conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads, group='sum') self.convs.append(conv) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): for node_type, x in x_dict.items(): x_dict[node_type] = self.lin_dict[node_type](x).relu_() for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) return self.lin(x_dict['author']) model = HGT(hidden_channels=64, out_channels=dataset.num_classes, num_heads=2, num_layers=2) We can initialize the model by calling it once (see :ref:`here` for more details about lazy initialization). .. code-block:: python with torch.no_grad(): # Initialize lazy modules. out = model(data.x_dict, data.edge_index_dict) and run the standard training procedure as outlined :ref:`here`. Heterogeneous Graph Samplers ---------------------------- :pyg:`PyG` provides various functionalities for sampling heterogeneous graphs, *i.e.* in the standard :class:`torch_geometric.loader.NeighborLoader` class or in dedicated heterogeneous graph samplers such as :class:`torch_geometric.loader.HGTLoader`. This is especially useful for efficient representation learning on large heterogeneous graphs, where processing the full number of neighbors is too computationally expensive. 
Heterogeneous graph support for other samplers such as :class:`torch_geometric.loader.ClusterLoader` or :class:`torch_geometric.loader.GraphSAINTLoader` will be added soon. Overall, all heterogeneous graph loaders will produce a :class:`~torch_geometric.data.HeteroData` object as output, holding a subset of the original data, and mainly differ in the way their sampling procedures work. As such, only minimal code changes are required to convert the training procedure from :ref:`full-batch training` to mini-batch training. Performing neighbor sampling using :class:`~torch_geometric.loader.NeighborLoader` works as outlined in the following `example `__: .. code-block:: python import torch_geometric.transforms as T from torch_geometric.datasets import OGB_MAG from torch_geometric.loader import NeighborLoader transform = T.ToUndirected() # Add reverse edge types. data = OGB_MAG(root='./data', preprocess='metapath2vec', transform=transform)[0] train_loader = NeighborLoader( data, # Sample 15 neighbors for each node and each edge type for 2 iterations: num_neighbors=[15] * 2, # Use a batch size of 128 for sampling training nodes of type "paper": batch_size=128, input_nodes=('paper', data['paper'].train_mask), ) batch = next(iter(train_loader)) Notably, :class:`~torch_geometric.loader.NeighborLoader` works for both homogeneous *and* heterogeneous graphs. When operating on heterogeneous graphs, more fine-grained control over the number of sampled neighbors of individual edge types is possible, but not necessary, *e.g.*: .. code-block:: python num_neighbors = {key: [15] * 2 for key in data.edge_types} Using the :obj:`input_nodes` argument, we further specify the type and indices of nodes from which we want to sample local neighborhoods, *i.e.* all the "paper" nodes marked as training nodes according to :obj:`data['paper'].train_mask`. Printing :obj:`batch` then yields the following output: ..
code-block:: text HeteroData( paper={ x=[20799, 128], y=[20799], train_mask=[20799], val_mask=[20799], test_mask=[20799], batch_size=128 }, author={ x=[4419, 128] }, institution={ x=[302, 128] }, field_of_study={ x=[2605, 128] }, (author, affiliated_with, institution)={ edge_index=[2, 0] }, (author, writes, paper)={ edge_index=[2, 5927] }, (paper, cites, paper)={ edge_index=[2, 11829] }, (paper, has_topic, field_of_study)={ edge_index=[2, 10573] }, (institution, rev_affiliated_with, author)={ edge_index=[2, 829] }, (paper, rev_writes, author)={ edge_index=[2, 5512] }, (field_of_study, rev_has_topic, paper)={ edge_index=[2, 10499] } ) As such, :obj:`batch` holds a total of 28,125 nodes involved in computing the embeddings of 128 "paper" nodes. Sampled nodes are always sorted based on the order in which they were sampled. Thus, the first :obj:`batch['paper'].batch_size` nodes represent the set of original mini-batch nodes, making it easy to obtain the final output embeddings via slicing. Training our heterogeneous GNN model in mini-batch mode is then similar to training it in full-batch mode, except that we now iterate over the mini-batches produced by :obj:`train_loader` and optimize model parameters based on individual mini-batches: .. code-block:: python def train(): model.train() total_examples = total_loss = 0 for batch in train_loader: optimizer.zero_grad() batch = batch.to('cuda:0') batch_size = batch['paper'].batch_size out = model(batch.x_dict, batch.edge_index_dict) loss = F.cross_entropy(out['paper'][:batch_size], batch['paper'].y[:batch_size]) loss.backward() optimizer.step() total_examples += batch_size total_loss += float(loss) * batch_size return total_loss / total_examples Importantly, we only make use of the first 128 "paper" nodes during loss computation.
We do so by slicing both "paper" labels :obj:`batch['paper'].y` and "paper" output predictions :obj:`out['paper']` based on :obj:`batch['paper'].batch_size`, representing the labels and final output predictions of original mini-batch nodes, respectively. ================================================ FILE: docs/source/tutorial/load_csv.rst ================================================ Loading Graphs from CSV ======================= In this example, we will show how to load a set of :obj:`*.csv` files as input and construct a **heterogeneous graph** from it, which can be used as input to a `heterogeneous graph model `__. This tutorial is also available as an executable `example script `_ in the :obj:`examples/hetero` directory. We are going to use the `MovieLens dataset `_ collected by the GroupLens research group. This toy dataset describes 5-star rating and tagging activity from MovieLens. The dataset contains approximately 100k ratings across more than 9k movies from more than 600 users. We are going to use this dataset to generate two node types holding data for **movies** and **users**, respectively, and one edge type connecting **users and movies**, representing the relation of how a user has rated a specific movie. First, we download the dataset to an arbitrary folder (in this case, the current directory): .. code-block:: python from torch_geometric.data import download_url, extract_zip url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip' extract_zip(download_url(url, '.'), '.') movie_path = './ml-latest-small/movies.csv' rating_path = './ml-latest-small/ratings.csv' Before we create the heterogeneous graph, let's take a look at the data. .. code-block:: python import pandas as pd print(pd.read_csv(movie_path).head()) print(pd.read_csv(rating_path).head()) .. 
list-table:: Head of :obj:`movies.csv` :widths: 5 40 60 :header-rows: 1 * - movieId - title - genres * - 1 - Toy Story (1995) - Adventure|Animation|Children|Comedy|Fantasy * - 2 - Jumanji (1995) - Adventure|Children|Fantasy * - 3 - Grumpier Old Men (1995) - Comedy|Romance * - 4 - Waiting to Exhale (1995) - Comedy|Drama|Romance * - 5 - Father of the Bride Part II (1995) - Comedy We see that the :obj:`movies.csv` file provides three columns: :obj:`movieId` assigns a unique identifier to each movie, while the :obj:`title` and :obj:`genres` columns represent the title and genres of the given movie. We can make use of those two columns to define a feature representation that can be easily interpreted by machine learning models. .. list-table:: Head of :obj:`ratings.csv` :widths: 5 5 10 30 :header-rows: 1 * - userId - movieId - rating - timestamp * - 1 - 1 - 4.0 - 964982703 * - 1 - 3 - 4.0 - 964981247 * - 1 - 6 - 4.0 - 964982224 * - 1 - 47 - 5.0 - 964983815 * - 1 - 50 - 5.0 - 964982931 The :obj:`ratings.csv` data connects users (as given by :obj:`userId`) and movies (as given by :obj:`movieId`), and defines how a given user has rated a specific movie (:obj:`rating`). For simplicity, we do not make use of the additional :obj:`timestamp` information. For representing this data in the :pyg:`PyG` data format, we first define a method :meth:`load_node_csv` that reads in a :obj:`*.csv` file and returns a node-level feature representation :obj:`x` of shape :obj:`[num_nodes, num_features]`, together with a mapping of its index column: ..
code-block:: python import torch def load_node_csv(path, index_col, encoders=None, **kwargs): df = pd.read_csv(path, index_col=index_col, **kwargs) mapping = {index: i for i, index in enumerate(df.index.unique())} x = None if encoders is not None: xs = [encoder(df[col]) for col, encoder in encoders.items()] x = torch.cat(xs, dim=-1) return x, mapping Here, :meth:`load_node_csv` reads the :obj:`*.csv` file from :obj:`path`, and creates a dictionary :obj:`mapping` that maps its index column to a consecutive value in the range :obj:`{ 0, ..., num_rows - 1 }`. This is needed as we want our final data representation to be as compact as possible, *e.g.*, the representation of a movie in the first row should be accessible via :obj:`x[0]`. We further utilize the concept of encoders, which define how the values of specific columns should be encoded into a numerical feature representation. For example, we can define a sentence encoder that encodes raw column strings into low-dimensional embeddings. For this, we make use of the excellent `sentence-transformers `_ library which provides a large number of state-of-the-art pretrained NLP embedding models: .. code-block:: bash pip install sentence-transformers .. code-block:: python class SequenceEncoder: def __init__(self, model_name='all-MiniLM-L6-v2', device=None): self.device = device self.model = SentenceTransformer(model_name, device=device) @torch.no_grad() def __call__(self, df): x = self.model.encode(df.values, show_progress_bar=True, convert_to_tensor=True, device=self.device) return x.cpu() The :class:`SequenceEncoder` class loads a pre-trained NLP model as given by :obj:`model_name`, and uses it to encode a list of strings into a :pytorch:`PyTorch` tensor of shape :obj:`[num_strings, embedding_dim]`. We can use this :class:`SequenceEncoder` to encode the :obj:`title` of the :obj:`movies.csv` file. 
In a similar fashion, we can create another encoder that converts the genres of movies, *e.g.*, :obj:`Adventure|Children|Fantasy`, into multi-hot categorical features. For this, we first need to find all existing genres present in the data, create a feature representation :obj:`x` of shape :obj:`[num_movies, num_genres]`, and assign a :obj:`1` to :obj:`x[i, j]` in case the genre :obj:`j` is present in movie :obj:`i`: .. code-block:: python class GenresEncoder: def __init__(self, sep='|'): self.sep = sep def __call__(self, df): genres = set(g for col in df.values for g in col.split(self.sep)) mapping = {genre: i for i, genre in enumerate(genres)} x = torch.zeros(len(df), len(mapping)) for i, col in enumerate(df.values): for genre in col.split(self.sep): x[i, mapping[genre]] = 1 return x With this, we can obtain our final representation of movies via: .. code-block:: python movie_x, movie_mapping = load_node_csv( movie_path, index_col='movieId', encoders={ 'title': SequenceEncoder(), 'genres': GenresEncoder() }) Similarly, we can utilize :meth:`load_node_csv` for obtaining a user mapping from :obj:`userId` to consecutive values as well. However, there is no additional feature information for users present in this dataset. As such, we do not define any encoders: .. code-block:: python _, user_mapping = load_node_csv(rating_path, index_col='userId') With this, we are ready to initialize our :class:`~torch_geometric.data.HeteroData` object and pass two node types into it: .. code-block:: python from torch_geometric.data import HeteroData data = HeteroData() data['user'].num_nodes = len(user_mapping) # Users do not have any features. data['movie'].x = movie_x print(data) HeteroData( user={ num_nodes=610 }, movie={ x=[9742, 404] } ) As users do not have any node-level information, we only define their number of nodes. As a result, we likely need to learn distinct user embeddings via :class:`torch.nn.Embedding` in an end-to-end fashion during training of a heterogeneous graph model.
Next, we take a look at connecting users with movies as defined by their ratings. For this, we define a method :meth:`load_edge_csv` that returns the final :obj:`edge_index` representation of shape :obj:`[2, num_ratings]` from :obj:`ratings.csv`, as well as any additional features present in the raw :obj:`*.csv` file: .. code-block:: python def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping, encoders=None, **kwargs): df = pd.read_csv(path, **kwargs) src = [src_mapping[index] for index in df[src_index_col]] dst = [dst_mapping[index] for index in df[dst_index_col]] edge_index = torch.tensor([src, dst]) edge_attr = None if encoders is not None: edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()] edge_attr = torch.cat(edge_attrs, dim=-1) return edge_index, edge_attr Here, :obj:`src_index_col` and :obj:`dst_index_col` define the index columns of source and destination nodes, respectively. We further make use of the node-level mappings :obj:`src_mapping` and :obj:`dst_mapping` to ensure that raw indices are mapped to the correct consecutive indices in our final representation. For every edge defined in the file, it looks up the forward indices in :obj:`src_mapping` and :obj:`dst_mapping`, and moves the data appropriately. Similarly to :meth:`load_node_csv`, encoders are used to return additional edge-level feature information. For example, for loading the ratings from the :obj:`rating` column in :obj:`ratings.csv`, we can define an :class:`IdentityEncoder` that simply converts a list of floating-point values into a :pytorch:`PyTorch` tensor: .. code-block:: python class IdentityEncoder: def __init__(self, dtype=None): self.dtype = dtype def __call__(self, df): return torch.from_numpy(df.values).view(-1, 1).to(self.dtype) With this, we are ready to finalize our :class:`~torch_geometric.data.HeteroData` object: .. 
code-block:: python edge_index, edge_label = load_edge_csv( rating_path, src_index_col='userId', src_mapping=user_mapping, dst_index_col='movieId', dst_mapping=movie_mapping, encoders={'rating': IdentityEncoder(dtype=torch.long)}, ) data['user', 'rates', 'movie'].edge_index = edge_index data['user', 'rates', 'movie'].edge_label = edge_label print(data) HeteroData( user={ num_nodes=610 }, movie={ x=[9742, 404] }, (user, rates, movie)={ edge_index=[2, 100836], edge_label=[100836, 1] } ) This :class:`~torch_geometric.data.HeteroData` object is the native format of heterogeneous graphs in :pyg:`PyG` and can be used as input for `heterogeneous graph models `__. .. note:: Click `here `_ to see the final example script. ================================================ FILE: docs/source/tutorial/multi_gpu_vanilla.rst ================================================ Multi-GPU Training in Pure PyTorch ================================== .. note:: For multi-GPU training with cuGraph, refer to `cuGraph examples `_. For many large-scale, real-world datasets, it may be necessary to scale up training across multiple GPUs. This tutorial goes over how to set up a multi-GPU training pipeline in :pyg:`PyG` with :pytorch:`PyTorch` via :class:`torch.nn.parallel.DistributedDataParallel`, without the need for any other third-party libraries (such as :lightning:`PyTorch Lightning`). Note that this approach is based on data-parallelism. This means that each GPU runs an identical copy of the model; you might want to look into `PyTorch FSDP `_ if you want to scale your model across devices. Data-parallelism allows you to increase the batch size of your model by aggregating gradients across GPUs and then sharing the same optimizer step within every model replica. This `DDP+MNIST-tutorial `_ by Princeton University has some nice illustrations of the process.
Specifically this tutorial shows how to train a :class:`~torch_geometric.nn.models.GraphSAGE` GNN model on the :class:`~torch_geometric.datasets.Reddit` dataset. For this, we will use :class:`torch.nn.parallel.DistributedDataParallel` to scale-up training across all available GPUs. We will do this by spawning multiple processes from our :python:`Python` code which will all execute the same function. Per process, we set up our model instance and feed data through it by utilizing the :class:`~torch_geometric.loader.NeighborLoader`. Gradients are synchronized by wrapping the model in :class:`torch.nn.parallel.DistributedDataParallel` (as described in its `official tutorial `_), which in turn relies on :obj:`torch.distributed`-IPC-facilities. .. note:: The complete script of this tutorial can be found at `examples/multi_gpu/distributed_sampling.py `_. Defining a Spawnable Runner ~~~~~~~~~~~~~~~~~~~~~~~~~~~ To create our training script, we use the :pytorch:`PyTorch`-provided wrapper of the vanilla :python:`Python` :class:`multiprocessing` module. Here, the :obj:`world_size` corresponds to the number of GPUs we will be using at once. :meth:`torch.multiprocessing.spawn` will take care of spawning :obj:`world_size` processes. Each process will load the same script as a module and subsequently execute the :meth:`run`-function: .. code-block:: python from torch_geometric.datasets import Reddit import torch.multiprocessing as mp def run(rank: int, world_size: int, dataset: Reddit): pass if __name__ == '__main__': dataset = Reddit('./data/Reddit') world_size = torch.cuda.device_count() mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True) Note that we initialize the dataset *before* spawning any processes. With this, we only initialize the dataset once, and any data inside it will be automatically moved to shared memory via :obj:`torch.multiprocessing` such that processes do not need to create their own replica of the data. 
In addition, note how the :meth:`run` function accepts :obj:`rank` as its first argument. This argument is not explicitly provided by us. It corresponds to the process index (starting at :obj:`0`) injected by :pytorch:`PyTorch`. Later we will use this to select a unique GPU for every :obj:`rank`. With this, we can start to implement our spawnable runner function. The first step is to initialize a process group with :obj:`torch.distributed`. Up to this point, processes are not aware of each other, and we set a hardcoded server address for rendezvous using the :obj:`nccl` backend. More details can be found in the `"Writing Distributed Applications with PyTorch" `_ tutorial: .. code-block:: python import os import torch.distributed as dist import torch def run(rank: int, world_size: int, dataset: Reddit): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12345' dist.init_process_group('nccl', rank=rank, world_size=world_size) Next, we split training indices into :obj:`world_size` chunks, one for each GPU, and initialize the :class:`~torch_geometric.loader.NeighborLoader` class to only operate on its specific chunk of the training set: .. code-block:: python from torch_geometric.loader import NeighborLoader def run(rank: int, world_size: int, dataset: Reddit): ... data = dataset[0] train_index = data.train_mask.nonzero().view(-1) train_index = train_index.split(train_index.size(0) // world_size)[rank] train_loader = NeighborLoader( data, input_nodes=train_index, num_neighbors=[25, 10], batch_size=1024, num_workers=4, shuffle=True, ) Note that our :meth:`run` function is called for each rank, which means that each rank holds a separate :class:`~torch_geometric.loader.NeighborLoader` instance. Similarly, we create a :class:`~torch_geometric.loader.NeighborLoader` instance for evaluation. For simplicity, we only do this on rank :obj:`0` such that computation of metrics does not need to communicate across different processes.
We recommend taking a look at the `torchmetrics `_ package for distributed computation of metrics. .. code-block:: python def run(rank: int, world_size: int, dataset: Reddit): ... if rank == 0: val_index = data.val_mask.nonzero().view(-1) val_loader = NeighborLoader( data, input_nodes=val_index, num_neighbors=[25, 10], batch_size=1024, num_workers=4, shuffle=False, ) Now that we have our data loaders defined, we initialize our :class:`~torch_geometric.nn.GraphSAGE` model and wrap it inside :class:`torch.nn.parallel.DistributedDataParallel`. We also move the model to its exclusive GPU using the :obj:`rank` as a shortcut for the full device identifier. The wrapper on our model manages communication between each rank and synchronizes gradients across all ranks before updating the model parameters across all ranks: .. code-block:: python from torch.nn.parallel import DistributedDataParallel from torch_geometric.nn import GraphSAGE def run(rank: int, world_size: int, dataset: Reddit): ... torch.manual_seed(12345) model = GraphSAGE( in_channels=dataset.num_features, hidden_channels=256, num_layers=2, out_channels=dataset.num_classes, ).to(rank) model = DistributedDataParallel(model, device_ids=[rank]) Finally, we can set up our optimizer and define our training loop, which follows a similar flow as usual single GPU training loops - the actual magic of gradient and model weight synchronization across different processes will happen behind the scenes within :class:`~torch.nn.parallel.DistributedDataParallel`: .. code-block:: python import torch.nn.functional as F def run(rank: int, world_size: int, dataset: Reddit): ... 
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(1, 11): model.train() for batch in train_loader: batch = batch.to(rank) optimizer.zero_grad() out = model(batch.x, batch.edge_index)[:batch.batch_size] loss = F.cross_entropy(out, batch.y[:batch.batch_size]) loss.backward() optimizer.step() After each training epoch, we evaluate and report validation metrics. As previously mentioned, we do this on a single GPU only. To synchronize all processes and to ensure that the model weights have been updated, we need to call :meth:`torch.distributed.barrier`: .. code-block:: python dist.barrier() if rank == 0: print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') if rank == 0: model.eval() count = correct = 0 with torch.no_grad(): for batch in val_loader: batch = batch.to(rank) out = model(batch.x, batch.edge_index)[:batch.batch_size] pred = out.argmax(dim=-1) correct += (pred == batch.y[:batch.batch_size]).sum() count += batch.batch_size print(f'Validation Accuracy: {correct/count:.4f}') dist.barrier() After finishing training, we can clean up processes and destroy the process group via: .. code-block:: python dist.destroy_process_group() And that's it. Putting it all together gives a working multi-GPU example that follows a training flow that is similar to single GPU training. You can run the shown tutorial by yourself by looking at `examples/multi_gpu/distributed_sampling.py `_. ================================================ FILE: docs/source/tutorial/multi_node_multi_gpu_vanilla.rst ================================================ Multi-Node Training using SLURM =============================== .. note:: For multi-GPU training with cuGraph, refer to `cuGraph examples `_. This tutorial introduces a skeleton on how to perform distributed training on multiple GPUs over multiple nodes using the `SLURM workload manager `_ available at many supercomputing centers. The code is based on `our tutorial on single-node multi-GPU training `_. 
Please go there first to understand the basics if you are unfamiliar with the concepts of distributed training in :pytorch:`PyTorch`. .. note:: The complete script of this tutorial can be found at `examples/multi_gpu/distributed_sampling_multinode.py `_. You can find the example :obj:`*.sbatch` file `next to it `_ and tune it to your needs. A submission script to manage startup ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ As we are now running on multiple nodes, we can no longer use our :obj:`__main__` entrypoint and start processes from there. This is where the workload manager comes in as it allows us to prepare a special :obj:`*.sbatch` file. This file is a standard bash script with instructions on how to setup the processes and your environment. Our example starts with the usual shebang :obj:`#!/bin/bash` and special comments instructing which resources the SLURM system should reserve for our training run. Configuration of the specifics usually depends on your site (and your usage limits!). The following is a minimal example which works with a quite unrestricted configuration available to us: .. code-block:: bash #!/bin/bash #SBATCH --job-name=pyg-multinode-tutorial # identifier for the job listings #SBATCH --output=pyg-multinode.log # outputfile #SBATCH --partition=gpucloud # ADJUST this to your system #SBATCH -N 2 # number of nodes you want to use #SBATCH --ntasks=4 # number of processes to be run #SBATCH --gpus-per-task=1 # every process wants one GPU! #SBATCH --gpu-bind=none # NCCL can't deal with task-binding... This example will create two processes each on two nodes with each process having a single GPU reserved. In the following part, we have to set up some environment variables for :obj:`torch.distributed` to properly do the rendezvous procedure. In theory you could also set those inside the :python:`Python` process: .. 
code-block:: bash export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) echo "MASTER_ADDR:MASTER_PORT="${MASTER_ADDR}:${MASTER_PORT} If you do not want to let your script randomly open a port and listen for incoming connections, you can also use a file on your shared filesystem. Now the only thing left to add is the execution of the training script: .. code-block:: console srun python distributed_sampling_multinode.py Note how the :obj:`python` call is prefixed with the :obj:`srun` command and thus :obj:`--ntasks` replicas will be started. Finally, to submit the :obj:`*.sbatch` file itself into the work queue, use the :obj:`sbatch` utility in your shell: .. code-block:: console sbatch distributed_sampling_multinode.sbatch Using a cluster configured with pyxis-containers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ If your cluster supports the :obj:`pyxis` plugin developed by NVIDIA, you can use a ready-to-use :pyg:`PyG` container that is updated each month with the latest from NVIDIA and :pyg:`PyG`, see `here `_ for more information. The container sets up all necessary environment variables from which you can now directly run the example using :obj:`srun` from your command prompt: .. code-block:: console srun --partition= -N --ntasks= --gpus-per-task=1 --gpu-bind=none --container-name=pyg-test --container-image= --container-mounts='.:/workspace' python3 distributed_sampling_multinode.py Note that :obj:`--container-mounts='.:/workspace'` makes the current folder (which should include the example code) available in the default startup folder :obj:`workspace` of the container. If you want to eventually customize packages in the container without having access to :obj:`docker` (very likely on a public HPC), you can create your own image by following `this tutorial `_. 
Modifying the training script ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ As SLURM now takes care of creating multiple :python:`Python` processes and we cannot share any data (each process will have the full dataset loaded!), our :obj:`__main__` section now has to query the environment for the process setup generated by SLURM or the :obj:`pyxis` container: .. code-block:: python # Get the world size from the WORLD_SIZE variable or directly from SLURM: world_size = int(os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS'))) # Likewise for RANK and LOCAL_RANK: rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID'))) local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID'))) run(world_size, rank, local_rank) The :meth:`torch.distributed.init_process_group` function will now pick up the :obj:`MASTER_ADDR` and :obj:`MASTER_PORT` variables from the environment: .. code-block:: python def run(world_size: int, rank: int, local_rank: int): dist.init_process_group('nccl', world_size=world_size, rank=rank) We also have to replace the usage of :obj:`rank` depending on whether we want to use it for node-local purposes like selecting a GPU or global tasks such as data splitting. Data splitting keeps using the global :obj:`rank`: .. code-block:: python train_idx = data.train_mask.nonzero(as_tuple=False).view(-1) train_idx = train_idx.split(train_idx.size(0) // world_size)[rank] while we need to assign the model to a node-local GPU and thus use :obj:`local_rank`: .. code-block:: python model = SAGE(dataset.num_features, 256, dataset.num_classes).to(local_rank) model = DistributedDataParallel(model, device_ids=[local_rank]) ================================================ FILE: docs/source/tutorial/neighbor_loader.rst ================================================ Scaling GNNs via Neighbor Sampling ================================== One of the challenges of Graph Neural Networks is to scale them to large graphs, *e.g.*, in industrial and social applications.
Traditional deep neural networks are known to scale well to large amounts of data by decomposing the training loss into individual samples (called a *mini-batch*) and approximating exact gradients stochastically. In contrast, applying stochastic mini-batch training in GNNs is challenging since the embedding of a given node depends recursively on all its neighbors’ embeddings, leading to high inter-dependency between nodes that grows exponentially with respect to the number of layers. This phenomenon is often referred to as *neighbor explosion*. As a simple workaround, GNNs are typically executed in a full-batch fashion (see `here `_ for an example), where the GNN has access to all hidden node representations in all its layers. However, this is not feasible in large-scale graphs due to memory limitations and slow convergence. Scalability techniques are indispensable for applying GNNs to large-scale graphs in order to alleviate the neighbor explosion problem induced by mini-batch training, *i.e.* **node-wise**, **layer-wise** or **subgraph-wise** sampling techniques, or to **decouple propagations from predictions**. In this tutorial, we take a closer look at the most common node-wise sampling approach, originally introduced in the `"Inductive Representation Learning on Large Graphs" `_ paper. Neighbor Sampling ----------------- :pyg:`PyG` implements neighbor sampling via its :class:`torch_geometric.loader.NeighborLoader` class. Neighbor sampling works by recursively sampling at most :math:`k` neighbors for a node :math:`v \in \mathcal{V}`, *i.e.* :math:`\tilde{\mathcal{N}}(v) \subset \mathcal{N}(v)` with :math:`|\tilde{\mathcal{N}}(v)| \le k`, leading to an overall bounded :math:`L`-hop neighborhood size of :math:`\mathcal{O}(k^L)`.
That is, starting from a set of seed nodes :math:`\mathcal{B} \subset \mathcal{V}`, we sample at most :math:`k` neighbors for every node :math:`v \in \mathcal{B}`, and then proceed to sample neighbors for every sampled node in the previous hop, and so on. The resulting graph structure holds a **directed** :math:`L`-hop subgraph around every node :math:`v \in \mathcal{B}`, for which it is guaranteed that every node has at least one path of length at most :math:`L` to at least one of the seed nodes in :math:`\mathcal{B}`. As such, a message passing GNN with :math:`L` layers will incorporate the full set of sampled nodes in its computation graph. .. figure:: ../_static/thumbnails/neighbor_loader.png :align: center :width: 40% | It is important to note that neighbor sampling can only mitigate the neighbor explosion problem to some extent since the overall neighborhood size still increases exponentially with the number of layers. As a result, sampling for more than two or three iterations is generally not feasible. Oftentimes, the number of sampled hops and the number of message passing layers are kept in sync. Specifically, it is very wasteful to sample for more hops than there exist message passing layers since the GNN will never be able to incorporate the features of the nodes sampled in later hops into the final node representation of its seed nodes. However, it is nonetheless possible to utilize deeper GNNs, but one needs to be careful to convert the sampled subgraph into a bidirectional variant to ensure correct message passing flow. :pyg:`PyG` provides support for this via an additional argument (:obj:`subgraph_type`) in :class:`~torch_geometric.loader.NeighborLoader`, while other mini-batch techniques are designed for this use-case out-of-the-box, *e.g.*, :class:`~torch_geometric.loader.ClusterLoader`, :class:`~torch_geometric.loader.GraphSAINTSampler` and :class:`~torch_geometric.loader.ShaDowKHopSampler`. Basic Usage ----------- ..
note:: In this section of the tutorial, we will learn how to utilize the :class:`~torch_geometric.loader.NeighborLoader` class of :pyg:`PyG` to train GNNs on single graphs in a mini-batch fashion. A fully working example on large-scale real-world data is available in `examples/reddit.py `_. The :class:`~torch_geometric.loader.NeighborLoader` is initialized from a :pyg:`PyG` :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and defines how sampling should be performed: * :obj:`input_nodes` defines the set of seed nodes from which we want to start sampling. * :obj:`num_neighbors` defines the number of neighbors to sample for each node in each hop. * :obj:`batch_size` defines the number of seed nodes we want to consider at once. * :obj:`replace` defines whether to sample with or without replacement. * :obj:`shuffle` defines whether seed nodes should be shuffled at every epoch. .. code-block:: python import torch from torch_geometric.data import Data from torch_geometric.loader import NeighborLoader x = torch.randn(8, 32) # Node features of shape [num_nodes, num_features] y = torch.randint(0, 4, (8, )) # Node labels of shape [num_nodes] edge_index = torch.tensor([ [2, 3, 3, 4, 5, 6, 7], [0, 0, 1, 1, 2, 3, 4]], ) # 0 1 # / \/ \ # 2 3 4 # | | | # 5 6 7 data = Data(x=x, y=y, edge_index=edge_index) loader = NeighborLoader( data, input_nodes=torch.tensor([0, 1]), num_neighbors=[2, 1], batch_size=1, replace=False, shuffle=False, ) Here, we initialize the :class:`~torch_geometric.loader.NeighborLoader` to sample subgraphs for the first two nodes, where we want to sample 2 neighbors in the first hop, and 1 neighbor in the second hop. Our :obj:`batch_size` is set to :obj:`1`, such that :obj:`input_nodes` will be split into chunks of size :obj:`1`. In the execution of :class:`~torch_geometric.loader.NeighborLoader`, we expect that the seed node :obj:`0` samples nodes :obj:`2` and :obj:`3` in the first hop.
In the second hop, node :obj:`2` samples node :obj:`5`, and node :obj:`3` samples node :obj:`6`. Let's confirm by looking at the output of the :obj:`loader`: .. code-block:: python batch = next(iter(loader)) batch.edge_index >>> tensor([[1, 2, 3, 4], [0, 0, 1, 2]]) batch.n_id >>> tensor([0, 2, 3, 5, 6]) batch.batch_size >>> 1 The :class:`~torch_geometric.loader.NeighborLoader` will return a :class:`~torch_geometric.data.Data` object, which contains the following attributes: * :obj:`batch.edge_index` contains the edge indices of the subgraph. * :obj:`batch.n_id` contains the original node indices of all the sampled nodes. * :obj:`batch.batch_size` contains the number of seed nodes/the batch size. In addition, node and edge features will be filtered to only contain the features of sampled nodes/edges, respectively. Importantly, :obj:`batch.edge_index` contains the sampled subgraph with relabeled node indices, such that its indices range from :obj:`0` to :obj:`batch.num_nodes - 1`. If you want to reconstruct the original node indices of :obj:`batch.edge_index`, do: .. code-block:: python batch.n_id[batch.edge_index] >>> tensor([[2, 3, 5, 6], [0, 0, 2, 3]]) Furthermore, while :class:`~torch_geometric.loader.NeighborLoader` starts sampling *from* seed nodes, the resulting subgraph will hold edges that point *to* the seed nodes. This aligns well with the default :pyg:`PyG` message passing flow from source to destination nodes. Lastly, nodes in the output of :class:`~torch_geometric.loader.NeighborLoader` are guaranteed to be sorted. In particular, the first :obj:`batch_size` sampled nodes will exactly match with the seed nodes that were used for sampling: .. code-block:: python batch.n_id[:batch.batch_size] >>> tensor([0]) Afterwards, we can use :class:`~torch_geometric.loader.NeighborLoader` as a data loading routine to train GNNs on large-scale graphs in mini-batch fashion. For this, let's create a simple two-layer :class:`~torch_geometric.nn.models.GraphSAGE` model: ..
code-block:: python from torch_geometric.nn import GraphSAGE device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GraphSAGE( in_channels=32, hidden_channels=64, out_channels=4, num_layers=2 ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) We can now combine the :obj:`loader` and :obj:`model` to define our training routine: .. code-block:: python import torch.nn.functional as F for batch in loader: optimizer.zero_grad() batch = batch.to(device) out = model(batch.x, batch.edge_index) # NOTE Only consider predictions and labels of seed nodes: y = batch.y[:batch.batch_size] out = out[:batch.batch_size] loss = F.cross_entropy(out, y) loss.backward() optimizer.step() The training loop follows a similar design to any other :pytorch:`PyTorch` training loop. The only important difference is that by default the model will output a matrix of shape :obj:`[batch.num_nodes, *]`, while we are only interested in the predictions of the seed nodes. As such, we can use efficient slicing both on the node predictions and the ground-truth information :obj:`batch.y` to only obtain predictions and ground-truth information of actual seed nodes. This ensures that we are only making use of the first :obj:`batch_size` many nodes for loss and metric computation. Hierarchical Extension ---------------------- A drawback of :class:`~torch_geometric.loader.NeighborLoader` is that it computes representations for *all* sampled nodes at *all* depths of the network. However, nodes sampled in later hops no longer contribute to the node representations of seed nodes in later GNN layers, thus performing useless computation. :class:`~torch_geometric.loader.NeighborLoader` will be marginally slower since we are computing node embeddings for nodes we no longer need. This is a trade-off we make to obtain a clean, modular and experimentation-friendly GNN design, which does not tie the definition of the model to its utilized data loader routine.
The `Hierarchical Neighborhood Sampling <../advanced/hgam.html>`__ tutorial shows how to eliminate this overhead and speed up training and inference in mini-batch GNNs further. Advanced Options ---------------- :class:`~torch_geometric.loader.NeighborLoader` provides many more features for advanced usage. In particular, * :class:`~torch_geometric.loader.NeighborLoader` supports both sampling on homogeneous and heterogeneous graphs out-of-the-box. For sampling on heterogeneous graphs, simply initialize it with a :class:`~torch_geometric.data.HeteroData` object. Sampling on heterogeneous graphs via :class:`~torch_geometric.loader.NeighborLoader` allows for fine-grained control of sampling parameters, *e.g.*, it allows specifying the number of neighbors to sample for each edge type individually. Take a look at the `Heterogeneous Graph Learning <../advanced/heterogeneous.html>`__ tutorial for additional information. * By default, :class:`~torch_geometric.loader.NeighborLoader` fuses sampled nodes across different seed nodes into a single subgraph. This way, shared neighbors of seed nodes will not be duplicated in the resulting subgraph, which saves memory. You can disable this behavior by passing the :obj:`disjoint=True` option to the :class:`~torch_geometric.loader.NeighborLoader`. * By default, the subgraphs returned from :class:`~torch_geometric.loader.NeighborLoader` will be **directed**, which restricts its use to GNNs with equal depth to the number of sampling hops. If you want to utilize deeper GNNs, specify the :obj:`subgraph_type` option. If set to :obj:`"bidirectional"`, sampled edges are converted to bidirectional edges. If set to :obj:`"induced"`, the returned subgraph will contain the induced subgraph of all sampled nodes. * :class:`~torch_geometric.loader.NeighborLoader` is designed to perform sampling from individual seed nodes. As such, it is not directly applicable in a link prediction scenario.
For this use case, we developed the :class:`~torch_geometric.loader.LinkNeighborLoader`, which expects a set of input edges, and will return subgraphs that were created via neighbor sampling from both source and destination nodes. ================================================ FILE: docs/source/tutorial/point_cloud.rst ================================================ Point Cloud Processing ====================== This tutorial explains how to leverage Graph Neural Networks (GNNs) for operating and training on point cloud data. Although point clouds do not come with a graph structure by default, we can utilize :pyg:`PyG` transformations to make them applicable for the full suite of GNNs available in :pyg:`PyG`. The key idea is to create a synthetic graph from point clouds, from which we can learn meaningful local geometric structures via a GNN's message passing scheme. These point representations can then be used to, *e.g.*, perform point cloud classification or segmentation. 3D Point Cloud Datasets ----------------------- :pyg:`PyG` provides several point cloud datasets, such as the :class:`~torch_geometric.datasets.PCPNetDataset`, :class:`~torch_geometric.datasets.S3DIS` and :class:`~torch_geometric.datasets.ShapeNet` datasets. To get started, we also provide the :class:`~torch_geometric.datasets.GeometricShapes` dataset, which is a toy dataset that contains various geometric shapes such as cubes, spheres or pyramids. Notably, the :class:`~torch_geometric.datasets.GeometricShapes` dataset contains meshes instead of point clouds by default, represented via :obj:`pos` and :obj:`face` attributes, which hold the information of vertices and their triangular connectivity, respectively: ..
code-block:: python from torch_geometric.datasets import GeometricShapes dataset = GeometricShapes(root='data/GeometricShapes') print(dataset) >>> GeometricShapes(40) data = dataset[0] print(data) >>> Data(pos=[32, 3], face=[3, 30], y=[1]) When visualizing the first mesh in the dataset, we can see that it represents a circle: .. figure:: ../_figures/point_cloud1.png :align: center :width: 40% | Since we are interested in point clouds, we can transform our meshes into points via the usage of :class:`torch_geometric.transforms`. In particular, :pyg:`PyG` provides the :class:`~torch_geometric.transforms.SamplePoints` transformation, which will uniformly sample a fixed number of points on the mesh faces according to their face area. We can add this transformation to the dataset by simply setting it via :obj:`dataset.transform = SamplePoints(num=...)`. Each time an example is accessed from the dataset, the transformation procedure will get called, converting our mesh into a point cloud. Note that sampling points is stochastic, and so you will receive a new point cloud upon every access: .. code-block:: python import torch_geometric.transforms as T dataset.transform = T.SamplePoints(num=256) data = dataset[0] print(data) >>> Data(pos=[256, 3], y=[1]) Note that we now have :obj:`256` points in our example, and the triangular connectivity stored in :obj:`face` has been removed. Visualizing the points now shows that we have correctly sampled points on the surface of the initial mesh: .. figure:: ../_figures/point_cloud2.png :align: center :width: 40% | Finally, let's convert our point cloud into a graph. Since we are interested in learning local geometric structures, we want to construct a graph in such a way that nearby points are connected. Typically, this is either done via :math:`k`-nearest neighbor search or via ball queries (which connect all points that are within a certain radius to the query point). 
:pyg:`PyG` provides utilities for such graph generation via the :class:`~torch_geometric.transforms.KNNGraph` and :class:`~torch_geometric.transforms.RadiusGraph` transformations, respectively. .. code-block:: python from torch_geometric.transforms import SamplePoints, KNNGraph dataset.transform = T.Compose([SamplePoints(num=256), KNNGraph(k=6)]) data = dataset[0] print(data) >>> Data(pos=[256, 3], edge_index=[2, 1536], y=[1]) You can see that the :obj:`data` object now also contains an :obj:`edge_index` representation, holding :obj:`1536` edges in total, 6 edges for each of the 256 points. We can confirm that our graph looks good via the following visualization: .. figure:: ../_figures/point_cloud3.png :align: center :width: 40% | PointNet++ Implementation ------------------------- `PointNet++ `_ is a pioneering work that proposes a Graph Neural Network architecture for point cloud classification and segmentation. PointNet++ processes point clouds iteratively by following a simple grouping, neighborhood aggregation and downsampling scheme: .. figure:: ../_figures/point_cloud4.png :align: center :width: 100% | 1. The **grouping phase** constructs a graph via :math:`k`-nearest neighbor search or via ball queries as described above. 2. The **neighborhood aggregation** phase executes a GNN layer that, for each point, aggregates information from its direct neighbors (given by the graph constructed in the previous phase). This allows PointNet++ to capture local context at different scales. 3. The **downsampling phase** implements a pooling scheme suitable for point clouds with potentially different sizes. For simplicity, we will ignore this phase for now. We recommend taking a look at `examples/pointnet2_classification.py `_ for guidance on how to implement this step. Neighborhood Aggregation ~~~~~~~~~~~~~~~~~~~~~~~~ The PointNet++ layer follows a simple neural message passing scheme defined via ..
math:: \mathbf{h}^{(\ell + 1)}_i = \max_{j \in \mathcal{N}(i)} \textrm{MLP} \left( \mathbf{h}_j^{(\ell)}, \mathbf{p}_j - \mathbf{p}_i \right) where * :math:`\mathbf{h}_i^{(\ell)} \in \mathbb{R}^d` denotes the hidden features of point :math:`i` in layer :math:`\ell`, and * :math:`\mathbf{p}_i \in \mathbb{R}^3` denotes the position of point :math:`i`. We can make use of the :class:`~torch_geometric.nn.conv.MessagePassing` interface in :pyg:`PyG` to implement this layer from scratch. The :class:`~torch_geometric.nn.conv.MessagePassing` interface helps us in **creating message passing graph neural networks** by automatically taking care of message propagation. Here, we only need to define its :meth:`~torch_geometric.nn.conv.MessagePassing.message` function and which aggregation scheme we want to use, *e.g.*, :obj:`aggr="max"` (see `here `_ for the accompanying tutorial): .. code-block:: python from torch import Tensor from torch.nn import Sequential, Linear, ReLU from torch_geometric.nn import MessagePassing class PointNetLayer(MessagePassing): def __init__(self, in_channels: int, out_channels: int): # Message passing with "max" aggregation. super().__init__(aggr='max') # Initialization of the MLP: # Here, the number of input features correspond to the hidden # node dimensionality plus point dimensionality (=3). self.mlp = Sequential( Linear(in_channels + 3, out_channels), ReLU(), Linear(out_channels, out_channels), ) def forward(self, h: Tensor, pos: Tensor, edge_index: Tensor, ) -> Tensor: # Start propagating messages.
return self.propagate(edge_index, h=h, pos=pos) def message(self, h_j: Tensor, pos_j: Tensor, pos_i: Tensor, ) -> Tensor: # h_j: The features of neighbors as shape [num_edges, in_channels] # pos_j: The position of neighbors as shape [num_edges, 3] # pos_i: The central node position as shape [num_edges, 3] edge_feat = torch.cat([h_j, pos_j - pos_i], dim=-1) return self.mlp(edge_feat) As one can see, implementing the PointNet++ layer is quite straightforward in :pyg:`PyG`. In the :meth:`__init__` function, we first define that we want to apply **max aggregation**, and afterwards initialize an MLP that takes care of transforming node features of neighbors and the spatial relation between source and destination nodes to a (trainable) message. In the :meth:`forward` function, we can start **propagating messages** based on :obj:`edge_index`, and pass in everything needed in order to create messages. In the :meth:`message` function, we can now access neighbor and central node information via :obj:`*_j` and :obj:`*_i` suffixes, respectively, and return a message for each edge. Network Architecture ~~~~~~~~~~~~~~~~~~~~ We can make use of above :class:`PointNetLayer` to define our network architecture (or use its equivalent :class:`torch_geometric.nn.conv.PointNetConv` directly integrated in :pyg:`PyG`). With this, our overall :class:`PointNet` architecture looks as follows: .. 
code-block:: python from torch_geometric.nn import global_max_pool class PointNet(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = PointNetLayer(3, 32) self.conv2 = PointNetLayer(32, 32) self.classifier = Linear(32, dataset.num_classes) def forward(self, pos: Tensor, edge_index: Tensor, batch: Tensor, ) -> Tensor: # Perform two-layers of message passing: h = self.conv1(h=pos, pos=pos, edge_index=edge_index) h = h.relu() h = self.conv2(h=h, pos=pos, edge_index=edge_index) h = h.relu() # Global Pooling: h = global_max_pool(h, batch) # [num_examples, hidden_channels] # Classifier: return self.classifier(h) model = PointNet() If we inspect the model, we can see that everything is initialized correctly: .. code-block:: python print(model) >>> PointNet( ... (conv1): PointNetLayer() ... (conv2): PointNetLayer() ... (classifier): Linear(in_features=32, out_features=40, bias=True) ... ) Here, we create our network architecture by inheriting from :class:`torch.nn.Module` and initialize **two** :class:`PointNetLayer` **modules** and a **final linear classifier** in its constructor. In the :meth:`forward` method, we apply two graph-based convolutional operators and enhance them by ReLU non-linearities. The first operator takes in 3 input features (the positions of nodes) and maps them to 32 output features. After applying both operators, each point holds information about its 2-hop neighborhood, and should already be able to distinguish between simple local shapes. Next, we apply a global graph readout function, *i.e.*, :meth:`~torch_geometric.nn.pool.global_max_pool`, which takes the maximum value along the node dimension for each example. In order to map the different nodes to their corresponding examples, we use the :obj:`batch` vector which will be automatically created for us when using the mini-batch :class:`torch_geometric.loader.DataLoader`. Last, we apply a linear classifier to map the global 32 features per point cloud to one of the 40 classes.
Training Procedure ~~~~~~~~~~~~~~~~~~ We are now ready to write two simple procedures to train and test our model on the training and test datasets, respectively. If you are not new to :pytorch:`PyTorch`, this scheme should appear familiar to you. Otherwise, the :pytorch:`PyTorch` documentation provide a `good introduction on how to train a neural network in PyTorch `_: .. code-block:: python from torch_geometric.loader import DataLoader train_dataset = GeometricShapes(root='data/GeometricShapes', train=True) train_dataset.transform = T.Compose([SamplePoints(num=256), KNNGraph(k=6)]) test_dataset = GeometricShapes(root='data/GeometricShapes', train=False) test_dataset.transform = T.Compose([SamplePoints(num=256), KNNGraph(k=6)]) train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=10) model = PointNet() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) criterion = torch.nn.CrossEntropyLoss() def train(): model.train() total_loss = 0 for data in train_loader: optimizer.zero_grad() logits = model(data.pos, data.edge_index, data.batch) loss = criterion(logits, data.y) loss.backward() optimizer.step() total_loss += float(loss) * data.num_graphs return total_loss / len(train_loader.dataset) @torch.no_grad() def test(): model.eval() total_correct = 0 for data in test_loader: logits = model(data.pos, data.edge_index, data.batch) pred = logits.argmax(dim=-1) total_correct += int((pred == data.y).sum()) return total_correct / len(test_loader.dataset) for epoch in range(1, 51): loss = train() test_acc = test() print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}') Using this setup, you should get around **75%-80% test set accuracy**, even when training only on a single example per class. 
================================================ FILE: docs/source/tutorial/shallow_node_embeddings.rst ================================================ Shallow Node Embeddings ======================= In this tutorial, we will take a closer look at how to learn *shallow node embeddings* in an unsupervised fashion via :pyg:`PyG`. Introduction ------------ The key difference between *shallow* node embeddings (*e.g.,* :class:`~torch_geometric.nn.models.Node2Vec`) and *deep* node embeddings (*e.g.,* GNNs) is the choice of the encoder :math:`\textrm{ENC}(v, \mathcal{G}) = \mathbf{z}_v \in \mathbb{R}^d`. Specifically, shallow node embedding techniques rely on embedding nodes into low-dimensional vectorial representations :math:`\mathbf{z}_v` via a *shallow embedding lookup table* such that the likelihood of preserving neighborhoods is maximized, *i.e.*, nearby nodes should receive similar embeddings while distant nodes should receive distinct embeddings. These techniques generalize the famous `SkipGram `_ model for obtaining low-dimensional word embeddings, in which sequences of words are now interpreted as sequences of nodes, *e.g.*, given via randomly-generated walks: .. figure:: ../_figures/shallow_node_embeddings.png :align: center :width: 100% | Specifically, given a *random walk* :math:`\mathcal{W} = (v_{\pi(1)}, \ldots, v_{\pi(k)})` of length :math:`k` starting at node :math:`v \in \mathcal{V}`, the objective is to maximize the likelihood of observing node :math:`v_{\pi(i)}` given node :math:`v`. This objective can be efficiently trained via stochastic gradient descent in a contrastive learning scenario: ..
math:: \mathcal{L} = \sum_{w \in \mathcal{W}} - \log \left(\sigma(\mathbf{z}_v^{\top} \mathbf{z}_w) \right) + \sum_{w \sim \mathcal{V} \setminus \mathcal{W}} - \log \left( 1 - \sigma(\mathbf{z}_v^{\top} \mathbf{z}_w) \right), in which non-existent walks (so-called *negative examples*) are sampled and trained jointly, and :math:`\sigma` denotes the :math:`\textrm{sigmoid}` function. Notably, the dot-product :math:`\mathbf{z}_v^{\top} \mathbf{z}_w` between the embeddings is usually used to measure similarity, but other similarity measures are applicable as well. Importantly, shallow node embeddings are trained in an unsupervised fashion, and can eventually be used as input for a given downstream task, *e.g.*, in node-level tasks :math:`\mathbf{z}_v` can directly be used as input to a final classifier. For edge-level tasks, edge-level representations can be obtained via averaging :math:`\frac{1}{2} (\mathbf{z}_v + \mathbf{z}_w)` or via the Hadamard product :math:`\mathbf{z}_v \odot \mathbf{z}_w`. Despite the simplicity of node embedding techniques, they are also subject to certain shortcomings. In particular, they fail to incorporate rich feature information attached to nodes and edges, and cannot be trivially applied to unseen graphs as learnable parameters are fixed to the nodes of a particular graph (making this approach transductive by nature and hard-to-scale due to the :math:`\mathcal{O}(|\mathcal{V}| \cdot d)` parameter complexity). However, it is still a commonly used technique to preserve structural graph information into fixed-size vectors, and is oftentimes also used to generate inputs to GNNs for further processing in case the initial set of node features is not rich. Node2Vec -------- .. note:: In this section of the tutorial, we will learn node embeddings for **homogeneous graphs** using the :class:`~torch_geometric.nn.models.Node2Vec` module of :pyg:`PyG`. The code is available in `examples/node2vec.py `_ and as a `Google Colab tutorial notebook `_.
:class:`~torch_geometric.nn.models.Node2Vec` is a method for learning shallow node embeddings, which allows for flexible control of random walk procedures based on breadth-first or depth-first samplers. In particular, its parameter :obj:`p` dictates the likelihood of immediately revisiting a node in the walk, while its parameter :obj:`q` interpolates between breadth-first and depth-first strategies. To begin the example, let us load in the needed packages and the data that we will be working with: .. code-block:: python from torch_geometric.datasets import Planetoid from torch_geometric.nn import Node2Vec data = Planetoid('./data/Planetoid', name='Cora')[0] We are now ready to initialize our :class:`~torch_geometric.nn.models.Node2Vec` module: .. code-block:: python import torch from torch_geometric.nn import Node2Vec device = 'cuda' if torch.cuda.is_available() else 'cpu' model = Node2Vec( data.edge_index, embedding_dim=128, walks_per_node=10, walk_length=20, context_size=10, p=1.0, q=1.0, num_negative_samples=1, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) :class:`~torch_geometric.nn.models.Node2Vec` takes the graph structure :obj:`edge_index` as input (but none of its feature information), the :obj:`embedding_dim` of the shallow embeddings, and additional parameters to control the random walk and negative sampling procedures. In particular, :obj:`walks_per_node` and :obj:`walk_length` specify the number of walks to perform for each node and their length, respectively. The :obj:`context_size` then denotes how many nodes in the walk are actually used for gradient optimization, *i.e.*, :class:`~torch_geometric.nn.models.Node2Vec` slides over each sampled walk and splits it into windows of size :obj:`context_size`. As previously mentioned, :obj:`p` and :obj:`q` denote how random walks are generated. Finally, :obj:`num_negative_samples` specifies how many negative walks we want to generate for each positive walk.
After initializing, we can go ahead and train our :class:`~torch_geometric.nn.models.Node2Vec` model right away. We start this by creating a data loader that will generate positive and negative random walks for us: .. code-block:: python loader = model.loader(batch_size=128, shuffle=True, num_workers=4) To generate random walks, we can simply iterate over the data loader, *e.g.*: .. code-block:: python pos_rw, neg_rw = next(iter(loader)) Here, :obj:`pos_rw` will contain the node indices of positive random walks and :obj:`neg_rw` will contain the node indices of negative walks. In particular, :obj:`pos_rw` is a two-dimensional matrix of shape :obj:`[batch_size * walks_per_node * (2 + walk_length - context_size), context_size]`, and :obj:`neg_rw` is a two-dimensional matrix of shape :obj:`[num_negative_samples * pos_rw.size(0), context_size]`. Using this :obj:`loader` and the built-in contrastive :meth:`~torch_geometric.nn.models.Node2Vec.loss` function, we can define our :meth:`train` function as follows: .. code-block:: python def train(): model.train() total_loss = 0 for pos_rw, neg_rw in loader: optimizer.zero_grad() loss = model.loss(pos_rw.to(device), neg_rw.to(device)) loss.backward() optimizer.step() total_loss += loss.item() return total_loss / len(loader) After finishing training, we can obtain the final node embeddings from the model as follows: .. code-block:: python z = model() # Full node-level embeddings. z = model(torch.tensor([0, 1, 2])) # Embeddings of first three nodes. MetaPath2Vec ------------ .. note:: In this section of the tutorial, we will learn node embeddings for **heterogeneous graphs** using the :class:`~torch_geometric.nn.models.MetaPath2Vec` module of :pyg:`PyG`. The code is available as `examples/hetero/metapath2vec.py `_ and as a `Google Colab tutorial notebook `_. An extension of :class:`~torch_geometric.nn.models.Node2Vec` to *heterogeneous graphs* is the :class:`~torch_geometric.nn.models.MetaPath2Vec` model.
:class:`~torch_geometric.nn.models.MetaPath2Vec` works similar to :class:`~torch_geometric.nn.models.Node2Vec` but expects a dictionary of edge indices as input (holding the :obj:`edge_index` for each edge type in the graph), and samples random walks based on a given :obj:`metapath` formulation, *e.g.*, .. code-block:: python metapath = [ ('author', 'writes', 'paper'), ('paper', 'published_in', 'venue'), ('venue', 'publishes', 'paper'), ('paper', 'written_by', 'author'), ] denotes that random walk sampling is performed from author nodes to paper nodes to venue nodes back to paper nodes and author nodes. Otherwise, initialization and training of the model stays the same as in the :class:`~torch_geometric.nn.models.Node2Vec` case. ================================================ FILE: examples/README.md ================================================ # Examples This folder contains a plethora of examples covering different GNN use-cases. This readme highlights some key examples. > [!NOTE] > We recommend the [NVIDIA PyG Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg/tags) for best results and easiest setup with NVIDIA GPUs. See the [cuGraph installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html#accelerating-pyg-with-nvidia-cugraph-gnn) for details. A great and simple example to start with is [`gcn.py`](./gcn.py), showing a user how to train a [`GCN`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GCN.html) model for node-level prediction on small-scale homogeneous data. For a simple GNN based link prediction example, see [`link_pred.py`](./link_pred.py). For an improved GNN based link prediction approach using Attract-Repel embeddings that can significantly boost accuracy (up to 23% improvement in AUC), see [`ar_link_pred.py`](./ar_link_pred.py). 
This approach is based on [Pseudo-Euclidean Attract-Repel Embeddings for Undirected Graphs](https://arxiv.org/abs/2106.09671). To see an example for doing link prediction with an advanced Graph Transformer called [`LPFormer`](https://arxiv.org/abs/2310.11009), see [`lpformer.py`](./lpformer.py). For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see the `ogbn_*.py` examples: - [`ogbn_train.py`](./ogbn_train.py) is an example for training a GNN on the large-scale `ogbn-papers100M` dataset, containing approximately 1.6B edges, or on the medium-scale `ogbn-products` dataset, containing ~62M edges. - Uses SGFormer (a kind of GraphTransformer) by default. - [SGFormer Paper](https://arxiv.org/pdf/2306.10759) - [Polynormer](https://arxiv.org/pdf/2403.01232) - [Kumo.ai x NVIDIA x Stanford Graph Transformer Webinar](https://www.youtube.com/watch?v=wAYryx3GjLw) - [`ogbn_proteins_deepgcn.py`](./ogbn_proteins_deepgcn.py) is an example to showcase how to train deep GNNs on the `ogbn-proteins` dataset. - [`ogbn_train_perforatedai.py`](https://github.com/PerforatedAI/PerforatedAI-Examples/tree/master/otherExamples/torch_geometric/OGBNProducts) shows how to optimize the `ogbn_train.py` workflow using [Perforated AI](https://github.com/PerforatedAI/PerforatedAI-API). Perforated AI provides a PyTorch add-on which increases network accuracy by empowering each artificial neuron with artificial dendrites. For an example on [Relational Deep Learning](https://arxiv.org/abs/2312.04615) with the [RelBench datasets](https://relbench.stanford.edu/), see [`rdl.py`](./rdl.py). For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile). For examples on scaling PyG up via multi-GPUs, see the examples under [`examples/multi_gpu`](./multi_gpu). For examples on working with heterogeneous data, see the examples under [`examples/hetero`](./hetero). For examples on co-training LLMs with GNNs, see the examples under [`examples/llm`](./llm).
- [Stanford GNN+LLM Talk](https://www.nvidia.com/en-us/on-demand/session/other25-nv-0003/) We recommend looking into the [PyTorch documentation](https://docs.pytorch.org/tutorials/beginner/dist_overview.html) for examples on setting up model-parallel GNNs. ### Scale to Trillions of Edges with cuGraph [cuGraph](https://github.com/rapidsai/cugraph) is a collection of packages focused on GPU-accelerated graph analytics including support for property graphs and scaling up to thousands of GPUs. cuGraph supports the creation and manipulation of graphs followed by the execution of scalable fast graph algorithms. It is part of the [RAPIDS](https://rapids.ai) accelerated data science framework. [cuGraph GNN](https://github.com/rapidsai/cugraph-gnn) is a collection of GPU-accelerated plugins that support PyTorch and PyG natively through the _cuGraph-PyG_ and _WholeGraph_ subprojects. cuGraph GNN is built on top of cuGraph, leveraging its low-level [pylibcugraph](https://github.com/rapidsai/cugraph/python/pylibcugraph) API and C++ primitives for sampling and other GNN operations ([libcugraph](https://github.com/rapidsai/cugraph/python/libcugraph)). It also includes the `libwholegraph` and `pylibwholegraph` libraries for high-performance distributed edgelist and embedding storage. Users have the option of working with these lower-level libraries directly, or through the higher-level API in cuGraph-PyG that directly implements the `GraphStore`, `FeatureStore`, `NodeLoader`, and `LinkLoader` interfaces. Complete documentation on RAPIDS graph packages, including `cugraph`, `cugraph-pyg`, `pylibwholegraph`, and `pylibcugraph` is available on the [RAPIDS docs pages](https://docs.rapids.ai/api/cugraph/nightly/graph_support). See [`rapidsai/cugraph-gnn/tree/branch-25.12/python/cugraph-pyg/cugraph_pyg/examples` on GitHub](https://github.com/rapidsai/cugraph-gnn/tree/branch-25.12/python/cugraph-pyg/cugraph_pyg/examples) for fully scalable PyG example workflows.
================================================ FILE: examples/agnn.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import AGNNConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) data = dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() self.lin1 = torch.nn.Linear(dataset.num_features, 16) self.prop1 = AGNNConv(requires_grad=False) self.prop2 = AGNNConv(requires_grad=True) self.lin2 = torch.nn.Linear(16, dataset.num_classes) def forward(self): x = F.dropout(data.x, training=self.training) x = F.relu(self.lin1(x)) x = self.prop1(x, data.edge_index) x = self.prop2(x, data.edge_index) x = F.dropout(x, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model, data = Net().to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) def train(): model.train() optimizer.zero_grad() F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward() optimizer.step() @torch.no_grad() def test(): model.eval() out, accs = model(), [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): pred = out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs best_val_acc = test_acc = 0 for epoch in range(1, 201): train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, ' f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/ar_link_pred.py ================================================ import 
argparse import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCNConv from torch_geometric.utils import negative_sampling, train_test_split_edges class GCNEncoder(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels) self.conv2 = GCNConv(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() return self.conv2(x, edge_index) class LinkPredictor(torch.nn.Module): def __init__(self, in_channels, hidden_channels): super().__init__() self.lin1 = torch.nn.Linear(in_channels * 2, hidden_channels) self.lin2 = torch.nn.Linear(hidden_channels, 1) def forward(self, z_i, z_j): x = torch.cat([z_i, z_j], dim=1) x = self.lin1(x).relu() x = self.lin2(x) return x.view(-1) class ARLinkPredictor(torch.nn.Module): def __init__(self, in_channels): super().__init__() # Split dimensions between attract and repel self.attract_dim = in_channels // 2 self.repel_dim = in_channels - self.attract_dim def forward(self, z_i, z_j): # Split into attract and repel parts z_i_attr = z_i[:, :self.attract_dim] z_i_repel = z_i[:, self.attract_dim:] z_j_attr = z_j[:, :self.attract_dim] z_j_repel = z_j[:, self.attract_dim:] # Calculate AR score attract_score = (z_i_attr * z_j_attr).sum(dim=1) repel_score = (z_i_repel * z_j_repel).sum(dim=1) return attract_score - repel_score def train(encoder, predictor, data, optimizer): encoder.train() predictor.train() # Forward pass and calculate loss optimizer.zero_grad() z = encoder(data.x, data.train_pos_edge_index) # Positive edges pos_out = predictor(z[data.train_pos_edge_index[0]], z[data.train_pos_edge_index[1]]) # Sample and predict on negative edges neg_edge_index = negative_sampling( edge_index=data.train_pos_edge_index, num_nodes=data.num_nodes, 
num_neg_samples=data.train_pos_edge_index.size(1), ) neg_out = predictor(z[neg_edge_index[0]], z[neg_edge_index[1]]) # Calculate loss pos_loss = F.binary_cross_entropy_with_logits(pos_out, torch.ones_like(pos_out)) neg_loss = F.binary_cross_entropy_with_logits(neg_out, torch.zeros_like(neg_out)) loss = pos_loss + neg_loss loss.backward() optimizer.step() return loss.item() @torch.no_grad() def test(encoder, predictor, data): encoder.eval() predictor.eval() z = encoder(data.x, data.train_pos_edge_index) pos_val_out = predictor(z[data.val_pos_edge_index[0]], z[data.val_pos_edge_index[1]]) neg_val_out = predictor(z[data.val_neg_edge_index[0]], z[data.val_neg_edge_index[1]]) pos_test_out = predictor(z[data.test_pos_edge_index[0]], z[data.test_pos_edge_index[1]]) neg_test_out = predictor(z[data.test_neg_edge_index[0]], z[data.test_neg_edge_index[1]]) val_auc = compute_auc(pos_val_out, neg_val_out) test_auc = compute_auc(pos_test_out, neg_test_out) return val_auc, test_auc def compute_auc(pos_out, neg_out): pos_out = torch.sigmoid(pos_out).cpu().numpy() neg_out = torch.sigmoid(neg_out).cpu().numpy() # Simple AUC calculation from sklearn.metrics import roc_auc_score y_true = torch.cat( [torch.ones(pos_out.shape[0]), torch.zeros(neg_out.shape[0])]) y_score = torch.cat([torch.tensor(pos_out), torch.tensor(neg_out)]) return roc_auc_score(y_true, y_score) def main(): parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='Cora', choices=['Cora', 'CiteSeer', 'PubMed']) parser.add_argument('--hidden_channels', type=int, default=128) parser.add_argument('--out_channels', type=int, default=64) parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--use_ar', action='store_true', help='Use Attract-Repel embeddings') parser.add_argument('--lr', type=float, default=0.01) args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load dataset transform = T.Compose([ T.NormalizeFeatures(), 
T.ToDevice(device), ]) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', args.dataset) dataset = Planetoid(path, args.dataset, transform=transform) data = dataset[0] # Process data for link prediction data = train_test_split_edges(data) # Initialize encoder encoder = GCNEncoder( in_channels=dataset.num_features, hidden_channels=args.hidden_channels, out_channels=args.out_channels, ).to(device) # Choose predictor based on args if args.use_ar: predictor = ARLinkPredictor(in_channels=args.out_channels).to(device) print(f"Running link prediction on {args.dataset}" f"with Attract-Repel embeddings") else: predictor = LinkPredictor( in_channels=args.out_channels, hidden_channels=args.hidden_channels).to(device) print(f"Running link prediction on {args.dataset}" f"with Traditional embeddings") optimizer = torch.optim.Adam( list(encoder.parameters()) + list(predictor.parameters()), lr=args.lr) best_val_auc = 0 final_test_auc = 0 for epoch in range(1, args.epochs + 1): loss = train(encoder, predictor, data, optimizer) val_auc, test_auc = test(encoder, predictor, data) if val_auc > best_val_auc: best_val_auc = val_auc final_test_auc = test_auc if epoch % 10 == 0: print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' f'Val AUC: {val_auc:.4f}, ' f'Test AUC: {test_auc:.4f}') print(f'Final results - Val AUC: {best_val_auc:.4f}, ' f'Test AUC: {final_test_auc:.4f}') # Calculate R-fraction if using AR if args.use_ar: with torch.no_grad(): z = encoder(data.x, data.train_pos_edge_index) attr_dim = args.out_channels // 2 z_attr = z[:, :attr_dim] z_repel = z[:, attr_dim:] attract_norm_squared = torch.sum(z_attr**2) repel_norm_squared = torch.sum(z_repel**2) r_fraction = repel_norm_squared / (attract_norm_squared + repel_norm_squared) print(f"R-fraction: {r_fraction.item():.4f}") if __name__ == '__main__': main() ================================================ FILE: examples/argva_node_clustering.py ================================================ import os.path as osp import 
matplotlib.pyplot as plt import torch from sklearn.cluster import KMeans from sklearn.manifold import TSNE from sklearn.metrics.cluster import ( completeness_score, homogeneity_score, v_measure_score, ) from torch.nn import Linear import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import ARGVA, GCNConv if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') transform = T.Compose([ T.ToDevice(device), T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True, add_negative_train_samples=False), ]) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, name='Cora', transform=transform) train_data, val_data, test_data = dataset[0] class Encoder(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels) self.conv_mu = GCNConv(hidden_channels, out_channels) self.conv_logstd = GCNConv(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index) class Discriminator(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.lin1 = Linear(in_channels, hidden_channels) self.lin2 = Linear(hidden_channels, hidden_channels) self.lin3 = Linear(hidden_channels, out_channels) def forward(self, x): x = self.lin1(x).relu() x = self.lin2(x).relu() return self.lin3(x) encoder = Encoder(train_data.num_features, hidden_channels=32, out_channels=32) discriminator = Discriminator(in_channels=32, hidden_channels=64, out_channels=32) model = ARGVA(encoder, discriminator).to(device) encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005) discriminator_optimizer 
= torch.optim.Adam(discriminator.parameters(), lr=0.001) def train(): model.train() encoder_optimizer.zero_grad() z = model.encode(train_data.x, train_data.edge_index) # We optimize the discriminator more frequently than the encoder. for _ in range(5): discriminator_optimizer.zero_grad() discriminator_loss = model.discriminator_loss(z) discriminator_loss.backward() discriminator_optimizer.step() loss = model.recon_loss(z, train_data.pos_edge_label_index) loss = loss + model.reg_loss(z) loss = loss + (1 / train_data.num_nodes) * model.kl_loss() loss.backward() encoder_optimizer.step() return float(loss.detach()) @torch.no_grad() def test(data): model.eval() z = model.encode(data.x, data.edge_index) # Cluster embedded values using k-means. kmeans_input = z.cpu().numpy() kmeans = KMeans(n_clusters=7, random_state=0, n_init='auto').fit(kmeans_input) pred = kmeans.predict(kmeans_input) labels = data.y.cpu().numpy() completeness = completeness_score(labels, pred) hm = homogeneity_score(labels, pred) nmi = v_measure_score(labels, pred) auc, ap = model.test(z, data.pos_edge_label_index, data.neg_edge_label_index) return auc, ap, completeness, hm, nmi for epoch in range(1, 151): loss = train() auc, ap, completeness, hm, nmi = test(test_data) print(f'Epoch: {epoch:03d}, Loss: {loss:.3f}, AUC: {auc:.3f}, ' f'AP: {ap:.3f}, Completeness: {completeness:.3f}, ' f'Homogeneity: {hm:.3f}, NMI: {nmi:.3f}') @torch.no_grad() def plot_points(data, colors): model.eval() z = model.encode(data.x, data.edge_index) z = TSNE(n_components=2).fit_transform(z.cpu().numpy()) y = data.y.cpu().numpy() plt.figure(figsize=(8, 8)) for i in range(dataset.num_classes): plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i]) plt.axis('off') plt.show() colors = [ '#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700' ] plot_points(test_data, colors) ================================================ FILE: examples/arma.py ================================================ import 
os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import ARMAConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) data = dataset[0] class Net(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = ARMAConv(in_channels, hidden_channels, num_stacks=3, num_layers=2, shared_weights=True, dropout=0.25) self.conv2 = ARMAConv(hidden_channels, out_channels, num_stacks=3, num_layers=2, shared_weights=True, dropout=0.25, act=lambda x: x) def forward(self, x, edge_index): x = F.dropout(x, training=self.training) x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') model, data = Net(dataset.num_features, 16, dataset.num_classes).to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() def test(): model.eval() out, accs = model(data.x, data.edge_index), [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): pred = out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs best_val_acc = test_acc = 0 for epoch in range(1, 401): train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, ' f'Val: 
{best_val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/attentive_fp.py ================================================ import os.path as osp from math import sqrt import torch import torch.nn.functional as F from rdkit import Chem from torch_geometric.datasets import MoleculeNet from torch_geometric.loader import DataLoader from torch_geometric.nn.models import AttentiveFP class GenFeatures: def __init__(self): self.symbols = [ 'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At', 'other' ] self.hybridizations = [ Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2, 'other', ] self.stereos = [ Chem.rdchem.BondStereo.STEREONONE, Chem.rdchem.BondStereo.STEREOANY, Chem.rdchem.BondStereo.STEREOZ, Chem.rdchem.BondStereo.STEREOE, ] def __call__(self, data): # Generate AttentiveFP features according to Table 1. mol = Chem.MolFromSmiles(data.smiles) xs = [] for atom in mol.GetAtoms(): symbol = [0.] * len(self.symbols) symbol[self.symbols.index(atom.GetSymbol())] = 1. degree = [0.] * 6 degree[atom.GetDegree()] = 1. formal_charge = atom.GetFormalCharge() radical_electrons = atom.GetNumRadicalElectrons() hybridization = [0.] * len(self.hybridizations) hybridization[self.hybridizations.index( atom.GetHybridization())] = 1. aromaticity = 1. if atom.GetIsAromatic() else 0. hydrogens = [0.] * 5 hydrogens[atom.GetTotalNumHs()] = 1. chirality = 1. if atom.HasProp('_ChiralityPossible') else 0. chirality_type = [0.] * 2 if atom.HasProp('_CIPCode'): chirality_type[['R', 'S'].index(atom.GetProp('_CIPCode'))] = 1. 
x = torch.tensor(symbol + degree + [formal_charge] + [radical_electrons] + hybridization + [aromaticity] + hydrogens + [chirality] + chirality_type) xs.append(x) data.x = torch.stack(xs, dim=0) edge_indices = [] edge_attrs = [] for bond in mol.GetBonds(): edge_indices += [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]] edge_indices += [[bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]] bond_type = bond.GetBondType() single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0. double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0. triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0. aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0. conjugation = 1. if bond.GetIsConjugated() else 0. ring = 1. if bond.IsInRing() else 0. stereo = [0.] * 4 stereo[self.stereos.index(bond.GetStereo())] = 1. edge_attr = torch.tensor( [single, double, triple, aromatic, conjugation, ring] + stereo) edge_attrs += [edge_attr, edge_attr] if len(edge_attrs) == 0: data.edge_index = torch.zeros((2, 0), dtype=torch.long) data.edge_attr = torch.zeros((0, 10), dtype=torch.float) else: data.edge_index = torch.tensor(edge_indices).t().contiguous() data.edge_attr = torch.stack(edge_attrs, dim=0) return data path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'AFP_Mol') dataset = MoleculeNet(path, name='ESOL', pre_transform=GenFeatures()).shuffle() N = len(dataset) // 10 val_dataset = dataset[:N] test_dataset = dataset[N:2 * N] train_dataset = dataset[2 * N:] train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=200) test_loader = DataLoader(test_dataset, batch_size=200) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = AttentiveFP(in_channels=39, hidden_channels=200, out_channels=1, edge_dim=10, num_layers=2, num_timesteps=2, dropout=0.2).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=10**-2.5, weight_decay=10**-5) def train(): total_loss = 
total_examples = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.edge_attr, data.batch) loss = F.mse_loss(out, data.y) loss.backward() optimizer.step() total_loss += float(loss) * data.num_graphs total_examples += data.num_graphs return sqrt(total_loss / total_examples) @torch.no_grad() def test(loader): mse = [] for data in loader: data = data.to(device) out = model(data.x, data.edge_index, data.edge_attr, data.batch) mse.append(F.mse_loss(out, data.y, reduction='none').cpu()) return float(torch.cat(mse, dim=0).mean().sqrt()) for epoch in range(1, 201): train_rmse = train() val_rmse = test(val_loader) test_rmse = test(test_loader) print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} ' f'Test: {test_rmse:.4f}') ================================================ FILE: examples/autoencoder.py ================================================ import argparse import os.path as osp import time import torch import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import GAE, VGAE, GCNConv parser = argparse.ArgumentParser() parser.add_argument('--variational', action='store_true') parser.add_argument('--linear', action='store_true') parser.add_argument('--dataset', type=str, default='Cora', choices=['Cora', 'CiteSeer', 'PubMed']) parser.add_argument('--epochs', type=int, default=400) args = parser.parse_args() if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') transform = T.Compose([ T.NormalizeFeatures(), T.ToDevice(device), T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True, add_negative_train_samples=False), ]) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, args.dataset, transform=transform) 
train_data, val_data, test_data = dataset[0] class GCNEncoder(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, 2 * out_channels) self.conv2 = GCNConv(2 * out_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() return self.conv2(x, edge_index) class VariationalGCNEncoder(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, 2 * out_channels) self.conv_mu = GCNConv(2 * out_channels, out_channels) self.conv_logstd = GCNConv(2 * out_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index) class LinearEncoder(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv = GCNConv(in_channels, out_channels) def forward(self, x, edge_index): return self.conv(x, edge_index) class VariationalLinearEncoder(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv_mu = GCNConv(in_channels, out_channels) self.conv_logstd = GCNConv(in_channels, out_channels) def forward(self, x, edge_index): return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index) in_channels, out_channels = dataset.num_features, 16 if not args.variational and not args.linear: model = GAE(GCNEncoder(in_channels, out_channels)) elif not args.variational and args.linear: model = GAE(LinearEncoder(in_channels, out_channels)) elif args.variational and not args.linear: model = VGAE(VariationalGCNEncoder(in_channels, out_channels)) elif args.variational and args.linear: model = VGAE(VariationalLinearEncoder(in_channels, out_channels)) model = model.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() optimizer.zero_grad() z = model.encode(train_data.x, train_data.edge_index) loss = model.recon_loss(z, 
train_data.pos_edge_label_index) if args.variational: loss = loss + (1 / train_data.num_nodes) * model.kl_loss() loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(data): model.eval() z = model.encode(data.x, data.edge_index) return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index) times = [] for epoch in range(1, args.epochs + 1): start = time.time() loss = train() auc, ap = test(test_data) print(f'Epoch: {epoch:03d}, AUC: {auc:.4f}, AP: {ap:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/cluster_gcn_ppi.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F from sklearn.metrics import f1_score from torch_geometric.data import Batch from torch_geometric.datasets import PPI from torch_geometric.loader import ClusterData, ClusterLoader, DataLoader from torch_geometric.nn import BatchNorm, SAGEConv path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI') train_dataset = PPI(path, split='train') val_dataset = PPI(path, split='val') test_dataset = PPI(path, split='test') train_data = Batch.from_data_list(train_dataset) cluster_data = ClusterData(train_data, num_parts=50, recursive=False, save_dir=train_dataset.processed_dir) train_loader = ClusterLoader(cluster_data, batch_size=1, shuffle=True, num_workers=12) val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False) class Net(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers): super().__init__() self.convs = torch.nn.ModuleList() self.batch_norms = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) self.batch_norms.append(BatchNorm(hidden_channels)) for _ in range(num_layers - 2): 
self.convs.append(SAGEConv(hidden_channels, hidden_channels)) self.batch_norms.append(BatchNorm(hidden_channels)) self.convs.append(SAGEConv(hidden_channels, out_channels)) def forward(self, x, edge_index): for conv, batch_norm in zip(self.convs[:-1], self.batch_norms): x = conv(x, edge_index) x = batch_norm(x) x = F.relu(x) x = F.dropout(x, p=0.2, training=self.training) return self.convs[-1](x, edge_index) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(in_channels=train_dataset.num_features, hidden_channels=1024, out_channels=train_dataset.num_classes, num_layers=6).to(device) loss_op = torch.nn.BCEWithLogitsLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() loss = loss_op(model(data.x, data.edge_index), data.y) loss.backward() optimizer.step() total_loss += loss.item() * data.num_nodes return total_loss / train_data.num_nodes @torch.no_grad() def test(loader): model.eval() ys, preds = [], [] for data in loader: ys.append(data.y) out = model(data.x.to(device), data.edge_index.to(device)) preds.append((out > 0).float().cpu()) y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy() return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0 times = [] for epoch in range(1, 201): start = time.time() loss = train() val_f1 = test(val_loader) test_f1 = test(test_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, ' f'Test: {test_f1:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/cluster_gcn_reddit.py ================================================ import time import torch import torch.nn.functional as F from torch.nn import ModuleList from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import 
ClusterData, ClusterLoader, NeighborLoader from torch_geometric.nn import SAGEConv dataset = Reddit('../data/Reddit') data = dataset[0] cluster_data = ClusterData(data, num_parts=1500, recursive=False, save_dir=dataset.processed_dir) train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True, num_workers=12) subgraph_loader = NeighborLoader(data, num_neighbors=[-1], batch_size=1024, shuffle=False, num_workers=12) class Net(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.convs = ModuleList( [SAGEConv(in_channels, 128), SAGEConv(128, out_channels)]) def forward(self, x, edge_index): for i, conv in enumerate(self.convs): x = conv(x, edge_index) if i != len(self.convs) - 1: x = F.relu(x) x = F.dropout(x, p=0.5, training=self.training) return F.log_softmax(x, dim=-1) def inference(self, x_all): pbar = tqdm(total=x_all.size(0) * len(self.convs)) pbar.set_description('Evaluating') # Compute representations of nodes layer by layer, using *all* # available edges. This leads to faster computation in contrast to # immediately computing the final representations of each batch. 
for i, conv in enumerate(self.convs): xs = [] for batch in subgraph_loader: edge_index = batch.edge_index.to(device) x = x_all[batch.n_id].to(device) x_target = x[:batch.batch_size] x = conv((x, x_target), edge_index) if i != len(self.convs) - 1: x = F.relu(x) xs.append(x.cpu()) pbar.update(batch.batch_size) x_all = torch.cat(xs, dim=0) pbar.close() return x_all device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(dataset.num_features, dataset.num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.005) def train(): model.train() total_loss = total_nodes = 0 for batch in train_loader: batch = batch.to(device) optimizer.zero_grad() out = model(batch.x, batch.edge_index) loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask]) loss.backward() optimizer.step() nodes = batch.train_mask.sum().item() total_loss += loss.item() * nodes total_nodes += nodes return total_loss / total_nodes @torch.no_grad() def test(): # Inference should be performed on the full graph. 
model.eval() out = model.inference(data.x) y_pred = out.argmax(dim=-1) accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: correct = y_pred[mask].eq(data.y[mask]).sum().item() accs.append(correct / mask.sum().item()) return accs times = [] for epoch in range(1, 31): start = time.time() loss = train() if epoch % 5 == 0: train_acc, val_acc, test_acc = test() print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, test: {test_acc:.4f}') else: print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/colors_topk_pool.py ================================================ import copy import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import GINConv, TopKPooling, global_add_pool from torch_geometric.utils import scatter class HandleNodeAttention: def __call__(self, data): data = copy.copy(data) data.attn = torch.softmax(data.x[:, 0], dim=0) data.x = data.x[:, 1:] return data path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'COLORS-3') dataset = TUDataset(path, 'COLORS-3', use_node_attr=True, transform=HandleNodeAttention()) train_loader = DataLoader(dataset[:500], batch_size=60, shuffle=True) val_loader = DataLoader(dataset[500:3000], batch_size=60) test_loader = DataLoader(dataset[3000:], batch_size=60) class Net(torch.nn.Module): def __init__(self, in_channels): super().__init__() self.conv1 = GINConv(Seq(Lin(in_channels, 64), ReLU(), Lin(64, 64))) self.pool1 = TopKPooling(in_channels, min_score=0.05) self.conv2 = GINConv(Seq(Lin(64, 64), ReLU(), Lin(64, 64))) self.lin = torch.nn.Linear(64, 1) def forward(self, 
data): x, edge_index, batch = data.x, data.edge_index, data.batch out = F.relu(self.conv1(x, edge_index)) out, edge_index, _, batch, perm, score = self.pool1( out, edge_index, None, batch, attn=x) ratio = out.size(0) / x.size(0) out = F.relu(self.conv2(out, edge_index)) out = global_add_pool(out, batch) out = self.lin(out).view(-1) attn_loss = F.kl_div(torch.log(score + 1e-14), data.attn[perm], reduction='none') attn_loss = scatter(attn_loss, batch, reduce='mean') return out, attn_loss, ratio device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(dataset.num_features).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Initialize to optimal attention weights: # model.pool1.weight.data = torch.tensor([0., 1., 0., 0.]).view(1,4).to(device) def train(epoch): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out, attn_loss, _ = model(data) loss = ((out - data.y).pow(2) + 100 * attn_loss).mean() loss.backward() total_loss += loss.item() * data.num_graphs optimizer.step() return total_loss / len(train_loader.dataset) def test(loader): model.eval() corrects, total_ratio = [], 0 for data in loader: data = data.to(device) out, _, ratio = model(data) pred = out.round().to(torch.long) corrects.append(pred.eq(data.y.to(torch.long))) total_ratio += ratio return torch.cat(corrects, dim=0), total_ratio / len(loader) for epoch in range(1, 301): loss = train(epoch) train_correct, train_ratio = test(train_loader) val_correct, val_ratio = test(val_loader) test_correct, test_ratio = test(test_loader) train_acc = train_correct.sum().item() / train_correct.size(0) val_acc = val_correct.sum().item() / val_correct.size(0) test_acc1 = test_correct[:2500].sum().item() / 2500 test_acc2 = test_correct[2500:5000].sum().item() / 2500 test_acc3 = test_correct[5000:].sum().item() / 2500 print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.3f}, ' f'Val: {val_acc:.3f}, Test Orig: 
{test_acc1:.3f}, ' f'Test Large: {test_acc2:.3f}, Test LargeC: {test_acc3:.3f}, ' f'Train/Val/Test Ratio=' f'{train_ratio:.3f}/{val_ratio:.3f}/{test_ratio:.3f}') ================================================ FILE: examples/compile/gcn.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCNConv if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') path = osp.dirname(osp.realpath(__file__)) path = osp.join(path, '..', '..', 'data', 'Planetoid') dataset = Planetoid( path, name='Cora', transform=T.Compose([ T.NormalizeFeatures(), T.GCNNorm(), ])) data = dataset[0].to(device) class GCN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() # Pre-process normalization to avoid CPU communication/graph breaks: self.conv1 = GCNConv(in_channels, hidden_channels, normalize=False) self.conv2 = GCNConv(hidden_channels, out_channels, normalize=False) def forward(self, x, edge_index, edge_weight): x = F.dropout(x, p=0.5, training=self.training) x = self.conv1(x, edge_index, edge_weight).relu() x = F.dropout(x, p=0.5, training=self.training) x = self.conv2(x, edge_index) return x model = GCN( in_channels=dataset.num_features, hidden_channels=16, out_channels=dataset.num_classes, ).to(device) # Compile the model into an optimized version: model = torch.compile(model, dynamic=False) optimizer = torch.optim.Adam([ dict(params=model.conv1.parameters(), weight_decay=5e-4), dict(params=model.conv2.parameters(), weight_decay=0) ], lr=0.01) # Only perform weight-decay on first convolution. 
def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index, data.edge_weight) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return float(loss.detach()) @torch.no_grad() def test(): model.eval() pred = model(data.x, data.edge_index, data.edge_weight).argmax(dim=-1) accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs times = [] for epoch in range(1, 201): start = time.time() loss = train() train_acc, val_acc, test_acc = test() times.append(time.time() - start) print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') print(f'Median time per epoch: {torch.tensor(times).median():.4f}s') ================================================ FILE: examples/compile/gin.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F import torch_geometric from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, GINConv, global_add_pool if not torch_geometric.typing.WITH_PT21: quit('Dynamic shape compilation requires PyTorch >= 2.1.0') if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # MPS is currently slower than CPU due to missing int64 min/max ops device = torch.device('cpu') else: device = torch.device('cpu') path = osp.dirname(osp.realpath(__file__)) path = osp.join(path, '..', '..', 'data', 'TU') dataset = TUDataset(path, name='MUTAG').shuffle() train_loader = DataLoader(dataset[:0.9], batch_size=128, shuffle=True) test_loader = DataLoader(dataset[0.9:], batch_size=128) class GIN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers): super().__init__() self.convs = 
torch.nn.ModuleList() for _ in range(num_layers): mlp = MLP([in_channels, hidden_channels, hidden_channels]) self.convs.append(GINConv(nn=mlp, train_eps=False)) in_channels = hidden_channels self.mlp = MLP([hidden_channels, hidden_channels, out_channels], norm=None, dropout=0.5) def forward(self, x, edge_index, batch, batch_size): for conv in self.convs: x = conv(x, edge_index).relu() # Pass the batch size to avoid CPU communication/graph breaks: x = global_add_pool(x, batch, size=batch_size) return self.mlp(x) model = GIN( in_channels=dataset.num_features, hidden_channels=32, out_channels=dataset.num_classes, num_layers=5, ).to(device) # Compile the model into an optimized version: model = torch.compile(model, dynamic=True) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.batch, data.batch_size) loss = F.cross_entropy(out, data.y) loss.backward() optimizer.step() total_loss += float(loss.detach()) * data.num_graphs return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() total_correct = 0 for data in loader: data = data.to(device) out = model(data.x, data.edge_index, data.batch, data.batch_size) pred = out.argmax(dim=-1) total_correct += int((pred == data.y).sum()) return total_correct / len(loader.dataset) times = [] for epoch in range(1, 101): start = time.time() loss = train() train_acc = test(train_loader) test_acc = test(test_loader) times.append(time.time() - start) print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Test: {test_acc:.4f}') print(f'Median time per epoch: {torch.tensor(times).median():.4f}s') ================================================ FILE: examples/contrib/README.md ================================================ # Examples for External Contributions This directory contains examples demonstrating functionality 
included in the `torch_geometric.contrib` package. The `contrib` package of PyG is a staging area for early-stage, experimental code. Modules included here might be moved to the main library in the future. | Example | Description | | ---------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | | [`rbcd_attack.py`](./rbcd_attack.py) | An example of the RBCD (Resource-based Critical Data) attack | | [`rbcd_attack_poisoning.py`](./rbcd_attack_poisoning.py) | An example of the RBCD (Resource-Based Critical Data) attack with data poisoning strategies | | [`pgm_explainer_node_classification.py`](./pgm_explainer_node_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for node classification | | [`pgm_explainer_graph_classification.py`](./pgm_explainer_graph_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for graph classification | ================================================ FILE: examples/contrib/pgm_explainer_graph_classification.py ================================================ """This is an example of using the PGM explainer algorithm on a graph classification task. 
""" import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear, ReLU, Sequential import torch_geometric.transforms as T from torch_geometric.contrib.explain import PGMExplainer from torch_geometric.datasets import MNISTSuperpixels from torch_geometric.explain import Explainer from torch_geometric.loader import DataLoader from torch_geometric.nn import ( NNConv, global_mean_pool, graclus, max_pool, max_pool_x, ) from torch_geometric.utils import normalized_cut path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST') transform = T.Cartesian(cat=False) train_dataset = MNISTSuperpixels(path, True, transform=transform) test_dataset = MNISTSuperpixels(path, False, transform=transform) train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) d = train_dataset def normalized_cut_2d(edge_index, pos): row, col = edge_index edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1) return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0)) class Net(torch.nn.Module): def __init__(self): super().__init__() nn1 = Sequential( Linear(2, 25), ReLU(), Linear(25, d.num_features * 32), ) self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean') nn2 = Sequential( Linear(2, 25), ReLU(), Linear(25, 32 * 64), ) self.conv2 = NNConv(32, 64, nn2, aggr='mean') self.fc1 = torch.nn.Linear(64, 128) self.fc2 = torch.nn.Linear(128, d.num_classes) def forward(self, x, edge_index, **kwargs): data = kwargs.get('data') data = data.detach().clone() x = F.elu(self.conv1(x, edge_index, data.edge_attr)) weight = normalized_cut_2d(edge_index, data.pos) cluster = graclus(edge_index, weight, x.size(0)) data.edge_attr = None data.x = x data.edge_index = edge_index data = max_pool(cluster, data, transform=transform) data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) weight = normalized_cut_2d(data.edge_index, data.pos) cluster = 
graclus(data.edge_index, weight, data.x.size(0)) x, batch = max_pool_x(cluster, data.x, data.batch) x = global_mean_pool(x, batch) x = F.elu(self.fc1(x)) x = F.dropout(x, training=self.training) return F.log_softmax(self.fc2(x), dim=1) def train(model, dataloader): model.train() for data in dataloader: data = data.to(device) optimizer.zero_grad() F.nll_loss(model(data.x, data), data.y).backward() optimizer.step() if __name__ == "__main__": device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f'current device: {device}') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) for _ in range(2): train(model, train_loader) explainer = Explainer( model=model, algorithm=PGMExplainer(perturb_feature_list=[0], perturbation_mode="mean"), explanation_type='phenomenon', node_mask_type="object", model_config=dict(mode="multiclass_classification", task_level="graph", return_type="raw")) i = 0 for explain_dataset in test_loader: explain_dataset.to(device) explanation = explainer(x=explain_dataset.x, edge_index=explain_dataset.edge_index, target=explain_dataset.y, edge_attr=explain_dataset.edge_attr, data=explain_dataset) for k in explanation.available_explanations: print(explanation[k]) i += 1 if i > 2: break ================================================ FILE: examples/contrib/pgm_explainer_node_classification.py ================================================ """This is an example of using the PGM explainer algorithm on a node classification task. 
""" import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.contrib.explain import PGMExplainer from torch_geometric.datasets import Planetoid from torch_geometric.explain import Explainer, ModelConfig from torch_geometric.nn import GCNConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') transform = T.Compose([T.GCNNorm(), T.NormalizeFeatures()]) dataset = Planetoid(path, dataset, transform=transform) data = dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_features, 16, normalize=False) self.conv2 = GCNConv(16, dataset.num_classes, normalize=False) def forward(self, x, edge_index, edge_weight): x = F.relu(self.conv1(x, edge_index, edge_weight)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index, edge_weight) return F.log_softmax(x, dim=1) if __name__ == "__main__": device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) data = data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) x, edge_index, edge_weight, target = \ data.x, data.edge_index, data.edge_weight, data.y model.train() for _ in range(1, 500): optimizer.zero_grad() log_logits = model(x, edge_index, edge_weight) loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() model.eval() log_logits = model(x, edge_index, edge_weight) predicted_target = log_logits.argmax(dim=1) explainer = Explainer( model=model, algorithm=PGMExplainer(), node_mask_type='attributes', explanation_type='phenomenon', model_config=ModelConfig(mode='multiclass_classification', task_level='node', return_type='raw')) node_idx = 100 explanation = explainer(x=data.x, edge_index=edge_index, index=node_idx, target=predicted_target, edge_weight=edge_weight) print(f'Significance of relevant neighbors: 
{explanation.pgm_stats}') ================================================ FILE: examples/contrib/rbcd_attack.py ================================================ import copy import os.path as osp from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.optim import Adam import torch_geometric.transforms as T from torch_geometric.contrib.nn import GRBCDAttack, PRBCDAttack from torch_geometric.datasets import Planetoid from torch_geometric.nn import GATConv, GCNConv from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.utils import softmax path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid') class GCN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.norm = gcn_norm self.conv1 = GCNConv(in_channels, hidden_channels, normalize=False) self.conv2 = GCNConv(hidden_channels, out_channels, normalize=False) def reset_parameters(self): self.conv1.reset_parameters() self.conv2.reset_parameters() def forward(self, x, edge_index, edge_weight=None, **kwargs): # Normalize edge indices only once: if not kwargs.get('skip_norm', False): edge_index, edge_weight = self.norm( edge_index, edge_weight, num_nodes=x.size(0), add_self_loops=True, ) x = self.conv1(x, edge_index, edge_weight).relu() x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index, edge_weight) return x class WeightedGATConv(GATConv): """Extended GAT to allow for weighted edges.""" def edge_update(self, alpha_j: Tensor, alpha_i: Optional[Tensor], edge_attr: Optional[Tensor], index: Tensor, ptr: Optional[Tensor], size_i: Optional[int]) -> Tensor: # Given edge-level attention coefficients for source and target nodes, # we simply need to sum them up to "emulate" concatenation: alpha = alpha_j if alpha_i is None else alpha_j + alpha_i alpha = F.leaky_relu(alpha, self.negative_slope) if edge_attr is not None: assert edge_attr.dim() == 1, 'Only scalar edge weights 
supported' edge_attr = edge_attr.view(-1, 1) # `alpha` unchanged if edge_attr == 1 and -Inf if edge_attr == 0; # We choose log to counteract underflow in subsequent exp/softmax alpha = alpha + torch.log2(edge_attr) alpha = softmax(alpha, index, ptr, size_i) alpha = F.dropout(alpha, p=self.dropout, training=self.training) return alpha class GAT(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() # Initialize edge weights of self-loops with 1: self.conv1 = WeightedGATConv(in_channels, hidden_channels, fill_value=1.) self.conv2 = WeightedGATConv(hidden_channels, out_channels, fill_value=1.) def reset_parameters(self): self.conv1.reset_parameters() self.conv2.reset_parameters() def forward(self, x, edge_index, edge_weight=None): x = self.conv1(x, edge_index, edge_weight).relu() x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index, edge_weight) return x def train(model, data, epochs=200, lr=0.01, weight_decay=5e-4): model.train() optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) for _ in range(epochs): optimizer.zero_grad() pred = model(data.x, data.edge_index, data.edge_weight) loss = F.cross_entropy(pred[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() def accuracy(pred, y, mask): return (pred.argmax(-1)[mask] == y[mask]).float().mean() @torch.no_grad() def test(model, data): model.eval() pred = model(data.x, data.edge_index, data.edge_weight) return float(accuracy(pred, data.y, data.test_mask)) # The metric in PRBCD is assumed to be best if lower (like a loss). 
def metric(*args, **kwargs):
    """Negated accuracy, so that lower is better as PRBCD expects."""
    return -accuracy(*args, **kwargs)


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())
    data = dataset[0].to(device)

    gcn = GCN(dataset.num_features, 16, dataset.num_classes).to(device)
    gat = GAT(dataset.num_features, 16, dataset.num_classes).to(device)

    # Train both victims on the clean graph; the attacks below are evasion
    # attacks, i.e. the trained weights stay fixed while the graph changes.
    train(gcn, data)
    gcn.eval()
    train(gat, data)
    gat.eval()

    node_idx = 42
    local_budget = 2  # Degree of (training) node 42 is 2.

    # Perturb 5% of edges:
    # (edge_index stores each undirected edge twice, hence the `/ 2`.)
    global_budget = int(0.05 * data.edge_index.size(1) / 2)

    print('------------- GAT: Local Evasion -------------')
    # Note: GRBCD is faster than PRBCD for small budgets but not as consistent

    grbcd = GRBCDAttack(gat, block_size=250_000)
    # The learning rate is one of the most important parameters for PRBCD and a
    # good heuristic is to choose it s.t. the budget is exhausted within a few
    # steps. Moreover, a high learning rate mitigates the impact of the
    # relaxation gap ({0, 1} -> [0, 1]) of the edge weights. See poisoning
    # example for a debug plot.
    prbcd = PRBCDAttack(gat, block_size=250_000, metric=metric, lr=2_000)

    clean_acc = test(gat, data)
    print(f'Clean accuracy: {clean_acc:.3f}')

    # GRBCD: Attack a single node:
    pert_edge_index, perts = grbcd.attack(
        data.x,
        data.edge_index,
        data.y,
        budget=local_budget,
        idx_attack=[node_idx],
    )

    # Negate PRBCD's (private) probability-margin loss to report the
    # confidence margin of the true class over the best other class:
    clean_margin = -PRBCDAttack._probability_margin_loss(
        gat(data.x, data.edge_index), data.y, [node_idx])
    pert_margin = -PRBCDAttack._probability_margin_loss(
        gat(data.x, pert_edge_index), data.y, [node_idx])
    print(f'GRBCD: Confidence margin of target to best non-target dropped '
          f'from {clean_margin:.3f} to {pert_margin:.3f}')
    adv_edges = ', '.join(str((u, v)) for u, v in perts.T.tolist())
    print(f'Adv. edges: {adv_edges}')

    # PRBCD: Attack single node:
    pert_edge_index, perts = prbcd.attack(
        data.x,
        data.edge_index,
        data.y,
        budget=local_budget,
        idx_attack=[node_idx],
    )

    clean_margin = -PRBCDAttack._probability_margin_loss(
        gat(data.x, data.edge_index), data.y, [node_idx])
    pert_margin = -PRBCDAttack._probability_margin_loss(
        gat(data.x, pert_edge_index), data.y, [node_idx])
    print(f'PRBCD: Confidence margin of target to best non-target dropped '
          f'from {clean_margin:.3f} to {pert_margin:.3f}')
    adv_edges = ', '.join(str((u, v)) for u, v in perts.T.tolist())
    print(f'Adv. edges: {adv_edges}\n')

    print('------------- GCN: Global Evasion -------------')

    # Global evasion: attack all test nodes at once with the 5% edge budget.
    grbcd = GRBCDAttack(gcn, block_size=250_000)
    prbcd = PRBCDAttack(gcn, block_size=250_000, metric=metric, lr=2_000)

    clean_acc = test(gcn, data)

    # GRBCD: Attack test set:
    pert_edge_index, perts = grbcd.attack(
        data.x,
        data.edge_index,
        data.y,
        budget=global_budget,
        idx_attack=data.test_mask,
    )

    # Shallow copy: only `edge_index` is swapped for the perturbed graph.
    pert_data = copy.copy(data)
    pert_data.edge_index = pert_edge_index
    pert_acc = test(gcn, pert_data)
    print(f'GRBCD: Accuracy dropped from {clean_acc:.3f} to {pert_acc:.3f}')

    # PRBCD: Attack test set:
    pert_edge_index, perts = prbcd.attack(
        data.x,
        data.edge_index,
        data.y,
        budget=global_budget,
        idx_attack=data.test_mask,
    )

    pert_data = copy.copy(data)
    pert_data.edge_index = pert_edge_index
    pert_acc = test(gcn, pert_data)
    print(f'PRBCD: Accuracy dropped from {clean_acc:.3f} to {pert_acc:.3f}')

================================================
FILE: examples/contrib/rbcd_attack_poisoning.py
================================================
import copy
import os.path as osp
import sys
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from rbcd_attack import GCN, metric, test, train
from torch import Tensor
from torch.optim import Adam

import torch_geometric.transforms as T
from torch_geometric.contrib.nn import PRBCDAttack
from torch_geometric.datasets import Planetoid

try:
    import higher
except ImportError: sys.exit('Install `higher` via `pip install higher` for poisoning example') path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # IMPORTANT: Edge weights are being ignored later and most adjacency matrix # preprocessing should be part of the model (part of backpropagation): dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures()) data = dataset[0].to(device) gcn = GCN(dataset.num_features, 16, dataset.num_classes).to(device) train(gcn, data) print('------------- GCN: Global Poisoning -------------') clean_acc = test(gcn, data) print(f'Clean accuracy: {clean_acc:.3f}') n_epochs = 50 lr = 0.04 weight_decay = 5e-4 class PoisoningPRBCDAttack(PRBCDAttack): def _forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, **kwargs) -> Tensor: """Forward model.""" self.model.reset_parameters() with torch.enable_grad(): ped = copy.copy(data) ped.x, ped.edge_index, ped.edge_weight = x, edge_index, edge_weight train(self.model, ped, n_epochs, lr, weight_decay) self.model.eval() return self.model(x, edge_index, edge_weight) def _forward_and_gradient(self, x: Tensor, labels: Tensor, idx_attack: Optional[Tensor] = None, **kwargs) -> Tuple[Tensor, Tensor]: """Forward and update edge weights.""" self.block_edge_weight.requires_grad = True self.model.reset_parameters() self.model.train() opt = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) with higher.innerloop_ctx(self.model, opt) as (fmodel, diffopt): edge_index, edge_weight = self._get_modified_adj( self.edge_index, self.edge_weight, self.block_edge_index, self.block_edge_weight) # Normalize only once (only relevant if model normalizes adj) if hasattr(fmodel, 'norm'): edge_index, edge_weight = fmodel.norm( edge_index, edge_weight, num_nodes=x.size(0), add_self_loops=True, ) for _ in range(n_epochs): pred = fmodel.forward(x, edge_index, edge_weight, skip_norm=True) loss = 
F.cross_entropy(pred[data.train_mask], data.y[data.train_mask]) diffopt.step(loss) pred = fmodel(x, edge_index, edge_weight) loss = self.loss(pred, labels, idx_attack) gradient = torch.autograd.grad(loss, self.block_edge_weight)[0] # Clip gradient for stability: clip_norm = 0.5 grad_len_sq = gradient.square().sum() if grad_len_sq > clip_norm: gradient *= clip_norm / grad_len_sq.sqrt() self.model.eval() return loss, gradient prbcd = PoisoningPRBCDAttack(gcn, block_size=250_000, metric=metric, lr=100) # PRBCD: Attack test set: global_budget = int(0.05 * data.edge_index.size(1) / 2) # Perturb 5% of edges pert_edge_index, perts = prbcd.attack( data.x, data.edge_index, data.y, budget=global_budget, idx_attack=data.test_mask, ) gcn.reset_parameters() pert_data = copy.copy(data) pert_data.edge_index = pert_edge_index train(gcn, pert_data) pert_acc = test(gcn, pert_data) # Note that the values here a bit more noisy than in the evasion case: print(f'PRBCD: Accuracy dropped from {clean_acc:.3f} to {pert_acc:.3f}') fig, ax1 = plt.subplots() plt.title('Global Poisoning GCN') color = 'tab:red' ax1.plot(prbcd.attack_statistics['loss'], color=color, label='Loss') ax1.tick_params(axis='y', labelcolor=color) ax1.set_ylabel('Loss') ax1.set_xlabel('Steps') # It is best practice choosing the learning rate s.t. 
the budget is exhausted: ax2 = ax1.twinx() color = 'tab:blue' ax2.plot(prbcd.attack_statistics['prob_mass_after_update'], color=color, linestyle='--', label='Before projection') ax2.plot(prbcd.attack_statistics['prob_mass_after_projection'], color=color, label='After projection') ax2.tick_params(axis='y', labelcolor=color) ax2.set_ylabel('Used budget') plt.legend() fig.show() ================================================ FILE: examples/cora.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import SplineConv from torch_geometric.typing import WITH_SPLINE if not WITH_SPLINE: quit("This example requires 'pyg-lib>=0.6.0'") dataset = 'Cora' transform = T.Compose([ T.RandomNodeSplit(num_val=500, num_test=500), T.TargetIndegree(), ]) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset, transform=transform) data = dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SplineConv(dataset.num_features, 16, dim=1, kernel_size=2) self.conv2 = SplineConv(16, dataset.num_classes, dim=1, kernel_size=2) def forward(self): x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr x = F.dropout(x, training=self.training) x = F.elu(self.conv1(x, edge_index, edge_attr)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index, edge_attr) return F.log_softmax(x, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model, data = Net().to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3) def train(): model.train() optimizer.zero_grad() F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward() optimizer.step() @torch.no_grad() def test(): model.eval() log_probs, accs = model(), [] for _, mask in 
data('train_mask', 'test_mask'): pred = log_probs[mask].max(1)[1] acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs for epoch in range(1, 201): train() train_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/correct_and_smooth.py ================================================ import os.path as osp import torch from ogb.nodeproppred import Evaluator, PygNodePropPredDataset import torch_geometric.transforms as T from torch_geometric.nn import MLP, CorrectAndSmooth from torch_geometric.typing import WITH_TORCH_SPARSE if not WITH_TORCH_SPARSE: quit("This example requires 'torch-sparse'") root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB') dataset = PygNodePropPredDataset('ogbn-products', root, transform=T.ToSparseTensor()) evaluator = Evaluator(name='ogbn-products') split_idx = dataset.get_idx_split() data = dataset[0] device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = MLP([dataset.num_features, 200, 200, dataset.num_classes], dropout=0.5, norm="batch_norm", act_first=True).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) criterion = torch.nn.CrossEntropyLoss() x, y = data.x.to(device), data.y.to(device) train_idx = split_idx['train'].to(device) val_idx = split_idx['valid'].to(device) test_idx = split_idx['test'].to(device) x_train, y_train = x[train_idx], y[train_idx] def train(): model.train() optimizer.zero_grad() out = model(x_train) loss = criterion(out, y_train.view(-1)) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(out=None): model.eval() out = model(x) if out is None else out pred = out.argmax(dim=-1, keepdim=True) train_acc = evaluator.eval({ 'y_true': y[train_idx], 'y_pred': pred[train_idx] })['acc'] val_acc = evaluator.eval({ 'y_true': y[val_idx], 'y_pred': pred[val_idx] })['acc'] test_acc = evaluator.eval({ 
'y_true': y[test_idx], 'y_pred': pred[test_idx] })['acc'] return train_acc, val_acc, test_acc, out best_val_acc = 0 for epoch in range(1, 301): loss = train() train_acc, val_acc, test_acc, out = test() if val_acc > best_val_acc: best_val_acc = val_acc y_soft = out.softmax(dim=-1) print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}') adj_t = data.adj_t.to(device) deg = adj_t.sum(dim=1).to(torch.float) deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 DAD = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1) DA = deg_inv_sqrt.view(-1, 1) * deg_inv_sqrt.view(-1, 1) * adj_t post = CorrectAndSmooth(num_correction_layers=50, correction_alpha=1.0, num_smoothing_layers=50, smoothing_alpha=0.8, autoscale=False, scale=20.) print('Correct and smooth...') y_soft = post.correct(y_soft, y_train, train_idx, DAD) y_soft = post.smooth(y_soft, y_train, train_idx, DA) print('Done!') train_acc, val_acc, test_acc, _ = test(y_soft) print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/cpp/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.10) project(hello-world) # The first thing do is to tell cmake to find the TorchScatter # and TorchSparse libraries. The package pulls in all the necessary # torch libraries, so there is no need to add `find_package(Torch)`. find_package(TorchScatter REQUIRED) find_package(TorchSparse REQUIRED) find_package(Python3 COMPONENTS Development) add_executable(hello-world main.cpp) # We now need to link the TorchScatter and TorchSparse libraries # to our executable. We can do that by using the # TorchScatter::TorchScatter and TorchSparse::TorchSparse targets, # which also adds all the necessary torch dependencies. 
target_compile_features(hello-world PUBLIC cxx_range_for)
target_link_libraries(hello-world TorchScatter::TorchScatter)
target_link_libraries(hello-world TorchSparse::TorchSparse)
target_link_libraries(hello-world ${CUDA_cusparse_LIBRARY})
# Current LibTorch releases require C++17 (the official LibTorch CMake
# example sets CXX_STANDARD 17 as well); C++14 is too old for its headers.
set_property(TARGET hello-world PROPERTY CXX_STANDARD 17)

================================================
FILE: examples/cpp/README.md
================================================
# PyG in C++

This is a minimal example of getting PyG to work in C++ with CMake.

In order to successfully compile this example, make sure you have both the C++ APIs of [`TorchScatter`](https://github.com/rusty1s/pytorch_scatter#c-api) and [`TorchSparse`](https://github.com/rusty1s/pytorch_sparse/#c-api) installed.

For this, we need to add `LibTorch` to the `-DCMAKE_PREFIX_PATH` (run `import torch; print(torch.utils.cmake_prefix_path)` to obtain it).
Then, *e.g.*, to install `TorchScatter`, run:

```
git clone https://github.com/rusty1s/pytorch_scatter.git
cd pytorch_scatter
mkdir build && cd build
cmake -DWITH_CUDA=on -DCMAKE_PREFIX_PATH="..." ..
make
(sudo) make install
```

Once both dependencies are sorted, we can start the CMake fun:

1. Run `save_model.py` to create and save a PyG GNN model.
1. Create a `build` directory inside the current one.
1. From within the `build` directory, run the following commands:
   - `cmake -DCMAKE_PREFIX_PATH="<PATH_TO_LIBTORCH>;<PATH_TO_TORCHSCATTER>;<PATH_TO_TORCHSPARSE>" ..`
   - `cmake --build .`

That's it!
You should now have a `hello-world` executable in your `build` folder.
Run it via: ``` ./hello-world ../model.pt ``` ================================================ FILE: examples/cpp/main.cpp ================================================ #include #include #include #include int main(int argc, const char *argv[]) { if (argc != 2) { std::cerr << "usage: hello-world \n"; return -1; } torch::jit::script::Module model; try { model = torch::jit::load(argv[1]); } catch (const c10::Error &e) { std::cerr << "error loading the model\n"; return -1; } auto x = torch::randn({5, 32}); auto edge_index = torch::tensor({ {0, 1, 1, 2, 2, 3, 3, 4}, {1, 0, 2, 1, 3, 2, 4, 3}, }); std::vector inputs; inputs.push_back(x); inputs.push_back(edge_index); auto out = model.forward(inputs).toTensor(); std::cout << "output tensor shape: " << out.sizes() << std::endl; } ================================================ FILE: examples/cpp/save_model.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import BatchNorm1d, Linear, ReLU, Sequential from torch_geometric.nn import GINConv, global_mean_pool class GIN(torch.nn.Module): def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int): super().__init__() self.convs = torch.nn.ModuleList() for _ in range(num_layers): mlp = Sequential( Linear(in_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels), ) self.convs.append(GINConv(mlp)) in_channels = hidden_channels self.lin1 = Linear(hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, out_channels) def forward( self, x: Tensor, edge_index: Tensor, batch: Optional[Tensor] = None, ) -> Tensor: for conv in self.convs: x = conv(x, edge_index).relu() x = global_mean_pool(x, batch) x = self.lin1(x).relu() x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return x model = GIN(32, 64, 16, num_layers=3) model = torch.jit.script(model) 
model.save('model.pt') ================================================ FILE: examples/datapipe.py ================================================ # In this example, you will find data loading implementations using PyTorch # DataPipes (https://pytorch.org/data/) across various tasks: # (1) molecular graph data loading pipe # (2) mesh/point cloud data loading pipe # In particular, we make use of PyG's built-in DataPipes, e.g., for batching # multiple PyG data objects together or for converting SMILES strings into # molecular graph representations. We also showcase how to write your own # DataPipe (i.e. for loading and parsing mesh data into PyG data objects). import argparse import csv import os.path as osp import time from itertools import chain, tee import torch from torch.utils.data import IterDataPipe from torch.utils.data.datapipes.iter import ( FileLister, FileOpener, IterableWrapper, ) from torch_geometric.data import Data, download_url, extract_zip def molecule_datapipe() -> IterDataPipe: # Download HIV dataset from MoleculeNet: url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets' root_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data') path = download_url(f'{url}/HIV.csv', root_dir) datapipe = FileOpener([path], mode="rt") # Convert CSV rows into dictionaries, skipping the header row datapipe = datapipe.map(lambda file: ( dict(zip(["smiles", "activity", "HIV_active"], row)) for i, row in enumerate(csv.reader(file[1])) if i > 0 and row)) datapipe = IterableWrapper(chain.from_iterable(datapipe)) datapipe = datapipe.parse_smiles(target_key='HIV_active') datapipe, = tee(datapipe, 1) return IterableWrapper(datapipe) @torch.utils.data.functional_datapipe('read_mesh') class MeshOpener(IterDataPipe): # A custom DataPipe to load and parse mesh data into PyG data objects. 
def __init__(self, dp: IterDataPipe) -> None: try: import meshio # noqa: F401 import torch_cluster # noqa: F401 except ImportError as e: raise ImportError( "To run this example, please install required packages:\n" "pip install meshio torch-cluster") from e super().__init__() self.dp = dp def __iter__(self): import meshio for path in self.dp: category = osp.basename(path).split('_')[0] try: mesh = meshio.read(path) except UnicodeDecodeError: # Failed to read the file because it is not in the expected OFF # format. continue pos = torch.from_numpy(mesh.points).to(torch.float) face = torch.from_numpy(mesh.cells[0].data).t().contiguous() yield Data(pos=pos, face=face, category=category) def mesh_datapipe() -> IterDataPipe: # Download ModelNet10 dataset from Princeton: url = 'http://vision.princeton.edu/projects/2014/3DShapeNets' root_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data') path = download_url(f'{url}/ModelNet10.zip', root_dir) root_dir = osp.join(root_dir, 'ModelNet10') if not osp.exists(root_dir): extract_zip(path, root_dir) def is_train(path: str) -> bool: return 'train' in path datapipe = FileLister([root_dir], masks='*.off', recursive=True) datapipe = datapipe.filter(is_train) datapipe = datapipe.read_mesh() datapipe, = tee(datapipe, 1) datapipe = IterableWrapper(datapipe) datapipe = datapipe.sample_points(1024) # Use PyG transforms from here. datapipe = datapipe.knn_graph(k=8) return datapipe DATAPIPES = { 'molecule': molecule_datapipe, 'mesh': mesh_datapipe, } if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--task', default='molecule', choices=DATAPIPES.keys()) args = parser.parse_args() datapipe = DATAPIPES[args.task]() print('Example output:') print(next(iter(datapipe))) # Shuffling + Batching support: datapipe = datapipe.shuffle() datapipe = datapipe.batch_graphs(batch_size=32) # The first epoch will take longer than the remaining ones... 
print('Iterating over all data...') t = time.perf_counter() for _ in datapipe: pass print(f'Done! [{time.perf_counter() - t:.2f}s]') print('Iterating over all data a second time...') t = time.perf_counter() for _ in datapipe: pass print(f'Done! [{time.perf_counter() - t:.2f}s]') ================================================ FILE: examples/dgcnn_classification.py ================================================ import argparse import os.path as osp import random import torch import torch.nn.functional as F from torch.nn import Linear import torch_geometric.transforms as T from torch_geometric.datasets import MedShapeNet, ModelNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, DynamicEdgeConv, global_max_pool parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( '--dataset', type=str, default='modelnet10', choices=['modelnet10', 'modelnet40', 'medshapenet'], help='Dataset name.', ) parser.add_argument( '--dataset_dir', type=str, default='./data', help='Root directory of dataset.', ) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--num_workers', type=int, default=6) parser.add_argument('--epochs', type=int, default=201) args = parser.parse_args() num_epochs = args.epochs num_workers = args.num_workers batch_size = args.batch_size root = osp.join(args.dataset_dir, args.dataset) print('The root is: ', root) pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024) print('The Dataset is: ', args.dataset) if args.dataset == 'modelnet40': print('Loading training data') train_dataset = ModelNet(root, '40', True, transform, pre_transform) print('Loading test data') test_dataset = ModelNet(root, '40', False, transform, pre_transform) elif args.dataset == 'medshapenet': print('Loading dataset') dataset = MedShapeNet(root=root, size=50, pre_transform=pre_transform, transform=transform, force_reload=False) random.seed(42) train_indices = [] 
test_indices = [] for label in range(dataset.num_classes): by_class = [ i for i, data in enumerate(dataset) if int(data.y) == label ] random.shuffle(by_class) split_point = int(0.7 * len(by_class)) train_indices.extend(by_class[:split_point]) test_indices.extend(by_class[split_point:]) train_dataset = dataset[train_indices] test_dataset = dataset[test_indices] elif args.dataset == 'modelnet10': print('Loading training data') train_dataset = ModelNet(root, '10', True, transform, pre_transform) print('Loading test data') test_dataset = ModelNet(root, '10', False, transform, pre_transform) else: raise ValueError( f"Unknown dataset name '{args.dataset}'. " f"Available options: 'modelnet10', 'modelnet40', 'medshapenet'.") train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) print('Running model') class Net(torch.nn.Module): def __init__(self, out_channels, k=20, aggr='max'): super().__init__() self.conv1 = DynamicEdgeConv(MLP([2 * 3, 64, 64, 64]), k, aggr) self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128]), k, aggr) self.lin1 = Linear(128 + 64, 1024) self.mlp = MLP([1024, 512, 256, out_channels], dropout=0.5, norm=None) def forward(self, data): pos, batch = data.pos, data.batch x1 = self.conv1(pos, batch) x2 = self.conv2(x1, batch) out = self.lin1(torch.cat([x1, x2], dim=1)) out = global_max_pool(out, batch) out = self.mlp(out) return F.log_softmax(out, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(train_dataset.num_classes, k=20).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data) loss = F.nll_loss(out, data.y) loss.backward() total_loss += loss.item() * 
data.num_graphs optimizer.step() return total_loss / len(train_dataset) def test(loader): model.eval() correct = 0 for data in loader: data = data.to(device) with torch.no_grad(): pred = model(data).max(dim=1)[1] correct += pred.eq(data.y).sum().item() return correct / len(loader.dataset) for epoch in range(1, num_epochs): loss = train() test_acc = test(test_loader) print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Test: {test_acc:.4f}') scheduler.step() ================================================ FILE: examples/dgcnn_segmentation.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torchmetrics.functional import jaccard_index import torch_geometric.transforms as T from torch_geometric.datasets import ShapeNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, DynamicEdgeConv from torch_geometric.utils import scatter category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') transform = T.Compose([ T.RandomJitter(0.01), T.RandomRotate(15, axis=0), T.RandomRotate(15, axis=1), T.RandomRotate(15, axis=2) ]) pre_transform = T.NormalizeScale() train_dataset = ShapeNet(path, category, split='trainval', transform=transform, pre_transform=pre_transform) test_dataset = ShapeNet(path, category, split='test', pre_transform=pre_transform) train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=6) test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=6) class Net(torch.nn.Module): def __init__(self, out_channels, k=30, aggr='max'): super().__init__() self.conv1 = DynamicEdgeConv(MLP([2 * 6, 64, 64]), k, aggr) self.conv2 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr) self.conv3 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr) self.mlp = MLP([3 * 64, 1024, 256, 128, out_channels], dropout=0.5, norm=None) def forward(self, data): x, pos, 
batch = data.x, data.pos, data.batch x0 = torch.cat([x, pos], dim=-1) x1 = self.conv1(x0, batch) x2 = self.conv2(x1, batch) x3 = self.conv3(x2, batch) out = self.mlp(torch.cat([x1, x2, x3], dim=1)) return F.log_softmax(out, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(train_dataset.num_classes, k=30).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.8) def train(): model.train() total_loss = correct_nodes = total_nodes = 0 for i, data in enumerate(train_loader): data = data.to(device) optimizer.zero_grad() out = model(data) loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() total_loss += loss.item() correct_nodes += out.argmax(dim=1).eq(data.y).sum().item() total_nodes += data.num_nodes if (i + 1) % 10 == 0: print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} ' f'Train Acc: {correct_nodes / total_nodes:.4f}') total_loss = correct_nodes = total_nodes = 0 @torch.no_grad() def test(loader): model.eval() ious, categories = [], [] y_map = torch.empty(loader.dataset.num_classes, device=device).long() for data in loader: data = data.to(device) outs = model(data) sizes = (data.ptr[1:] - data.ptr[:-1]).tolist() for out, y, category in zip(outs.split(sizes), data.y.split(sizes), data.category.tolist()): category = list(ShapeNet.seg_classes.keys())[category] part = ShapeNet.seg_classes[category] part = torch.tensor(part, device=device) y_map[part] = torch.arange(part.size(0), device=device) iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y], num_classes=part.size(0), absent_score=1.0) ious.append(iou) categories.append(data.category) iou = torch.tensor(ious, device=device) category = torch.cat(categories, dim=0) mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU. return float(mean_iou.mean()) # Global IoU. 
for epoch in range(1, 31): train() iou = test(test_loader) print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}') ================================================ FILE: examples/dir_gnn.py ================================================ import argparse import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import WikipediaNetwork from torch_geometric.nn import DirGNNConv, GCNConv, SAGEConv parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='chameleon') parser.add_argument('--hidden_channels', type=int, default=128) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--epochs', type=int, default=1000) parser.add_argument('--alpha', type=float, default=1) parser.add_argument('--conv', type=str, default='gcn') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Wikipedia') dataset = WikipediaNetwork( root=path, name=args.dataset, transform=T.NormalizeFeatures(), ) data = dataset[0].to(device) data.train_mask = data.train_mask[:, 0] data.val_mask = data.val_mask[:, 0] data.test_mask = data.test_mask[:, 0] if args.conv == 'gcn': Conv = GCNConv elif args.conv == 'sage': Conv = SAGEConv else: raise NotImplementedError class DirGNN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, alpha): super().__init__() self.conv1 = Conv(in_channels, hidden_channels) self.conv1 = DirGNNConv(self.conv1, alpha, root_weight=False) self.conv2 = Conv(hidden_channels, out_channels) self.conv2 = DirGNNConv(self.conv2, alpha, root_weight=False) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) return x model = DirGNN( dataset.num_features, args.hidden_channels, dataset.num_classes, alpha=args.alpha, ).to(device) optimizer = torch.optim.Adam(model.parameters(), 
lr=args.lr) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() pred = model(data.x, data.edge_index).argmax(dim=-1) accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs best_val_acc = final_test_acc = 0 for epoch in range(1, args.epochs + 1): loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/distributed/README.md ================================================ # Examples for Distributed Graph Learning This directory contains examples for distributed graph learning. The examples are organized into two subdirectories: 1. [`graphlearn_for_pytorch`](./graphlearn_for_pytorch): Distributed training via the external [GraphLearn-for-PyTorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch) package. 1. [`kuzu`](./kuzu): Remote backend via the [Kùzu](https://kuzudb.com/) graph database. ================================================ FILE: examples/distributed/graphlearn_for_pytorch/README.md ================================================ # Using GraphLearn-for-PyTorch (GLT) for Distributed Training with PyG **[GraphLearn-for-PyTorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch)** is a graph learning library for PyTorch that makes distributed GNN training easy and efficient. GLT leverages GPUs to accelerate graph sampling and utilizes UVA and GPU caches to reduce the data conversion and transferring costs during graph sampling and model training. 
Most of the APIs of GLT are compatible with PyG, so PyG users only need to modify a few lines of their PyG code to train their model with GLT. ## Requirements - `python >= 3.6` - `torch >= 1.12` - `graphlearn-torch` ## Distributed (Multi-Node) Example This example shows how to leverage [GraphLearn-for-PyTorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch) to train PyG models in a distributed scenario with GPUs. The dataset in this example is `ogbn-products` from the [Open Graph Benchmark](https://ogb.stanford.edu/), but you can also train on `ogbn-papers100M` with only minor modifications. To run this example, you can either follow the steps described below or directly use our [`launch.py`](launch.py) script. The training results will be saved in `dist_sage_sup.txt`. ### Running the Example #### Step 1: Prepare and partition the data Here, we use `ogbn-products` and partition it into two partitions: ```bash python partition_ogbn_dataset.py --dataset=ogbn-products --root_dir=../../../data/ogbn-products --num_partitions=2 ``` #### Step 2: Run the example on each training node For example, to run the example on two nodes, each with two GPUs: ```bash # Node 0: CUDA_VISIBLE_DEVICES=0,1 python dist_train_sage_supervised.py \ --num_nodes=2 --node_rank=0 --master_addr=localhost \ --dataset=ogbn-products --dataset_root_dir=../../../data/ogbn-products \ --in_channel=100 --out_channel=47 # Node 1: CUDA_VISIBLE_DEVICES=2,3 python dist_train_sage_supervised.py \ --num_nodes=2 --node_rank=1 --master_addr=localhost \ --dataset=ogbn-products --dataset_root_dir=../../../data/ogbn-products \ --in_channel=100 --out_channel=47 ``` **Notes:** 1. You should change the `master_addr` to the IP of `node#0`. 1. Since there is randomness during data partitioning, please ensure all nodes are using the same partitioned data when running `dist_train_sage_supervised.py`.
### Using the `launch.py` Script #### Step 1: Set up a distributed file system **Note**: You may skip this step if you have already set up folder(s) synchronized across machines. To perform distributed sampling, files and code need to be accessible from multiple machines. A distributed file system (*e.g.*, [NFS](https://wiki.archlinux.org/index.php/NFS), [SSHFS](https://www.digitalocean.com/community/tutorials/how-to-use-sshfs-to-mount-remote-file-systems-over-ssh), [Ceph](https://docs.ceph.com/en/latest/install), ...) saves you from manually synchronizing files such as partition information. #### Step 2: Prepare and partition the data In distributed training (under the worker mode), each node in the cluster holds a partition of the graph. Thus, before the training starts, we partition the `ogbn-products` dataset into multiple partitions, each of which corresponds to a specific training worker. The partitioning occurs in three steps: 1. Run the partition algorithm to assign nodes to partitions. 1. Construct the partitioned graph structure based on the node assignment. 1. Split the node features and edge features into partitions. GLT supports caching graph topology and frequently accessed features in GPU memory to accelerate GPU sampling and feature collection. For feature caching, we adopt a pre-sampling-based approach to determine the hotness of nodes, and cache features for nodes with higher hotness while loading the graph. The uncached features are stored in pinned memory for efficient access via UVA. For further information about partitioning, please refer to the [official tutorial](https://github.com/alibaba/graphlearn-for-pytorch/blob/main/docs/tutorial/dist.md).
Here, we use `ogbn-products` and partition it into two partitions: ```bash python partition_ogbn_dataset.py --dataset=ogbn-products --root_dir=../../../data/ogbn-products --num_partitions=2 ``` #### Step 3: Set up the configuration file An example configuration file is given in [`dist_train_sage_sup_config.yml`](dist_train_sage_sup_config.yml). #### Step 4: Launch the distributed training ```bash pip install paramiko pip install click apt install tmux python launch.py --config=dist_train_sage_sup_config.yml --master_addr=0.0.0.0 --master_port=11234 ``` Here, `master_addr` is the master RPC address, and `master_port` is the port used for PyTorch's process group initialization across training processes; the two following ports (`master_port + 1` and `master_port + 2`) are used for the RPC of the training and testing loaders, respectively. Note that you should change the `master_addr` to the IP of `node#0`. ================================================ FILE: examples/distributed/graphlearn_for_pytorch/dist_train_sage_sup_config.yml ================================================ # IP addresses for all nodes. # Note: The first 3 params are expected to form usernames@nodes:ports. nodes: - 0.0.0.0 - 1.1.1.1 # SSH ports for each node: ports: [22, 22] # Username for remote IPs: usernames: - your_username_for_node_0 - your_username_for_node_1 # Path to Python with GLT environment for each node: python_bins: - /path/to/python - /path/to/python # The dataset name, e.g., ogbn-products, ogbn-papers100M. # Note: make sure the name of dataset_root_dir is the same as the dataset name.
dataset: ogbn-products # `in_channel` and `out_channel` of the dataset, e.g.,: # - ogbn-products: in_channel=100, out_channel=47 # - ogbn-papers100M: in_channel=128, out_channel=172 in_channel: 100 out_channel: 47 # Path to the pytorch_geometric directory: dst_paths: - /path/to/pytorch_geometric - /path/to/pytorch_geometric # Setup visible CUDA devices for each node: visible_devices: - 0,1,2,3 - 0,1,2,3 ================================================ FILE: examples/distributed/graphlearn_for_pytorch/dist_train_sage_supervised.py ================================================ import argparse import os.path as osp import time import graphlearn_torch as glt import torch import torch.distributed import torch.nn.functional as F from ogb.nodeproppred import Evaluator from torch import Tensor from torch.nn.parallel import DistributedDataParallel from torch_geometric.io import fs from torch_geometric.nn import GraphSAGE @torch.no_grad() def test(model, test_loader, dataset_name): evaluator = Evaluator(name=dataset_name) model.eval() xs = [] y_true = [] for i, batch in enumerate(test_loader): if i == 0: device = batch.x.device x = model(batch.x, batch.edge_index)[:batch.batch_size] xs.append(x.cpu()) y_true.append(batch.y[:batch.batch_size].cpu()) xs = [t.to(device) for t in xs] y_true = [t.to(device) for t in y_true] y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True) y_true = torch.cat(y_true, dim=0).unsqueeze(-1) test_acc = evaluator.eval({ 'y_true': y_true, 'y_pred': y_pred, })['acc'] return test_acc def run_training_proc( local_proc_rank: int, num_nodes: int, node_rank: int, num_training_procs_per_node: int, dataset_name: str, in_channels: int, out_channels: int, dataset: glt.distributed.DistDataset, train_idx: Tensor, test_idx: Tensor, epochs: int, batch_size: int, master_addr: str, training_pg_master_port: int, train_loader_master_port: int, test_loader_master_port: int, ): # Initialize graphlearn_torch distributed worker group context: 
glt.distributed.init_worker_group( world_size=num_nodes * num_training_procs_per_node, rank=node_rank * num_training_procs_per_node + local_proc_rank, group_name='distributed-sage-supervised-trainer') current_ctx = glt.distributed.get_context() current_device = torch.device(local_proc_rank % torch.cuda.device_count()) # Initialize training process group of PyTorch: torch.distributed.init_process_group( backend='nccl', # or choose 'gloo' if 'nccl' is not supported. rank=current_ctx.rank, world_size=current_ctx.world_size, init_method=f'tcp://{master_addr}:{training_pg_master_port}', ) # Create distributed neighbor loader for training. # We replace PyG's NeighborLoader with GLT's DistNeighborLoader. # GLT parameters for sampling are quite similar to PyG. # We only need to configure additional network and device parameters: train_idx = train_idx.split( train_idx.size(0) // num_training_procs_per_node)[local_proc_rank] train_loader = glt.distributed.DistNeighborLoader( data=dataset, num_neighbors=[15, 10, 5], input_nodes=train_idx, batch_size=batch_size, shuffle=True, collect_features=True, to_device=current_device, worker_options=glt.distributed.MpDistSamplingWorkerOptions( num_workers=1, worker_devices=[current_device], worker_concurrency=4, master_addr=master_addr, master_port=train_loader_master_port, channel_size='1GB', pin_memory=True, ), ) # Create distributed neighbor loader for testing. 
test_idx = test_idx.split(test_idx.size(0) // num_training_procs_per_node)[local_proc_rank] test_loader = glt.distributed.DistNeighborLoader( data=dataset, num_neighbors=[15, 10, 5], input_nodes=test_idx, batch_size=batch_size, shuffle=False, collect_features=True, to_device=current_device, worker_options=glt.distributed.MpDistSamplingWorkerOptions( num_workers=2, worker_devices=[ torch.device('cuda', i % torch.cuda.device_count()) for i in range(2) ], worker_concurrency=4, master_addr=master_addr, master_port=test_loader_master_port, channel_size='2GB', pin_memory=True, ), ) # Define the model and optimizer. torch.cuda.set_device(current_device) model = GraphSAGE( in_channels=in_channels, hidden_channels=256, num_layers=3, out_channels=out_channels, ).to(current_device) model = DistributedDataParallel(model, device_ids=[current_device.index]) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # Train and test: f = open('dist_sage_sup.txt', 'a+') for epoch in range(0, epochs): model.train() start = time.time() for batch in train_loader: optimizer.zero_grad() out = model(batch.x, batch.edge_index)[:batch.batch_size] loss = F.cross_entropy(out, batch.y[:batch.batch_size].long()) loss.backward() optimizer.step() f.write(f'-- [Trainer {current_ctx.rank}] Epoch: {epoch:03d}, ' f'Loss: {loss:.4f}, Epoch Time: {time.time() - start}\n') torch.cuda.synchronize() torch.distributed.barrier() if epoch == 0 or epoch > (epochs // 2): test_acc = test(model, test_loader, dataset_name) f.write(f'-- [Trainer {current_ctx.rank}] ' f'Test Acc: {test_acc:.4f}\n') torch.cuda.synchronize() torch.distributed.barrier() if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '--dataset', type=str, default='ogbn-products', help='The name of the dataset', ) parser.add_argument( '--in_channel', type=int, default=100, help='Number of input features of the dataset', ) parser.add_argument( '--out_channel', type=int, default=47, help='Number of classes of the 
dataset', ) parser.add_argument( '--num_dataset_partitions', type=int, default=2, help='The number of partitions', ) parser.add_argument( '--dataset_root_dir', type=str, default='../../../data/products', help='The root directory (relative path) of the partitioned dataset', ) parser.add_argument( '--num_nodes', type=int, default=2, help='Number of distributed nodes', ) parser.add_argument( '--node_rank', type=int, default=0, help='The current node rank', ) parser.add_argument( '--num_training_procs', type=int, default=2, help='The number of training processes per node', ) parser.add_argument( '--epochs', type=int, default=10, help='The number of training epochs', ) parser.add_argument( '--batch_size', type=int, default=512, help='The batch size for the training and testing data loaders', ) parser.add_argument( '--master_addr', type=str, default='localhost', help='The master address for RPC initialization', ) parser.add_argument( '--training_pg_master_port', type=int, default=11111, help="The port used for PyTorch's process group initialization", ) parser.add_argument( '--train_loader_master_port', type=int, default=11112, help='The port used for RPC initialization for training', ) parser.add_argument( '--test_loader_master_port', type=int, default=11113, help='The port used for RPC initialization for testing', ) args = parser.parse_args() # Record configuration information for debugging f = open('dist_sage_sup.txt', 'a+') f.write('--- Distributed training example of supervised SAGE ---\n') f.write(f'* dataset: {args.dataset}\n') f.write(f'* dataset root dir: {args.dataset_root_dir}\n') f.write(f'* number of dataset partitions: {args.num_dataset_partitions}\n') f.write(f'* total nodes: {args.num_nodes}\n') f.write(f'* node rank: {args.node_rank}\n') f.write(f'* number of training processes per node: ' f'{args.num_training_procs}\n') f.write(f'* epochs: {args.epochs}\n') f.write(f'* batch size: {args.batch_size}\n') f.write(f'* master addr: {args.master_addr}\n') 
f.write(f'* training process group master port: ' f'{args.training_pg_master_port}\n') f.write(f'* training loader master port: ' f'{args.train_loader_master_port}\n') f.write(f'* testing loader master port: {args.test_loader_master_port}\n') f.write('--- Loading data partition ...\n') root_dir = osp.join(osp.dirname(osp.realpath(__file__)), args.dataset_root_dir) data_pidx = args.node_rank % args.num_dataset_partitions dataset = glt.distributed.DistDataset() label_file = osp.join(root_dir, f'{args.dataset}-label', 'label.pt') dataset.load( root_dir=osp.join(root_dir, f'{args.dataset}-partitions'), partition_idx=data_pidx, graph_mode='ZERO_COPY', whole_node_label_file=label_file, ) train_file = osp.join(root_dir, f'{args.dataset}-train-partitions', f'partition{data_pidx}.pt') train_idx = fs.torch_load(train_file) test_file = osp.join(root_dir, f'{args.dataset}-test-partitions', f'partition{data_pidx}.pt') test_idx = fs.torch_load(test_file) train_idx.share_memory_() test_idx.share_memory_() f.write('--- Launching training processes ...\n') torch.multiprocessing.spawn( run_training_proc, args=( args.num_nodes, args.node_rank, args.num_training_procs, args.dataset, args.in_channel, args.out_channel, dataset, train_idx, test_idx, args.epochs, args.batch_size, args.master_addr, args.training_pg_master_port, args.train_loader_master_port, args.test_loader_master_port, ), nprocs=args.num_training_procs, join=True, ) ================================================ FILE: examples/distributed/graphlearn_for_pytorch/launch.py ================================================ import argparse import click import paramiko import yaml if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '--config', type=str, default='dist_train_sage_sup_config.yml', help='The path to the configuration file', ) parser.add_argument( '--epochs', type=int, default=10, help='The number of training epochs', ) parser.add_argument( '--batch_size', type=int, default=512, 
help='The batch size for the training and testing data loaders', ) parser.add_argument( '--master_addr', type=str, default='0.0.0.0', help='Master IP address for synchronization across all training nodes', ) parser.add_argument( '--master_port', type=str, default='11345', help='The port for synchronization across all training nodes', ) args = parser.parse_args() config = open(args.config) config = yaml.safe_load(config) dataset = config['dataset'] ip_list = config['nodes'] port_list = config['ports'] username_list = config['usernames'] dst_path_list = config['dst_paths'] node_ranks = list(range(len(ip_list))) num_nodes = len(node_ranks) visible_devices = config['visible_devices'] python_bins = config['python_bins'] num_cores = len(str(visible_devices[0]).split(',')) in_channel = str(config['in_channel']) out_channel = str(config['out_channel']) dataset_path = '../../../data/' passwd_dict = {} for username, ip in zip(username_list, ip_list): passwd_dict[ip + username] = click.prompt( f'Password for {username}@{ip}', hide_input=True) for username, ip, port, dst, noderk, device, pythonbin in zip( username_list, ip_list, port_list, dst_path_list, node_ranks, visible_devices, python_bins, ): trans = paramiko.Transport((ip, port)) trans.connect(username=username, password=passwd_dict[ip + username]) ssh = paramiko.SSHClient() ssh._transport = trans to_dist_dir = 'cd ' + dst + \ '/examples/distributed/graphlearn_for_pytorch/ ' exec_example = "tmux new -d 'CUDA_VISIBLE_DEVICES=" + str(device) + \ " " + pythonbin + " dist_train_sage_supervised.py --dataset=" + \ dataset + " --dataset_root_dir=" + dataset_path + dataset + \ " --in_channel=" + in_channel + " --out_channel=" + out_channel + \ " --node_rank=" + str(noderk) + " --num_dataset_partitions=" + \ str(num_nodes) + " --num_nodes=" + str(num_nodes) + \ " --num_training_procs=" + str(num_cores) + " --master_addr=" + \ args.master_addr + " --training_pg_master_port=" + \ args.master_port + " --train_loader_master_port=" + 
\ str(int(args.master_port) + 1) + " --test_loader_master_port=" + \ str(int(args.master_port) + 2) + " --batch_size=" + \ str(args.batch_size) + " --epochs=" + str(args.epochs) print(to_dist_dir + ' && ' + exec_example + " '") stdin, stdout, stderr = ssh.exec_command( to_dist_dir + ' && ' + exec_example + " '", bufsize=1) print(stdout.read().decode()) print(stderr.read().decode()) ssh.close() ================================================ FILE: examples/distributed/graphlearn_for_pytorch/partition_ogbn_dataset.py ================================================ import argparse import ast import os.path as osp import graphlearn_torch as glt import torch from ogb.nodeproppred import PygNodePropPredDataset def partition_dataset( ogbn_dataset: str, root_dir: str, num_partitions: int, num_nbrs: glt.NumNeighbors, chunk_size: int, cache_ratio: float, ): ########################################################################### # In distributed training (under the worker mode), each node in the cluster # holds a partition of the graph. Thus before the training starts, we # partition the dataset into multiple partitions, each of which corresponds # to a specific training worker. # The partitioning occurs in three steps: # 1. Run a partition algorithm to assign nodes to partitions. # 2. Construct partition graph structure based on the node assignment. # 3. Split the node features and edge features based on the partition # result. 
########################################################################### print(f'-- Loading {ogbn_dataset} ...') dataset = PygNodePropPredDataset(ogbn_dataset, root_dir) data = dataset[0] print(f'* node count: {data.num_nodes}') print(f'* edge count: {data.num_edges}') split_idx = dataset.get_idx_split() print('-- Saving label ...') label_dir = osp.join(root_dir, f'{ogbn_dataset}-label') glt.utils.ensure_dir(label_dir) torch.save(data.y.squeeze(), osp.join(label_dir, 'label.pt')) print('-- Partitioning training idx ...') train_idx = split_idx['train'] train_idx = train_idx.split(train_idx.size(0) // num_partitions) train_idx_partitions_dir = osp.join( root_dir, f'{ogbn_dataset}-train-partitions', ) glt.utils.ensure_dir(train_idx_partitions_dir) for pidx in range(num_partitions): torch.save( train_idx[pidx], osp.join(train_idx_partitions_dir, f'partition{pidx}.pt'), ) print('-- Partitioning test idx ...') test_idx = split_idx['test'] test_idx = test_idx.split(test_idx.size(0) // num_partitions) test_idx_partitions_dir = osp.join( root_dir, f'{ogbn_dataset}-test-partitions', ) glt.utils.ensure_dir(test_idx_partitions_dir) for pidx in range(num_partitions): torch.save( test_idx[pidx], osp.join(test_idx_partitions_dir, f'partition{pidx}.pt'), ) print('-- Initializing graph ...') csr_topo = glt.data.Topology(edge_index=data.edge_index, input_layout='COO') graph = glt.data.Graph(csr_topo, mode='ZERO_COPY') print('-- Sampling hotness ...') glt_sampler = glt.sampler.NeighborSampler(graph, num_nbrs) node_probs = [] for pidx in range(num_partitions): seeds = train_idx[pidx] prob = glt_sampler.sample_prob(seeds, data.num_nodes) node_probs.append(prob.cpu()) print('-- Partitioning graph and features ...') partitions_dir = osp.join(root_dir, f'{ogbn_dataset}-partitions') freq_partitioner = glt.partition.FrequencyPartitioner( output_dir=partitions_dir, num_parts=num_partitions, num_nodes=data.num_nodes, edge_index=data.edge_index, probs=node_probs, node_feat=data.x, 
chunk_size=chunk_size, cache_ratio=cache_ratio, ) freq_partitioner.partition() if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '--dataset', type=str, default='ogbn-products', help='The name of the dataset', ) parser.add_argument( '--num_partitions', type=int, default=2, help='The Number of partitions', ) parser.add_argument( '--root_dir', type=str, default='../../../data/ogbn-products', help='The root directory (relative path) of the partitioned dataset', ) parser.add_argument( '--num_nbrs', type=ast.literal_eval, default='[15,10,5]', help='The number of neighbors to sample hotness for feature caching', ) parser.add_argument( '--chunk_size', type=int, default=10000, help='The chunk size for feature partitioning', ) parser.add_argument( '--cache_ratio', type=float, default=0.2, help='The proportion to cache features per partition', ) args = parser.parse_args() partition_dataset( ogbn_dataset=args.dataset, root_dir=osp.join(osp.dirname(osp.realpath(__file__)), args.root_dir), num_partitions=args.num_partitions, num_nbrs=args.num_nbrs, chunk_size=args.chunk_size, cache_ratio=args.cache_ratio, ) ================================================ FILE: examples/distributed/kuzu/README.md ================================================ # Using Kùzu as a Remote Backend for PyG [Kùzu](https://kuzudb.com/) is an in-process property graph database management system built for query speed and scalability. It provides an integration with PyG via the [remote backend interface](https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html) of PyG. 
The Python API of Kùzu outputs a [`torch_geometric.data.FeatureStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.FeatureStore.html) and a [`torch_geometric.data.GraphStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.GraphStore.html) that can be plugged directly into existing familiar PyG interfaces such as [`NeighborLoader`](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/loader/neighbor_loader.html), which enables training GNNs directly on graphs stored in Kùzu. This is particularly useful if you would like to train on graphs that don't fit in your CPU's memory. ## Installation You can install Kùzu as follows: ```bash pip install kuzu ``` ## Usage The API and design documentation of Kùzu can be found at [https://kuzudb.com/docs/](https://kuzudb.com/docs/). ## Examples We provide the following examples to showcase the usage of the Kùzu remote backend within PyG: ### PubMed Open In Colab The PubMed example is hosted on [Google Colab](https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6). In this example, we work on a small dataset for demonstration purposes. The [PubMed](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html) dataset consists of 19,717 papers as nodes and 88,648 citation relationships between them. ### `papers_100M` This example shows how to use the remote backend feature of Kùzu to work with a large graph of papers and citations on a single machine. The data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/). The dataset contains approximately 111 million nodes and 1.6 billion edges.
================================================ FILE: examples/distributed/kuzu/papers_100M/README.md ================================================ # `papers_100M` Example This example shows how to use the remote backend feature of [Kùzu](https://kuzudb.com) to work with a large graph of papers and citations on a single machine. The data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/). The dataset contains approximately 111 million nodes and 1.6 billion edges. ## Prepare the data 1. Download the dataset from [`http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip`](http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip) and put the `*.zip` file into this directory. 1. Run `python prepare_data.py`. The script will automatically extract the data and convert it to the format that Kùzu can read. A Kùzu database instance is then created under `papers100M` and the data is loaded into it. ## Train a Model Afterwards, run `python train.py` to train a three-layer [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) model on this dataset.
================================================ FILE: examples/distributed/kuzu/papers_100M/prepare_data.py ================================================ from multiprocessing import cpu_count from os import path from zipfile import ZipFile import kuzu import numpy as np from tqdm import tqdm with ZipFile("papers100M-bin.zip", 'r') as papers100M_zip: print('Extracting papers100M-bin.zip...') papers100M_zip.extractall() with ZipFile("papers100M-bin/raw/data.npz", 'r') as data_zip: print('Extracting data.npz...') data_zip.extractall() with ZipFile("papers100M-bin/raw/node-label.npz", 'r') as node_label_zip: print('Extracting node-label.npz...') node_label_zip.extractall() print("Converting edge_index to CSV...") edge_index = np.load('edge_index.npy', mmap_mode='r') csvfile = open('edge_index.csv', 'w') csvfile.write('src,dst\n') for i in tqdm(range(edge_index.shape[1])): csvfile.write(str(edge_index[0, i]) + ',' + str(edge_index[1, i]) + '\n') csvfile.close() print("Generating IDs for nodes...") node_year = np.load('node_year.npy', mmap_mode='r') length = node_year.shape[0] ids = np.arange(length) np.save('ids.npy', ids) ids_path = path.abspath(path.join('.', 'ids.npy')) edge_index_path = path.abspath(path.join('.', 'edge_index.csv')) node_label_path = path.abspath(path.join('.', 'node_label.npy')) node_feature_path = path.abspath(path.join('.', 'node_feat.npy')) node_year_path = path.abspath(path.join('.', 'node_year.npy')) print("Creating Kùzu database...") db = kuzu.Database('papers100M') conn = kuzu.Connection(db, num_threads=cpu_count()) print("Creating Kùzu tables...") conn.execute( "CREATE NODE TABLE paper(id INT64, x FLOAT[128], year INT64, y FLOAT, " "PRIMARY KEY (id));") conn.execute("CREATE REL TABLE cites(FROM paper TO paper, MANY_MANY);") print("Copying nodes to Kùzu tables...") conn.execute('COPY paper FROM ("%s", "%s", "%s", "%s") BY COLUMN;' % (ids_path, node_feature_path, node_year_path, node_label_path)) print("Copying edges to Kùzu tables...") 
conn.execute('COPY cites FROM "%s";' % (edge_index_path)) print("All done!") ================================================ FILE: examples/distributed/kuzu/papers_100M/train.py ================================================ import multiprocessing as mp import os.path as osp import kuzu import pandas as pd import torch import torch.nn.functional as F from tqdm import tqdm from torch_geometric.loader import NeighborLoader from torch_geometric.nn import MLP, BatchNorm, SAGEConv NUM_EPOCHS = 1 LOADER_BATCH_SIZE = 1024 print('Batch size:', LOADER_BATCH_SIZE) print('Number of epochs:', NUM_EPOCHS) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print('Using device:', device) # Load the train set: train_path = osp.join('.', 'papers100M-bin', 'split', 'time', 'train.csv.gz') train_df = pd.read_csv( osp.abspath(train_path), compression='gzip', header=None, ) input_nodes = torch.tensor(train_df[0].values, dtype=torch.long) ######################################################################## # The below code sets up the remote backend of Kùzu for PyG. # Please refer to: https://kuzudb.com/docs/client-apis/python-api/overview.html # for how to use the Python API of Kùzu. ######################################################################## # The buffer pool size of Kùzu is set to 40GB. You can change it to a smaller # value if you have less memory. KUZU_BM_SIZE = 40 * 1024**3 # Create Kùzu database: db = kuzu.Database(osp.abspath(osp.join('.', 'papers100M')), KUZU_BM_SIZE) # Get remote backend for PyG: feature_store, graph_store = db.get_torch_geometric_remote_backend( mp.cpu_count()) # Plug the graph store and feature store into the `NeighborLoader`. # Note that `filter_per_worker` is set to `False`. This is because the Kùzu # database is already using multi-threading to scan the features in parallel # and the database object is not fork-safe. 
loader = NeighborLoader( data=(feature_store, graph_store), num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]}, batch_size=LOADER_BATCH_SIZE, input_nodes=('paper', input_nodes), num_workers=4, filter_per_worker=False, ) class GraphSAGE(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.2): super().__init__() self.convs = torch.nn.ModuleList() self.norms = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) self.norms.append(BatchNorm(hidden_channels)) for _ in range(1, num_layers): self.convs.append(SAGEConv(hidden_channels, hidden_channels)) self.norms.append(BatchNorm(hidden_channels)) self.mlp = MLP( in_channels=in_channels + num_layers * hidden_channels, hidden_channels=2 * out_channels, out_channels=out_channels, num_layers=2, norm='batch_norm', act='leaky_relu', ) self.dropout = dropout def forward(self, x, edge_index): x = F.dropout(x, p=self.dropout, training=self.training) xs = [x] for conv, norm in zip(self.convs, self.norms): x = conv(x, edge_index) x = norm(x) x = x.relu() x = F.dropout(x, p=self.dropout, training=self.training) xs.append(x) return self.mlp(torch.cat(xs, dim=-1)) model = GraphSAGE(in_channels=128, hidden_channels=1024, out_channels=172, num_layers=3).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) for epoch in range(1, NUM_EPOCHS + 1): total_loss = total_examples = 0 for batch in tqdm(loader): batch = batch.to(device) batch_size = batch['paper'].batch_size optimizer.zero_grad() out = model( batch['paper'].x, batch['paper', 'cites', 'paper'].edge_index, )[:batch_size] y = batch['paper'].y[:batch_size].long().view(-1) loss = F.cross_entropy(out, y) loss.backward() optimizer.step() total_loss += float(loss) * y.numel() total_examples += y.numel() print(f'Epoch: {epoch:02d}, Loss: {total_loss / total_examples:.4f}') ================================================ FILE: examples/distributed/pyg/README.md 
================================================ # Distributed Training with PyG > **Deprecated:** `torch_geometric.distributed` is deprecated. > Please refer to [NVIDIA cuGraph-GNN](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html#accelerating-pyg-with-nvidia-cugraph-gnn) for scalable distributed GNN training with NVIDIA GPUs. ================================================ FILE: examples/dna.py ================================================ import os.path as osp import torch import torch.nn.functional as F from sklearn.model_selection import StratifiedKFold from torch_geometric.datasets import Planetoid from torch_geometric.nn import DNAConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset) data = dataset[0] data.train_mask = data.val_mask = data.test_mask = None def gen_uniform_20_20_60_split(data): skf = StratifiedKFold(5, shuffle=True, random_state=55) idx = [torch.from_numpy(i) for _, i in skf.split(data.y, data.y)] data.train_idx = idx[0].to(torch.long) data.val_idx = idx[1].to(torch.long) data.test_idx = torch.cat(idx[2:], dim=0).to(torch.long) return data data = gen_uniform_20_20_60_split(data) class Net(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers, heads=1, groups=1): super().__init__() self.hidden_channels = hidden_channels self.lin1 = torch.nn.Linear(in_channels, hidden_channels) self.convs = torch.nn.ModuleList() for _ in range(num_layers): self.convs.append( DNAConv(hidden_channels, heads, groups, dropout=0.8)) self.lin2 = torch.nn.Linear(hidden_channels, out_channels) def reset_parameters(self): self.lin1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.lin2.reset_parameters() def forward(self, x, edge_index): x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x_all = x.view(-1, 1, self.hidden_channels) for conv in self.convs: x = F.relu(conv(x_all, 
edge_index)) x = x.view(-1, 1, self.hidden_channels) x_all = torch.cat([x_all, x], dim=1) x = x_all[:, -1] x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return torch.log_softmax(x, dim=1) if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') model = Net(in_channels=dataset.num_features, hidden_channels=128, out_channels=dataset.num_classes, num_layers=5, heads=8, groups=16) model, data = model.to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_idx], data.y[data.train_idx]) loss.backward() optimizer.step() @torch.no_grad() def test(): model.eval() out, accs = model(data.x, data.edge_index), [] for _, idx in data('train_idx', 'val_idx', 'test_idx'): pred = out[idx].argmax(1) acc = pred.eq(data.y[idx]).sum().item() / idx.numel() accs.append(acc) return accs best_val_acc = test_acc = 0 for epoch in range(1, 201): train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, ' f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/egc.py ================================================ import argparse import os.path as osp import torch import torch.nn.functional as F from ogb.graphproppred import Evaluator from ogb.graphproppred import PygGraphPropPredDataset as OGBG from ogb.graphproppred.mol_encoder import AtomEncoder from torch.nn import BatchNorm1d, Linear, ReLU, Sequential from torch.optim.lr_scheduler import ReduceLROnPlateau import torch_geometric.transforms as T from torch_geometric.loader import DataLoader from torch_geometric.nn import EGConv, 
global_mean_pool from torch_geometric.typing import WITH_TORCH_SPARSE if not WITH_TORCH_SPARSE: quit("This example requires 'torch-sparse'") parser = argparse.ArgumentParser() parser.add_argument('--use_multi_aggregators', action='store_true', help='Switch between EGC-S and EGC-M') args = parser.parse_args() path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB') dataset = OGBG('ogbg-molhiv', path, pre_transform=T.ToSparseTensor()) evaluator = Evaluator('ogbg-molhiv') split_idx = dataset.get_idx_split() train_dataset = dataset[split_idx['train']] val_dataset = dataset[split_idx['valid']] test_dataset = dataset[split_idx['test']] train_loader = DataLoader(train_dataset, batch_size=32, num_workers=4, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=256) test_loader = DataLoader(test_dataset, batch_size=256) class Net(torch.nn.Module): def __init__(self, hidden_channels, num_layers, num_heads, num_bases): super().__init__() if args.use_multi_aggregators: aggregators = ['sum', 'mean', 'max'] else: aggregators = ['symnorm'] self.encoder = AtomEncoder(hidden_channels) self.convs = torch.nn.ModuleList() self.norms = torch.nn.ModuleList() for _ in range(num_layers): self.convs.append( EGConv(hidden_channels, hidden_channels, aggregators, num_heads, num_bases)) self.norms.append(BatchNorm1d(hidden_channels)) self.mlp = Sequential( Linear(hidden_channels, hidden_channels // 2, bias=False), BatchNorm1d(hidden_channels // 2), ReLU(inplace=True), Linear(hidden_channels // 2, hidden_channels // 4, bias=False), BatchNorm1d(hidden_channels // 4), ReLU(inplace=True), Linear(hidden_channels // 4, 1), ) def forward(self, x, adj_t, batch): adj_t = adj_t.set_value(None) # EGConv works without any edge features x = self.encoder(x) for conv, norm in zip(self.convs, self.norms): h = conv(x, adj_t) h = norm(h) h = h.relu_() x = x + h x = global_mean_pool(x, batch) return self.mlp(x) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = 
Net(hidden_channels=236, num_layers=4, num_heads=4, num_bases=4).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20, min_lr=1e-5) def train(): model.train() total_loss = total_examples = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.adj_t, data.batch) loss = F.binary_cross_entropy_with_logits(out, data.y.to(torch.float)) loss.backward() optimizer.step() total_loss += float(loss) * data.num_graphs total_examples += data.num_graphs return total_loss / total_examples @torch.no_grad() def evaluate(loader): model.eval() y_pred, y_true = [], [] for data in loader: data = data.to(device) pred = model(data.x, data.adj_t, data.batch) y_pred.append(pred.cpu()) y_true.append(data.y.cpu()) y_true = torch.cat(y_true, dim=0) y_pred = torch.cat(y_pred, dim=0) return evaluator.eval({'y_true': y_true, 'y_pred': y_pred})['rocauc'] for epoch in range(1, 31): loss = train() val_rocauc = evaluate(val_loader) test_rocauc = evaluate(test_loader) scheduler.step(val_rocauc) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_rocauc:.4f}, ' f'Test: {test_rocauc:.4f}') ================================================ FILE: examples/equilibrium_median.py ================================================ r"""Replicates the experiment from `"Deep Graph Infomax" `_ to try and teach `EquilibriumAggregation` to learn to take the median of a set of numbers. This example converges slowly to being able to predict the median similar to what is observed in the paper. 
""" import numpy as np import torch from torch_geometric.nn import EquilibriumAggregation input_size = 100 steps = 10000000 embedding_size = 10 eval_each = 1000 model = EquilibriumAggregation(1, 10, [256, 256], 1) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) norm = torch.distributions.normal.Normal(0.5, 0.4) gamma = torch.distributions.gamma.Gamma(0.2, 0.5) uniform = torch.distributions.uniform.Uniform(0, 1) total_loss = 0 n_loss = 0 for i in range(1, steps + 1): optimizer.zero_grad() dist = np.random.choice([norm, gamma, uniform]) x = dist.sample((input_size, 1)) y = model(x) loss = (y - x.median()).norm(2) / input_size loss.backward() optimizer.step() total_loss += loss n_loss += 1 if i % eval_each == 0: print(f"Epoch: {i}, Loss {total_loss / n_loss:.6f}") ================================================ FILE: examples/explain/README.md ================================================ # Examples for Generating Explanations of Graph Neural Networks This directory contains examples demonstrating the use of the `torch_geometric.explain` package. The `explain` package of PyG provides a set of tools to explain the predictions of a GNN model or to explain the underlying phenomenon of a dataset. 
| Example | Description | | ---------------------------------------------------------------------- | ------------------------------------------------------- | | [`gnn_explainer.py`](./gnn_explainer.py) | `GNNExplainer` for node classification | | [`gnn_explainer_link_pred.py`](./gnn_explainer_link_pred.py) | `GNNExplainer` for link prediction | | [`gnn_explainer_ba_shapes.py`](./gnn_explainer_ba_shapes.py) | `GNNExplainer` applied on the `BAShapes` dataset | | [`captum_explainer.py`](./captum_explainer.py) | Captum-based explainer for node classification | | [`captum_explainer_hetero_link.py`](./captum_explainer_hetero_link.py) | Captum-based explainer for heterogenous link prediction | | [`graphmask_explainer.py`](./graphmask_explainer.py) | `GraphMaskExplainer` for node classification | ================================================ FILE: examples/explain/captum_explainer.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.explain import CaptumExplainer, Explainer from torch_geometric.nn import GCNConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, dataset) data = dataset[0] class GCN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_features, 16) self.conv2 = GCNConv(16, dataset.num_classes) def forward(self, x, edge_index): x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GCN().to(device) data = data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) for _ in range(1, 201): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) 
loss.backward() optimizer.step() explainer = Explainer( model=model, algorithm=CaptumExplainer('IntegratedGradients'), explanation_type='model', model_config=dict( mode='multiclass_classification', task_level='node', return_type='log_probs', ), node_mask_type='attributes', edge_mask_type='object', threshold_config=dict( threshold_type='topk', value=200, ), ) node_index = 10 explanation = explainer(data.x, data.edge_index, index=node_index) print(f'Generated explanations in {explanation.available_explanations}') path = 'feature_importance.png' explanation.visualize_feature_importance(path, top_k=10) print(f"Feature importance plot has been saved to '{path}'") path = 'subgraph.pdf' explanation.visualize_graph(path) print(f"Subgraph plot has been saved to '{path}'") ================================================ FILE: examples/explain/captum_explainer_hetero_link.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear import torch_geometric.transforms as T from torch_geometric.datasets import MovieLens from torch_geometric.explain import CaptumExplainer, Explainer from torch_geometric.nn import SAGEConv, to_hetero device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') dataset = MovieLens(path, model_name='all-MiniLM-L6-v2') data = dataset[0].to(device) # Add user node features for message passing: data['user'].x = torch.eye(data['user'].num_nodes, device=device) del data['user'].num_nodes # Add a reverse ('movie', 'rev_rates', 'user') relation for message passing: data = T.ToUndirected()(data) data['user', 'movie'].edge_label = data['user', 'movie'].edge_label.to(torch.float) del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label. 
# Perform a link-level split into training, validation, and test edges: data, _, _ = T.RandomLinkSplit( num_val=0.1, num_test=0.1, neg_sampling_ratio=0.0, edge_types=[('user', 'rates', 'movie')], rev_edge_types=[('movie', 'rev_rates', 'user')], )(data) class GNNEncoder(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv((-1, -1), hidden_channels) self.conv2 = SAGEConv((-1, -1), out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) return x class EdgeDecoder(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.lin1 = Linear(2 * hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, 1) def forward(self, z_dict, edge_label_index): row, col = edge_label_index z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1) z = self.lin1(z).relu() z = self.lin2(z) return z.view(-1) class Model(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.encoder = GNNEncoder(hidden_channels, hidden_channels) self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum') self.decoder = EdgeDecoder(hidden_channels) def forward(self, x_dict, edge_index_dict, edge_label_index): z_dict = self.encoder(x_dict, edge_index_dict) return self.decoder(z_dict, edge_label_index) model = Model(hidden_channels=32).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) for _ in range(1, 10): model.train() optimizer.zero_grad() pred = model( data.x_dict, data.edge_index_dict, data['user', 'movie'].edge_label_index, ) loss = F.mse_loss(pred, data['user', 'movie'].edge_label) loss.backward() optimizer.step() explainer = Explainer( model=model, algorithm=CaptumExplainer('IntegratedGradients'), explanation_type='model', model_config=dict( mode='regression', task_level='edge', return_type='raw', ), node_mask_type='attributes', edge_mask_type='object', threshold_config=dict( threshold_type='topk', 
value=200, ), ) index = torch.tensor([2, 10]) # Explain edge labels with index 2 and 10. explanation = explainer( data.x_dict, data.edge_index_dict, index=index, edge_label_index=data['user', 'movie'].edge_label_index, ) print(f'Generated explanations in {explanation.available_explanations}') path = 'feature_importance.png' explanation.visualize_feature_importance(path, top_k=10) print(f"Feature importance plot has been saved to '{path}'") ================================================ FILE: examples/explain/gnn_explainer.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.explain import Explainer, GNNExplainer from torch_geometric.nn import GCNConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, dataset) data = dataset[0] class GCN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_features, 16) self.conv2 = GCNConv(16, dataset.num_classes) def forward(self, x, edge_index): x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GCN().to(device) data = data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) for _ in range(1, 201): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=200), explanation_type='model', node_mask_type='attributes', edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='log_probs', ), ) node_index = 10 explanation = explainer(data.x, data.edge_index, 
index=node_index) print(f'Generated explanations in {explanation.available_explanations}') path = 'feature_importance.png' explanation.visualize_feature_importance(path, top_k=10) print(f"Feature importance plot has been saved to '{path}'") path = 'subgraph.pdf' explanation.visualize_graph(path) print(f"Subgraph visualization plot has been saved to '{path}'") ================================================ FILE: examples/explain/gnn_explainer_ba_shapes.py ================================================ import torch import torch.nn.functional as F from sklearn.metrics import roc_auc_score from sklearn.model_selection import train_test_split from tqdm import tqdm import torch_geometric.transforms as T from torch_geometric.datasets import ExplainerDataset from torch_geometric.datasets.graph_generator import BAGraph from torch_geometric.explain import Explainer, GNNExplainer from torch_geometric.nn import GCN from torch_geometric.utils import k_hop_subgraph dataset = ExplainerDataset( graph_generator=BAGraph(num_nodes=300, num_edges=5), motif_generator='house', num_motifs=80, transform=T.Constant(), ) data = dataset[0] idx = torch.arange(data.num_nodes) train_idx, test_idx = train_test_split(idx, train_size=0.8, stratify=data.y) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') data = data.to(device) model = GCN(data.num_node_features, hidden_channels=20, num_layers=3, out_channels=dataset.num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.cross_entropy(out[train_idx], data.y[train_idx]) torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() pred = model(data.x, data.edge_index).argmax(dim=-1) train_correct = int((pred[train_idx] == data.y[train_idx]).sum()) train_acc = train_correct / train_idx.size(0) 
test_correct = int((pred[test_idx] == data.y[test_idx]).sum()) test_acc = test_correct / test_idx.size(0) return train_acc, test_acc pbar = tqdm(range(1, 2001)) for epoch in pbar: loss = train() if epoch == 1 or epoch % 200 == 0: train_acc, test_acc = test() pbar.set_description(f'Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Test: {test_acc:.4f}') pbar.close() model.eval() for explanation_type in ['phenomenon', 'model']: explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=300), explanation_type=explanation_type, node_mask_type='attributes', edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='raw', ), ) # Explanation ROC AUC over all test nodes: targets, preds = [], [] node_indices = range(400, data.num_nodes, 5) for node_index in tqdm(node_indices, leave=False, desc='Train Explainer'): target = data.y if explanation_type == 'phenomenon' else None explanation = explainer(data.x, data.edge_index, index=node_index, target=target) _, _, _, hard_edge_mask = k_hop_subgraph(node_index, num_hops=3, edge_index=data.edge_index) targets.append(data.edge_mask[hard_edge_mask].cpu()) preds.append(explanation.edge_mask[hard_edge_mask].cpu()) auc = roc_auc_score(torch.cat(targets), torch.cat(preds)) print(f'Mean ROC AUC (explanation type {explanation_type:10}): {auc:.4f}') ================================================ FILE: examples/explain/gnn_explainer_link_pred.py ================================================ import os.path as osp import torch import torch.nn.functional as F from sklearn.metrics import roc_auc_score import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig from torch_geometric.nn import GCNConv if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = 
torch.device('cpu') dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') transform = T.Compose([ T.NormalizeFeatures(), T.ToDevice(device), T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True), ]) dataset = Planetoid(path, dataset, transform=transform) train_data, val_data, test_data = dataset[0] class GCN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels) self.conv2 = GCNConv(hidden_channels, out_channels) def encode(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) return x def decode(self, z, edge_label_index): src, dst = edge_label_index return (z[src] * z[dst]).sum(dim=-1) def forward(self, x, edge_index, edge_label_index): z = model.encode(x, edge_index) return model.decode(z, edge_label_index).view(-1) model = GCN(dataset.num_features, 128, 64).to(device) optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01) def train(): model.train() optimizer.zero_grad() out = model(train_data.x, train_data.edge_index, train_data.edge_label_index) loss = F.binary_cross_entropy_with_logits(out, train_data.edge_label) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(data): model.eval() out = model(data.x, data.edge_index, data.edge_label_index).sigmoid() return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()) for epoch in range(1, 201): loss = train() if epoch % 20 == 0: val_auc = test(val_data) test_auc = test(test_data) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, ' f'Test: {test_auc:.4f}') model_config = ModelConfig( mode='binary_classification', task_level='edge', return_type='raw', ) # Explain model output for a single edge: edge_label_index = val_data.edge_label_index[:, 0] explainer = Explainer( model=model, explanation_type='model', algorithm=GNNExplainer(epochs=200), node_mask_type='attributes', 
edge_mask_type='object', model_config=model_config, ) explanation = explainer( x=train_data.x, edge_index=train_data.edge_index, edge_label_index=edge_label_index, ) print(f'Generated model explanations in {explanation.available_explanations}') # Explain a selected target (phenomenon) for a single edge: edge_label_index = val_data.edge_label_index[:, 0] target = val_data.edge_label[0].unsqueeze(dim=0).long() explainer = Explainer( model=model, explanation_type='phenomenon', algorithm=GNNExplainer(epochs=200), node_mask_type='attributes', edge_mask_type='object', model_config=model_config, ) explanation = explainer( x=train_data.x, edge_index=train_data.edge_index, target=target, edge_label_index=edge_label_index, ) available_explanations = explanation.available_explanations print(f'Generated phenomenon explanations in {available_explanations}') ================================================ FILE: examples/explain/graphmask_explainer.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.explain import Explainer, GraphMaskExplainer from torch_geometric.nn import GATConv, GCNConv device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid') dataset = Planetoid(path, name='Cora') data = dataset[0].to(device) # GCN Node Classification ===================================================== class GCN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_features, 16) self.conv2 = GCNConv(16, dataset.num_classes) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) model = GCN().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) for _ in range(1, 201): model.train() 
optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() explainer = Explainer( model=model, algorithm=GraphMaskExplainer(2, epochs=5), explanation_type='model', node_mask_type='attributes', edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='log_probs', ), ) node_index = 10 explanation = explainer(data.x, data.edge_index, index=node_index) print(f'Generated explanations in {explanation.available_explanations}') # GAT Node Classification ===================================================== class GAT(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GATConv(dataset.num_features, 8, heads=8) self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) model = GAT().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) for _ in range(1, 201): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() explainer = Explainer( model=model, algorithm=GraphMaskExplainer(2, epochs=5), explanation_type='model', node_mask_type='attributes', edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='log_probs', ), ) node_index = torch.tensor([10, 20]) explanation = explainer(data.x, data.edge_index, index=node_index) print(f'Generated explanations in {explanation.available_explanations}') ================================================ FILE: examples/faust.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from 
torch_geometric.datasets import FAUST from torch_geometric.loader import DataLoader from torch_geometric.nn import SplineConv from torch_geometric.typing import WITH_SPLINE if not WITH_SPLINE: quit("This example requires 'pyg-lib>=0.6.0'") path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FAUST') pre_transform = T.Compose([T.FaceToEdge(), T.Constant(value=1)]) train_dataset = FAUST(path, True, T.Cartesian(), pre_transform) test_dataset = FAUST(path, False, T.Cartesian(), pre_transform) train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=1) d = train_dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5, aggr='add') self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5, aggr='add') self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add') self.conv4 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add') self.conv5 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add') self.conv6 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add') self.lin1 = torch.nn.Linear(64, 256) self.lin2 = torch.nn.Linear(256, d.num_nodes) def forward(self, data): x, edge_index, pseudo = data.x, data.edge_index, data.edge_attr x = F.elu(self.conv1(x, edge_index, pseudo)) x = F.elu(self.conv2(x, edge_index, pseudo)) x = F.elu(self.conv3(x, edge_index, pseudo)) x = F.elu(self.conv4(x, edge_index, pseudo)) x = F.elu(self.conv5(x, edge_index, pseudo)) x = F.elu(self.conv6(x, edge_index, pseudo)) x = F.elu(self.lin1(x)) x = F.dropout(x, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) target = torch.arange(d.num_nodes, dtype=torch.long, device=device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(epoch): model.train() if epoch == 61: for param_group in optimizer.param_groups: param_group['lr'] 
= 0.001 for data in train_loader: optimizer.zero_grad() F.nll_loss(model(data.to(device)), target).backward() optimizer.step() def test(): model.eval() correct = 0 for data in test_loader: pred = model(data.to(device)).max(1)[1] correct += pred.eq(target).sum().item() return correct / (len(test_dataset) * d.num_nodes) for epoch in range(1, 101): train(epoch) test_acc = test() print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}') ================================================ FILE: examples/film.py ================================================ import os.path as osp import torch import torch.nn.functional as F from sklearn.metrics import f1_score from torch.nn import BatchNorm1d from torch_geometric.datasets import PPI from torch_geometric.loader import DataLoader from torch_geometric.nn import FiLMConv path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI') train_dataset = PPI(path, split='train') val_dataset = PPI(path, split='val') test_dataset = PPI(path, split='test') train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False) class Net(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.0): super().__init__() self.dropout = dropout self.convs = torch.nn.ModuleList() self.convs.append(FiLMConv(in_channels, hidden_channels)) for _ in range(num_layers - 2): self.convs.append(FiLMConv(hidden_channels, hidden_channels)) self.convs.append(FiLMConv(hidden_channels, out_channels, act=None)) self.norms = torch.nn.ModuleList() for _ in range(num_layers - 1): self.norms.append(BatchNorm1d(hidden_channels)) def forward(self, x, edge_index): for conv, norm in zip(self.convs[:-1], self.norms): x = norm(conv(x, edge_index)) x = F.dropout(x, p=self.dropout, training=self.training) x = self.convs[-1](x, edge_index) return x if torch.cuda.is_available(): device = 
torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') model = Net(in_channels=train_dataset.num_features, hidden_channels=320, out_channels=train_dataset.num_classes, num_layers=4, dropout=0.1).to(device) criterion = torch.nn.BCEWithLogitsLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() loss = criterion(model(data.x, data.edge_index), data.y) total_loss += loss.item() * data.num_graphs loss.backward() optimizer.step() return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() ys, preds = [], [] for data in loader: ys.append(data.y) out = model(data.x.to(device), data.edge_index.to(device)) preds.append((out > 0).float().cpu()) y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy() return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0 for epoch in range(1, 501): loss = train() val_f1 = test(val_loader) test_f1 = test(test_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, ' f'Test: {test_f1:.4f}') ================================================ FILE: examples/gat.py ================================================ import argparse import os.path as osp import time import torch import torch.nn.functional as F import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.logging import init_wandb, log from torch_geometric.nn import GATConv parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='Cora') parser.add_argument('--hidden_channels', type=int, default=8) parser.add_argument('--heads', type=int, default=8) parser.add_argument('--lr', type=float, default=0.005) parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--wandb', 
action='store_true', help='Track experiment') args = parser.parse_args() if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs, hidden_channels=args.hidden_channels, lr=args.lr, device=device) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures()) data = dataset[0].to(device) class GAT(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, heads): super().__init__() self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6) # On the Pubmed dataset, use `heads` output heads in `conv2`. self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6) def forward(self, x, edge_index): x = F.dropout(x, p=0.6, training=self.training) x = F.elu(self.conv1(x, edge_index)) x = F.dropout(x, p=0.6, training=self.training) x = self.conv2(x, edge_index) return x model = GAT(dataset.num_features, args.hidden_channels, dataset.num_classes, args.heads).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return float(loss.detach()) @torch.no_grad() def test(): model.eval() pred = model(data.x, data.edge_index).argmax(dim=-1) accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs times = [] best_val_acc = final_test_acc = 0 for epoch in range(1, args.epochs + 1): start = time.time() loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc 
log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc) times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/gcn.py ================================================ import argparse import os.path as osp import time import torch import torch.nn.functional as F import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.logging import init_wandb, log from torch_geometric.nn import GCNConv parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='Cora') parser.add_argument('--hidden_channels', type=int, default=16) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--use_gdc', action='store_true', help='Use GDC') parser.add_argument('--wandb', action='store_true', help='Track experiment') args = parser.parse_args() device = torch_geometric.device('auto') init_wandb( name=f'GCN-{args.dataset}', lr=args.lr, epochs=args.epochs, hidden_channels=args.hidden_channels, device=device, ) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures()) data = dataset[0].to(device) if args.use_gdc: transform = T.GDC( self_loop_weight=1, normalization_in='sym', normalization_out='col', diffusion_kwargs=dict(method='ppr', alpha=0.05), sparsification_kwargs=dict(method='topk', k=128, dim=0), exact=True, ) data = transform(data) class GCN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels, normalize=not args.use_gdc) self.conv2 = GCNConv(hidden_channels, out_channels, normalize=not args.use_gdc) def forward(self, x, edge_index, edge_weight=None): x = F.dropout(x, p=0.5, training=self.training) x 
= self.conv1(x, edge_index, edge_weight).relu() x = F.dropout(x, p=0.5, training=self.training) x = self.conv2(x, edge_index, edge_weight) return x model = GCN( in_channels=dataset.num_features, hidden_channels=args.hidden_channels, out_channels=dataset.num_classes, ).to(device) optimizer = torch.optim.Adam([ dict(params=model.conv1.parameters(), weight_decay=5e-4), dict(params=model.conv2.parameters(), weight_decay=0) ], lr=args.lr) # Only perform weight-decay on first convolution. def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index, data.edge_attr) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return float(loss.detach()) @torch.no_grad() def test(): model.eval() pred = model(data.x, data.edge_index, data.edge_attr).argmax(dim=-1) accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs best_val_acc = test_acc = 0 times = [] for epoch in range(1, args.epochs + 1): start = time.time() loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc) times.append(time.time() - start) print(f'Median time per epoch: {torch.tensor(times).median():.4f}s') ================================================ FILE: examples/gcn2_cora.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F from torch.nn import Linear import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCN2Conv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) transform = T.Compose([T.NormalizeFeatures(), T.GCNNorm(), T.ToSparseTensor()]) dataset = Planetoid(path, dataset, transform=transform) data = dataset[0] 
class Net(torch.nn.Module): def __init__(self, hidden_channels, num_layers, alpha, theta, shared_weights=True, dropout=0.0): super().__init__() self.lins = torch.nn.ModuleList() self.lins.append(Linear(dataset.num_features, hidden_channels)) self.lins.append(Linear(hidden_channels, dataset.num_classes)) self.convs = torch.nn.ModuleList() for layer in range(num_layers): self.convs.append( GCN2Conv(hidden_channels, alpha, theta, layer + 1, shared_weights, normalize=False)) self.dropout = dropout def forward(self, x, adj_t): x = F.dropout(x, self.dropout, training=self.training) x = x_0 = self.lins[0](x).relu() for conv in self.convs: x = F.dropout(x, self.dropout, training=self.training) x = conv(x, x_0, adj_t) x = x.relu() x = F.dropout(x, self.dropout, training=self.training) x = self.lins[1](x) return x.log_softmax(dim=-1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(hidden_channels=64, num_layers=64, alpha=0.1, theta=0.5, shared_weights=True, dropout=0.6).to(device) data = data.to(device) optimizer = torch.optim.Adam([ dict(params=model.convs.parameters(), weight_decay=0.01), dict(params=model.lins.parameters(), weight_decay=5e-4) ], lr=0.01) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.adj_t) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() pred, accs = model(data.x, data.adj_t).argmax(dim=-1), [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs best_val_acc = test_acc = 0 times = [] for epoch in range(1, 1001): start = time.time() loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:04d}, Loss: {loss:.4f} Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {tmp_test_acc:.4f}, ' f'Final Test: 
{test_acc:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/gcn2_ppi.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F from sklearn.metrics import f1_score from torch.nn import Linear import torch_geometric.transforms as T from torch_geometric.datasets import PPI from torch_geometric.loader import DataLoader from torch_geometric.nn import GCN2Conv path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'GCN2_PPI') pre_transform = T.Compose([T.GCNNorm(), T.ToSparseTensor()]) train_dataset = PPI(path, split='train', pre_transform=pre_transform) val_dataset = PPI(path, split='val', pre_transform=pre_transform) test_dataset = PPI(path, split='test', pre_transform=pre_transform) train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False) class Net(torch.nn.Module): def __init__(self, hidden_channels, num_layers, alpha, theta, shared_weights=True, dropout=0.0): super().__init__() self.lins = torch.nn.ModuleList() self.lins.append(Linear(train_dataset.num_features, hidden_channels)) self.lins.append(Linear(hidden_channels, train_dataset.num_classes)) self.convs = torch.nn.ModuleList() for layer in range(num_layers): self.convs.append( GCN2Conv(hidden_channels, alpha, theta, layer + 1, shared_weights, normalize=False)) self.dropout = dropout def forward(self, x, adj_t): x = F.dropout(x, self.dropout, training=self.training) x = x_0 = self.lins[0](x).relu() for conv in self.convs: h = F.dropout(x, self.dropout, training=self.training) h = conv(h, x_0, adj_t) x = h + x x = x.relu() x = F.dropout(x, self.dropout, training=self.training) x = self.lins[1](x) return x device = torch.device('cuda' if torch.cuda.is_available() 
else 'cpu') model = Net(hidden_channels=2048, num_layers=9, alpha=0.5, theta=1.0, shared_weights=False, dropout=0.2).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = torch.nn.BCEWithLogitsLoss() def train(): model.train() total_loss = total_examples = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() loss = criterion(model(data.x, data.adj_t), data.y) loss.backward() optimizer.step() total_loss += loss.item() * data.num_nodes total_examples += data.num_nodes return total_loss / total_examples @torch.no_grad() def test(loader): model.eval() ys, preds = [], [] for data in loader: ys.append(data.y) out = model(data.x.to(device), data.adj_t.to(device)) preds.append((out > 0).float().cpu()) y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy() return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0 times = [] for epoch in range(1, 2001): start = time.time() loss = train() val_f1 = test(val_loader) test_f1 = test(test_loader) print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, ' f'Test: {test_f1:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/geniepath.py ================================================ import argparse import os.path as osp import time import torch from sklearn.metrics import f1_score from torch_geometric.datasets import PPI from torch_geometric.loader import DataLoader from torch_geometric.nn import GATConv parser = argparse.ArgumentParser() parser.add_argument('--model', type=str, default='GeniePathLazy') args = parser.parse_args() assert args.model in ['GeniePath', 'GeniePathLazy'] path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'PPI') train_dataset = PPI(path, split='train') val_dataset = PPI(path, split='val') test_dataset = PPI(path, split='test') train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) 
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False) dim = 256 lstm_hidden = 256 layer_num = 4 class Breadth(torch.nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.gatconv = GATConv(in_dim, out_dim, heads=1) def forward(self, x, edge_index): x = torch.tanh(self.gatconv(x, edge_index)) return x class Depth(torch.nn.Module): def __init__(self, in_dim, hidden): super().__init__() self.lstm = torch.nn.LSTM(in_dim, hidden, 1, bias=False) def forward(self, x, h, c): x, (h, c) = self.lstm(x, (h, c)) return x, (h, c) class GeniePathLayer(torch.nn.Module): def __init__(self, in_dim): super().__init__() self.breadth_func = Breadth(in_dim, dim) self.depth_func = Depth(dim, lstm_hidden) def forward(self, x, edge_index, h, c): x = self.breadth_func(x, edge_index) x = x[None, :] x, (h, c) = self.depth_func(x, h, c) x = x[0] return x, (h, c) class GeniePath(torch.nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.lin1 = torch.nn.Linear(in_dim, dim) self.gplayers = torch.nn.ModuleList( [GeniePathLayer(dim) for i in range(layer_num)]) self.lin2 = torch.nn.Linear(dim, out_dim) def forward(self, x, edge_index): x = self.lin1(x) h = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device) c = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device) for i, _ in enumerate(self.gplayers): x, (h, c) = self.gplayers[i](x, edge_index, h, c) x = self.lin2(x) return x class GeniePathLazy(torch.nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.lin1 = torch.nn.Linear(in_dim, dim) self.breadths = torch.nn.ModuleList( [Breadth(dim, dim) for i in range(layer_num)]) self.depths = torch.nn.ModuleList( [Depth(dim * 2, lstm_hidden) for i in range(layer_num)]) self.lin2 = torch.nn.Linear(dim, out_dim) def forward(self, x, edge_index): x = self.lin1(x) h = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device) c = torch.zeros(1, x.shape[0], lstm_hidden, 
device=x.device) h_tmps = [] for i, _ in enumerate(self.breadths): h_tmps.append(self.breadths[i](x, edge_index)) x = x[None, :] for i, _ in enumerate(self.depths): in_cat = torch.cat((h_tmps[i][None, :], x), -1) x, (h, c) = self.depths[i](in_cat, h, c) x = self.lin2(x[0]) return x device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') kwargs = {'GeniePath': GeniePath, 'GeniePathLazy': GeniePathLazy} model = kwargs[args.model](train_dataset.num_features, train_dataset.num_classes).to(device) loss_op = torch.nn.BCEWithLogitsLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.005) def train(): model.train() total_loss = 0 for data in train_loader: num_graphs = data.num_graphs data.batch = None data = data.to(device) optimizer.zero_grad() loss = loss_op(model(data.x, data.edge_index), data.y) total_loss += loss.item() * num_graphs loss.backward() optimizer.step() return total_loss / len(train_loader.dataset) def test(loader): model.eval() ys, preds = [], [] for data in loader: ys.append(data.y) with torch.no_grad(): out = model(data.x.to(device), data.edge_index.to(device)) preds.append((out > 0).float().cpu()) y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy() return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0 times = [] for epoch in range(1, 101): start = time.time() loss = train() val_f1 = test(val_loader) test_f1 = test(test_loader) print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, ' f'Test: {test_f1:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/glnn.py ================================================ # Implementation of: # Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation import argparse import os.path as osp import time import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import 
Planetoid from torch_geometric.nn import GCN, MLP parser = argparse.ArgumentParser() parser.add_argument('--lamb', type=float, default=0.0, help='Balances loss from hard labels and teacher outputs') args = parser.parse_args() if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures()) data = dataset[0].to(device) gnn = GCN(dataset.num_node_features, hidden_channels=16, out_channels=dataset.num_classes, num_layers=2).to(device) mlp = MLP([dataset.num_node_features, 64, dataset.num_classes], dropout=0.5, norm=None).to(device) gnn_optimizer = torch.optim.Adam(gnn.parameters(), lr=0.01, weight_decay=5e-4) mlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, weight_decay=5e-4) def train_teacher(): gnn.train() gnn_optimizer.zero_grad() out = gnn(data.x, data.edge_index) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() gnn_optimizer.step() return float(loss) @torch.no_grad() def test_teacher(): gnn.eval() pred = gnn(data.x, data.edge_index).argmax(dim=-1) accs = [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs times = [] print('Training Teacher GNN:') for epoch in range(1, 201): start = time.time() loss = train_teacher() if epoch % 20 == 0: train_acc, val_acc, test_acc = test_teacher() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') times.append(time.time() - start) start = time.time() print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") with torch.no_grad(): # Obtain soft labels from the GNN: y_soft = gnn(data.x, data.edge_index).log_softmax(dim=-1) def 
train_student(): mlp.train() mlp_optimizer.zero_grad() out = mlp(data.x) loss1 = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss2 = F.kl_div(out.log_softmax(dim=-1), y_soft, reduction='batchmean', log_target=True) loss = args.lamb * loss1 + (1 - args.lamb) * loss2 loss.backward() mlp_optimizer.step() return float(loss) @torch.no_grad() def test_student(): mlp.eval() pred = mlp(data.x).argmax(dim=-1) accs = [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs times = [] print('Training Student MLP:') for epoch in range(1, 501): start = time.time() loss = train_student() if epoch % 20 == 0: train_acc, val_acc, test_acc = test_student() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') times.append(time.time() - start) start = time.time() print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/gpse.py ================================================ import argparse import os.path as osp import time import torch import torch.nn.functional as F from torch_geometric.datasets import ZINC from torch_geometric.graphgym.models.encoder import AtomEncoder from torch_geometric.loader import DataLoader from torch_geometric.logging import log from torch_geometric.nn import ( GPSE, MLP, GCNConv, GINConv, GPSENodeEncoder, Linear, global_mean_pool, ) from torch_geometric.nn.models.gpse import precompute_GPSE from torch_geometric.transforms import AddGPSE def load_ZINC(args): """Load the ZINC dataset, and generate GPSE encodings for the graphs if args.gpse is not None. 
""" path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC_subset') gpse_model = GPSE.from_pretrained( name=args.gpse, root=osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'GPSE_pretrained')) if args.gpse else None if args.gpse and args.as_transform: # WARNING: Using a pre_transform will save the encodings to disk, # meaning any future runs will use the saved encodings. This is useful # for speeding up computation, but may not be desirable, e.g. when # experimenting with different pre-trained GPSE models. Alternatively, # AddGPSE can be used as a regular transform, which will compute the # encodings on-the-fly, but this will slow down the data loading # process. train_dataset = ZINC( path, subset=True, split='train', pre_transform=AddGPSE(gpse_model, use_vn=True, rand_type='NormalSE')) test_dataset = ZINC( path, subset=True, split='val', pre_transform=AddGPSE(gpse_model, use_vn=True, rand_type='NormalSE')) else: train_dataset = ZINC(path, subset=True, split='train') test_dataset = ZINC(path, subset=True, split='val') if args.gpse: precompute_GPSE(gpse_model, train_dataset) precompute_GPSE(gpse_model, test_dataset) train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=256) return train_loader, test_loader class IdentityNodeEncoder(torch.nn.Module): def __init__(self, emb_dim): super().__init__() def forward(self, batch): return batch class LinearNodeEncoder(torch.nn.Module): def __init__(self, emb_dim, emb_pe_out, bias=True): super().__init__() self.encoder = Linear(emb_dim - emb_pe_out, emb_dim, bias=bias) def forward(self, batch): batch.x = self.encoder(batch.x) return batch class TypeDictNodeEncoder(torch.nn.Module): def __init__(self, emb_dim, num_types=28): super().__init__() if num_types < 1: raise ValueError(f"Invalid 'node_encoder_num_types': {num_types}") self.encoder = torch.nn.Embedding(num_embeddings=num_types, embedding_dim=emb_dim) def forward(self, batch): 
# Encode just the first dimension if more exist batch.x = self.encoder(batch.x[:, 0]) return batch class GNNStackStage(torch.nn.Module): """Simple Staging mechanism that stacks an arbitrary number of GNN layers with skip connections and L2 normalization. Args: dim_in (int): Input dimension dim_out (int): Output dimension num_layers (int): Number of GNN layers conv_type (str): Type of graph convolution in GNN stage_type (str): Type of skip connections. Options: 'skipsum' or 'skipconcat', any other value means no skip connections. l2norm (bool): Whether to apply L2 normalization to outputs """ def __init__(self, dim_in, dim_out, num_layers, conv_type='gcn', stage_type='skipsum', l2norm=True): super().__init__() self.num_layers = num_layers self.stage_type = stage_type self.l2norm = l2norm conv_dict = {'gcn': GCNConv, 'gin': GINConv} for i in range(num_layers): if stage_type == 'skipconcat': d_in = dim_in if i == 0 else dim_in + i * dim_out else: d_in = dim_in if i == 0 else dim_out layer = conv_dict[conv_type](d_in, dim_out) self.add_module(f'layer{i}', layer) def forward(self, batch): for i, layer in enumerate(self.children()): x = batch.x batch.x = layer(batch.x, batch.edge_index) if self.stage_type == 'skipsum': batch.x = x + batch.x elif self.stage_type == 'skipconcat' and \ i < self.num_layers - 1: batch.x = torch.cat([x, batch.x], dim=1) if self.l2norm: batch.x = F.normalize(batch.x, p=2, dim=-1) return batch class GPSEPlusGNN(torch.nn.Module): """A GPSE encoder paired with a GNN module. Consists of: - encoder1: An optional encoder that is used to encode raw node features, common practice for biochemistry datasets. ZINC uses :class:`TypeDictNodeEncoder`, while ogbg-mol* datasets typically use :class:`~torch_geometric.graphgym.models.encoder.AtomEncoder`. If 'none', an :class`IdentityNodeEncoder` is passed that returns the inputs as-is. - encoder2: GPSE encoder that adds precomputed GPSE encodings in the dataset to node features if :obj:`gpse` is :obj:`True`. 
Otherwise is replaced by a linear layer that maps the :obj:`encoder1` outputs to the correct dimension. - premp: 2-layer MLP before message-passing. - gnn: Stacked message-passing layers of :obj:`conv_type`. - postmp: 1-layer MLP after message-passing to map GNN node states to a single output (for ZINC regression task). For classification tasks, :obj:`num_classes` outputs with softmax activation would be required. Args: dim_emb (int): Dimension of embedding outputs. Equals dimension of :obj:`encoder1` outputs (dim_emb - dim_pe_out) and :class:`~torch_geometric.nn.GPSENodeEncoder` outputs (dim_pe_out). dim_conv (int): Dimension of GNN message-passing layers. conv_type (str): Type of graph convolution in GNN. num_layers (int): Number of GNN layers. dim_pe_in (int): Original dimension of posenc_GPSE, i.e. the precomputed GPSE encodings. dim_pe_out (int): Desired dimension of GPSE-derived node features, mapped from the original GPSE encodings via GPSENodeEncoder. encoder (str): Encoding applied to raw node features. gpse (bool): Whether to use GPSE encodings. 
""" def __init__(self, dim_emb, dim_conv, conv_type, num_layers, dim_pe_in, dim_pe_out, encoder='none', gpse=True): super().__init__() encoder_dict = { 'none': IdentityNodeEncoder, 'Atom': AtomEncoder, 'TypeDict': TypeDictNodeEncoder } self.encoder1 = encoder_dict[encoder](dim_emb - dim_pe_out) self.encoder2 = GPSENodeEncoder( dim_emb, dim_pe_in, dim_pe_out, expand_x=False) if gpse else ( LinearNodeEncoder(dim_emb, dim_pe_out, bias=True)) self.premp = MLP([dim_emb, dim_emb, dim_conv]) self.gnn = GNNStackStage(dim_conv, dim_conv, num_layers, conv_type) self.postmp = MLP([dim_conv, 1]) def forward(self, batch): batch = self.encoder1(batch) batch.x = self.encoder2(batch.x, batch.pestat_GPSE) batch.x = self.premp(batch.x) batch = self.gnn(batch) batch = global_mean_pool(batch.x, batch.batch) batch = F.dropout(batch, p=0.5, training=self.training) batch = self.postmp(batch) return batch def train(loader): model.train() total_loss = 0 for data in loader: data = data.to(device) optimizer.zero_grad() out = model(data) pred = out.squeeze(-1) if out.ndim > 1 else out true = data.y.squeeze(-1) if data.y.ndim > 1 else data.y loss = F.mse_loss(pred, true) loss.backward() optimizer.step() total_loss += float(loss) * data.num_graphs return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() total_loss = 0 for data in loader: data = data.to(device) out = model(data) pred = out.squeeze(-1) if out.ndim > 1 else out true = data.y.squeeze(-1) if data.y.ndim > 1 else data.y loss = F.mse_loss(pred, true) total_loss += float(loss) * data.num_graphs return total_loss / len(loader.dataset) if __name__ == '__main__': parser = argparse.ArgumentParser(description='GPSE Example') parser.add_argument( '--gpse', type=str, default=None, const='molpcba', nargs='?', choices=['molpcba', 'zinc', 'pcqm4mv2', 'geom', 'chembl'], help='which model weights to use ' '(default: %(default)s)') parser.add_argument( '--as_transform', action='store_true', help='Whether to 
apply GPSE as a pre_transform to the ' 'dataset or not') args = parser.parse_args() train_loader, test_loader = load_ZINC(args) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GPSEPlusGNN(dim_emb=64, dim_conv=128, conv_type='gcn', num_layers=8, dim_pe_in=512, dim_pe_out=32, encoder='TypeDict', gpse=args.gpse).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) num_epochs = 100 times = [] for epoch in range(1, num_epochs + 1): start = time.time() loss = train(train_loader) train_acc = test(train_loader) test_acc = test(test_loader) log(Epoch=epoch, Loss=loss, Train=train_acc, Test=test_acc) times.append(time.time() - start) print(f'Median time per epoch: {torch.tensor(times).median():.4f}s') ================================================ FILE: examples/graph_gps.py ================================================ import argparse import os.path as osp from typing import Any, Dict, Optional import torch from torch.nn import ( BatchNorm1d, Embedding, Linear, ModuleList, ReLU, Sequential, ) from torch.optim.lr_scheduler import ReduceLROnPlateau import torch_geometric.transforms as T from torch_geometric.datasets import ZINC from torch_geometric.loader import DataLoader from torch_geometric.nn import GINEConv, GPSConv, global_add_pool from torch_geometric.nn.attention import PerformerAttention path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC-PE') transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe') train_dataset = ZINC(path, subset=True, split='train', pre_transform=transform) val_dataset = ZINC(path, subset=True, split='val', pre_transform=transform) test_dataset = ZINC(path, subset=True, split='test', pre_transform=transform) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=64) test_loader = DataLoader(test_dataset, batch_size=64) parser = argparse.ArgumentParser() parser.add_argument( '--attn_type', 
default='multihead', help="Global attention type such as 'multihead' or 'performer'.") args = parser.parse_args() class GPS(torch.nn.Module): def __init__(self, channels: int, pe_dim: int, num_layers: int, attn_type: str, attn_kwargs: Dict[str, Any]): super().__init__() self.node_emb = Embedding(28, channels - pe_dim) self.pe_lin = Linear(20, pe_dim) self.pe_norm = BatchNorm1d(20) self.edge_emb = Embedding(4, channels) self.convs = ModuleList() for _ in range(num_layers): nn = Sequential( Linear(channels, channels), ReLU(), Linear(channels, channels), ) conv = GPSConv(channels, GINEConv(nn), heads=4, attn_type=attn_type, attn_kwargs=attn_kwargs) self.convs.append(conv) self.mlp = Sequential( Linear(channels, channels // 2), ReLU(), Linear(channels // 2, channels // 4), ReLU(), Linear(channels // 4, 1), ) self.redraw_projection = RedrawProjection( self.convs, redraw_interval=1000 if attn_type == 'performer' else None) def forward(self, x, pe, edge_index, edge_attr, batch): x_pe = self.pe_norm(pe) x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1) edge_attr = self.edge_emb(edge_attr) for conv in self.convs: x = conv(x, edge_index, batch, edge_attr=edge_attr) x = global_add_pool(x, batch) return self.mlp(x) class RedrawProjection: def __init__(self, model: torch.nn.Module, redraw_interval: Optional[int] = None): self.model = model self.redraw_interval = redraw_interval self.num_last_redraw = 0 def redraw_projections(self): if not self.model.training or self.redraw_interval is None: return if self.num_last_redraw >= self.redraw_interval: fast_attentions = [ module for module in self.model.modules() if isinstance(module, PerformerAttention) ] for fast_attention in fast_attentions: fast_attention.redraw_projection_matrix() self.num_last_redraw = 0 return self.num_last_redraw += 1 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') attn_kwargs = {'dropout': 0.5} model = GPS(channels=64, pe_dim=8, num_layers=10, attn_type=args.attn_type, 
attn_kwargs=attn_kwargs).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5) scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=0.00001) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() model.redraw_projection.redraw_projections() out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch) loss = (out.squeeze() - data.y).abs().mean() loss.backward() total_loss += loss.item() * data.num_graphs optimizer.step() return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() total_error = 0 for data in loader: data = data.to(device) out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch) total_error += (out.squeeze() - data.y).abs().sum().item() return total_error / len(loader.dataset) for epoch in range(1, 101): loss = train() val_mae = test(val_loader) test_mae = test(test_loader) scheduler.step(val_mae) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, ' f'Test: {test_mae:.4f}') ================================================ FILE: examples/graph_sage_unsup.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F from sklearn.linear_model import LogisticRegression import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.loader import LinkNeighborLoader from torch_geometric.nn import GraphSAGE dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) data = dataset[0] train_loader = LinkNeighborLoader( data, batch_size=256, shuffle=True, neg_sampling_ratio=1.0, num_neighbors=[10, 10], ) if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = 
torch.device('xpu') else: device = torch.device('cpu') data = data.to(device, 'x', 'edge_index') model = GraphSAGE( data.num_node_features, hidden_channels=64, num_layers=2, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() total_loss = 0 for batch in train_loader: batch = batch.to(device) optimizer.zero_grad() h = model(batch.x, batch.edge_index) h_src = h[batch.edge_label_index[0]] h_dst = h[batch.edge_label_index[1]] pred = (h_src * h_dst).sum(dim=-1) loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label) loss.backward() optimizer.step() total_loss += float(loss) * pred.size(0) return total_loss / data.num_nodes @torch.no_grad() def test(): model.eval() out = model(data.x, data.edge_index).cpu() clf = LogisticRegression() clf.fit(out[data.train_mask], data.y[data.train_mask]) val_acc = clf.score(out[data.val_mask], data.y[data.val_mask]) test_acc = clf.score(out[data.test_mask], data.y[data.test_mask]) return val_acc, test_acc times = [] for epoch in range(1, 51): start = time.time() loss = train() val_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/graph_sage_unsup_ppi.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F import tqdm from sklearn.linear_model import SGDClassifier from sklearn.metrics import f1_score from sklearn.multioutput import MultiOutputClassifier import torch_geometric from torch_geometric.data import Batch from torch_geometric.datasets import PPI from torch_geometric.loader import DataLoader, LinkNeighborLoader from torch_geometric.nn import GraphSAGE path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI') train_dataset = PPI(path, split='train') 
val_dataset = PPI(path, split='val') test_dataset = PPI(path, split='test') # Group all training graphs into a single graph to perform sampling: train_data = Batch.from_data_list(train_dataset) loader = LinkNeighborLoader(train_data, batch_size=2048, shuffle=True, neg_sampling_ratio=1.0, num_neighbors=[10, 10], num_workers=6, persistent_workers=True) # Evaluation loaders (one datapoint corresponds to a graph) train_loader = DataLoader(train_dataset, batch_size=2) val_loader = DataLoader(val_dataset, batch_size=2) test_loader = DataLoader(test_dataset, batch_size=2) if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') model = GraphSAGE( in_channels=train_dataset.num_features, hidden_channels=64, num_layers=2, out_channels=64, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.005) def train(): model.train() total_loss = total_examples = 0 for data in tqdm.tqdm(loader): data = data.to(device) optimizer.zero_grad() h = model(data.x, data.edge_index) h_src = h[data.edge_label_index[0]] h_dst = h[data.edge_label_index[1]] link_pred = (h_src * h_dst).sum(dim=-1) # Inner product. 
loss = F.binary_cross_entropy_with_logits(link_pred, data.edge_label) loss.backward() optimizer.step() total_loss += float(loss) * link_pred.numel() total_examples += link_pred.numel() return total_loss / total_examples @torch.no_grad() def encode(loader): model.eval() xs, ys = [], [] for data in loader: data = data.to(device) xs.append(model(data.x, data.edge_index).cpu()) ys.append(data.y.cpu()) return torch.cat(xs, dim=0), torch.cat(ys, dim=0) @torch.no_grad() def test(): # Train classifier on training set: x, y = encode(train_loader) clf = MultiOutputClassifier(SGDClassifier(loss='log_loss', penalty='l2')) clf.fit(x, y) train_f1 = f1_score(y, clf.predict(x), average='micro') # Evaluate on validation set: x, y = encode(val_loader) val_f1 = f1_score(y, clf.predict(x), average='micro') # Evaluate on test set: x, y = encode(test_loader) test_f1 = f1_score(y, clf.predict(x), average='micro') return train_f1, val_f1, test_f1 times = [] for epoch in range(1, 6): start = time.time() loss = train() print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') train_f1, val_f1, test_f1 = test() print(f'Train F1: {train_f1:.4f}, Val F1: {val_f1:.4f}, ' f'Test F1: {test_f1:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/graph_saint.py ================================================ import argparse import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Flickr from torch_geometric.loader import GraphSAINTRandomWalkSampler from torch_geometric.nn import GraphConv from torch_geometric.typing import WITH_TORCH_SPARSE from torch_geometric.utils import degree if not WITH_TORCH_SPARSE: quit("This example requires 'torch-sparse'") path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Flickr') dataset = Flickr(path) data = dataset[0] row, col = data.edge_index data.edge_weight = 1. 
/ degree(col, data.num_nodes)[col] # Norm by in-degree. parser = argparse.ArgumentParser() parser.add_argument('--use_normalization', action='store_true') args = parser.parse_args() loader = GraphSAINTRandomWalkSampler(data, batch_size=6000, walk_length=2, num_steps=5, sample_coverage=100, save_dir=dataset.processed_dir, num_workers=4) class Net(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() in_channels = dataset.num_node_features out_channels = dataset.num_classes self.conv1 = GraphConv(in_channels, hidden_channels) self.conv2 = GraphConv(hidden_channels, hidden_channels) self.conv3 = GraphConv(hidden_channels, hidden_channels) self.lin = torch.nn.Linear(3 * hidden_channels, out_channels) def set_aggr(self, aggr): self.conv1.aggr = aggr self.conv2.aggr = aggr self.conv3.aggr = aggr def forward(self, x0, edge_index, edge_weight=None): x1 = F.relu(self.conv1(x0, edge_index, edge_weight)) x1 = F.dropout(x1, p=0.2, training=self.training) x2 = F.relu(self.conv2(x1, edge_index, edge_weight)) x2 = F.dropout(x2, p=0.2, training=self.training) x3 = F.relu(self.conv3(x2, edge_index, edge_weight)) x3 = F.dropout(x3, p=0.2, training=self.training) x = torch.cat([x1, x2, x3], dim=-1) x = self.lin(x) return x.log_softmax(dim=-1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(hidden_channels=256).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) def train(): model.train() model.set_aggr('add' if args.use_normalization else 'mean') total_loss = total_examples = 0 for data in loader: data = data.to(device) optimizer.zero_grad() if args.use_normalization: edge_weight = data.edge_norm * data.edge_weight out = model(data.x, data.edge_index, edge_weight) loss = F.nll_loss(out, data.y, reduction='none') loss = (loss * data.node_norm)[data.train_mask].sum() else: out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() 
total_loss += loss.item() * data.num_nodes total_examples += data.num_nodes return total_loss / total_examples @torch.no_grad() def test(): model.eval() model.set_aggr('mean') out = model(data.x.to(device), data.edge_index.to(device)) pred = out.argmax(dim=-1) correct = pred.eq(data.y.to(device)) accs = [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): accs.append(correct[mask].sum().item() / mask.sum().item()) return accs for epoch in range(1, 51): loss = train() accs = test() print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {accs[0]:.4f}, ' f'Val: {accs[1]:.4f}, Test: {accs[2]:.4f}') ================================================ FILE: examples/graph_unet.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.nn import GraphUNet from torch_geometric.utils import dropout_edge dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset) data = dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() pool_ratios = [2000 / data.num_nodes, 0.5] self.unet = GraphUNet(dataset.num_features, 32, dataset.num_classes, depth=3, pool_ratios=pool_ratios) def forward(self): edge_index, _ = dropout_edge(data.edge_index, p=0.2, force_undirected=True, training=self.training) x = F.dropout(data.x, p=0.92, training=self.training) x = self.unet(x, edge_index) return F.log_softmax(x, dim=1) device = 'cuda' if torch.cuda.is_available() else 'cpu' model, data = Net().to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001) def train(): model.train() optimizer.zero_grad() F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward() optimizer.step() @torch.no_grad() def test(): model.eval() out, accs = model(), [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): pred = 
out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs best_val_acc = test_acc = 0 for epoch in range(1, 201): train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, ' f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/hetero/README.md ================================================ # Examples for Heterogeneous Data | Example | Description | | ------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------- | | [`hetero_conv_dblp.py`](./hetero_conv_dblp.py) | Shows how to use the `HeteroConv(...)` wrapper; Trains it for node classification on the `DBLP` dataset. | | [`to_hetero_mag.py`](./to_hetero_mag.py) | Shows how to use the `to_hetero(...)` functionality; Trains it for node classification on the `ogb-mag` dataset. | | [`hetero_link_pred.py`](./hetero_link_pred.py) | Shows how to use the `to_hetero(...)` functionality; Trains it for link prediction on the `MovieLens` dataset. | | [`hgt_dblp.py`](./hgt_dblp.py) | Trains a Heterogeneous Graph Transformer (HGT) model for node classification on the `DBLP` dataset. | | [`hierarchical_sage.py`](./hierarchical_sage.py) | Shows how to perform hierarchical sampling; Trains a heterogeneous `GraphSAGE` model for node classification on the `ogb-mag` dataset. | | [`load_csv.py`](./load_csv.py) | Shows how to create heterogeneous graphs from raw `*.csv` data. | | [`metapath2vec.py`](./metapath2vec.py) | Train an unsupervised `MetaPath2Vec` model; Tests embeddings for node classification on the `AMiner` dataset. | | [`temporal_link_pred.py`](./temporal_link_pred.py) | Trains a heterogeneous `GraphSAGE` model for temporal link prediction on the `MovieLens` dataset. 
| | [`bipartite_sage.py`](./bipartite_sage.py) | Trains a GNN via metapaths for link prediction on the `MovieLens` dataset. | | [`bipartite_sage_unsup.py`](./bipartite_sage_unsup.py) | Trains a GNN via metapaths for link prediction on the large-scale `TaoBao` dataset. | | [`dmgi_unsup.py`](./dmgi_unsup.py) | Shows how to learn embeddings on the `IMDB` dataset using the `DMGI` model. | | [`han_imdb.py`](./han_imdb.py) | Shows how to train a heterogeneous Graph Attention Network (HAN) for node classification on the `IMDB` dataset. | | [`recommender_system.py`](./recommender_system.py) | Shows how to train a temporal GNN-based recommender system on the `MovieLens` dataset. | ================================================ FILE: examples/hetero/bipartite_sage.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch.nn import Embedding, Linear import torch_geometric.transforms as T from torch_geometric.datasets import MovieLens from torch_geometric.nn import SAGEConv from torch_geometric.nn.conv.gcn_conv import gcn_norm path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') dataset = MovieLens(path, model_name='all-MiniLM-L6-v2') data = dataset[0] data['user'].x = torch.arange(data['user'].num_nodes) data['user', 'movie'].edge_label = data['user', 'movie'].edge_label.float() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') data = data.to(device) # Add a reverse ('movie', 'rev_rates', 'user') relation for message passing: data = T.ToUndirected()(data) del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label. 
# Perform a link-level split into training, validation, and test edges: train_data, val_data, test_data = T.RandomLinkSplit( num_val=0.1, num_test=0.1, neg_sampling_ratio=0.0, edge_types=[('user', 'rates', 'movie')], rev_edge_types=[('movie', 'rev_rates', 'user')], )(data) # Generate the co-occurrence matrix of movies<>movies: metapath = [('movie', 'rev_rates', 'user'), ('user', 'rates', 'movie')] train_data = T.AddMetaPaths(metapaths=[metapath])(train_data) # Apply normalization to filter the metapath: _, edge_weight = gcn_norm( train_data['movie', 'movie'].edge_index, num_nodes=train_data['movie'].num_nodes, add_self_loops=False, ) edge_index = train_data['movie', 'movie'].edge_index[:, edge_weight > 0.002] train_data['movie', 'metapath_0', 'movie'].edge_index = edge_index val_data['movie', 'metapath_0', 'movie'].edge_index = edge_index test_data['movie', 'metapath_0', 'movie'].edge_index = edge_index class MovieGNNEncoder(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv(-1, hidden_channels) self.conv2 = SAGEConv(hidden_channels, hidden_channels) self.lin = Linear(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index).relu() return self.lin(x) class UserGNNEncoder(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv((-1, -1), hidden_channels) self.conv2 = SAGEConv((-1, -1), hidden_channels) self.conv3 = SAGEConv((-1, -1), hidden_channels) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): movie_x = self.conv1( x_dict['movie'], edge_index_dict[('movie', 'metapath_0', 'movie')], ).relu() user_x = self.conv2( (x_dict['movie'], x_dict['user']), edge_index_dict[('movie', 'rev_rates', 'user')], ).relu() user_x = self.conv3( (movie_x, user_x), edge_index_dict[('movie', 'rev_rates', 'user')], ).relu() return self.lin(user_x) class 
EdgeDecoder(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.lin1 = Linear(2 * hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, 1) def forward(self, z_src, z_dst, edge_label_index): row, col = edge_label_index z = torch.cat([z_src[row], z_dst[col]], dim=-1) z = self.lin1(z).relu() z = self.lin2(z) return z.view(-1) class Model(torch.nn.Module): def __init__(self, num_users, hidden_channels, out_channels): super().__init__() self.user_emb = Embedding(num_users, hidden_channels) self.user_encoder = UserGNNEncoder(hidden_channels, out_channels) self.movie_encoder = MovieGNNEncoder(hidden_channels, out_channels) self.decoder = EdgeDecoder(out_channels) def forward(self, x_dict, edge_index_dict, edge_label_index): z_dict = {} x_dict['user'] = self.user_emb(x_dict['user']) z_dict['user'] = self.user_encoder(x_dict, edge_index_dict) z_dict['movie'] = self.movie_encoder( x_dict['movie'], edge_index_dict[('movie', 'metapath_0', 'movie')], ) return self.decoder(z_dict['user'], z_dict['movie'], edge_label_index) model = Model(data['user'].num_nodes, hidden_channels=64, out_channels=64) model = model.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.0003) def train(): model.train() optimizer.zero_grad() out = model( train_data.x_dict, train_data.edge_index_dict, train_data['user', 'movie'].edge_label_index, ) loss = F.mse_loss(out, train_data['user', 'movie'].edge_label) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(data): model.eval() out = model( data.x_dict, data.edge_index_dict, data['user', 'movie'].edge_label_index, ).clamp(min=0, max=5) rmse = F.mse_loss(out, data['user', 'movie'].edge_label).sqrt() return float(rmse) for epoch in range(1, 701): loss = train() train_rmse = test(train_data) val_rmse = test(val_data) test_rmse = test(test_data) print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, ' f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}') 
================================================ FILE: examples/hetero/bipartite_sage_unsup.py ================================================ # An implementation of unsupervised bipartite GraphSAGE using the Alibaba # Taobao dataset. import os.path as osp import torch import torch.nn.functional as F import tqdm from sklearn.metrics import roc_auc_score from torch.nn import Embedding, Linear import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import Taobao from torch_geometric.loader import LinkNeighborLoader from torch_geometric.nn import SAGEConv from torch_geometric.utils.convert import to_scipy_sparse_matrix if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/Taobao') dataset = Taobao(path) data = dataset[0] data['user'].x = torch.arange(0, data['user'].num_nodes) data['item'].x = torch.arange(0, data['item'].num_nodes) # Only consider user<>item relationships for simplicity: del data['category'] del data['item', 'category'] del data['user', 'item'].time del data['user', 'item'].behavior # Add a reverse ('item', 'rev_to', 'user') relation for message passing: data = T.ToUndirected()(data) # Perform a link-level split into training, validation, and test edges: print('Computing data splits...') train_data, val_data, test_data = T.RandomLinkSplit( num_val=0.1, num_test=0.1, neg_sampling_ratio=1.0, add_negative_train_samples=False, edge_types=[('user', 'to', 'item')], rev_edge_types=[('item', 'rev_to', 'user')], )(data) print('Done!') # Compute sparsified item<>item relationships through users: print('Computing item<>item relationships...') mat = to_scipy_sparse_matrix(data['user', 'item'].edge_index).tocsr() mat = mat[:data['user'].num_nodes, :data['item'].num_nodes] comat = mat.T @ mat comat.setdiag(0) comat = comat >= 3. 
comat = comat.tocoo() row = torch.from_numpy(comat.row).to(torch.long) col = torch.from_numpy(comat.col).to(torch.long) item_to_item_edge_index = torch.stack([row, col], dim=0) # Add the generated item<>item relationships for high-order information: train_data['item', 'item'].edge_index = item_to_item_edge_index val_data['item', 'item'].edge_index = item_to_item_edge_index test_data['item', 'item'].edge_index = item_to_item_edge_index print('Done!') train_loader = LinkNeighborLoader( data=train_data, num_neighbors=[8, 4], edge_label_index=('user', 'to', 'item'), neg_sampling='binary', batch_size=2048, shuffle=True, num_workers=16, drop_last=True, ) val_loader = LinkNeighborLoader( data=val_data, num_neighbors=[8, 4], edge_label_index=( ('user', 'to', 'item'), val_data[('user', 'to', 'item')].edge_label_index, ), edge_label=val_data[('user', 'to', 'item')].edge_label, batch_size=2048, shuffle=False, num_workers=16, ) test_loader = LinkNeighborLoader( data=test_data, num_neighbors=[8, 4], edge_label_index=( ('user', 'to', 'item'), test_data[('user', 'to', 'item')].edge_label_index, ), edge_label=test_data[('user', 'to', 'item')].edge_label, batch_size=2048, shuffle=False, num_workers=16, ) class ItemGNNEncoder(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv(-1, hidden_channels) self.conv2 = SAGEConv(hidden_channels, hidden_channels) self.lin = Linear(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index).relu() return self.lin(x) class UserGNNEncoder(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv((-1, -1), hidden_channels) self.conv2 = SAGEConv((-1, -1), hidden_channels) self.conv3 = SAGEConv((-1, -1), hidden_channels) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): item_x = self.conv1( x_dict['item'], 
edge_index_dict[('item', 'to', 'item')], ).relu() user_x = self.conv2( (x_dict['item'], x_dict['user']), edge_index_dict[('item', 'rev_to', 'user')], ).relu() user_x = self.conv3( (item_x, user_x), edge_index_dict[('item', 'rev_to', 'user')], ).relu() return self.lin(user_x) class EdgeDecoder(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.lin1 = Linear(2 * hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, 1) def forward(self, z_src, z_dst, edge_label_index): row, col = edge_label_index z = torch.cat([z_src[row], z_dst[col]], dim=-1) z = self.lin1(z).relu() z = self.lin2(z) return z.view(-1) class Model(torch.nn.Module): def __init__(self, num_users, num_items, hidden_channels, out_channels): super().__init__() self.user_emb = Embedding(num_users, hidden_channels, device=device) self.item_emb = Embedding(num_items, hidden_channels, device=device) self.item_encoder = ItemGNNEncoder(hidden_channels, out_channels) self.user_encoder = UserGNNEncoder(hidden_channels, out_channels) self.decoder = EdgeDecoder(out_channels) def forward(self, x_dict, edge_index_dict, edge_label_index): z_dict = {} x_dict['user'] = self.user_emb(x_dict['user']) x_dict['item'] = self.item_emb(x_dict['item']) z_dict['item'] = self.item_encoder( x_dict['item'], edge_index_dict[('item', 'to', 'item')], ) z_dict['user'] = self.user_encoder(x_dict, edge_index_dict) return self.decoder(z_dict['user'], z_dict['item'], edge_label_index) model = Model( num_users=data['user'].num_nodes, num_items=data['item'].num_nodes, hidden_channels=64, out_channels=64, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) def train(): model.train() total_loss = total_examples = 0 for batch in tqdm.tqdm(train_loader): batch = batch.to(device) optimizer.zero_grad() pred = model( batch.x_dict, batch.edge_index_dict, batch['user', 'item'].edge_label_index, ) loss = F.binary_cross_entropy_with_logits( pred, batch['user', 'item'].edge_label) 
loss.backward() optimizer.step() total_loss += float(loss) total_examples += pred.numel() return total_loss / total_examples @torch.no_grad() def test(loader): model.eval() preds, targets = [], [] for batch in tqdm.tqdm(loader): batch = batch.to(device) pred = model( batch.x_dict, batch.edge_index_dict, batch['user', 'item'].edge_label_index, ).sigmoid().view(-1).cpu() target = batch['user', 'item'].edge_label.long().cpu() preds.append(pred) targets.append(target) pred = torch.cat(preds, dim=0).numpy() target = torch.cat(targets, dim=0).numpy() return roc_auc_score(target, pred) for epoch in range(1, 21): loss = train() val_auc = test(val_loader) test_auc = test(test_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:4f}, Val: {val_auc:.4f}, ' f'Test: {test_auc:.4f}') ================================================ FILE: examples/hetero/dmgi_unsup.py ================================================ # An implementation of "Unsupervised Attributed Multiplex Network # Embedding" (DMGI) for unsupervised learning on heterogeneous graphs: # * Paper: (AAAI 2020) import os.path as osp import torch import torch.nn.functional as F from sklearn.linear_model import LogisticRegression from torch.optim import Adam import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import IMDB from torch_geometric.nn import GCNConv path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/IMDB') dataset = IMDB(path) metapaths = [ [('movie', 'actor'), ('actor', 'movie')], # MAM [('movie', 'director'), ('director', 'movie')], # MDM ] data = T.AddMetaPaths(metapaths, drop_orig_edge_types=True)(dataset[0]) class DMGI(torch.nn.Module): def __init__(self, num_nodes, in_channels, out_channels, num_relations): super().__init__() self.convs = torch.nn.ModuleList( [GCNConv(in_channels, out_channels) for _ in range(num_relations)]) self.M = torch.nn.Bilinear(out_channels, out_channels, 1) self.Z = torch.nn.Parameter(torch.empty(num_nodes, out_channels)) 
self.reset_parameters() def reset_parameters(self): for conv in self.convs: conv.reset_parameters() torch.nn.init.xavier_uniform_(self.M.weight) self.M.bias.data.zero_() torch.nn.init.xavier_uniform_(self.Z) def forward(self, x, edge_indices): pos_hs, neg_hs, summaries = [], [], [] for conv, edge_index in zip(self.convs, edge_indices): pos_h = F.dropout(x, p=0.5, training=self.training) pos_h = conv(pos_h, edge_index).relu() pos_hs.append(pos_h) neg_h = F.dropout(x, p=0.5, training=self.training) neg_h = neg_h[torch.randperm(neg_h.size(0), device=neg_h.device)] neg_h = conv(neg_h, edge_index).relu() neg_hs.append(neg_h) summaries.append(pos_h.mean(dim=0, keepdim=True)) return pos_hs, neg_hs, summaries def loss(self, pos_hs, neg_hs, summaries): loss = 0. for pos_h, neg_h, s in zip(pos_hs, neg_hs, summaries): s = s.expand_as(pos_h) loss += -torch.log(self.M(pos_h, s).sigmoid() + 1e-15).mean() loss += -torch.log(1 - self.M(neg_h, s).sigmoid() + 1e-15).mean() pos_mean = torch.stack(pos_hs, dim=0).mean(dim=0) neg_mean = torch.stack(neg_hs, dim=0).mean(dim=0) pos_reg_loss = (self.Z - pos_mean).pow(2).sum() neg_reg_loss = (self.Z - neg_mean).pow(2).sum() loss += 0.001 * (pos_reg_loss - neg_reg_loss) return loss model = DMGI(data['movie'].num_nodes, data['movie'].x.size(-1), out_channels=64, num_relations=len(data.edge_types)) if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') data, model = data.to(device), model.to(device) optimizer = Adam(model.parameters(), lr=0.0005, weight_decay=0.0001) def train(): model.train() optimizer.zero_grad() x = data['movie'].x edge_indices = data.edge_index_dict.values() pos_hs, neg_hs, summaries = model(x, edge_indices) loss = model.loss(pos_hs, neg_hs, summaries) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): train_emb = model.Z[data['movie'].train_mask].cpu() val_emb = 
model.Z[data['movie'].val_mask].cpu() test_emb = model.Z[data['movie'].test_mask].cpu() train_y = data['movie'].y[data['movie'].train_mask].cpu() val_y = data['movie'].y[data['movie'].val_mask].cpu() test_y = data['movie'].y[data['movie'].test_mask].cpu() clf = LogisticRegression().fit(train_emb, train_y) return clf.score(val_emb, val_y), clf.score(test_emb, test_y) for epoch in range(1, 1001): loss = train() if epoch % 50 == 0: val_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/hetero/han_imdb.py ================================================ import os.path as osp from typing import Dict, List, Union import torch import torch.nn.functional as F from torch import nn import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import IMDB from torch_geometric.nn import HANConv path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/IMDB') metapaths = [[('movie', 'actor'), ('actor', 'movie')], [('movie', 'director'), ('director', 'movie')]] transform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edge_types=True, drop_unconnected_node_types=True) dataset = IMDB(path, transform=transform) data = dataset[0] class HAN(nn.Module): def __init__(self, in_channels: Union[int, Dict[str, int]], out_channels: int, hidden_channels=128, heads=8): super().__init__() self.han_conv = HANConv(in_channels, hidden_channels, heads=heads, dropout=0.6, metadata=data.metadata()) self.lin = nn.Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): out = self.han_conv(x_dict, edge_index_dict) out = self.lin(out['movie']) return out model = HAN(in_channels=-1, out_channels=3) if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') data, model = data.to(device), model.to(device) with 
torch.no_grad(): # Initialize lazy modules. out = model(data.x_dict, data.edge_index_dict) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001) def train() -> float: model.train() optimizer.zero_grad() out = model(data.x_dict, data.edge_index_dict) mask = data['movie'].train_mask loss = F.cross_entropy(out[mask], data['movie'].y[mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test() -> List[float]: model.eval() pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1) accs = [] for split in ['train_mask', 'val_mask', 'test_mask']: mask = data['movie'][split] acc = (pred[mask] == data['movie'].y[mask]).sum() / mask.sum() accs.append(float(acc)) return accs best_val_acc = 0 start_patience = patience = 100 for epoch in range(1, 200): loss = train() train_acc, val_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') if best_val_acc <= val_acc: patience = start_patience best_val_acc = val_acc else: patience -= 1 if patience <= 0: print('Stopping training as validation accuracy did not improve ' f'for {start_patience} epochs') break ================================================ FILE: examples/hetero/hetero_conv_dblp.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import DBLP from torch_geometric.nn import HeteroConv, Linear, SAGEConv path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP') # We initialize conference node features with a single one-vector as feature: dataset = DBLP(path, transform=T.Constant(node_types='conference')) data = dataset[0] class HeteroGNN(torch.nn.Module): def __init__(self, metadata, hidden_channels, out_channels, num_layers): super().__init__() self.convs = torch.nn.ModuleList() for _ in range(num_layers): conv = 
HeteroConv({ edge_type: SAGEConv((-1, -1), hidden_channels) for edge_type in metadata[1] }) self.convs.append(conv) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()} return self.lin(x_dict['author']) model = HeteroGNN(data.metadata(), hidden_channels=64, out_channels=4, num_layers=2) if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') data, model = data.to(device), model.to(device) with torch.no_grad(): # Initialize lazy modules. out = model(data.x_dict, data.edge_index_dict) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001) def train(): model.train() optimizer.zero_grad() out = model(data.x_dict, data.edge_index_dict) mask = data['author'].train_mask loss = F.cross_entropy(out[mask], data['author'].y[mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1) accs = [] for split in ['train_mask', 'val_mask', 'test_mask']: mask = data['author'][split] acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum() accs.append(float(acc)) return accs for epoch in range(1, 101): loss = train() train_acc, val_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/hetero/hetero_link_pred.py ================================================ import argparse import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import MovieLens from torch_geometric.nn import SAGEConv, to_hetero parser = 
argparse.ArgumentParser() parser.add_argument('--use_weighted_loss', action='store_true', help='Whether to use weighted MSE loss.') args = parser.parse_args() device = torch_geometric.device('auto') path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') dataset = MovieLens(path, model_name='all-MiniLM-L6-v2') data = dataset[0].to(device) # Add user node features for message passing: data['user'].x = torch.eye(data['user'].num_nodes, device=device) del data['user'].num_nodes # Add a reverse ('movie', 'rev_rates', 'user') relation for message passing: data = T.ToUndirected()(data) del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label. # Perform a link-level split into training, validation, and test edges: train_data, val_data, test_data = T.RandomLinkSplit( num_val=0.1, num_test=0.1, neg_sampling_ratio=0.0, edge_types=[('user', 'rates', 'movie')], rev_edge_types=[('movie', 'rev_rates', 'user')], )(data) # We have an unbalanced dataset with many labels for rating 3 and 4, and very # few for 0 and 1. Therefore we use a weighted MSE loss. if args.use_weighted_loss: weight = torch.bincount(train_data['user', 'movie'].edge_label) weight = weight.max() / weight else: weight = None def weighted_mse_loss(pred, target, weight=None): weight = 1. 
if weight is None else weight[target].to(pred.dtype) return (weight * (pred - target.to(pred.dtype)).pow(2)).mean() class GNNEncoder(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv((-1, -1), hidden_channels) self.conv2 = SAGEConv((-1, -1), out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) return x class EdgeDecoder(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.lin1 = Linear(2 * hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, 1) def forward(self, z_dict, edge_label_index): row, col = edge_label_index z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1) z = self.lin1(z).relu() z = self.lin2(z) return z.view(-1) class Model(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.encoder = GNNEncoder(hidden_channels, hidden_channels) self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum') self.decoder = EdgeDecoder(hidden_channels) def forward(self, x_dict, edge_index_dict, edge_label_index): z_dict = self.encoder(x_dict, edge_index_dict) return self.decoder(z_dict, edge_label_index) model = Model(hidden_channels=32).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() optimizer.zero_grad() pred = model(train_data.x_dict, train_data.edge_index_dict, train_data['user', 'movie'].edge_label_index) target = train_data['user', 'movie'].edge_label loss = weighted_mse_loss(pred, target, weight) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(data): model.eval() pred = model(data.x_dict, data.edge_index_dict, data['user', 'movie'].edge_label_index) pred = pred.clamp(min=0, max=5) target = data['user', 'movie'].edge_label.float() rmse = F.mse_loss(pred, target).sqrt() return float(rmse) for epoch in range(1, 301): loss = train() train_rmse = test(train_data) val_rmse = 
test(val_data) test_rmse = test(test_data) print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, ' f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}') ================================================ FILE: examples/hetero/hgt_dblp.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import DBLP from torch_geometric.nn import HGTConv, Linear path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP') # We initialize conference node features with a single one-vector as feature: dataset = DBLP(path, transform=T.Constant(node_types='conference')) data = dataset[0] class HGT(torch.nn.Module): def __init__(self, hidden_channels, out_channels, num_heads, num_layers): super().__init__() self.lin_dict = torch.nn.ModuleDict() for node_type in data.node_types: self.lin_dict[node_type] = Linear(-1, hidden_channels) self.convs = torch.nn.ModuleList() for _ in range(num_layers): conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads) self.convs.append(conv) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): x_dict = { node_type: self.lin_dict[node_type](x).relu_() for node_type, x in x_dict.items() } for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) return self.lin(x_dict['author']) model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1) if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') data, model = data.to(device), model.to(device) with torch.no_grad(): # Initialize lazy modules. 
out = model(data.x_dict, data.edge_index_dict) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001) def train(): model.train() optimizer.zero_grad() out = model(data.x_dict, data.edge_index_dict) mask = data['author'].train_mask loss = F.cross_entropy(out[mask], data['author'].y[mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1) accs = [] for split in ['train_mask', 'val_mask', 'test_mask']: mask = data['author'][split] acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum() accs.append(float(acc)) return accs for epoch in range(1, 101): loss = train() train_acc, val_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/hetero/hierarchical_sage.py ================================================ import argparse import torch import torch.nn.functional as F from tqdm import tqdm import torch_geometric import torch_geometric.transforms as T from torch_geometric.datasets import OGB_MAG from torch_geometric.loader import NeighborLoader from torch_geometric.nn import HeteroConv, Linear, SAGEConv from torch_geometric.utils import trim_to_layer parser = argparse.ArgumentParser() parser.add_argument('--use-sparse-tensor', action='store_true') args = parser.parse_args() if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') transforms = [T.ToUndirected(merge=True)] if args.use_sparse_tensor: transforms.append(T.ToSparseTensor()) dataset = OGB_MAG(root='../../data', preprocess='metapath2vec', transform=T.Compose(transforms)) data = dataset[0].to(device, 'x', 'y') class HierarchicalHeteroGraphSage(torch.nn.Module): def __init__(self, edge_types, hidden_channels, out_channels, 
num_layers): super().__init__() self.convs = torch.nn.ModuleList() for _ in range(num_layers): conv = HeteroConv( { edge_type: SAGEConv((-1, -1), hidden_channels) for edge_type in edge_types }, aggr='sum') self.convs.append(conv) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict, num_sampled_nodes_dict): for i, conv in enumerate(self.convs): x_dict, edge_index_dict, _ = trim_to_layer( layer=i, num_sampled_nodes_per_hop=num_sampled_nodes_dict, num_sampled_edges_per_hop=num_sampled_edges_dict, x=x_dict, edge_index=edge_index_dict, ) x_dict = conv(x_dict, edge_index_dict) x_dict = {key: x.relu() for key, x in x_dict.items()} return self.lin(x_dict['paper']) model = HierarchicalHeteroGraphSage( edge_types=data.edge_types, hidden_channels=64, out_channels=dataset.num_classes, num_layers=2, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) kwargs = {'batch_size': 1024, 'num_workers': 0} train_loader = NeighborLoader( data, num_neighbors=[10] * 2, shuffle=True, input_nodes=('paper', data['paper'].train_mask), **kwargs, ) val_loader = NeighborLoader( data, num_neighbors=[10] * 2, shuffle=False, input_nodes=('paper', data['paper'].val_mask), **kwargs, ) def train(): model.train() total_examples = total_loss = 0 for batch in tqdm(train_loader): batch = batch.to(device) optimizer.zero_grad() out = model( batch.x_dict, batch.adj_t_dict if args.use_sparse_tensor else batch.edge_index_dict, num_sampled_nodes_dict=batch.num_sampled_nodes_dict, num_sampled_edges_dict=batch.num_sampled_edges_dict, ) batch_size = batch['paper'].batch_size loss = F.cross_entropy(out[:batch_size], batch['paper'].y[:batch_size]) loss.backward() optimizer.step() total_examples += batch_size total_loss += float(loss) * batch_size return total_loss / total_examples @torch.no_grad() def test(loader): model.eval() total_examples = total_correct = 0 for batch in tqdm(loader): batch = batch.to(device) out = model( 
batch.x_dict, batch.adj_t_dict if args.use_sparse_tensor else batch.edge_index_dict, num_sampled_nodes_dict=batch.num_sampled_nodes_dict, num_sampled_edges_dict=batch.num_sampled_edges_dict, ) batch_size = batch['paper'].batch_size pred = out[:batch_size].argmax(dim=-1) total_examples += batch_size total_correct += int((pred == batch['paper'].y[:batch_size]).sum()) return total_correct / total_examples for epoch in range(1, 6): loss = train() val_acc = test(val_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}') ================================================ FILE: examples/hetero/load_csv.py ================================================ import os.path as osp import pandas as pd import torch from sentence_transformers import SentenceTransformer from torch_geometric.data import HeteroData, download_url, extract_zip from torch_geometric.transforms import RandomLinkSplit, ToUndirected url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip' root = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') extract_zip(download_url(url, root), root) movie_path = osp.join(root, 'ml-latest-small', 'movies.csv') rating_path = osp.join(root, 'ml-latest-small', 'ratings.csv') def load_node_csv(path, index_col, encoders=None, **kwargs): df = pd.read_csv(path, index_col=index_col, **kwargs) mapping = {index: i for i, index in enumerate(df.index.unique())} x = None if encoders is not None: xs = [encoder(df[col]) for col, encoder in encoders.items()] x = torch.cat(xs, dim=-1) return x, mapping def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping, encoders=None, **kwargs): df = pd.read_csv(path, **kwargs) src = [src_mapping[index] for index in df[src_index_col]] dst = [dst_mapping[index] for index in df[dst_index_col]] edge_index = torch.tensor([src, dst]) edge_attr = None if encoders is not None: edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()] edge_attr = torch.cat(edge_attrs, 
dim=-1) return edge_index, edge_attr class SequenceEncoder: # The 'SequenceEncoder' encodes raw column strings into embeddings. def __init__(self, model_name='all-MiniLM-L6-v2', device=None): self.device = device self.model = SentenceTransformer(model_name, device=device) @torch.no_grad() def __call__(self, df): x = self.model.encode(df.values, show_progress_bar=True, convert_to_tensor=True, device=self.device) return x.cpu() class GenresEncoder: # The 'GenreEncoder' splits the raw column strings by 'sep' and converts # individual elements to categorical labels. def __init__(self, sep='|'): self.sep = sep def __call__(self, df): genres = {g for col in df.values for g in col.split(self.sep)} mapping = {genre: i for i, genre in enumerate(genres)} x = torch.zeros(len(df), len(mapping)) for i, col in enumerate(df.values): for genre in col.split(self.sep): x[i, mapping[genre]] = 1 return x class IdentityEncoder: # The 'IdentityEncoder' takes the raw column values and converts them to # PyTorch tensors. def __init__(self, dtype=None): self.dtype = dtype def __call__(self, df): return torch.from_numpy(df.values).view(-1, 1).to(self.dtype) user_x, user_mapping = load_node_csv(rating_path, index_col='userId') movie_x, movie_mapping = load_node_csv( movie_path, index_col='movieId', encoders={ 'title': SequenceEncoder(), 'genres': GenresEncoder() }) edge_index, edge_label = load_edge_csv( rating_path, src_index_col='userId', src_mapping=user_mapping, dst_index_col='movieId', dst_mapping=movie_mapping, encoders={'rating': IdentityEncoder(dtype=torch.long)}, ) data = HeteroData() data['user'].num_nodes = len(user_mapping) # Users do not have any features. data['movie'].x = movie_x data['user', 'rates', 'movie'].edge_index = edge_index data['user', 'rates', 'movie'].edge_label = edge_label print(data) # We can now convert `data` into an appropriate format for training a # graph-based machine learning model: # 1. 
Add a reverse ('movie', 'rev_rates', 'user') relation for message passing. data = ToUndirected()(data) del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label. # 2. Perform a link-level split into training, validation, and test edges. transform = RandomLinkSplit( num_val=0.05, num_test=0.1, neg_sampling_ratio=0.0, edge_types=[('user', 'rates', 'movie')], rev_edge_types=[('movie', 'rev_rates', 'user')], ) train_data, val_data, test_data = transform(data) print(train_data) print(val_data) print(test_data) ================================================ FILE: examples/hetero/metapath2vec.py ================================================ # Reaches around 91.8% Micro-F1 after 5 epochs. import os.path as osp import torch import torch_geometric from torch_geometric.datasets import AMiner from torch_geometric.nn import MetaPath2Vec path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/AMiner') dataset = AMiner(path) data = dataset[0] metapath = [ ('author', 'writes', 'paper'), ('paper', 'published_in', 'venue'), ('venue', 'publishes', 'paper'), ('paper', 'written_by', 'author'), ] if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') model = MetaPath2Vec(data.edge_index_dict, embedding_dim=128, metapath=metapath, walk_length=50, context_size=7, walks_per_node=5, num_negative_samples=5, sparse=True).to(device) loader = model.loader(batch_size=128, shuffle=True, num_workers=6) optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01) def train(epoch, log_steps=100, eval_steps=2000): model.train() total_loss = 0 for i, (pos_rw, neg_rw) in enumerate(loader): optimizer.zero_grad() loss = model.loss(pos_rw.to(device), neg_rw.to(device)) loss.backward() optimizer.step() total_loss += loss.item() if (i + 1) % log_steps == 0: print(f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, ' f'Loss: {total_loss / log_steps:.4f}') 
total_loss = 0 if (i + 1) % eval_steps == 0: acc = test() print(f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, ' f'Acc: {acc:.4f}') @torch.no_grad() def test(train_ratio=0.1): model.eval() z = model('author', batch=data['author'].y_index.to(device)) y = data['author'].y perm = torch.randperm(z.size(0)) train_perm = perm[:int(z.size(0) * train_ratio)] test_perm = perm[int(z.size(0) * train_ratio):] return model.test(z[train_perm], y[train_perm], z[test_perm], y[test_perm], max_iter=150) for epoch in range(1, 6): train(epoch) acc = test() print(f'Epoch: {epoch}, Accuracy: {acc:.4f}') ================================================ FILE: examples/hetero/recommender_system.py ================================================ import argparse import os.path as osp import torch import torch.nn.functional as F from tqdm import tqdm import torch_geometric.transforms as T from torch_geometric import EdgeIndex from torch_geometric.datasets import MovieLens from torch_geometric.loader import LinkNeighborLoader, NeighborLoader from torch_geometric.metrics import ( LinkPredMAP, LinkPredPrecision, LinkPredRecall, ) from torch_geometric.nn import MIPSKNNIndex, SAGEConv, to_hetero parser = argparse.ArgumentParser() parser.add_argument('--k', type=int, default=20, help='Number of predictions') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') data = MovieLens(path, model_name='all-MiniLM-L6-v2')[0] # Add user node features for message passing: data['user'].x = torch.eye(data['user'].num_nodes) del data['user'].num_nodes # Only use edges with high ratings (>= 4): mask = data['user', 'rates', 'movie'].edge_label >= 4 data['user', 'movie'].edge_index = data['user', 'movie'].edge_index[:, mask] data['user', 'movie'].time = data['user', 'movie'].time[mask] del data['user', 'movie'].edge_label # Drop rating information from graph. 
# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing: data = T.ToUndirected()(data) # Perform a temporal link-level split into training and test edges: edge_label_index = data['user', 'movie'].edge_index time = data['user', 'movie'].time perm = time.argsort() train_index = perm[:int(0.8 * perm.numel())] test_index = perm[int(0.8 * perm.numel()):] kwargs = dict( # Shared data loader arguments: data=data, num_neighbors=[5, 5, 5], batch_size=256, time_attr='time', num_workers=4, persistent_workers=True, temporal_strategy='last', ) train_loader = LinkNeighborLoader( edge_label_index=(('user', 'movie'), edge_label_index[:, train_index]), edge_label_time=time[train_index] - 1, # No leakage. neg_sampling=dict(mode='binary', amount=2), shuffle=True, **kwargs, ) # During testing, we sample node-level subgraphs from both endpoints to # retrieve their embeddings. # This allows us to do efficient k-NN search on top of embeddings: src_loader = NeighborLoader( input_nodes='user', input_time=(time[test_index].min() - 1).repeat(data['user'].num_nodes), **kwargs, ) dst_loader = NeighborLoader( input_nodes='movie', input_time=(time[test_index].min() - 1).repeat(data['movie'].num_nodes), **kwargs, ) # Save test edges and the edges we want to exclude when evaluating: sparse_size = (data['user'].num_nodes, data['movie'].num_nodes) test_edge_label_index = EdgeIndex( edge_label_index[:, test_index].to(device), sparse_size=sparse_size, ).sort_by('row')[0] test_exclude_links = EdgeIndex( edge_label_index[:, train_index].to(device), sparse_size=sparse_size, ).sort_by('row')[0] class GNN(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.conv1 = SAGEConv((-1, -1), hidden_channels) self.conv2 = SAGEConv((-1, -1), hidden_channels) self.conv3 = SAGEConv((-1, -1), hidden_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index).relu() x = self.conv3(x, edge_index) return x class 
InnerProductDecoder(torch.nn.Module): def forward(self, x_dict, edge_label_index): x_src = x_dict['user'][edge_label_index[0]] x_dst = x_dict['movie'][edge_label_index[1]] return (x_src * x_dst).sum(dim=-1) class Model(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.encoder = GNN(hidden_channels) self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum') self.decoder = InnerProductDecoder() def forward(self, x_dict, edge_index_dict, edge_label_index): x_dict = self.encoder(x_dict, edge_index_dict) return self.decoder(x_dict, edge_label_index) model = Model(hidden_channels=64).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) def train(): model.train() total_loss = total_examples = 0 for batch in tqdm(train_loader): batch = batch.to(device) optimizer.zero_grad() out = model( batch.x_dict, batch.edge_index_dict, batch['user', 'movie'].edge_label_index, ) y = batch['user', 'movie'].edge_label loss = F.binary_cross_entropy_with_logits(out, y) loss.backward() optimizer.step() total_loss += float(loss) * y.numel() total_examples += y.numel() return total_loss / total_examples @torch.no_grad() def test(edge_label_index, exclude_links): model.eval() dst_embs = [] for batch in dst_loader: # Collect destination node/movie embeddings: batch = batch.to(device) emb = model.encoder(batch.x_dict, batch.edge_index_dict)['movie'] emb = emb[:batch['movie'].batch_size] dst_embs.append(emb) dst_emb = torch.cat(dst_embs, dim=0) del dst_embs # Instantiate k-NN index based on maximum inner product search (MIPS): mips = MIPSKNNIndex(dst_emb) # Initialize metrics: map_metric = LinkPredMAP(k=args.k).to(device) precision_metric = LinkPredPrecision(k=args.k).to(device) recall_metric = LinkPredRecall(k=args.k).to(device) num_processed = 0 for batch in src_loader: # Collect source node/user embeddings: batch = batch.to(device) # Compute user embeddings: emb = model.encoder(batch.x_dict, batch.edge_index_dict)['user'] emb = 
emb[:batch['user'].batch_size] # Filter labels/exclusion by current batch: _edge_label_index = edge_label_index.sparse_narrow( dim=0, start=num_processed, length=emb.size(0), ) _exclude_links = exclude_links.sparse_narrow( dim=0, start=num_processed, length=emb.size(0), ) num_processed += emb.size(0) # Perform MIPS search: _, pred_index_mat = mips.search(emb, args.k, _exclude_links) # Update retrieval metrics: map_metric.update(pred_index_mat, _edge_label_index) precision_metric.update(pred_index_mat, _edge_label_index) recall_metric.update(pred_index_mat, _edge_label_index) return ( float(map_metric.compute()), float(precision_metric.compute()), float(recall_metric.compute()), ) for epoch in range(1, 16): train_loss = train() print(f'Epoch: {epoch:02d}, Loss: {train_loss:.4f}') val_map, val_precision, val_recall = test( test_edge_label_index, test_exclude_links, ) print(f'Test MAP@{args.k}: {val_map:.4f}, ' f'Test Precision@{args.k}: {val_precision:.4f}, ' f'Test Recall@{args.k}: {val_recall:.4f}') ================================================ FILE: examples/hetero/temporal_link_pred.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear import torch_geometric.transforms as T from torch_geometric.datasets import MovieLens from torch_geometric.loader import LinkNeighborLoader from torch_geometric.nn import SAGEConv, to_hetero device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') dataset = MovieLens(path, model_name='all-MiniLM-L6-v2') data = dataset[0] # Add user node features for message passing: data['user'].x = torch.eye(data['user'].num_nodes, device=device) del data['user'].num_nodes # Add a reverse ('movie', 'rev_rates', 'user') relation for message passing: data = T.ToUndirected()(data) del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label. 
# Perform a 80/10/10 temporal link-level split:
# edges are ordered by timestamp, so the oldest 80% train, the next 10%
# validate and the newest 10% test.
perm = torch.argsort(data['user', 'movie'].time)
train_idx = perm[:int(0.8 * perm.size(0))]
val_idx = perm[int(0.8 * perm.size(0)):int(0.9 * perm.size(0))]
test_idx = perm[int(0.9 * perm.size(0)):]

edge_index = data['user', 'movie'].edge_index
# Loader arguments shared by the train/val/test loaders below:
kwargs = dict(
    data=data,
    num_neighbors=[20, 10],
    batch_size=1024,
    time_attr='time',
    temporal_strategy='last',
    num_workers=4,
    persistent_workers=True,
)
# `edge_label_time=... - 1` presumably restricts sampling to history strictly
# before each target rating (avoiding leakage) -- confirm against the
# LinkNeighborLoader `edge_label_time` semantics.
train_loader = LinkNeighborLoader(
    edge_label_index=(('user', 'movie'), edge_index[:, train_idx]),
    edge_label=data['user', 'movie'].edge_label[train_idx],
    edge_label_time=data['user', 'movie'].time[train_idx] - 1,
    shuffle=True,
    **kwargs,
)
val_loader = LinkNeighborLoader(
    edge_label_index=(('user', 'movie'), edge_index[:, val_idx]),
    edge_label=data['user', 'movie'].edge_label[val_idx],
    edge_label_time=data['user', 'movie'].time[val_idx] - 1,
    **kwargs,
)
test_loader = LinkNeighborLoader(
    edge_label_index=(('user', 'movie'), edge_index[:, test_idx]),
    edge_label=data['user', 'movie'].edge_label[test_idx],
    edge_label_time=data['user', 'movie'].time[test_idx] - 1,
    **kwargs,
)


class GNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder, converted to a heterogeneous model by
    `to_hetero` inside `Model`."""
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


class EdgeDecoder(torch.nn.Module):
    """Score (user, movie) pairs: concatenate both embeddings and apply a
    two-layer MLP, returning one scalar per edge."""
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)  # Flatten to one prediction per edge.


class Model(torch.nn.Module):
    """Heterogeneous encoder plus edge decoder used for rating regression."""
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)


model = Model(hidden_channels=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train():
    """Run one epoch of rating regression and return the mean MSE per
    labeled edge."""
    model.train()

    total_loss = total_examples = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        pred = model(
            batch.x_dict,
            batch.edge_index_dict,
            batch['user', 'movie'].edge_label_index,
        )
        target = batch['user', 'movie'].edge_label.float()

        loss = F.mse_loss(pred, target)
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its edge count for an exact mean:
        total_loss += float(loss * pred.size(0))
        total_examples += pred.size(0)

    return total_loss / total_examples


@torch.no_grad()
def test(loader):
    """Return the RMSE of clamped rating predictions over every edge yielded
    by `loader`."""
    model.eval()

    preds, targets = [], []
    for batch in loader:
        batch = batch.to(device)

        # Clamp to [0, 5], presumably the MovieLens rating range -- confirm.
        pred = model(
            batch.x_dict,
            batch.edge_index_dict,
            batch['user', 'movie'].edge_label_index,
        ).clamp(min=0, max=5)
        preds.append(pred)
        targets.append(batch['user', 'movie'].edge_label.float())

    # RMSE over all edges at once, not an average of per-batch RMSEs:
    pred = torch.cat(preds, dim=0)
    target = torch.cat(targets, dim=0)
    rmse = (pred - target).pow(2).mean().sqrt()
    return float(rmse)


for epoch in range(1, 11):
    loss = train()
    val_rmse = test(val_loader)
    test_rmse = test(test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val RMSE: {val_rmse:.4f}, '
          f'Test RMSE: {test_rmse:.4f}')

================================================
FILE: examples/hetero/to_hetero_mag.py
================================================
import argparse
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import ReLU
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import HGTLoader, NeighborLoader
from torch_geometric.nn import Linear, SAGEConv, Sequential, to_hetero

parser = argparse.ArgumentParser()
parser.add_argument('--use_hgt_loader', action='store_true') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/OGB') transform = T.ToUndirected(merge=True) dataset = OGB_MAG(path, preprocess='metapath2vec', transform=transform) # Already send node features/labels to GPU for faster access during sampling: data = dataset[0].to(device, 'x', 'y') train_input_nodes = ('paper', data['paper'].train_mask) val_input_nodes = ('paper', data['paper'].val_mask) kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True} if not args.use_hgt_loader: train_loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True, input_nodes=train_input_nodes, **kwargs) val_loader = NeighborLoader(data, num_neighbors=[10] * 2, input_nodes=val_input_nodes, **kwargs) else: train_loader = HGTLoader(data, num_samples=[1024] * 4, shuffle=True, input_nodes=train_input_nodes, **kwargs) val_loader = HGTLoader(data, num_samples=[1024] * 4, input_nodes=val_input_nodes, **kwargs) model = Sequential('x, edge_index', [ (SAGEConv((-1, -1), 64), 'x, edge_index -> x'), ReLU(inplace=True), (SAGEConv((-1, -1), 64), 'x, edge_index -> x'), ReLU(inplace=True), (Linear(-1, dataset.num_classes), 'x -> x'), ]) model = to_hetero(model, data.metadata(), aggr='sum').to(device) @torch.no_grad() def init_params(): # Initialize lazy parameters via forwarding a single batch to the model: batch = next(iter(train_loader)) batch = batch.to(device, 'edge_index') model(batch.x_dict, batch.edge_index_dict) def train(): model.train() total_examples = total_loss = 0 for batch in tqdm(train_loader): optimizer.zero_grad() batch = batch.to(device, 'edge_index') batch_size = batch['paper'].batch_size out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size] loss = F.cross_entropy(out, batch['paper'].y[:batch_size]) loss.backward() optimizer.step() total_examples += batch_size total_loss += 
float(loss) * batch_size return total_loss / total_examples @torch.no_grad() def test(loader): model.eval() total_examples = total_correct = 0 for batch in tqdm(loader): batch = batch.to(device, 'edge_index') batch_size = batch['paper'].batch_size out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size] pred = out.argmax(dim=-1) total_examples += batch_size total_correct += int((pred == batch['paper'].y[:batch_size]).sum()) return total_correct / total_examples init_params() # Initialize parameters. optimizer = torch.optim.Adam(model.parameters(), lr=0.01) for epoch in range(1, 21): loss = train() val_acc = test(val_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}') ================================================ FILE: examples/hierarchical_sampling.py ================================================ import os.path as osp import torch import torch.nn.functional as F from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborLoader from torch_geometric.nn.models.basic_gnn import GraphSAGE device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit') dataset = Reddit(path) # Already send node features/labels to GPU for faster access during sampling: data = dataset[0].to(device, 'x', 'y') kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True} loader = NeighborLoader(data, input_nodes=data.train_mask, num_neighbors=[20, 10, 5], shuffle=True, **kwargs) model = GraphSAGE( dataset.num_features, hidden_channels=64, out_channels=dataset.num_classes, num_layers=3, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(trim=False): for batch in tqdm(loader): optimizer.zero_grad() batch = batch.to(device) if not trim: out = model(batch.x, batch.edge_index) else: out = model( batch.x, batch.edge_index, num_sampled_nodes_per_hop=batch.num_sampled_nodes, 
num_sampled_edges_per_hop=batch.num_sampled_edges, ) out = out[:batch.batch_size] y = batch.y[:batch.batch_size] loss = F.cross_entropy(out, y) loss.backward() optimizer.step() print('One epoch training without Hierarchical Graph Sampling:') train(trim=False) print('One epoch training with Hierarchical Graph Sampling:') train(trim=True) ================================================ FILE: examples/infomax_inductive.py ================================================ import os.path as osp import torch from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborLoader from torch_geometric.nn import DeepGraphInfomax, SAGEConv device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit') dataset = Reddit(path) data = dataset[0].to(device, 'x', 'edge_index') train_loader = NeighborLoader(data, num_neighbors=[10, 10, 25], batch_size=256, shuffle=True, num_workers=12) test_loader = NeighborLoader(data, num_neighbors=[10, 10, 25], batch_size=256, num_workers=12) class Encoder(torch.nn.Module): def __init__(self, in_channels, hidden_channels): super().__init__() self.convs = torch.nn.ModuleList([ SAGEConv(in_channels, hidden_channels), SAGEConv(hidden_channels, hidden_channels), SAGEConv(hidden_channels, hidden_channels) ]) self.activations = torch.nn.ModuleList() self.activations.extend([ torch.nn.PReLU(hidden_channels), torch.nn.PReLU(hidden_channels), torch.nn.PReLU(hidden_channels) ]) def forward(self, x, edge_index, batch_size): for conv, act in zip(self.convs, self.activations): x = conv(x, edge_index) x = act(x) return x[:batch_size] def corruption(x, edge_index, batch_size): return x[torch.randperm(x.size(0))], edge_index, batch_size model = DeepGraphInfomax( hidden_channels=512, encoder=Encoder(dataset.num_features, 512), summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)), corruption=corruption).to(device) model = 
model.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) def train(epoch): model.train() total_loss = total_examples = 0 for batch in tqdm(train_loader, desc=f'Epoch {epoch:02d}'): optimizer.zero_grad() pos_z, neg_z, summary = model(batch.x, batch.edge_index, batch.batch_size) loss = model.loss(pos_z, neg_z, summary) loss.backward() optimizer.step() total_loss += float(loss) * pos_z.size(0) total_examples += pos_z.size(0) return total_loss / total_examples @torch.no_grad() def test(): model.eval() zs = [] for batch in tqdm(test_loader, desc='Evaluating'): pos_z, _, _ = model(batch.x, batch.edge_index, batch.batch_size) zs.append(pos_z.cpu()) z = torch.cat(zs, dim=0) train_val_mask = data.train_mask | data.val_mask acc = model.test(z[train_val_mask], data.y[train_val_mask], z[data.test_mask], data.y[data.test_mask], max_iter=10000) return acc for epoch in range(1, 31): loss = train(epoch) print(f'Epoch {epoch:02d}, Loss: {loss:.4f}') test_acc = test() print(f'Test Accuracy: {test_acc:.4f}') ================================================ FILE: examples/infomax_transductive.py ================================================ import os.path as osp import torch from torch_geometric.datasets import Planetoid from torch_geometric.nn import DeepGraphInfomax, GCNConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset) class Encoder(torch.nn.Module): def __init__(self, in_channels, hidden_channels): super().__init__() self.conv = GCNConv(in_channels, hidden_channels) self.prelu = torch.nn.PReLU(hidden_channels) def forward(self, x, edge_index): x = self.conv(x, edge_index) x = self.prelu(x) return x def corruption(x, edge_index): return x[torch.randperm(x.size(0), device=x.device)], edge_index if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = 
torch.device('cpu') model = DeepGraphInfomax( hidden_channels=512, encoder=Encoder(dataset.num_features, 512), summary=lambda z, *args, **kwargs: z.mean(dim=0).sigmoid(), corruption=corruption, ).to(device) data = dataset[0].to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) def train(): model.train() optimizer.zero_grad() pos_z, neg_z, summary = model(data.x, data.edge_index) loss = model.loss(pos_z, neg_z, summary) loss.backward() optimizer.step() return loss.item() def test(): model.eval() z, _, _ = model(data.x, data.edge_index) acc = model.test(z[data.train_mask], data.y[data.train_mask], z[data.test_mask], data.y[data.test_mask], max_iter=150) return acc for epoch in range(1, 301): loss = train() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}') acc = test() print(f'Accuracy: {acc:.4f}') ================================================ FILE: examples/jit/README.md ================================================ # JIT Examples This directory contains examples demonstrating the use of Just-In-Time (JIT) compilation in different GNN models. 
| Example | Description | | ---------------------- | ----------------------------------------------------------------- | | [`gcn.py`](./gcn.py) | JIT compilation in `GCN` | | [`gat.py`](./gat.py) | JIT compilation in `GAT` | | [`gin.py`](./gin.py) | JIT compilation in `GIN` | | [`film.py`](./film.py) | JIT compilation in [`GNN-FiLM`](https://arxiv.org/abs/1906.12192) | ================================================ FILE: examples/jit/film.py ================================================ import os.path as osp import torch import torch.nn.functional as F from sklearn.metrics import f1_score from torch import Tensor from torch.nn import BatchNorm1d from torch_geometric.datasets import PPI from torch_geometric.loader import DataLoader from torch_geometric.nn import FiLMConv path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'PPI') train_dataset = PPI(path, split='train') val_dataset = PPI(path, split='val') test_dataset = PPI(path, split='test') train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False) class FiLM(torch.nn.Module): def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int, dropout: float = 0.0): super().__init__() self.dropout = dropout self.convs = torch.nn.ModuleList() self.convs.append(FiLMConv(in_channels, hidden_channels)) for _ in range(num_layers - 2): conv = FiLMConv(hidden_channels, hidden_channels) self.convs.append(conv) self.last_conv = FiLMConv(hidden_channels, out_channels, act=None) self.norms = torch.nn.ModuleList() for _ in range(num_layers - 1): self.norms.append(BatchNorm1d(hidden_channels)) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: for conv, norm in zip(self.convs, self.norms): x = norm(conv(x, edge_index)) x = F.dropout(x, p=self.dropout, training=self.training) x = self.last_conv(x, edge_index) return x device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = FiLM(train_dataset.num_features, 320, train_dataset.num_classes, num_layers=4, dropout=0.1) model = torch.jit.script(model).to(device) criterion = torch.nn.BCEWithLogitsLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() loss = criterion(model(data.x, data.edge_index), data.y) total_loss += loss.item() * data.num_graphs loss.backward() optimizer.step() return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() ys, preds = [], [] for data in loader: ys.append(data.y) out = model(data.x.to(device), data.edge_index.to(device)) preds.append((out > 0).float().cpu()) y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy() return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0 for epoch in range(1, 501): loss = train() val_f1 = test(val_loader) test_f1 = test(test_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, ' f'Test: {test_f1:.4f}') ================================================ FILE: examples/jit/gat.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import GATConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', '..', 'data', dataset) dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) class GAT(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6) self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=True, dropout=0.6) def forward(self, x, edge_index): x = F.dropout(x, p=0.6, training=self.training) x = F.elu(self.conv1(x, edge_index)) x = F.dropout(x, p=0.6, training=self.training) x = 
self.conv2(x, edge_index) return F.log_softmax(x, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model, data = GAT().to(device), dataset[0].to(device) model = torch.jit.script(model) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return loss.item() @torch.no_grad() def test(): model.eval() out, accs = model(data.x, data.edge_index), [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): pred = out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs for epoch in range(1, 201): loss = train() train_acc, val_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/jit/gcn.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch import Tensor import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCNConv device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid') dataset = Planetoid(path, 'Cora', transform=T.NormalizeFeatures()) data = dataset[0] class GCN(torch.nn.Module): def __init__(self, in_channels: int, hidden_channels: int, out_channels: int): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels) self.conv2 = GCNConv(hidden_channels, out_channels) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x = F.dropout(x, p=0.5, training=self.training) x = self.conv1(x, edge_index).relu() x = F.dropout(x, p=0.5, training=self.training) x = self.conv2(x, edge_index) return x model = 
GCN(dataset.num_features, 16, dataset.num_classes) model = torch.jit.script(model).to(device) data = data.to(device) optimizer = torch.optim.Adam([ dict(params=model.conv1.parameters(), weight_decay=5e-4), dict(params=model.conv2.parameters(), weight_decay=0) ], lr=0.01) # Only perform weight-decay on first convolution. def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() pred = model(data.x, data.edge_index).argmax(dim=-1) accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs best_val_acc = final_test_acc = 0 for epoch in range(1, 201): loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/jit/gin.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch.nn import BatchNorm1d as BatchNorm from torch.nn import Linear, ReLU, Sequential from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import GINConv, global_add_pool from torch_geometric.transforms import OneHotDegree path = osp.join(osp.dirname(osp.realpath(__file__)), '..', '..', 'data', 'TU') dataset = TUDataset(path, name='IMDB-BINARY', transform=OneHotDegree(135)) dataset = dataset.shuffle() test_dataset = dataset[:len(dataset) // 10] train_dataset = dataset[len(dataset) // 10:] test_loader = DataLoader(test_dataset, batch_size=128) train_loader = DataLoader(train_dataset, batch_size=128) class GIN(torch.nn.Module): def 
__init__(self, in_channels, hidden_channels, out_channels, num_layers): super().__init__() self.convs = torch.nn.ModuleList() self.batch_norms = torch.nn.ModuleList() for _ in range(num_layers): mlp = Sequential( Linear(in_channels, 2 * hidden_channels), BatchNorm(2 * hidden_channels), ReLU(), Linear(2 * hidden_channels, hidden_channels), ) conv = GINConv(mlp, train_eps=True) self.convs.append(conv) self.batch_norms.append(BatchNorm(hidden_channels)) in_channels = hidden_channels self.lin1 = Linear(hidden_channels, hidden_channels) self.batch_norm1 = BatchNorm(hidden_channels) self.lin2 = Linear(hidden_channels, out_channels) def forward(self, x, edge_index, batch): for conv, batch_norm in zip(self.convs, self.batch_norms): x = F.relu(batch_norm(conv(x, edge_index))) x = global_add_pool(x, batch) x = F.relu(self.batch_norm1(self.lin1(x))) x = F.dropout(x, p=0.5, training=self.training) x = self.lin2(x) return F.log_softmax(x, dim=-1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GIN(dataset.num_features, 64, dataset.num_classes, num_layers=3) model = model.to(device) model = torch.jit.script(model) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() total_loss = 0. 
for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.batch) loss = F.nll_loss(out, data.y) loss.backward() total_loss += loss.item() * data.num_graphs optimizer.step() return total_loss / len(train_dataset) @torch.no_grad() def test(loader): model.eval() total_correct = 0 for data in loader: data = data.to(device) out = model(data.x, data.edge_index, data.batch) pred = out.max(dim=1)[1] total_correct += pred.eq(data.y).sum().item() return total_correct / len(loader.dataset) for epoch in range(1, 101): loss = train() train_acc = test(train_loader) test_acc = test(test_loader) print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' f'Train: {train_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/kge_fb15k_237.py ================================================ import argparse import os.path as osp import torch import torch.optim as optim from torch_geometric.datasets import FB15k_237 from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE model_map = { 'transe': TransE, 'complex': ComplEx, 'distmult': DistMult, 'rotate': RotatE, } parser = argparse.ArgumentParser() parser.add_argument('--model', choices=model_map.keys(), type=str.lower, required=True) args = parser.parse_args() device = 'cuda' if torch.cuda.is_available() else 'cpu' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FB15k') train_data = FB15k_237(path, split='train')[0].to(device) val_data = FB15k_237(path, split='val')[0].to(device) test_data = FB15k_237(path, split='test')[0].to(device) model_arg_map = {'rotate': {'margin': 9.0}} model = model_map[args.model]( num_nodes=train_data.num_nodes, num_relations=train_data.num_edge_types, hidden_channels=50, **model_arg_map.get(args.model, {}), ).to(device) loader = model.loader( head_index=train_data.edge_index[0], rel_type=train_data.edge_type, tail_index=train_data.edge_index[1], batch_size=1000, shuffle=True, ) optimizer_map = 
{ 'transe': optim.Adam(model.parameters(), lr=0.01), 'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6), 'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6), 'rotate': optim.Adam(model.parameters(), lr=1e-3), } optimizer = optimizer_map[args.model] def train(): model.train() total_loss = total_examples = 0 for head_index, rel_type, tail_index in loader: optimizer.zero_grad() loss = model.loss(head_index, rel_type, tail_index) loss.backward() optimizer.step() total_loss += float(loss) * head_index.numel() total_examples += head_index.numel() return total_loss / total_examples @torch.no_grad() def test(data): model.eval() return model.test( head_index=data.edge_index[0], rel_type=data.edge_type, tail_index=data.edge_index[1], batch_size=20000, k=10, ) for epoch in range(1, 501): loss = train() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}') if epoch % 25 == 0: rank, mrr, hits = test(val_data) print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, ' f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}') rank, mrr, hits_at_10 = test(test_data) print(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, ' f'Test Hits@10: {hits_at_10:.4f}') ================================================ FILE: examples/label_prop.py ================================================ import os.path as osp from ogb.nodeproppred import Evaluator, PygNodePropPredDataset import torch_geometric.transforms as T from torch_geometric.nn import LabelPropagation root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB') dataset = PygNodePropPredDataset( 'ogbn-arxiv', root, transform=T.Compose([ T.ToUndirected(), T.ToSparseTensor(), ])) split_idx = dataset.get_idx_split() evaluator = Evaluator(name='ogbn-arxiv') data = dataset[0] model = LabelPropagation(num_layers=3, alpha=0.9) out = model(data.y, data.adj_t, mask=split_idx['train']) y_pred = out.argmax(dim=-1, keepdim=True) val_acc = evaluator.eval({ 'y_true': data.y[split_idx['valid']], 'y_pred': 
y_pred[split_idx['valid']], })['acc'] test_acc = evaluator.eval({ 'y_true': data.y[split_idx['test']], 'y_pred': y_pred[split_idx['test']], })['acc'] print(f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/lcm_aggr_2nd_min.py ================================================ # Final validation accuracy: ~95% import argparse import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.data import Data, InMemoryDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import LCMAggregation from torch_geometric.transforms import BaseTransform parser = argparse.ArgumentParser() parser.add_argument('--num_bits', type=int, default=8) args = parser.parse_args() class RandomPermutation(BaseTransform): def forward(self, data: Data) -> Data: data.x = torch.x[torch.randperm(data.x.size(0))] return data class Random2ndMinimumDataset(InMemoryDataset): r""""A labeled dataset, where each sample is a multiset of integers encoded as bit-vectors, and the label is the second smallest integer in the multiset. 
""" def __init__( self, num_examples: int, num_bits: int, min_num_elems: int, max_num_elems: int, ): super().__init__(transform=RandomPermutation()) self.data, self.slices = self.collate([ self.get_data(num_bits, min_num_elems, max_num_elems) for _ in range(num_examples) ]) def get_data( self, num_bits: int, min_num_elems: int, max_num_elems: int, ) -> Data: num_elems = int(torch.randint(min_num_elems, max_num_elems + 1, (1, ))) x = torch.randint(0, 2, (num_elems, num_bits)) power = torch.pow(2, torch.arange(num_bits)).flip([0]) ints = (x * power.view(1, -1)).sum(dim=-1) y = x[ints.topk(k=2, largest=False).indices[-1:]].to(torch.float) return Data(x=x, y=y) train_dataset = Random2ndMinimumDataset( num_examples=2**16, # 65,536 num_bits=args.num_bits, min_num_elems=2, max_num_elems=16, ) # Validate on multi sets of size 32, larger than observed during training: val_dataset = Random2ndMinimumDataset( num_examples=2**10, # 1024 num_bits=args.num_bits, min_num_elems=32, max_num_elems=32, ) train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=128) class BitwiseEmbedding(torch.nn.Module): def __init__(self, emb_dim: int): super().__init__() self.embs = torch.nn.ModuleList( [torch.nn.Embedding(2, emb_dim) for _ in range(args.num_bits)]) def forward(self, x: Tensor) -> Tensor: xs = [emb(b) for emb, b in zip(self.embs, x.t())] return torch.stack(xs, dim=0).sum(0) class LCM(torch.nn.Module): def __init__(self, emb_dim: int, dropout: float = 0.25): super().__init__() self.encoder = torch.nn.Sequential( BitwiseEmbedding(emb_dim), torch.nn.Linear(emb_dim, emb_dim), torch.nn.Dropout(), torch.nn.GELU(), ) self.aggr = LCMAggregation(emb_dim, emb_dim, project=False) self.decoder = torch.nn.Sequential( torch.nn.Linear(emb_dim, emb_dim), torch.nn.Dropout(dropout), torch.nn.GELU(), torch.nn.Linear(emb_dim, args.num_bits), ) def forward(self, x: Tensor, batch: Tensor) -> Tensor: x = self.encoder(x) x = self.aggr(x, batch) 
x = self.decoder(x) return x device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = LCM(emb_dim=128).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) def train(): total_loss = total_examples = 0 for batch in train_loader: batch = batch.to(device) optimizer.zero_grad() out = model(batch.x, batch.batch) loss = F.binary_cross_entropy_with_logits(out, batch.y) loss.backward() optimizer.step() total_loss += batch.num_graphs * float(loss) total_examples += batch.num_graphs return total_loss / total_examples @torch.no_grad() def test(loader): total_correct = total_examples = 0 for batch in loader: batch = batch.to(device) pred = model(batch.x, batch.batch).sigmoid().round() num_mistakes = (pred != batch.y).sum(dim=-1) total_correct += int((num_mistakes == 0).sum()) total_examples += batch.num_graphs return total_correct / total_examples for epoch in range(1, 1001): loss = train() val_acc = test(val_loader) print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}') ================================================ FILE: examples/lightgcn.py ================================================ import os.path as osp import torch from tqdm import tqdm from torch_geometric.datasets import AmazonBook from torch_geometric.nn import LightGCN from torch_geometric.utils import degree device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Amazon') dataset = AmazonBook(path) data = dataset[0] num_users, num_books = data['user'].num_nodes, data['book'].num_nodes data = data.to_homogeneous().to(device) # Use all message passing edges as training labels: batch_size = 8192 mask = data.edge_index[0] < data.edge_index[1] train_edge_label_index = data.edge_index[:, mask] train_loader = torch.utils.data.DataLoader( range(train_edge_label_index.size(1)), shuffle=True, batch_size=batch_size, ) model = LightGCN( num_nodes=data.num_nodes, embedding_dim=64, 
num_layers=2,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train():
    """Run one epoch with `model.recommendation_loss`; return the mean loss
    per positive (user, book) pair."""
    total_loss = total_examples = 0
    for index in tqdm(train_loader):
        # Sample positive and negative labels.
        # Negatives pair each positive's user with a uniformly random book id.
        pos_edge_label_index = train_edge_label_index[:, index]
        neg_edge_label_index = torch.stack([
            pos_edge_label_index[0],
            torch.randint(num_users, num_users + num_books,
                          (index.numel(), ), device=device)
        ], dim=0)
        edge_label_index = torch.cat([
            pos_edge_label_index,
            neg_edge_label_index,
        ], dim=1)

        optimizer.zero_grad()
        # First half of the scores belongs to positives, second to negatives.
        pos_rank, neg_rank = model(data.edge_index, edge_label_index).chunk(2)

        loss = model.recommendation_loss(
            pos_rank,
            neg_rank,
            node_id=edge_label_index.unique(),
        )
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * pos_rank.numel()
        total_examples += pos_rank.numel()

    return total_loss / total_examples


@torch.no_grad()
def test(k: int):
    """Return (Precision@k, Recall@k) over `data.edge_label_index`.

    Users are scored in chunks of `batch_size`; training interactions are
    masked out before ranking. Both sums are divided by the number of users
    with at least one held-out interaction.
    """
    emb = model.get_embedding(data.edge_index)
    user_emb, book_emb = emb[:num_users], emb[num_users:]

    precision = recall = total_examples = 0
    for start in range(0, num_users, batch_size):
        end = start + batch_size
        logits = user_emb[start:end] @ book_emb.t()

        # Exclude training edges:
        mask = ((train_edge_label_index[0] >= start) &
                (train_edge_label_index[0] < end))
        logits[train_edge_label_index[0, mask] - start,
               train_edge_label_index[1, mask] - num_users] = float('-inf')

        # Computing precision and recall:
        ground_truth = torch.zeros_like(logits, dtype=torch.bool)
        mask = ((data.edge_label_index[0] >= start) &
                (data.edge_label_index[0] < end))
        ground_truth[data.edge_label_index[0, mask] - start,
                     data.edge_label_index[1, mask] - num_users] = True
        # Number of held-out books per user in this chunk.
        node_count = degree(data.edge_label_index[0, mask] - start,
                            num_nodes=logits.size(0))

        topk_index = logits.topk(k, dim=-1).indices
        isin_mat = ground_truth.gather(1, topk_index)

        precision += float((isin_mat.sum(dim=-1) / k).sum())
        # `clamp` avoids 0/0 for users without held-out books (they add 0).
        recall += float((isin_mat.sum(dim=-1) / node_count.clamp(1e-6)).sum())
        total_examples += int((node_count > 0).sum())

    return precision / total_examples, recall / total_examples


for epoch in range(1, 101):
    loss = train()
    precision, recall = test(k=20)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Precision@20: '
          f'{precision:.4f}, Recall@20: {recall:.4f}')
================================================
FILE: examples/link_pred.py
================================================
import os.path as osp

import torch
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, name='Cora', transform=transform)
# After applying the `RandomLinkSplit` transform, each data object is turned
# into a 3-tuple (train_data, val_data, test_data), i.e. one data object per
# split.
train_data, val_data, test_data = dataset[0] class Net(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels) self.conv2 = GCNConv(hidden_channels, out_channels) def encode(self, x, edge_index): x = self.conv1(x, edge_index).relu() return self.conv2(x, edge_index) def decode(self, z, edge_label_index): return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1) def decode_all(self, z): prob_adj = z @ z.t() return (prob_adj > 0).nonzero(as_tuple=False).t() model = Net(dataset.num_features, 128, 64).to(device) optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01) criterion = torch.nn.BCEWithLogitsLoss() def train(): model.train() optimizer.zero_grad() z = model.encode(train_data.x, train_data.edge_index) # We perform a new round of negative sampling for every training epoch: neg_edge_index = negative_sampling( edge_index=train_data.edge_index, num_nodes=train_data.num_nodes, num_neg_samples=train_data.edge_label_index.size(1), method='sparse') edge_label_index = torch.cat( [train_data.edge_label_index, neg_edge_index], dim=-1, ) edge_label = torch.cat([ train_data.edge_label, train_data.edge_label.new_zeros(neg_edge_index.size(1)) ], dim=0) out = model.decode(z, edge_label_index).view(-1) loss = criterion(out, edge_label) loss.backward() optimizer.step() return loss @torch.no_grad() def test(data): model.eval() z = model.encode(data.x, data.edge_index) out = model.decode(z, data.edge_label_index).view(-1).sigmoid() return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()) best_val_auc = final_test_auc = 0 for epoch in range(1, 101): loss = train() val_auc = test(val_data) test_auc = test(test_data) if val_auc > best_val_auc: best_val_auc = val_auc final_test_auc = test_auc print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, ' f'Test: {test_auc:.4f}') print(f'Final Test: {final_test_auc:.4f}') z = model.encode(test_data.x, 
test_data.edge_index) final_edge_index = model.decode_all(z) ================================================ FILE: examples/linkx.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import LINKXDataset from torch_geometric.nn import LINKX if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'LINKX') dataset = LINKXDataset(path, name='Penn94') data = dataset[0].to(device) model = LINKX(data.num_nodes, data.num_features, hidden_channels=32, out_channels=dataset.num_classes, num_layers=1, num_edge_layers=1, num_node_layers=1, dropout=0.5).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) mask = data.train_mask[:, 0] # Use the first set of the five masks. loss = F.cross_entropy(out[mask], data.y[mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): accs = [] model.eval() pred = model(data.x, data.edge_index).argmax(dim=-1) for _, mask in data('train_mask', 'val_mask', 'test_mask'): mask = mask[:, 0] # Use the first set of the five masks. 
accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs for epoch in range(1, 201): loss = train() train_acc, val_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/llm/README.md ================================================ # Examples for Co-training LLMs and GNNs | Example | Description | | -------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | [`g_retriever.py`](./g_retriever.py) | Example helper functions for using the [G-retriever](https://arxiv.org/abs/2402.07630) GNN+LLM module in PyG. Includes an [example repo](https://github.com/neo4j-product-examples/neo4j-gnn-llm-example) for [Neo4j](https://neo4j.com) integration with an associated [blog post](https://developer.nvidia.com/blog/boosting-qa-accuracy-with-graphrag-using-pyg-and-graph-databases/) demonstrating 2x accuracy gains over LLMs on real medical data. For a complete end-to-end pipeline (KG Creation, Subgraph Retrieval, GNN+LLM Finetuning, Testing, LLM Judge Eval), see [`txt2kg_rag.py`](./txt2kg_rag.py). 
For a native PyG implementation without external graph databases, see [gretriever-stark-prime](https://github.com/puririshi98/gretriever-stark-prime/tree/main). | | [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports MoleculeGPT and InstructMol dataset | | [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | | [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text | | [`protein_mpnn.py`](./protein_mpnn.py) | Example for [Robust deep learning--based protein sequence design using ProteinMPNN](https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1) | | [`txt2kg_rag.py`](./txt2kg_rag.py) | Full end 2 end RAG pipeline using TXT2KG and Vector and Graph RAG with a GNN to achieve state of the art results. Uses the [techQA dataset](https://paperswithcode.com/dataset/techqa) but can be extended to handle any RAG dataset with a corpus of documents and an associated set of Q+A pairs to be split for train/eval/test. See [Stanford GNN+LLM Talk](https://www.nvidia.com/en-us/on-demand/session/other25-nv-0003/) for more details. Note that the TechQA data requires only a single document to answer each question so it can be viewed as a toy example. To see significant accuracy boosts from GNN+LLM TXT2KG based RAG, use data that requires multiple text chunks to answer a question. In cases where single document can answer, basic RAG should be sufficient. | ================================================ FILE: examples/llm/g_retriever.py ================================================ """This example provides helper functions for using the G-retriever model (https://arxiv.org/abs/2402.07630) in PyG. 
Requirements: `pip install datasets transformers pcst_fast sentencepiece accelerate` Example blog showing 2x accuracy over agentic graphRAG on real medical data (integration with Neo4j Graph DB): https://developer.nvidia.com/blog/boosting-qa-accuracy-with-graphrag-using-pyg-and-graph-databases/ https://github.com/neo4j-product-examples/neo4j-gnn-llm-example See examples/llm/txt2kg_rag.py for e2e pipeline in PyG including: - KG Creation - Subgraph Retrieval - GNN+LLM Finetuning - Testing - LLM Judge Eval """ import math import torch from torch import Tensor def adjust_learning_rate(param_group: dict, LR: float, epoch: int, num_epochs: int): """Decay learning rate with half-cycle cosine after warmup. Args: param_group (dict): Parameter group. LR (float): Learning rate. epoch (int): current epoch num_epochs (int): total epochs Returns: float: Adjusted learning rate. """ min_lr = 5e-6 warmup_epochs = 1 if epoch < warmup_epochs: lr = LR else: lr = min_lr + (LR - min_lr) * 0.5 * ( 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / (num_epochs - warmup_epochs))) param_group['lr'] = lr return lr def save_params_dict(model, save_path): """Saves a model's parameters, excluding non-trainable weights. Args: model (torch.nn.Module): The model to save parameters from. save_path (str): The path to save the parameters to. 
""" # Get the model's state dictionary, which contains all its parameters state_dict = model.state_dict() # Create a dictionary mapping parameter names to their requires_grad status param_grad_dict = { k: v.requires_grad for (k, v) in model.named_parameters() } # Remove non-trainable parameters from the state dictionary for k in list(state_dict.keys()): if k in param_grad_dict.keys() and not param_grad_dict[k]: del state_dict[k] # Delete parameters that do not require gradient # Save the filtered state dictionary to the specified path torch.save(state_dict, save_path) def load_params_dict(model, save_path): # Load the saved model parameters from the specified file path state_dict = torch.load(save_path) # Update the model's parameters with the loaded state dictionary model.load_state_dict(state_dict) # Return the model with updated parameters return model def normalize_batch_dtype(batch): batch.x = batch.x.float() if hasattr(batch, "edge_attr") and batch.edge_attr is not None: batch.edge_attr = batch.edge_attr.float() def get_loss(model, batch, model_save_name="gnn+llm") -> Tensor: """Compute the loss for a given model and batch of data. Args: model: The model to compute the loss for. batch: The batch of data to compute the loss for. model_save_name: The name of the model being used (e.g. 'llm'). Returns: Tensor: The computed loss. 
""" # Check the type of model being used to determine the input arguments if model_save_name == 'llm': # For LLM models return model(batch.question, batch.label, batch.desc) else: # (GNN+LLM) normalize_batch_dtype(batch) return model( batch.question, # ["list", "of", "questions", "here"] batch.x, # [num_nodes, num_features] batch.edge_index, # [2, num_edges] batch.batch, # which node belongs to which batch index batch.label, # list answers (labels) batch.edge_attr, # edge attributes batch.desc # list of text graph descriptions ) def inference_step(model, batch, model_save_name="gnn+llm", max_out_tokens=128): """Performs inference on a given batch of data using the provided model. Args: model (nn.Module): The model to use for inference. batch: The batch of data to process. model_save_name (str): The name of the model (e.g. 'llm'). max_out_tokens (int): The maximum number of tokens for our model to output. Returns: The output of the inference step. """ # Check the type of model being used to determine the input arguments if model_save_name == 'llm': # Perform inference on the question and textual graph description return model.inference(batch.question, batch.desc, max_out_tokens=max_out_tokens) else: # (GNN+LLM) normalize_batch_dtype(batch) return model.inference(batch.question, batch.x, batch.edge_index, batch.batch, batch.edge_attr, batch.desc, max_out_tokens=max_out_tokens) ================================================ FILE: examples/llm/git_mol.py ================================================ """This example implements the GIT-Mol model (https://arxiv.org/abs/2308.06911) using PyG. 
""" import argparse import os.path as osp import torch from accelerate import Accelerator from accelerate.utils import DistributedDataParallelKwargs from torch.optim.lr_scheduler import StepLR from tqdm import tqdm from torch_geometric import seed_everything from torch_geometric.datasets import GitMolDataset from torch_geometric.llm.models import GITMol from torch_geometric.loader import DataLoader @torch.no_grad() def eval(model, data_loader): model.eval() loss = 0 for batch in data_loader: batch_loss = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr, batch.smiles, batch.image, batch.caption) loss += batch_loss.item() / len(data_loader) return loss def train( num_epochs: int, lr: float, weight_decay: float, batch_size: int, checkpointing: bool, ): # Load dataset ================================================ path = osp.dirname(osp.realpath(__file__)) path = osp.join(path, '..', '..', 'data', 'GITMol') train_dataset = GitMolDataset(path, split=0) val_dataset = GitMolDataset(path, split=1) test_dataset = GitMolDataset(path, split=2) seed_everything(42) train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, pin_memory=True, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=False, pin_memory=True, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=False, pin_memory=True, shuffle=False) # Create model =============================================== ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) device = accelerator.device model = GITMol().to(device) optimizer = torch.optim.AdamW( [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=weight_decay) scheduler = StepLR(optimizer, step_size=1, gamma=0.1) model, optimizer, train_loader, scheduler = accelerator.prepare( model, optimizer, train_loader, scheduler) val_loader = accelerator.prepare_data_loader(val_loader, 
device_placement=True) test_loader = accelerator.prepare_data_loader(test_loader, device_placement=True) # Train and eval ============================================ best_epoch = 0 best_val_loss = float('inf') for epoch in range(num_epochs): # Train model.train() epoch_loss = 0 if epoch == 0: print("Training beginning...") epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' for batch in tqdm(train_loader, desc=epoch_str): optimizer.zero_grad() loss = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr, batch.smiles, batch.image, batch.caption) accelerator.backward(loss) optimizer.step() epoch_loss += loss.item() train_loss = epoch_loss / len(train_loader) # Eval val_loss = eval(model, val_loader) print( f'{epoch_str}, Train loss: {train_loss:4f}, Val loss: {val_loss:4f}' # noqa: E501 ) if checkpointing and val_loss < best_val_loss: best_val_loss = val_loss best_epoch = epoch torch.save( { 'model_state_dict': accelerator.unwrap_model(model).state_dict(), 'best_loss': best_val_loss }, f'gitmol_pretrain_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501 ) torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() # Test test_loss = eval(model, test_loader) print(f'Test loss: {test_loss:4f}') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=3) parser.add_argument('--lr', type=float, default=1e-5) parser.add_argument('--batch_size', type=int, default=4) parser.add_argument("--weight_decay", type=float, default=0.01) parser.add_argument('--checkpointing', type=bool, default=True) args = parser.parse_args() train( args.epochs, args.lr, args.weight_decay, args.batch_size, args.checkpointing, ) ================================================ FILE: examples/llm/glem.py ================================================ """This example run GLEM model using PyG. Original Paper: https://arxiv.org/abs/2210.14709 “Learning on Large-scale Text-attributed Graphs via Variational Inference“. 
Requirements on top of basic PyG: `pip install ogb transformers peft tqdm`. GLEM is a data augmentation co-training strategy for LM and GNN, our implementation extended original implementation from LM to LLM and opt for LoRA from peft. ``note:: use additional trick, please add your external prediction by assigning `ext_pred_path` and combine it into pretraining phase and node features """ import argparse import os import os.path as osp import time import psutil import torch from ogb.nodeproppred import Evaluator, PygNodePropPredDataset from torch_geometric import seed_everything from torch_geometric.data import download_google_url from torch_geometric.datasets import TAGDataset from torch_geometric.llm import GLEM from torch_geometric.loader import DataLoader, NeighborLoader from torch_geometric.nn.models import GAT, GCN, GraphSAGE def get_n_params(model): pp = 0 for p in list(model.parameters()): nn = 1 for s in list(p.size()): nn = nn * s pp += nn return pp def main(args): gpu = args.gpu dataset_name = args.dataset text_type = args.text_type if args.dataset == 'arxiv' else 'raw_text' root = osp.join('data', 'ogb') hf_model = args.hf_model pl_ratio = args.pl_ratio gnn_lr = args.gnn_lr lm_lr = args.lm_lr em_order = args.em_order gnn_epochs = args.gnn_epochs lm_epochs = args.lm_epochs patience = args.patience verbose = args.verbose out_dir = args.out_dir lm_batch_size = args.lm_batch_size gnn_batch_size = args.gnn_batch_size lm_use_lora = args.lm_use_lora token_on_disk = args.token_on_disk num_em_iters = args.num_em_iters start_time = time.time() train_with_ext_pred = not args.train_without_ext_pred and \ dataset_name == 'products' ext_pred = None pretrain_augmented = False ext_pseudo_labels = None device = torch.device( f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu') print(f'Running on: {torch.cuda.get_device_name({gpu})}') torch.cuda.empty_cache() if train_with_ext_pred: ext_pred_path = download_google_url( id='15sO2m7BeW7C1Upmdw3Cx1JS__6nxTAzY', 
folder='data/ogb/ogbn_products/ext_preds', filename='giant_sagn_scr.pt', log=True) ext_pred = torch.load(ext_pred_path, map_location=device) ext_pseudo_labels = ext_pred.argmax(dim=-1) pretrain_augmented = True seed_everything(42) dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) split_idx = dataset.get_idx_split() data = dataset._data tag_dataset = TAGDataset(root, dataset, hf_model, token_on_disk=token_on_disk) text_dataset = tag_dataset.to_text_dataset(text_type) print(tag_dataset.num_classes, tag_dataset.raw_file_names) num_classes = tag_dataset.num_classes num_features = data.num_features # =========================== LM Data split =============================== split_idx = tag_dataset.get_idx_split() # GLEM train with augmented data, mark original train data as gold data, gold_idx = split_idx['train'] split_idx['valid'] test_idx = split_idx['test'] # random sample pseudo labels nodes, generate their index num_pseudo_labels = int(gold_idx.numel() * pl_ratio) idx_to_select = torch.randperm(test_idx.numel())[:num_pseudo_labels] pseudo_labels_idx = test_idx[idx_to_select] train_idx = torch.cat( (gold_idx, pseudo_labels_idx)) # augmented train_indx print(f'train_idx: {train_idx.size(0)}, ' f'gold_idx: {gold_idx.size(0)}, ' f'pseudo labels ratio: {pl_ratio}, ' f'{train_idx.size(0)/gold_idx.size(0) - 1.0}') gold_dataset = torch.utils.data.Subset(dataset=text_dataset, indices=gold_idx) train_dataset = torch.utils.data.Subset(dataset=text_dataset, indices=train_idx) # ========================== LM Data Loader =============================== print('Building language model dataloader...', end='-->') # if set train_without_ext_pred == True, use this for pretrain text_pretrain_loader = DataLoader(gold_dataset, batch_size=lm_batch_size, drop_last=False, pin_memory=True, shuffle=True) # training with augmented data, text_train_loader = DataLoader(train_dataset, batch_size=lm_batch_size, drop_last=False, pin_memory=True, shuffle=True) text_test_loader = 
DataLoader(text_dataset, batch_size=lm_batch_size * 4, drop_last=False, pin_memory=True, shuffle=False) print('done') # =========================== GNN Data Loader ============================= initial_memory = torch.cuda.memory_allocated() data = data.to(device) if ext_pred is not None: data.x = torch.cat((data.x, ext_pred), dim=1) num_features += ext_pred.size(1) current_memory_1 = torch.cuda.max_memory_allocated() # 1 GB = 1073741824 Byte gpu_usage = float(current_memory_1 - initial_memory) / 1073741824 # Print the maximum memory usage after running the model print(f'GPU memory usage -- data to gpu: {gpu_usage:.2f} GB') print('build GNN dataloader(GraphSAGE NeighborLoader)', end='-->') # train on gold data w/o pseudo labels graph_pretrain_loader = NeighborLoader( data, input_nodes=gold_idx, num_neighbors=[15, 10, 5], batch_size=gnn_batch_size, shuffle=True, num_workers=12, persistent_workers=True, ) # graph data loader w/ pseudo labels in M-step graph_train_loader = NeighborLoader( data, input_nodes=train_idx, num_neighbors=[15, 10, 5], batch_size=gnn_batch_size, shuffle=True, num_workers=12, persistent_workers=True, ) # for gnn inference subgraph_loader = NeighborLoader( data, input_nodes=None, num_neighbors=[-1], batch_size=gnn_batch_size * 4, num_workers=12, persistent_workers=True, ) # =========================== internal function =========================== evaluator = Evaluator(name=f'ogbn-{dataset_name}') def evaluate(out, split): y_true = data.y.cpu() y_pred = out.argmax(dim=-1, keepdim=True) train_acc, val_acc, test_acc = None, None, None if 'train' in split: train_acc = evaluator.eval({ 'y_true': y_true[split_idx['train']], 'y_pred': y_pred[split_idx['train']], })['acc'] if 'valid' in split: val_acc = evaluator.eval({ 'y_true': y_true[split_idx['valid']], 'y_pred': y_pred[split_idx['valid']], })['acc'] if 'test' in split: test_acc = evaluator.eval({ 'y_true': y_true[split_idx['test']], 'y_pred': y_pred[split_idx['test']], })['acc'] return train_acc, 
val_acc, test_acc # =========================== Build GNN Model ============================= gnn = None if args.gnn_model == 'SAGE': gnn = GraphSAGE( in_channels=num_features, hidden_channels=args.gnn_hidden_channels, num_layers=args.gnn_num_layers, out_channels=dataset.num_classes, ) elif args.gnn_model == 'GAT': gnn = GAT(in_channels=num_features, hidden_channels=args.gnn_hidden_channels, num_layers=args.gnn_num_layers, out_channels=dataset.num_classes, heads=args.gat_heads) else: gnn = GCN( in_channels=num_features, hidden_channels=args.gnn_hidden_channels, num_layers=args.gnn_num_layers, out_channels=dataset.num_classes, ) print("# GNN Params:", get_n_params(gnn)) # =========================== Build LM Model ============================== model = GLEM(lm_to_use=hf_model, gnn_to_use=gnn, out_channels=num_classes, lm_use_lora=lm_use_lora, device=device) lm = model.lm print("# LM Params:", get_n_params(lm)) gnn_opt = torch.optim.Adam(gnn.parameters(), lr=gnn_lr) lm_opt = torch.optim.Adam(lm.parameters(), lr=lm_lr) def load_model(em_phase): print(f'Move {em_phase} model from cpu memory') if em_phase == 'lm': model.lm = model.lm.to(device, non_blocking=True) optimizer = torch.optim.Adam(model.lm.parameters(), lr=lm_lr) if em_phase == 'gnn': model.gnn = model.gnn.to(device, non_blocking=True) optimizer = torch.optim.Adam(model.gnn.parameters(), lr=gnn_lr) return optimizer # ================================= Run GLEM ============================== preds_filename = 'lm_pretrain' preds_dir = f'{out_dir}preds/{dataset_name}/' gnn_test_acc = 0.0 lm_test_acc = 0.0 # =============================== GLEM pretraining ======================== pretrain_phase = 'lm' if em_order == 'lm': pretrain_phase = 'gnn' pretrain_start_time = time.time() # pretraining pretrain_loader = graph_pretrain_loader test_loader = subgraph_loader pretrain_num_epochs = gnn_epochs pretrain_opt = gnn_opt if pretrain_phase == 'gnn': model.gnn = model.gnn.to(device) print('pretraining gnn to generate 
pseudo labels') if train_with_ext_pred: pretrain_loader = graph_train_loader preds_filename = 'gnn_pretrain' elif pretrain_phase == 'lm': model.lm = model.lm.to(device) print('pretraining lm to generate pseudo labels') pretrain_num_epochs = lm_epochs pretrain_loader = text_pretrain_loader test_loader = text_test_loader pretrain_opt = lm_opt if train_with_ext_pred: pretrain_loader = text_train_loader preds_filename = 'lm_pretrain' early_stopping = 0 best_val_acc = 0.0 for epoch in range(1, pretrain_num_epochs + 1): acc, loss = model.train(pretrain_phase, pretrain_loader, pretrain_opt, ext_pseudo_labels, epoch, pretrain_augmented, verbose) if epoch >= 5 or epoch == pretrain_num_epochs: pretrain_preds = model.inference(pretrain_phase, test_loader, verbose=verbose) train_acc, val_acc, _ = evaluate(pretrain_preds, ['train', 'valid']) print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}') if val_acc <= best_val_acc: early_stopping += 1 if early_stopping > patience: print(f'Pretrain Early stopped by Epoch: {epoch}') break else: best_val_acc = val_acc preds = model.inference(pretrain_phase, test_loader, verbose=verbose) train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) if pretrain_phase == 'gnn': gnn_test_acc = max(gnn_test_acc, test_acc) model.gnn = model.gnn.to('cpu', non_blocking=True) else: lm_test_acc = max(lm_test_acc, test_acc) model.lm = model.lm.to('cpu', non_blocking=True) torch.cuda.empty_cache() pretrain_phase_time = time.time() - pretrain_start_time print(f'Pretrain {pretrain_phase} time: {pretrain_phase_time:.2f}s') os.makedirs(osp.dirname(preds_dir), exist_ok=True) torch.save(preds, osp.join(preds_dir, f'{preds_filename}.pt')) print( f'Saved predictions to {osp.join(preds_dir, f"{preds_filename}.pt")}') train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) print(f'Pretraining acc: {train_acc:.4f}, Val: {val_acc:.4f}, ' f'Test: {test_acc:.4f}') # EM iterations em_phase = em_order """ We run E-step(LM training) and 
M-Step(GNN training) alternatively in each em iterations, so the total number of iterations is num_em_iter * 2 and we switch the em_phase at end of each iteration in following loop """ gnn_val_acc = lm_val_acc = 0.0 for em_it in range(1, num_em_iters * 2 + 1): pseudo_labels = preds.argmax(dim=-1) best_val_acc = 0.0 print(f'EM iteration: {em_it}, EM phase: {em_phase}') optimizer = load_model(em_phase) num_epochs = lm_epochs train_loader = text_train_loader test_loader = text_test_loader early_stopping = 0 if em_phase == 'gnn': train_loader = graph_train_loader num_epochs = gnn_epochs test_loader = subgraph_loader for epoch in range(1, num_epochs + 1): acc, loss = model.train(em_phase, train_loader, optimizer, pseudo_labels, epoch, True, verbose) if epoch >= 5 or epoch == num_epochs: cur_preds = model.inference(em_phase, test_loader, verbose=verbose) train_acc, val_acc, _ = evaluate(cur_preds, ['train', 'valid']) print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f},') if val_acc <= best_val_acc: early_stopping += 1 if early_stopping > patience: print(f'''Early stopped by Epoch: {epoch}, \ Best acc: {best_val_acc}''') break else: best_val_acc = val_acc preds = model.inference(em_phase, test_loader, verbose=verbose) if em_phase == 'gnn': gnn_val_acc = max(gnn_val_acc, best_val_acc) model.gnn = model.gnn.to('cpu', non_blocking=True) em_phase = 'lm' else: lm_val_acc = max(lm_val_acc, best_val_acc) model.lm = model.lm.to('cpu', non_blocking=True) em_phase = 'gnn' torch.cuda.empty_cache() print(f'Best GNN validation acc: {gnn_val_acc},' f'LM validation acc: {lm_val_acc}') print('============================') if gnn_val_acc > lm_val_acc: em_phase = 'gnn' model.gnn = model.gnn.to(device, non_blocking=True) test_loader = subgraph_loader else: em_phase = 'lm' model.lm = model.lm.to(device, non_blocking=True) test_loader = text_test_loader test_preds = model.inference(em_phase, test_loader, verbose=verbose) train_acc, val_acc, test_acc = evaluate(test_preds, ['train', 'valid', 
'test']) final_test_acc = max(gnn_test_acc, max(lm_test_acc, test_acc)) print(f'Best test acc: {final_test_acc}, model: {em_phase}') end_time = time.time() running_time = (end_time - start_time) / 3600 print(f'Total running time: {running_time:.2f} hours') if __name__ == '__main__': available_gb = psutil.virtual_memory().available / (1024**3) if available_gb < 80: print(f" WARNING: This test may require more RAM than available.\n" f" Estimated RAM needed: ~80 GB\n" f" Detected available RAM: {available_gb:.2f} GB\n" " If the program crashes or is killed, consider upgrading " "system memory.") parser = argparse.ArgumentParser(description='GLEM Example:') parser.add_argument('--gpu', type=int, default=0) parser.add_argument('--num_runs', type=int, default=10, help='number of runs') parser.add_argument('--num_em_iters', type=int, default=1, help='number of iterations') parser.add_argument("--dataset", type=str, default='products', help='arxiv or products') parser.add_argument( "--text_type", type=str, default='raw_text', help="type of text, support raw_text, llm_explanation," "all for arxiv and raw_text for products") parser.add_argument("--pl_ratio", type=float, default=0.5, help="pseudo labels ratio") parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny', help='huggingface model repo id') parser.add_argument( '--gnn_model', type=str, default='SAGE', help='gnn model for node classification,' 'options: SAGE, GAT, GCN') parser.add_argument('--gnn_hidden_channels', type=int, default=256) parser.add_argument('--gnn_num_layers', type=int, default=3) parser.add_argument('--gat_heads', type=int, default=4, help='Number of multi-head-attentions for GAT ') parser.add_argument('--lm_batch_size', type=int, default=256) parser.add_argument('--gnn_batch_size', type=int, default=1024) parser.add_argument( '--external_pred_path', type=str, default=None, help="Other model's output logits during the " "pretraining phase or simply concatenate it with" "node 
features as augmented data for gnn") parser.add_argument('--alpha', type=float, default=0.5, help='pseudo label weight in E-step') parser.add_argument('--beta', type=float, default=0.5, help='pseudo label weight in M-step') parser.add_argument('--lm_epochs', type=int, default=10) parser.add_argument('--gnn_epochs', type=int, default=50) parser.add_argument('--gnn_lr', type=float, default=0.002) parser.add_argument('--lm_lr', type=float, default=0.001) parser.add_argument('--patience', type=int, default=3, help='Patience for early stopping') parser.add_argument('--verbose', action='store_true', help='show progress bar during training or not') parser.add_argument('--em_order', type=str, default='lm', help='decide train LM first or GNN first') parser.add_argument('--lm_use_lora', action='store_true', help='use Lora to fine-tune model or not') parser.add_argument( '--token_on_disk', action='store_true', help='save token on disk and load token from disk' 'for reducing duplicated tokenizing') parser.add_argument('--out_dir', type=str, default='output/', help='output directory') parser.add_argument( '--train_without_ext_pred', action='store_true', help='train glem without using additional pseudo labels ' 'for augmenting data only available for ogbn-products') args = parser.parse_args() print(args) main(args) ================================================ FILE: examples/llm/molecule_gpt.py ================================================ """This example implements the MoleculeGPT model (https://ai4d3.github.io/papers/34.pdf) using PyG. 
""" import argparse import math import os.path as osp import time import torch from torch.nn.utils import clip_grad_norm_ from tqdm import tqdm from torch_geometric import seed_everything from torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset from torch_geometric.llm.models import LLM, MoleculeGPT, SentenceTransformer from torch_geometric.loader import DataLoader from torch_geometric.nn import GINEConv def save_params_dict(model, save_path): state_dict = model.state_dict() param_grad_dict = { k: v.requires_grad for (k, v) in model.named_parameters() } for k in list(state_dict.keys()): if k in param_grad_dict.keys() and not param_grad_dict[k]: del state_dict[k] # Delete parameters that do not require gradient torch.save(state_dict, save_path) @torch.no_grad() def eval(model, data_loader): model.eval() loss = 0 for batch in data_loader: batch_loss = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr, batch.smiles, batch.instruction, batch.y) loss += batch_loss.item() / len(data_loader) return loss def train( dataset_name: str, num_epochs: int, lr: float, batch_size: int, checkpointing: bool, ): def adjust_learning_rate(param_group, LR, epoch): # Decay the learning rate with half-cycle cosine after warmup min_lr = 5e-6 warmup_epochs = 1 if epoch < warmup_epochs: lr = LR else: lr = min_lr + (LR - min_lr) * 0.5 * ( 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / (num_epochs - warmup_epochs))) param_group['lr'] = lr return lr def get_clippable_params(params): return [ p for p in params if isinstance(p, torch.Tensor) and not hasattr(p, '_spec') ] start_time = time.time() # Load dataset ================================================ path = osp.dirname(osp.realpath(__file__)) path = osp.join(path, '..', '..', 'data', dataset_name) if dataset_name == 'MoleculeGPT': dataset = MoleculeGPTDataset(path) elif dataset_name == 'InstructMol': dataset = InstructMolDataset(path) train_size, val_size = int(0.8 * len(dataset)), int(0.1 * 
len(dataset)) train_dataset = dataset[:train_size] val_dataset = dataset[train_size:train_size + val_size] test_dataset = dataset[train_size + val_size:] seed_everything(42) train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, pin_memory=True, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=False, pin_memory=True, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=False, pin_memory=True, shuffle=False) # Create model =============================================== llm = LLM( # model_name='lmsys/vicuna-7b-v1.5', model_name='Qwen/Qwen3-0.6B', num_params=1, dtype=torch.bfloat16, sys_prompt='You are an agent, answer my questions.', ) graph_encoder = GINEConv( nn=torch.nn.Sequential( torch.nn.Linear(6, 768), torch.nn.ReLU(), torch.nn.Linear(768, 768), ), train_eps=True, edge_dim=4, ) smiles_encoder = SentenceTransformer( model_name='DeepChem/ChemBERTa-77M-MTR', pooling_strategy='last_hidden_state', ) model = MoleculeGPT( llm=llm, graph_encoder=graph_encoder, smiles_encoder=smiles_encoder, ) # Train and eval ============================================ params = [p for _, p in model.named_parameters() if p.requires_grad] optimizer = torch.optim.AdamW([ { 'params': params, 'lr': lr, 'weight_decay': 0.05, }, ], betas=(0.9, 0.95)) grad_steps = 2 best_epoch = 0 best_val_loss = float('inf') for epoch in range(num_epochs): # Train model.train() epoch_loss = 0 if epoch == 0: print(f"Total Preparation Time: {time.time() - start_time:2f}s") start_time = time.time() print("Training beginning...") epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' loader = tqdm(train_loader, desc=epoch_str) for step, batch in enumerate(loader): optimizer.zero_grad() loss = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr, batch.smiles, batch.instruction, batch.y) loss.backward() clip_grad_norm_( get_clippable_params(optimizer.param_groups[0]['params']), 0.1) if (step + 1) % grad_steps == 0: 
adjust_learning_rate(optimizer.param_groups[0], lr, step / len(train_loader) + epoch) optimizer.step() epoch_loss += loss.detach().item() if (step + 1) % grad_steps == 0: lr = optimizer.param_groups[0]['lr'] train_loss = epoch_loss / len(train_loader) # Eval val_loss = eval(model, val_loader) print( f'{epoch_str}, Train loss: {train_loss:4f}, Val loss: {val_loss:4f}' # noqa: E501 ) if checkpointing and val_loss < best_val_loss: best_val_loss = val_loss best_epoch = epoch save_params_dict( model, f'moleculegpt_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501 ) torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() print(f"Total Training Time: {time.time() - start_time:2f}s") # Test test_loss = eval(model, test_loader) print(f'Test loss: {test_loss:4f}') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--dataset_name", type=str, default='MoleculeGPT', choices=['MoleculeGPT', 'InstructMol'], help='Support MoleculeGPT and InstructMol') parser.add_argument('--epochs', type=int, default=3) parser.add_argument('--lr', type=float, default=1e-5) parser.add_argument('--batch_size', type=int, default=2) parser.add_argument('--checkpointing', type=bool, default=True) args = parser.parse_args() start_time = time.time() train( args.dataset_name, args.epochs, args.lr, args.batch_size, args.checkpointing, ) print(f'Total Time: {time.time() - start_time:2f}s') ================================================ FILE: examples/llm/protein_mpnn.py ================================================ """This example implements the ProteinMPNN model (https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1) using PyG. 
""" import argparse import time import numpy as np import psutil import torch from torch_geometric import seed_everything from torch_geometric.datasets import ProteinMPNNDataset from torch_geometric.llm.models import ProteinMPNN from torch_geometric.loader import DataLoader def loss_smoothed(y, logits, mask, weight=0.1): """Negative log probabilities.""" y_onehot = torch.nn.functional.one_hot(y, 21).float() # Label smoothing y_onehot = y_onehot + weight / float(y_onehot.size(-1)) y_onehot = y_onehot / y_onehot.sum(-1, keepdim=True) loss = -(y_onehot * logits).sum(-1) loss_av = torch.sum(loss * mask) / 2000.0 return loss, loss_av def loss_nll(y, logits, mask): """Negative log probabilities.""" criterion = torch.nn.NLLLoss(reduction='none') loss = criterion(logits.contiguous().view(-1, logits.size(-1)), y.contiguous().view(-1)).view(y.size()) y_argmaxed = torch.argmax(logits, -1) # [B, L] true_false = (y == y_argmaxed).float() loss_av = torch.sum(loss * mask) / torch.sum(mask) return loss, loss_av, true_false class NoamOpt: """Optim wrapper that implements rate.""" def __init__(self, model_size, factor, warmup, optimizer, step): self.optimizer = optimizer self._step = step self.warmup = warmup self.factor = factor self.model_size = model_size self._rate = 0 @property def param_groups(self): """Return param_groups.""" return self.optimizer.param_groups def step(self): """Update parameters and rate.""" self._step += 1 rate = self.rate() for p in self.optimizer.param_groups: p['lr'] = rate self._rate = rate self.optimizer.step() def rate(self, step=None): """Implement learning rate above.""" if step is None: step = self._step return self.factor * (self.model_size**(-0.5) * min(step**(-0.5), step * self.warmup**(-1.5))) def zero_grad(self): self.optimizer.zero_grad() def train(model, optimizer, data_loader, device, scaler): model.train() train_sum = 0.0 train_acc = 0.0 train_weights = 0.0 for batch in data_loader: optimizer.zero_grad() batch = batch.to(device) 
mask_for_loss = batch.mask * batch.chain_mask_all y = batch.chain_seq_label if torch.cuda.is_available() and args.mixed_precision: with torch.amp.autocast('cuda'): logits = model(batch.x, batch.chain_seq_label, batch.mask, batch.chain_mask_all, batch.residue_idx, batch.chain_encoding_all, batch.batch) _, loss = loss_smoothed(y, logits, mask_for_loss) scaler.scale(loss).backward() if args.gradient_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm) scaler.step(optimizer) scaler.update() else: logits = model(batch.x, batch.chain_seq_label, batch.mask, batch.chain_mask_all, batch.residue_idx, batch.chain_encoding_all, batch.batch) _, loss = loss_smoothed(y, logits, mask_for_loss) loss.backward() if args.gradient_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm) optimizer.step() loss, _, true_false = loss_nll(y, logits, mask_for_loss) train_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy() train_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy() train_weights += torch.sum(mask_for_loss).cpu().data.numpy() train_loss = train_sum / train_weights train_accuracy = train_acc / train_weights train_perplexity = np.exp(train_loss) return train_perplexity, train_accuracy @torch.no_grad() def eval(model, data_loader, device): model.eval() valid_sum = 0. valid_weights = 0. valid_acc = 0. 
for batch in data_loader: batch = batch.to(device) logits = model(batch.x, batch.chain_seq_label, batch.mask, batch.chain_mask_all, batch.residue_idx, batch.chain_encoding_all, batch.batch) mask_for_loss = batch.mask * batch.chain_mask_all y = batch.chain_seq_label loss, _, true_false = loss_nll(y, logits, mask_for_loss) valid_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy() valid_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy() valid_weights += torch.sum(mask_for_loss).cpu().data.numpy() valid_loss = valid_sum / valid_weights valid_accuracy = valid_acc / valid_weights valid_perplexity = np.exp(valid_loss) return valid_perplexity, valid_accuracy def main(args): wall_clock_start = time.perf_counter() seed_everything(123) scaler = torch.amp.GradScaler() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if args.size == 'large' and psutil.virtual_memory().total < 64.1 * 1024**3: print('Warning: may not have enough RAM to run this example.') print('Consider upgrading RAM if an error occurs.') print('Estimated RAM Needed: ~64.1GB.') train_dataset = ProteinMPNNDataset( root=args.data_path, size=args.size, split='train', rescut=args.rescut, max_length=args.max_protein_length, ) valid_dataset = ProteinMPNNDataset( root=args.data_path, size=args.size, split='valid', rescut=args.rescut, max_length=args.max_protein_length, ) test_dataset = ProteinMPNNDataset( root=args.data_path, size=args.size, split='test', rescut=args.rescut, max_length=args.max_protein_length, ) train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=6) valid_loader = DataLoader(valid_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=6) test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=6) model = ProteinMPNN( hidden_dim=args.hidden_dim, num_encoder_layers=args.num_encoder_layers, num_decoder_layers=args.num_decoder_layers, 
num_neighbors=args.num_neighbors, dropout=args.dropout, augment_eps=args.backbone_noise, num_positional_embedding=16, ).to(device) total_step = 0 optimizer = NoamOpt( model_size=args.hidden_dim, factor=2, warmup=4000, optimizer=torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9), step=total_step) times = [] for e in range(args.num_epochs): start = time.perf_counter() train_perplexity, train_accuracy = train(model, optimizer, train_loader, device, scaler) valid_perplexity, valid_accuracy = eval(model, valid_loader, device) print( f'epoch: {e:03d}, step: {total_step}, ' f'train: {train_perplexity:.3f}, valid: {valid_perplexity:.3f}, ' f'train_acc: {train_accuracy:.3f}, valid_acc: {valid_accuracy:.3f}' ) times.append(time.perf_counter() - start) print(f'Average Epoch Time: {torch.tensor(times).mean():.4f}s') print(f'Median Epoch Time: {torch.tensor(times).median():.4f}s') print(f'Total Program Runtime: ' f'{time.perf_counter() - wall_clock_start:.4f}s') # Test test_perplexity, test_accuracy = eval(model, test_loader, device) print(f'test: {test_perplexity:.3f}, test_acc: {test_accuracy:.3f}') if __name__ == '__main__': parser = argparse.ArgumentParser() # dataset config parser.add_argument('--data_path', type=str, default='data/ProteinMPNN', help='path for loading training data') parser.add_argument( '--size', type=str, default='small', choices=['small', 'large'], help='Use of "small (229.4 MB)" or "large (64.1 GB)" dataset') parser.add_argument('--max_protein_length', type=int, default=10000, help='maximum length of the protein complext') parser.add_argument('--rescut', type=float, default=3.5, help='PDB resolution cutoff') # training config parser.add_argument('--num_epochs', type=int, default=50, help='number of epochs to train for') parser.add_argument('--train_batch_size', type=int, default=4, help='number of tokens for one train batch') parser.add_argument('--eval_batch_size', type=int, default=8, help='number of tokens for one valid or test 
batch') parser.add_argument( '--gradient_norm', type=float, default=-1.0, help='clip gradient norm, set to negative to omit clipping') parser.add_argument('--mixed_precision', type=bool, default=True, help='train with mixed precision') # model config parser.add_argument('--hidden_dim', type=int, default=128, help='hidden model dimension') parser.add_argument('--num_encoder_layers', type=int, default=3, help='number of encoder layers') parser.add_argument('--num_decoder_layers', type=int, default=3, help='number of decoder layers') parser.add_argument('--num_neighbors', type=int, default=30, help='number of neighbors for the sparse graph') parser.add_argument('--num_rbf', type=int, default=16, help='number of radial basis functions') parser.add_argument('--dropout', type=float, default=0.1, help='dropout level; 0.0 means no dropout') parser.add_argument( '--backbone_noise', type=float, default=0.2, help='amount of noise added to backbone during training') args = parser.parse_args() main(args) ================================================ FILE: examples/llm/txt2kg_rag.py ================================================ import argparse import gc import json import os import random import re import sys from datetime import datetime from glob import glob from itertools import chain from pathlib import Path import yaml try: import wandb wandb_available = True except ImportError: wandb_available = False import torch from g_retriever import ( adjust_learning_rate, get_loss, inference_step, load_params_dict, save_params_dict, ) from huggingface_hub import hf_hub_download from torch.nn.utils import clip_grad_norm_ from tqdm import tqdm from torch_geometric import seed_everything from torch_geometric.llm import RAGQueryLoader from torch_geometric.llm.models import ( LLM, TXT2KG, GRetriever, LLMJudge, SentenceTransformer, ) from torch_geometric.llm.models.txt2kg import _chunk_text from torch_geometric.llm.utils.backend_utils import ( create_graph_from_triples, 
create_remote_backend_from_graph_data, make_pcst_filter, preprocess_triplet, ) from torch_geometric.llm.utils.feature_store import KNNRAGFeatureStore from torch_geometric.llm.utils.graph_store import NeighborSamplingRAGGraphStore from torch_geometric.llm.utils.vectorrag import DocumentRetriever from torch_geometric.loader import DataLoader from torch_geometric.nn import GAT, SGFormer # Define constants for better readability NV_NIM_MODEL_DEFAULT = "nvidia/llama-3.1-nemotron-ultra-253b-v1" LLM_GENERATOR_NAME_DEFAULT = "meta-llama/Meta-Llama-3.1-8B-Instruct" ENCODER_MODEL_NAME_DEFAULT = "Alibaba-NLP/gte-modernbert-base" KG_CHUNK_SIZE_DEFAULT = 512 GNN_HID_CHANNELS_DEFAULT = 1024 GNN_LAYERS_DEFAULT = 4 LR_DEFAULT = 1e-5 EPOCHS_DEFAULT = 2 BATCH_SIZE_DEFAULT = 1 EVAL_BATCH_SIZE_DEFAULT = 2 LLM_GEN_MODE_DEFAULT = "full" DEFAULT_ENDPOINT_URL = "https://integrate.api.nvidia.com/v1" max_chars_in_train_answer = 128 def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--gnn_model', type=str, default="GAT", choices=["GAT", "SGFormer"], help="The GNN model to use. 
Default is GAT.") parser.add_argument('--NV_NIM_MODEL', type=str, default=NV_NIM_MODEL_DEFAULT, help="The NIM LLM to use for TXT2KG for LLMJudge") parser.add_argument('--NV_NIM_KEY', type=str, help="NVIDIA API key") parser.add_argument( '--ENDPOINT_URL', type=str, default=DEFAULT_ENDPOINT_URL, help="The URL hosting your model, \ in case you are not using the public NIM.") parser.add_argument( '--kg_chunk_size', type=int, default=KG_CHUNK_SIZE_DEFAULT, help="When splitting context documents for txt2kg,\ the maximum number of characters per chunk.") parser.add_argument('--gnn_hidden_channels', type=int, default=GNN_HID_CHANNELS_DEFAULT, help="Hidden channels for GNN") parser.add_argument('--num_gnn_layers', type=int, default=GNN_LAYERS_DEFAULT, help="Number of GNN layers") parser.add_argument('--lr', type=float, default=LR_DEFAULT, help="Learning rate") parser.add_argument('--epochs', type=int, default=EPOCHS_DEFAULT, help="Number of epochs") parser.add_argument('--batch_size', type=int, default=BATCH_SIZE_DEFAULT, help="Batch size") parser.add_argument('--eval_batch_size', type=int, default=EVAL_BATCH_SIZE_DEFAULT, help="Evaluation batch size") parser.add_argument('--llm_generator_name', type=str, default=LLM_GENERATOR_NAME_DEFAULT, help="The LLM to use for Generation") parser.add_argument( '--llm_generator_mode', type=str, default=LLM_GEN_MODE_DEFAULT, choices=["frozen", "lora", "full"], help="Whether to freeze the Generator LLM,\ use LORA, or fully finetune") parser.add_argument('--dont_save_model', action="store_true", help="Whether to skip model saving.") parser.add_argument('--log_steps', type=int, default=30, help="Log to wandb every N steps") parser.add_argument('--wandb_project', type=str, default="techqa", help="Weights & Biases project name") parser.add_argument('--wandb', action="store_true", help="Enable wandb logging") parser.add_argument( '--num_gpus', type=int, default=None, help="Number of GPUs to use. 
If not specified," "will determine automatically based on model size.") parser.add_argument('--regenerate_dataset', action="store_true", help="Regenerate the dataset") parser.add_argument( '--doc_parsing_mode', type=str, default=None, choices=["paragraph", "file"], help="How to parse documents: 'paragraph' splits " "files by paragraphs, 'file' treats each file as" "one document. " "This will override any value set in the config file.") parser.add_argument( '--k_for_docs', type=int, default=None, help="Number of docs to retrieve for each question. " "This will override any value set in the config file.") parser.add_argument( '--doc_chunk_size', type=int, default=None, help="The chunk size to use VectorRAG (document retrieval). " "This will override any value set in the config file.") parser.add_argument( '--dataset', type=str, default="techqa", help="Dataset folder name, " "should contain corpus and train.json files." "extracted triples, processed dataset, " "document retriever, and model checkpoints " "will be saved in the dataset folder") parser.add_argument( '--skip_graph_rag', action="store_true", help="Skip the graph RAG step. 
" "Used to compare the performance of Vector+Graph RAG vs Vector RAG.") parser.add_argument( '--use_x_percent_corpus', default=100.0, type=float, help="Debug flag that allows user to only use a random percentage " "of available knowledge base corpus for RAG") args = parser.parse_args() assert args.NV_NIM_KEY, "NVIDIA API key is required for TXT2KG and eval" assert args.use_x_percent_corpus <= 100 and \ args.use_x_percent_corpus > 0, "Please provide a value in (0,100]" if args.skip_graph_rag: print("Skipping graph RAG step, setting GNN layers to 0...") args.num_gnn_layers = 0 config_path = os.path.join(args.dataset, "config.yaml") if os.path.exists(config_path): print(f"Loading config from {config_path}...") with open(config_path) as config_file: config = yaml.safe_load(config_file) if config is not None: # Use a loop to check and apply config values for each parameter config_params = [ 'doc_parsing_mode', 'doc_chunk_size', 'k_for_docs' ] for param in config_params: if param in config and getattr(args, param) is None: setattr(args, param, config[param]) print(f"Using config value for {param}: {config[param]}") else: print("Skipping config loading...") if args.dataset == "techqa": if args.doc_chunk_size is None: args.doc_chunk_size = 1024 if args.k_for_docs is None: args.k_for_docs = 14 assert args.doc_chunk_size is not None, "doc_chunk_size has not been set" assert args.k_for_docs is not None, "k_for_docs has not been set" return args sys_prompt = ( "You are an expert assistant that can answer " "any question from its knowledge, given a knowledge graph embedding and " "it's textualized context. Just give the answer, without explanation.") prompt_template = """ [QUESTION] {question} [END_QUESTION] [RETRIEVED_CONTEXTS] {context} [END_RETRIEVED_CONTEXTS] """ def _process_and_chunk_text(text, chunk_size, doc_parsing_mode): full_chunks = [] """ Some corpora of docs are grouped into chunked files, typically by paragraph. 
Only split into individual documents if multiple paragraphs are detected. """ if doc_parsing_mode == "paragraph": paragraphs = re.split(r'\n{2,}', text) else: # doc_parsing_mode == 'file' or doc_parsing_mode is None paragraphs = [text] for paragraph in paragraphs: if chunk_size is not None: chunks = _chunk_text(paragraph, chunk_size) else: # defaults to 512 in _chunk_text chunks = _chunk_text(paragraph) full_chunks.extend(chunks) return full_chunks def get_data(args): # need a JSON dict of Questions and answers, see below for how its used json_path = Path(args.dataset) / "train.json" corpus_path = Path(args.dataset) / "corpus" # techqa specified but neither corpus or train.json exists if "techqa" in args.dataset.lower() and not (json_path.exists() or corpus_path.exists()): print("Could not find Q&A pairs and/or knowledge base corpus") print("Would you like to download the TechQA dataset for demo?") user_input = input("Y/N: ") if user_input.lower() == "y" or user_input.lower() == "yes": print("Downloading data...") # downloads zip_path = hf_hub_download( repo_id="nvidia/TechQA-RAG-Eval", repo_type="dataset", filename="corpus.zip", ) json_path = hf_hub_download( repo_id="nvidia/TechQA-RAG-Eval", repo_type="dataset", filename="train.json", ) # move to working dir if not os.path.exists(args.dataset): os.mkdir(args.dataset) import zipfile with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(args.dataset) import shutil shutil.copy(json_path, os.path.join(args.dataset, "train.json")) elif user_input.lower() == "n" or user_input.lower() == "no": sys.exit("No selected, no data to work with... exiting.") else: sys.exit("Invalid user input, exiting.") with open(os.path.join(args.dataset, "train.json")) as file: json_obj = json.load(file) text_contexts = [] # Read corpus data to create the KG and for document retrieval (RAG). # Prefer *.json files, fall back to txt files. 
# TODO: add support for additional corpus file formats: PDF, CSV, XML, # HTML, possibly others. # corpus folder is simply a folder with context documents in it. file_paths = glob(os.path.join(args.dataset, "corpus", "*.json")) if len(file_paths) > 0: for file_path in file_paths: with open(file_path, "r") as f: data = json.load(f) doc_type = data[0]["document_type"] if doc_type != "text": raise ValueError(f"Bad extraction for {file_path}, expecting " f"text only but got {doc_type}") text_contexts.extend( _process_and_chunk_text(data[0]["metadata"]["content"], args.doc_chunk_size, args.doc_parsing_mode)) else: for file_path in glob(os.path.join(args.dataset, "corpus", "*")): with open(file_path, "r") as f: text_context = f.read() text_contexts.extend( _process_and_chunk_text(text_context, args.doc_chunk_size, args.doc_parsing_mode)) if args.use_x_percent_corpus < 100: random.shuffle(text_contexts) text_contexts = text_contexts[ 0:int(len(text_contexts) * args.use_x_percent_corpus / 100.0)] return json_obj, text_contexts def index_kg(args, context_docs): kg_maker = TXT2KG(NVIDIA_NIM_MODEL=args.NV_NIM_MODEL, NVIDIA_API_KEY=args.NV_NIM_KEY, ENDPOINT_URL=args.ENDPOINT_URL, chunk_size=args.kg_chunk_size) print( "Note that if the TXT2KG process is too slow for your liking, using " "the public NIM, consider deploying yourself using local_lm flag of " "TXT2KG or using https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct " # noqa "to deploy to a private endpoint, which you can pass to this script " "w/ --ENDPOINT_URL flag.") print( "Guide for deploying NIM: https://developer.nvidia.com/blog/a-simple-guide-to-deploying-generative-ai-with-nvidia-nim/" # noqa ) total_tqdm_count = len(context_docs) initial_tqdm_count = 0 checkpoint_file = list(Path(args.dataset).glob("*--*--checkpoint_kg.pt")) if len(checkpoint_file) > 1: raise RuntimeError("Error: more than one checkpoint file found") if len(checkpoint_file) == 1: print("Restoring KG from checkpoint")
checkpoint_file = checkpoint_file[0] checkpoint_model_name = checkpoint_file.name.split('--')[0] # check if triples generation are using the correct model if args.NV_NIM_MODEL.split('/')[-1] != checkpoint_model_name: raise RuntimeError( "Error: stored triples were generated using a different model") saved_relevant_triples = torch.load(checkpoint_file, weights_only=False) kg_maker.relevant_triples = saved_relevant_triples kg_maker.doc_id_counter = len(saved_relevant_triples) initial_tqdm_count = kg_maker.doc_id_counter context_docs = context_docs[kg_maker.doc_id_counter:] chkpt_interval = 10 chkpt_count = 0 for context_doc in tqdm(context_docs, total=total_tqdm_count, initial=initial_tqdm_count, desc="Extracting KG triples"): kg_maker.add_doc_2_KG(txt=context_doc) chkpt_count += 1 if chkpt_count == chkpt_interval: chkpt_count = 0 path = args.dataset + "/{m}--{t}--checkpoint_kg.pt" model = kg_maker.NIM_MODEL.split( '/')[-1] if not kg_maker.local_LM else "local" path = path.format(m=model, t=datetime.now().strftime("%Y%m%d_%H%M%S")) torch.save(kg_maker.relevant_triples, path) relevant_triples = kg_maker.relevant_triples triples = list( chain.from_iterable(triple_set for triple_set in relevant_triples.values())) triples = [preprocess_triplet(triplet) for triplet in triples] triples = list(dict.fromkeys(triples)) raw_triples_path = args.dataset + "/{m}--{t}--raw_triples.pt" model_name = kg_maker.NIM_MODEL.split( '/')[-1] if not kg_maker.local_LM else "local" torch.save( triples, raw_triples_path.format(m=model_name, t=datetime.now().strftime("%Y%m%d_%H%M%S"))) for old_checkpoint_file in Path( args.dataset).glob("*--*--checkpoint_kg.pt"): os.remove(old_checkpoint_file) return triples def update_data_lists(args, data_lists): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # creating the embedding model sent_trans_batch_size = 256 model = SentenceTransformer( model_name=ENCODER_MODEL_NAME_DEFAULT).to(device).eval() model_kwargs = { "output_device": 
device, "batch_size": int(sent_trans_batch_size / 4), } doc_retriever_path = os.path.join(args.dataset, "document_retriever.pt") if os.path.exists(doc_retriever_path): print("Loading document retriever from checkpoint...") vector_retriever = DocumentRetriever.load(doc_retriever_path, model=model.encode, model_kwargs=model_kwargs) if args.k_for_docs != vector_retriever.k_for_docs: vector_retriever.k_for_docs = args.k_for_docs else: return data_lists else: raise ValueError("Document retriever not found") print("k_for_docs changed, updating data lists...") total_points = sum(len(data_list) for data_list in data_lists.values()) progress_bar = tqdm(total=total_points, desc="Updating text contexts") for data_list in data_lists.values(): for data_point in data_list: q = data_point["question"] data_point["text_context"] = vector_retriever.query(q) progress_bar.update(1) progress_bar.close() vector_retriever.save(doc_retriever_path) del vector_retriever gc.collect() torch.cuda.empty_cache() dataset_name = os.path.basename(args.dataset) dataset_path = os.path.join(args.dataset, f"{dataset_name}.pt") torch.save(data_lists, dataset_path) return data_lists def make_dataset(args): qa_pairs, context_docs = get_data(args) print("Number of Docs in our VectorDB =", len(context_docs)) data_lists = {"train": [], "validation": [], "test": []} triples = [] raw_triples_file = list(Path(args.dataset).glob("*--*--raw_triples.pt")) if len(raw_triples_file) > 1: raise RuntimeError("Error: multiple raw_triples files found") if len(raw_triples_file) == 1: raw_triples_file = raw_triples_file[0] stored_model_name = raw_triples_file.name.split('--')[0] if args.NV_NIM_MODEL.split('/')[-1] != stored_model_name: raise RuntimeError( "Error: stored triples were generated using a different model") print(f" -> Saved triples generated with: {stored_model_name}") triples = torch.load(raw_triples_file) else: triples = index_kg(args, context_docs) print("Number of triples in our GraphDB =", len(triples)) 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # creating the embedding model sent_trans_batch_size = 256 model = SentenceTransformer( model_name=ENCODER_MODEL_NAME_DEFAULT).to(device) print("Creating the graph data from raw triples...") # create the graph data from raw triples graph_data = create_graph_from_triples( triples=triples, embedding_model=model.encode, embedding_method_kwargs={ "batch_size": min(len(triples), sent_trans_batch_size), "verbose": True }) print("Creating the graph and feature stores...") # creating the graph and feature stores fs, gs = create_remote_backend_from_graph_data( graph_data=graph_data, path="backend", graph_db=NeighborSamplingRAGGraphStore, feature_db=KNNRAGFeatureStore).load() """ NOTE: these retriever hyperparams are very important. Tuning may be needed for custom data... """ model_kwargs = { "output_device": device, "batch_size": int(sent_trans_batch_size / 4), "verbose": True } doc_retriever_path = os.path.join(args.dataset, "document_retriever.pt") if os.path.exists(doc_retriever_path): print("Loading document retriever from checkpoint...") vector_retriever = DocumentRetriever.load(doc_retriever_path, model=model.encode, model_kwargs=model_kwargs) if args.k_for_docs != vector_retriever.k_for_docs: vector_retriever.k_for_docs = args.k_for_docs else: print("Creating document retriever...") vector_retriever = DocumentRetriever(context_docs, k_for_docs=args.k_for_docs, model=model.encode, model_kwargs=model_kwargs) vector_retriever.save(doc_retriever_path) subgraph_filter = make_pcst_filter( triples, model, topk=5, # nodes topk_e=5, # edges cost_e=.5, # edge cost num_clusters=10) # num clusters # number of neighbors for each seed node selected by KNN fanout = 100 # number of hops for neighborsampling num_hops = 2 query_loader_config = { "k_nodes": 1024, # k for Graph KNN "num_neighbors": [fanout] * num_hops, # number of sampled neighbors "encoder_model": model, } # GraphDB retrieval done with 
KNN+NeighborSampling+PCST # PCST = Prize Collecting Steiner Tree # VectorDB retrieval just vanilla vector RAG print("Now to retrieve context for each query from " "our Vector and Graph DBs...") query_loader = RAGQueryLoader(graph_data=(fs, gs), subgraph_filter=subgraph_filter, vector_retriever=vector_retriever, config=query_loader_config) # pre-process the dataset total_data_list = [] extracted_triple_sizes = [] global max_chars_in_train_answer for data_point in tqdm(qa_pairs, desc="Building un-split dataset"): if data_point["is_impossible"]: continue QA_pair = (data_point["question"], data_point["answer"]) q = QA_pair[0] max_chars_in_train_answer = max(len(QA_pair[1]), max_chars_in_train_answer) # (TODO) make this batch queries for retrieving w/ CuVS+CuGraph subgraph = query_loader.query(q) subgraph.label = QA_pair[1] total_data_list.append(subgraph) extracted_triple_sizes.append(len(subgraph.triples)) random.shuffle(total_data_list) # stats print("Min # of Retrieved Triples =", min(extracted_triple_sizes)) print("Max # of Retrieved Triples =", max(extracted_triple_sizes)) print("Average # of Retrieved Triples =", sum(extracted_triple_sizes) / len(extracted_triple_sizes)) # 60:20:20 split data_lists["train"] = total_data_list[:int(.6 * len(total_data_list))] data_lists["validation"] = total_data_list[int(.6 * len(total_data_list) ):int(.8 * len(total_data_list))] data_lists["test"] = total_data_list[int(.8 * len(total_data_list)):] dataset_name = os.path.basename(args.dataset) dataset_path = os.path.join(args.dataset, f"{dataset_name}.pt") torch.save((data_lists, max_chars_in_train_answer), dataset_path) del model gc.collect() torch.cuda.empty_cache() return data_lists def train(args, train_loader, val_loader): if args.wandb: wandb.init(project=args.wandb_project, name=f"run_{datetime.now().strftime('%Y-%m-%d_%H:%M')}", config=vars(args)) hidden_channels = args.gnn_hidden_channels num_gnn_layers = args.num_gnn_layers if args.num_gnn_layers > 0: if args.gnn_model 
== "GAT": gnn = GAT(in_channels=768, hidden_channels=hidden_channels, out_channels=1024, num_layers=num_gnn_layers, heads=4) elif args.gnn_model == "SGFormer": gnn = SGFormer(in_channels=768, hidden_channels=hidden_channels, out_channels=1024, trans_num_heads=1, trans_dropout=0.5, gnn_num_layers=num_gnn_layers, gnn_dropout=0.5) else: raise ValueError(f"Invalid GNN model: {args.gnn_model}") else: gnn = None if args.llm_generator_mode == "full": llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt, n_gpus=args.num_gpus) elif args.llm_generator_mode == "lora": llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt, dtype=torch.float32, n_gpus=args.num_gpus) else: # frozen llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt, dtype=torch.float32, n_gpus=args.num_gpus).eval() for _, p in llm.named_parameters(): p.requires_grad = False model = GRetriever(llm=llm, gnn=gnn, use_lora=args.llm_generator_mode == "lora") save_name = os.path.join(args.dataset, "model.pt") if args.llm_generator_mode == "frozen" and args.num_gnn_layers == 0: if not args.dont_save_model: save_params_dict(model, save_path=save_name) return model if os.path.exists(save_name) and not args.regenerate_dataset: print("Re-using saved G-retriever model for testing...") model = load_params_dict(model, save_name) else: params = [p for _, p in model.named_parameters() if p.requires_grad] lr = args.lr optimizer = torch.optim.AdamW([{ 'params': params, 'lr': lr, 'weight_decay': 0.05 }], betas=(0.9, 0.95)) num_oom_errors = 0 for epoch in range(args.epochs): model.train() epoch_loss = 0 epoch_str = f'Epoch: {epoch + 1}|{args.epochs}' loader = tqdm(train_loader, desc=epoch_str) for step, batch in enumerate(loader): new_qs = [] for i, q in enumerate(batch["question"]): # insert VectorRAG context new_qs.append( prompt_template.format( question=q, context="\n".join(batch.text_context[i]))) batch.question = new_qs if args.skip_graph_rag: batch.desc = "" optimizer.zero_grad() 
try: loss = get_loss(model, batch) loss.backward() clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) if (step + 1) % 2 == 0: adjust_learning_rate(optimizer.param_groups[0], lr, step / len(train_loader) + epoch, args.epochs) optimizer.step() epoch_loss += float(loss.detach()) if args.wandb and (step + 1) % args.log_steps == 0: wandb.log({ "train/loss": float(loss.detach()), "train/lr": optimizer.param_groups[0]['lr'], }) if (step + 1) % 2 == 0: lr = optimizer.param_groups[0]['lr'] except torch.cuda.OutOfMemoryError: torch.cuda.empty_cache() print("Sequence length of last batch: ", model.seq_length_stats[-1]) # TODO: Implement CPU fallback (WIP) num_oom_errors += 1 print("Sequence length stats: ") print("seq_len avg: ", sum(model.seq_length_stats) / len(model.seq_length_stats)) print("seq_len min: ", min(model.seq_length_stats)) print("seq_len max: ", max(model.seq_length_stats)) print("Percent of OOM errors: ", num_oom_errors / len(train_loader)) train_loss = epoch_loss / len(train_loader) print(epoch_str + f', Train Loss: {train_loss:4f}') # Eval Step val_loss = 0 model.eval() with torch.no_grad(): for batch in val_loader: new_qs = [] for i, q in enumerate(batch["question"]): # insert VectorRAG context new_qs.append( prompt_template.format( question=q, context="\n".join(batch.text_context[i]))) batch.question = new_qs if args.skip_graph_rag: batch.desc = "" loss = get_loss(model, batch) val_loss += loss.item() val_loss = val_loss / len(val_loader) print(epoch_str + f", Val Loss: {val_loss:4f}") if args.wandb: wandb.log({ "val/loss": val_loss, "train/epoch_loss": train_loss, "epoch": epoch + 1 }) if args.wandb: wandb.finish() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() model.eval() if not args.dont_save_model: save_params_dict(model, save_path=save_name) return model def test(model, test_loader, args): llm_judge = LLMJudge(args.NV_NIM_MODEL, args.NV_NIM_KEY, args.ENDPOINT_URL) def eval(question: str, pred: str, correct_answer: str): # 
calculate the score based on pred and correct answer return llm_judge.score(question, pred, correct_answer) scores = [] eval_tuples = [] for test_batch in tqdm(test_loader, desc="Testing"): new_qs = [] raw_qs = test_batch["question"] for i, q in enumerate(test_batch["question"]): # insert VectorRAG context new_qs.append( prompt_template.format( question=q, context="\n".join(test_batch.text_context[i]))) test_batch.question = new_qs if args.skip_graph_rag: test_batch.desc = "" preds = (inference_step(model, test_batch, max_out_tokens=max_chars_in_train_answer / 2)) for question, pred, label in zip(raw_qs, preds, test_batch.label): eval_tuples.append((question, pred, label)) for question, pred, label in tqdm(eval_tuples, desc="Eval"): scores.append(eval(question, pred, label)) avg_scores = sum(scores) / len(scores) print("Avg marlin accuracy=", avg_scores) print("*" * 5 + "NOTE" + "*" * 5) print("Marlin Accuracy is Estimated by LLM as a Judge!") print("Improvement of this estimation process is WIP...") if __name__ == '__main__': # for reproducibility seed_everything(50) args = parse_args() if args.wandb and not wandb_available: print("Error: wandb package not found but --wandb flag was used.") print("Please install wandb and rerun the script.") sys.exit(1) # Need to sanitize sensitive keys saved_NIM_KEY = args.NV_NIM_KEY args.NV_NIM_KEY = "********" print(f"Starting {args.dataset} training with args: ", args) args.NV_NIM_KEY = saved_NIM_KEY dataset_name = os.path.basename(args.dataset) dataset_path = os.path.join(args.dataset, f"{dataset_name}.pt") if os.path.exists(dataset_path) and not args.regenerate_dataset: print(f"Re-using Saved {dataset_name} KG-RAG Dataset...") data_lists, max_chars_in_train_answer = torch.load( dataset_path, weights_only=False) doc_retriever_path = os.path.join(args.dataset, "document_retriever.pt") if os.path.exists(doc_retriever_path): print("Updating data lists with document retriever...") data_lists = update_data_lists(args, data_lists) 
else: data_lists = make_dataset(args) batch_size = args.batch_size eval_batch_size = args.eval_batch_size train_loader = DataLoader(data_lists["train"], batch_size=batch_size, drop_last=True, pin_memory=True, shuffle=True) val_loader = DataLoader(data_lists["validation"], batch_size=eval_batch_size, drop_last=False, pin_memory=True, shuffle=False) test_loader = DataLoader(data_lists["test"], batch_size=eval_batch_size, drop_last=False, pin_memory=True, shuffle=False) model = train(args, train_loader, val_loader) test(model, test_loader, args) ================================================ FILE: examples/lpformer.py ================================================ import random from argparse import ArgumentParser from collections import defaultdict import numpy as np import torch from ogb.linkproppred import Evaluator, PygLinkPropPredDataset from torch.utils.data import DataLoader from tqdm import tqdm from torch_geometric.nn.models import LPFormer parser = ArgumentParser() parser.add_argument('--data_name', type=str, default='ogbl-ppa') parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--runs', help="# random seeds to run over", type=int, default=5) parser.add_argument('--batch_size', type=int, default=32768) parser.add_argument('--hidden_channels', type=int, default=64) parser.add_argument('--gnn_layers', type=int, default=3) parser.add_argument('--dropout', help="Applies to GNN and Transformer", type=float, default=0.1) parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--eps', help="PPR precision", type=float, default=5e-5) parser.add_argument('--thresholds', help="List of cn, 1-hop, >1-hop (in that order)", nargs="+", default=[0, 1e-4, 1e-2]) args = parser.parse_args() device = torch.device(args.device) dataset = PygLinkPropPredDataset(name=args.data_name) data = dataset[0].to(device) data.edge_index = data.edge_index.to(device) if hasattr(data, 'x') 
and data.x is not None: data.x = data.x.to(device).to(torch.float) split_edge = dataset.get_edge_split() split_data = { "train_pos": split_edge['train']['edge'].to(device), "valid_pos": split_edge['valid']['edge'].to(device), "valid_neg": split_edge['valid']['edge_neg'].to(device), "test_pos": split_edge['test']['edge'].to(device), "test_neg": split_edge['test']['edge_neg'].to(device) } if hasattr(data, 'edge_weight') and data.edge_weight is not None: edge_weight = data.edge_weight.to(torch.float) data.edge_weight = data.edge_weight.view(-1).to(torch.float) else: edge_weight = torch.ones(data.edge_index.size(1)).to(device).float() # Convert edge_index to SparseTensor for efficiency # adj_prop = SparseTensor.from_edge_index( # data.edge_index, edge_weight.squeeze(-1), # [data.num_nodes, data.num_nodes]).to(device) adj_prop = torch.sparse_coo_tensor(data.edge_index, edge_weight.squeeze(-1), [data.num_nodes, data.num_nodes]).to(device) evaluator_hit = Evaluator(name=args.data_name) model = LPFormer(data.x.size(-1), args.hidden_channels, num_gnn_layers=args.gnn_layers, ppr_thresholds=args.thresholds, gnn_dropout=args.dropout, transformer_dropout=args.dropout, gcn_cache=True).to(device) # Get PPR matrix in sparse format ppr_matrix = model.calc_sparse_ppr(data.edge_index, data.num_nodes, eps=args.eps).to(device) def train_epoch(): model.train() train_pos = split_data['train_pos'].to(device) adjt_mask = torch.ones(train_pos.size(0), dtype=torch.bool, device=device) total_loss = total_examples = 0 d = DataLoader(range(train_pos.size(0)), args.batch_size, shuffle=True) for perm in tqdm(d, "Epoch"): edges = train_pos[perm].t() # Mask positive input samples - Common strategy during training adjt_mask[perm] = 0 edge2keep = train_pos[adjt_mask, :].t() # masked_adj_prop = SparseTensor.from_edge_index( # edge2keep.t(), sparse_sizes=(data['num_nodes'], # data['num_nodes'])).to_device(device) # masked_adj_prop = masked_adj_prop.to_symmetric() # Ensure symmetric edge2keep = 
torch.cat((edge2keep, edge2keep[[1, 0]]), dim=1) masked_adj_prop = torch.sparse_coo_tensor( edge2keep, torch.ones(edge2keep.size(1)).to(device), (data['num_nodes'], data['num_nodes'])).to(device) # For next batch adjt_mask[perm] = 1 pos_out = model(edges, data.x, masked_adj_prop, ppr_matrix) pos_loss = -torch.log(torch.sigmoid(pos_out) + 1e-6).mean() # Trivial random sampling neg_edges = torch.randint(0, data['num_nodes'], (edges.size(0), edges.size(1)), dtype=torch.long, device=edges.device) neg_out = model(neg_edges, data.x, adj_prop, ppr_matrix) neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + 1e-6).mean() loss = pos_loss + neg_loss loss.backward() optimizer.step() optimizer.zero_grad() num_examples = pos_out.size(0) total_loss += loss.item() * num_examples total_examples += num_examples return total_loss / total_examples @torch.no_grad() def test(): # NOTE: Eval for ogbl-citation2 is different # See `train.py` in https://github.com/HarryShomer/LPFormer/ for more # Also see there for how to eval under the HeaRT setting # HeaRT = https://arxiv.org/abs/2306.10453 model.eval() all_preds = defaultdict(list) for split_key, split_vals in split_data.items(): if "train" not in split_key: preds = [] for perm in DataLoader(range(split_vals.size(0)), args.batch_size): edges = split_vals[perm].t() perm_logits = model(edges, data.x, adj_prop, ppr_matrix) preds += [torch.sigmoid(perm_logits).cpu()] all_preds[split_key] = torch.cat(preds, dim=0) val_hits = evaluator_hit.eval({ 'y_pred_pos': all_preds['valid_pos'], 'y_pred_neg': all_preds['valid_neg'] })[f'hits@{evaluator_hit.K}'] test_hits = evaluator_hit.eval({ 'y_pred_pos': all_preds['test_pos'], 'y_pred_neg': all_preds['test_neg'] })[f'hits@{evaluator_hit.K}'] return val_hits, test_hits def set_seeds(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) # Train over args.runs seeds and average results # Best result for reach run chosen via validation val_perf_runs = [] 
test_perf_runs = [] for run in range(1, args.runs + 1): print("=" * 75) print(f"RUNNING run={run}") print("=" * 75) set_seeds(run) model.reset_parameters() optimizer = torch.optim.Adam(list(model.parameters()), lr=args.lr) best_valid = 0 best_valid_test = 0 for epoch in range(1, 1 + args.epochs): loss = train_epoch() print(f"Epoch {epoch} Loss: {loss:.4f}\n") if epoch % 5 == 0: print("Evaluating model...\n", flush=True) eval_val, eval_test = test() print(f"Valid Hits@{evaluator_hit.K} = {eval_val}") print(f"Test Hits@{evaluator_hit.K} = {eval_test}") if eval_val > best_valid: best_valid = eval_val best_valid_test = eval_test print( f"\nBest Performance:\n Valid={best_valid}\n Test={best_valid_test}") val_perf_runs.append(best_valid) test_perf_runs.append(best_valid_test) if args.runs > 1: print("\n\n") print(f"Results over {args.runs} runs:") print(f" Valid = {np.mean(val_perf_runs)} +/- {np.std(val_perf_runs)}") print(f" Test = {np.mean(test_perf_runs)} +/- {np.std(test_perf_runs)}") ================================================ FILE: examples/mem_pool.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F from torch.nn import BatchNorm1d, LeakyReLU, Linear from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import DeepGCNLayer, GATConv, MemPooling path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TUD') dataset = TUDataset(path, name="PROTEINS_full", use_node_attr=True) dataset._data.x = dataset._data.x[:, :-3] # only use non-binary features. 
dataset = dataset.shuffle() n = (len(dataset)) // 10 test_dataset = dataset[:n] val_dataset = dataset[n:2 * n] train_dataset = dataset[2 * n:] test_loader = DataLoader(test_dataset, batch_size=20) val_loader = DataLoader(val_dataset, batch_size=20) train_loader = DataLoader(train_dataset, batch_size=20) class Net(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, dropout): super().__init__() self.dropout = dropout self.lin = Linear(in_channels, hidden_channels) self.convs = torch.nn.ModuleList() for _ in range(2): conv = GATConv(hidden_channels, hidden_channels, dropout=dropout) norm = BatchNorm1d(hidden_channels) act = LeakyReLU() self.convs.append( DeepGCNLayer(conv, norm, act, block='res+', dropout=dropout)) self.mem1 = MemPooling(hidden_channels, 80, heads=5, num_clusters=10) self.mem2 = MemPooling(80, out_channels, heads=5, num_clusters=1) def forward(self, x, edge_index, batch): x = self.lin(x) for conv in self.convs: x = conv(x, edge_index) x, S1 = self.mem1(x, batch) x = F.leaky_relu(x) x = F.dropout(x, p=self.dropout) x, S2 = self.mem2(x) return ( F.log_softmax(x.squeeze(1), dim=-1), MemPooling.kl_loss(S1) + MemPooling.kl_loss(S2), ) device = 'cuda' if torch.cuda.is_available() else 'cpu' model = Net(dataset.num_features, 32, dataset.num_classes, dropout=0.1) model = model.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=4e-5) def train(): model.train() model.mem1.k.requires_grad = False model.mem2.k.requires_grad = False for data in train_loader: optimizer.zero_grad() data = data.to(device) out = model(data.x, data.edge_index, data.batch)[0] loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() kl_loss = 0. 
model.mem1.k.requires_grad = True model.mem2.k.requires_grad = True optimizer.zero_grad() for data in train_loader: data = data.to(device) kl_loss += model(data.x, data.edge_index, data.batch)[1] kl_loss /= len(train_loader.dataset) kl_loss.backward() optimizer.step() @torch.no_grad() def test(loader): model.eval() total_correct = 0 for data in loader: data = data.to(device) out = model(data.x, data.edge_index, data.batch)[0] total_correct += int((out.argmax(dim=-1) == data.y).sum()) return total_correct / len(loader.dataset) times = [] patience = start_patience = 250 test_acc = best_val_acc = 0. for epoch in range(1, 2001): start = time.time() train() val_acc = test(val_loader) if epoch % 500 == 0: optimizer.param_groups[0]['lr'] *= 0.5 if best_val_acc < val_acc: patience = start_patience best_val_acc = val_acc test_acc = test(test_loader) else: patience -= 1 print(f'Epoch {epoch:02d}, Val: {val_acc:.3f}, Test: {test_acc:.3f}') if patience <= 0: break times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/mixhop.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.nn import BatchNorm, Linear, MixHopConv if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, name='Cora') data = dataset[0] class MixHop(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = MixHopConv(dataset.num_features, 60, powers=[0, 1, 2]) self.norm1 = BatchNorm(3 * 60) self.conv2 = MixHopConv(3 * 60, 60, powers=[0, 1, 2]) self.norm2 = BatchNorm(3 * 60) self.conv3 = MixHopConv(3 * 60, 60, 
powers=[0, 1, 2]) self.norm3 = BatchNorm(3 * 60) self.lin = Linear(3 * 60, dataset.num_classes) def forward(self, x, edge_index): x = F.dropout(x, p=0.7, training=self.training) x = self.conv1(x, edge_index) x = self.norm1(x) x = F.dropout(x, p=0.9, training=self.training) x = self.conv2(x, edge_index) x = self.norm2(x) x = F.dropout(x, p=0.9, training=self.training) x = self.conv3(x, edge_index) x = self.norm3(x) x = F.dropout(x, p=0.9, training=self.training) return self.lin(x) model, data = MixHop().to(device), data.to(device) optimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=0.005) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.01) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() scheduler.step() return float(loss) @torch.no_grad() def test(): model.eval() pred = model(data.x, data.edge_index).argmax(dim=-1) accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs best_val_acc = test_acc = 0 for epoch in range(1, 101): loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/mnist_graclus.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import MNISTSuperpixels from torch_geometric.loader import DataLoader from torch_geometric.nn import ( SplineConv, global_mean_pool, graclus, max_pool, max_pool_x, ) from torch_geometric.typing import WITH_SPLINE, WITH_TORCH_CLUSTER from 
torch_geometric.utils import normalized_cut if not WITH_TORCH_CLUSTER: quit("This example requires 'torch-cluster'") if not WITH_SPLINE: quit("This example requires 'pyg-lib>=0.6.0'") path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST') transform = T.Cartesian(cat=False) train_dataset = MNISTSuperpixels(path, True, transform=transform) test_dataset = MNISTSuperpixels(path, False, transform=transform) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64) d = train_dataset def normalized_cut_2d(edge_index, pos): row, col = edge_index edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1) return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0)) class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SplineConv(d.num_features, 32, dim=2, kernel_size=5) self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5) self.fc1 = torch.nn.Linear(64, 128) self.fc2 = torch.nn.Linear(128, d.num_classes) def forward(self, data): data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) weight = normalized_cut_2d(data.edge_index, data.pos) cluster = graclus(data.edge_index, weight, data.x.size(0)) data.edge_attr = None data = max_pool(cluster, data, transform=transform) data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) weight = normalized_cut_2d(data.edge_index, data.pos) cluster = graclus(data.edge_index, weight, data.x.size(0)) x, batch = max_pool_x(cluster, data.x, data.batch) x = global_mean_pool(x, batch) x = F.elu(self.fc1(x)) x = F.dropout(x, training=self.training) return F.log_softmax(self.fc2(x), dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(epoch): model.train() if epoch == 16: for param_group in optimizer.param_groups: param_group['lr'] = 0.001 if epoch == 26: for param_group in 
optimizer.param_groups: param_group['lr'] = 0.0001 for data in train_loader: data = data.to(device) optimizer.zero_grad() F.nll_loss(model(data), data.y).backward() optimizer.step() def test(): model.eval() correct = 0 for data in test_loader: data = data.to(device) pred = model(data).max(1)[1] correct += pred.eq(data.y).sum().item() return correct / len(test_dataset) for epoch in range(1, 31): train(epoch) test_acc = test() print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}') ================================================ FILE: examples/mnist_nn_conv.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear, ReLU, Sequential import torch_geometric.transforms as T from torch_geometric.datasets import MNISTSuperpixels from torch_geometric.loader import DataLoader from torch_geometric.nn import ( NNConv, global_mean_pool, graclus, max_pool, max_pool_x, ) from torch_geometric.typing import WITH_TORCH_CLUSTER from torch_geometric.utils import normalized_cut if not WITH_TORCH_CLUSTER: quit("This example requires 'torch-cluster'") path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST') transform = T.Cartesian(cat=False) train_dataset = MNISTSuperpixels(path, True, transform=transform) test_dataset = MNISTSuperpixels(path, False, transform=transform) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) d = train_dataset def normalized_cut_2d(edge_index, pos): row, col = edge_index edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1) return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0)) class Net(torch.nn.Module): def __init__(self): super().__init__() nn1 = Sequential( Linear(2, 25), ReLU(), Linear(25, d.num_features * 32), ) self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean') nn2 = Sequential( Linear(2, 25), ReLU(), Linear(25, 32 * 64), ) self.conv2 = 
NNConv(32, 64, nn2, aggr='mean') self.fc1 = torch.nn.Linear(64, 128) self.fc2 = torch.nn.Linear(128, d.num_classes) def forward(self, data): data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) weight = normalized_cut_2d(data.edge_index, data.pos) cluster = graclus(data.edge_index, weight, data.x.size(0)) data.edge_attr = None data = max_pool(cluster, data, transform=transform) data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) weight = normalized_cut_2d(data.edge_index, data.pos) cluster = graclus(data.edge_index, weight, data.x.size(0)) x, batch = max_pool_x(cluster, data.x, data.batch) x = global_mean_pool(x, batch) x = F.elu(self.fc1(x)) x = F.dropout(x, training=self.training) return F.log_softmax(self.fc2(x), dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(epoch): model.train() if epoch == 16: for param_group in optimizer.param_groups: param_group['lr'] = 0.001 if epoch == 26: for param_group in optimizer.param_groups: param_group['lr'] = 0.0001 for data in train_loader: data = data.to(device) optimizer.zero_grad() F.nll_loss(model(data), data.y).backward() optimizer.step() def test(): model.eval() correct = 0 for data in test_loader: data = data.to(device) pred = model(data).max(1)[1] correct += pred.eq(data.y).sum().item() return correct / len(test_dataset) for epoch in range(1, 31): train(epoch) test_acc = test() print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}') ================================================ FILE: examples/mnist_voxel_grid.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import MNISTSuperpixels from torch_geometric.loader import DataLoader from torch_geometric.nn import SplineConv, max_pool, max_pool_x, voxel_grid from torch_geometric.typing import 
WITH_SPLINE if not WITH_SPLINE: quit("This example requires 'pyg-lib>=0.6.0'") path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST') transform = T.Cartesian(cat=False) train_dataset = MNISTSuperpixels(path, True, transform=transform) test_dataset = MNISTSuperpixels(path, False, transform=transform) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64) d = train_dataset class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SplineConv(d.num_features, 32, dim=2, kernel_size=5) self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5) self.conv3 = SplineConv(64, 64, dim=2, kernel_size=5) self.fc1 = torch.nn.Linear(4 * 64, 128) self.fc2 = torch.nn.Linear(128, d.num_classes) def forward(self, data): data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) cluster = voxel_grid(data.pos, batch=data.batch, size=5, start=0, end=28) data.edge_attr = None data = max_pool(cluster, data, transform=transform) data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) cluster = voxel_grid(data.pos, batch=data.batch, size=7, start=0, end=28) data.edge_attr = None data = max_pool(cluster, data, transform=transform) data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) cluster = voxel_grid(data.pos, batch=data.batch, size=14, start=0, end=27.99) x, _ = max_pool_x(cluster, data.x, data.batch, size=4) x = x.view(-1, self.fc1.weight.size(1)) x = F.elu(self.fc1(x)) x = F.dropout(x, training=self.training) x = self.fc2(x) return F.log_softmax(x, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(epoch): model.train() if epoch == 6: for param_group in optimizer.param_groups: param_group['lr'] = 0.001 if epoch == 16: for param_group in optimizer.param_groups: param_group['lr'] = 0.0001 for data in train_loader: data = 
data.to(device) optimizer.zero_grad() F.nll_loss(model(data), data.y).backward() optimizer.step() def test(): model.eval() correct = 0 for data in test_loader: data = data.to(device) pred = model(data).max(1)[1] correct += pred.eq(data.y).sum().item() return correct / len(test_dataset) for epoch in range(1, 21): train(epoch) test_acc = test() print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}') ================================================ FILE: examples/multi_gpu/README.md ================================================ # Examples for Distributed Training ## Examples with NVIDIA GPUs ### Examples with cuGraph For the best performance with NVIDIA GPUs, we recommend using **cuGraph**. Refer to [our installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html#accelerating-pyg-with-nvidia-cugraph-gnn) for setup instructions and to the [cuGraph-PyG examples](https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples) for ready-to-run training scripts covering single-node, multi-node, and link-prediction workloads. ### Examples with Pure PyTorch | Example | Scalability | Description | | ---------------------------------------------------------------------------------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------ | | [`distributed_batching.py`](./distributed_batching.py) | single-node | Graph-level prediction on many small graphs (ogbg-molhiv) using `DataLoader` with `DistributedSampler`. | | [`distributed_sampling.py`](./distributed_sampling.py) | single-node | Node-level classification on a single large graph (Reddit) using `NeighborLoader` for multi-hop subgraph sampling. | | [`distributed_sampling_multinode.py`](./distributed_sampling_multinode.py) | multi-node | Training GNNs on a homogeneous graph with neighbor sampling on multiple nodes. 
| | [`distributed_sampling_multinode.sbatch`](./distributed_sampling_multinode.sbatch) | multi-node | Submitting a training job to a Slurm cluster using [`distributed_sampling_multinode.py`](./distributed_sampling_multinode.py). | | [`papers100m_gcn.py`](./papers100m_gcn.py) | single-node | Training GNNs on the `ogbn-papers100M` homogeneous graph w/ ~1.6B edges. | | [`papers100m_gcn_multinode.py`](./papers100m_gcn_multinode.py) | multi-node | Training GNNs on a homogeneous graph on multiple nodes. | | [`pcqm4m_ogb.py`](./pcqm4m_ogb.py) | single-node | Training GNNs for a graph-level regression task. | | [`mag240m_graphsage.py`](./mag240m_graphsage.py) | single-node | Training GNNs on a large heterogeneous graph. | | [`taobao.py`](./taobao.py) | single-node | Training link prediction GNNs on a heterogeneous graph. | | [`model_parallel.py`](./model_parallel.py) | single-node | Model parallelism by manually placing layers on each GPU. | ## Examples with Intel GPUs (XPUs) | Example | Scalability | Description | | -------------------------------------------------------------- | ---------------------- | ------------------------------------------------------------ | | [`distributed_sampling_xpu.py`](./distributed_sampling_xpu.py) | single-node, multi-gpu | Training GNNs on a homogeneous graph with neighbor sampling. 
| ================================================ FILE: examples/multi_gpu/distributed_batching.py ================================================ import os import os.path as osp import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F from ogb.graphproppred import Evaluator, PygGraphPropPredDataset from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder from torch.nn import BatchNorm1d as BatchNorm from torch.nn import Linear, ReLU, Sequential from torch.nn.parallel import DistributedDataParallel from torch.utils.data.distributed import DistributedSampler from torch_sparse import SparseTensor import torch_geometric.transforms as T from torch_geometric.loader import DataLoader from torch_geometric.nn import GINEConv, global_mean_pool class GIN(torch.nn.Module): def __init__( self, hidden_channels: int, out_channels: int, num_layers: int = 3, dropout: float = 0.5, ) -> None: super().__init__() self.dropout = dropout self.atom_encoder = AtomEncoder(hidden_channels) self.bond_encoder = BondEncoder(hidden_channels) self.convs = torch.nn.ModuleList() for _ in range(num_layers): nn = Sequential( Linear(hidden_channels, 2 * hidden_channels), BatchNorm(2 * hidden_channels), ReLU(), Linear(2 * hidden_channels, hidden_channels), BatchNorm(hidden_channels), ReLU(), ) self.convs.append(GINEConv(nn, train_eps=True)) self.lin = Linear(hidden_channels, out_channels) def forward( self, x: torch.Tensor, adj_t: SparseTensor, batch: torch.Tensor, ) -> torch.Tensor: x = self.atom_encoder(x) edge_attr = adj_t.coo()[2] adj_t = adj_t.set_value(self.bond_encoder(edge_attr), layout='coo') for conv in self.convs: x = conv(x, adj_t) x = F.dropout(x, p=self.dropout, training=self.training) x = global_mean_pool(x, batch) x = self.lin(x) return x def run(rank: int, world_size: int, dataset_name: str, root: str) -> None: os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', 
rank=rank, world_size=world_size) dataset = PygGraphPropPredDataset( dataset_name, root=root, pre_transform=T.ToSparseTensor(attr='edge_attr'), ) split_idx = dataset.get_idx_split() evaluator = Evaluator(dataset_name) train_dataset = dataset[split_idx['train']] train_loader = DataLoader( train_dataset, batch_size=128, sampler=DistributedSampler( train_dataset, shuffle=True, drop_last=True, ), ) torch.manual_seed(12345) model = GIN(128, dataset.num_tasks, num_layers=3, dropout=0.5).to(rank) model = DistributedDataParallel(model, device_ids=[rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = torch.nn.BCEWithLogitsLoss() if rank == 0: val_loader = DataLoader(dataset[split_idx['valid']], batch_size=256) test_loader = DataLoader(dataset[split_idx['test']], batch_size=256) for epoch in range(1, 51): model.train() train_loader.sampler.set_epoch(epoch) total_loss = torch.zeros(2, device=rank) for data in train_loader: data = data.to(rank) logits = model(data.x, data.adj_t, data.batch) loss = criterion(logits, data.y.to(torch.float)) loss.backward() optimizer.step() optimizer.zero_grad() with torch.no_grad(): total_loss[0] += loss * logits.size(0) total_loss[1] += data.num_graphs dist.all_reduce(total_loss, op=dist.ReduceOp.AVG) train_loss = total_loss[0] / total_loss[1] if rank == 0: # We evaluate on a single GPU for now. 
model.eval() y_pred, y_true = [], [] for data in val_loader: data = data.to(rank) with torch.no_grad(): y_pred.append(model.module(data.x, data.adj_t, data.batch)) y_true.append(data.y) val_rocauc = evaluator.eval({ 'y_pred': torch.cat(y_pred, dim=0), 'y_true': torch.cat(y_true, dim=0), })['rocauc'] y_pred, y_true = [], [] for data in test_loader: data = data.to(rank) with torch.no_grad(): y_pred.append(model.module(data.x, data.adj_t, data.batch)) y_true.append(data.y) test_rocauc = evaluator.eval({ 'y_pred': torch.cat(y_pred, dim=0), 'y_true': torch.cat(y_true, dim=0), })['rocauc'] print(f'Epoch: {epoch:03d}, ' f'Loss: {train_loss:.4f}, ' f'Val: {val_rocauc:.4f}, ' f'Test: {test_rocauc:.4f}') dist.barrier() dist.destroy_process_group() if __name__ == '__main__': dataset_name = 'ogbg-molhiv' root = osp.join( osp.dirname(__file__), '..', '..', 'data', 'OGB', ) # Download and process the dataset on main process. PygGraphPropPredDataset( dataset_name, root, pre_transform=T.ToSparseTensor(attr='edge_attr'), ) world_size = torch.cuda.device_count() print('Let\'s use', world_size, 'GPUs!') args = (world_size, dataset_name, root) mp.spawn(run, args=args, nprocs=world_size, join=True) ================================================ FILE: examples/multi_gpu/distributed_sampling.py ================================================ import os import os.path as osp from math import ceil import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F from torch import Tensor from torch.nn.parallel import DistributedDataParallel from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborLoader from torch_geometric.nn import SAGEConv class SAGE(torch.nn.Module): def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int = 2, ) -> None: super().__init__() self.convs = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) 
for _ in range(num_layers - 2): self.convs.append(SAGEConv(hidden_channels, hidden_channels)) self.convs.append(SAGEConv(hidden_channels, out_channels)) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: for i, conv in enumerate(self.convs): x = conv(x, edge_index) if i < len(self.convs) - 1: x = x.relu() x = F.dropout(x, p=0.5, training=self.training) return x @torch.no_grad() def test( loader: NeighborLoader, model: DistributedDataParallel, rank: int, ) -> Tensor: model.eval() total_correct = torch.tensor(0, dtype=torch.long, device=rank) total_examples = 0 for batch in loader: out = model(batch.x, batch.edge_index.to(rank)) pred = out[:batch.batch_size].argmax(dim=-1) y = batch.y[:batch.batch_size].to(rank) total_correct += (pred == y).sum() total_examples += batch.batch_size return total_correct / total_examples def run(rank: int, world_size: int, dataset: Reddit) -> None: os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) data = dataset[0] data = data.to(rank, 'x', 'y') # Move to device for faster feature fetch. 
# Split indices into `world_size` many chunks: train_idx = data.train_mask.nonzero(as_tuple=False).view(-1) train_idx = train_idx.split(ceil(train_idx.size(0) / world_size))[rank] val_idx = data.val_mask.nonzero(as_tuple=False).view(-1) val_idx = val_idx.split(ceil(val_idx.size(0) / world_size))[rank] test_idx = data.val_mask.nonzero(as_tuple=False).view(-1) test_idx = test_idx.split(ceil(test_idx.size(0) / world_size))[rank] kwargs = dict( data=data, batch_size=1024, num_neighbors=[25, 10], drop_last=True, num_workers=4, persistent_workers=True, ) train_loader = NeighborLoader( input_nodes=train_idx, shuffle=True, **kwargs, ) val_loader = NeighborLoader( input_nodes=val_idx, shuffle=False, **kwargs, ) test_loader = NeighborLoader( input_nodes=test_idx, shuffle=False, **kwargs, ) torch.manual_seed(12345) model = SAGE(dataset.num_features, 256, dataset.num_classes).to(rank) model = DistributedDataParallel(model, device_ids=[rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(1, 21): model.train() for batch in tqdm( train_loader, desc=f'Epoch {epoch:02d}', disable=rank != 0, ): out = model(batch.x, batch.edge_index.to(rank))[:batch.batch_size] loss = F.cross_entropy(out, batch.y[:batch.batch_size]) loss.backward() optimizer.step() optimizer.zero_grad() if rank == 0: print(f'Epoch {epoch:02d}: Train loss: {loss:.4f}') if epoch % 5 == 0: train_acc = test(train_loader, model, rank) val_acc = test(val_loader, model, rank) test_acc = test(test_loader, model, rank) if world_size > 1: dist.all_reduce(train_acc, op=dist.ReduceOp.AVG) dist.all_reduce(val_acc, op=dist.ReduceOp.AVG) dist.all_reduce(test_acc, op=dist.ReduceOp.AVG) if rank == 0: print(f'Train acc: {train_acc:.4f}, ' f'Val acc: {val_acc:.4f}, ' f'Test acc: {test_acc:.4f}') dist.destroy_process_group() if __name__ == '__main__': path = osp.join( osp.dirname(__file__), '..', '..', 'data', 'Reddit', ) dataset = Reddit(path) world_size = torch.cuda.device_count() print("Let's use", 
world_size, "GPUs!") mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True) ================================================ FILE: examples/multi_gpu/distributed_sampling_multinode.py ================================================ import copy import os from math import ceil import torch import torch.distributed as dist import torch.nn.functional as F from torch import Tensor from torch.nn.parallel import DistributedDataParallel from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborLoader from torch_geometric.nn import SAGEConv class SAGE(torch.nn.Module): def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int = 2, ): super().__init__() self.convs = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) for _ in range(num_layers - 2): self.convs.append(SAGEConv(hidden_channels, hidden_channels)) self.convs.append(SAGEConv(hidden_channels, out_channels)) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: for i, conv in enumerate(self.convs): x = conv(x, edge_index) if i < len(self.convs) - 1: x = x.relu_() x = F.dropout(x, p=0.5, training=self.training) return x @torch.no_grad() def inference( self, x_all: Tensor, device: torch.device, subgraph_loader: NeighborLoader, ) -> Tensor: pbar = tqdm(total=len(subgraph_loader) * len(self.convs)) pbar.set_description('Evaluating') # Compute representations of nodes layer by layer, using *all* # available edges. 
This leads to faster computation in contrast to # immediately computing the final representations of each batch: for i, conv in enumerate(self.convs): xs = [] for batch in subgraph_loader: x = x_all[batch.node_id.to(x_all.device)].to(device) x = conv(x, batch.edge_index.to(device)) x = x[:batch.batch_size] if i < len(self.convs) - 1: x = x.relu_() xs.append(x.cpu()) pbar.update(1) x_all = torch.cat(xs, dim=0) pbar.close() return x_all def run(world_size: int, rank: int, local_rank: int): # Will query the runtime environment for `MASTER_ADDR` and `MASTER_PORT`. # Make sure, those are set! dist.init_process_group('nccl', world_size=world_size, rank=rank) # Download and unzip only with one process ... if rank == 0: dataset = Reddit('data/Reddit') dist.barrier() # ... and then read from all the other processes: if rank != 0: dataset = Reddit('data/Reddit') dist.barrier() data = dataset[0] # Move to device for faster feature fetch. data = data.to(local_rank, 'x', 'y') # Split training indices into `world_size` many chunks: train_idx = data.train_mask.nonzero(as_tuple=False).view(-1) train_idx = train_idx.split(ceil(train_idx.size(0) / world_size))[rank] kwargs = dict(batch_size=1024, num_workers=4, persistent_workers=True) train_loader = NeighborLoader( data, input_nodes=train_idx, num_neighbors=[25, 10], shuffle=True, drop_last=True, **kwargs, ) if rank == 0: # Create single-hop evaluation neighbor loader: subgraph_loader = NeighborLoader( copy.copy(data), num_neighbors=[-1], shuffle=False, **kwargs, ) # No need to maintain these features during evaluation: del subgraph_loader.data.x, subgraph_loader.data.y # Add global node index information: subgraph_loader.data.node_id = torch.arange(data.num_nodes) torch.manual_seed(12345) model = SAGE(dataset.num_features, 256, dataset.num_classes).to(local_rank) model = DistributedDataParallel(model, device_ids=[local_rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(1, 21): model.train() for 
batch in train_loader: optimizer.zero_grad() out = model(batch.x, batch.edge_index.to(local_rank))[:batch.batch_size] loss = F.cross_entropy(out, batch.y[:batch.batch_size]) loss.backward() optimizer.step() dist.barrier() if rank == 0: print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') if rank == 0 and epoch % 5 == 0: # We evaluate on a single GPU for now model.eval() with torch.no_grad(): out = model.module.inference( data.x, local_rank, subgraph_loader, ) res = out.argmax(dim=-1) == data.y.to(out.device) acc1 = int(res[data.train_mask].sum()) / int(data.train_mask.sum()) acc2 = int(res[data.val_mask].sum()) / int(data.val_mask.sum()) acc3 = int(res[data.test_mask].sum()) / int(data.test_mask.sum()) print(f'Train: {acc1:.4f}, Val: {acc2:.4f}, Test: {acc3:.4f}') dist.barrier() dist.destroy_process_group() if __name__ == '__main__': # Get the world size from the WORLD_SIZE variable or directly from SLURM: world_size = int( os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS'))) # Likewise for RANK and LOCAL_RANK: rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID'))) local_rank = int( os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID'))) run(world_size, rank, local_rank) ================================================ FILE: examples/multi_gpu/distributed_sampling_multinode.sbatch ================================================ #!/bin/bash #SBATCH --job-name=pyg-multinode-tutorial # identifier for the job listings #SBATCH --output=pyg-multinode.log # outputfile #SBATCH --partition=gpucloud # ADJUST this to your system #SBATCH -N 2 # number of nodes you want to use #SBATCH --ntasks=4 # number of processes to be run #SBATCH --gpus-per-task=1 # every process wants one GPU! #SBATCH --gpu-bind=none # NCCL can't deal with task-binding... 
## Now you can add more stuff for your convenience #SBATCH --cpus-per-task=8 # make sure more cpu-cores are available to each process to spawn workers (default=1 and this is a hard limit) #SBATCH --mem=100G # total number of memory available per node (tensorflow need(ed) at least per GPU) #SBATCH --export=ALL # use your shell environment (PATHs, ...) # Thanks for shell-ideas to https://github.com/PrincetonUniversity/multi_gpu_training export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) echo "MASTER_ADDR:MASTER_PORT="${MASTER_ADDR}:${MASTER_PORT} echo "###########################################################################" echo "We recommend you set up your environment here (conda/spack/pip/modulefiles)" echo "then remove --export=ALL (allows running the sbatch from any shell" echo "###########################################################################" # use --output=0 so that only the first task logs to the file! srun --output=0 python distributed_sampling_multinode.py ================================================ FILE: examples/multi_gpu/distributed_sampling_xpu.py ================================================ """Distributed GAT training, targeting XPU devices. PVC has 2 tiles, each reports itself as a separate device. DDP approach allows us to employ both tiles. Additional requirements: IPEX (intel_extension_for_pytorch) oneCCL (oneccl_bindings_for_pytorch) We need to import both these modules, as they extend torch module with XPU/oneCCL related functionality. 
Run with: mpirun -np 2 python distributed_sampling_xpu.py """ import copy import os import os.path as osp from typing import Any, Tuple, Union import intel_extension_for_pytorch # noqa import oneccl_bindings_for_pytorch # noqa import torch import torch.distributed as dist import torch.nn.functional as F from ogb.nodeproppred import Evaluator, PygNodePropPredDataset from torch import Tensor from torch.nn import Linear as Lin from torch.nn.parallel import DistributedDataParallel as DDP from tqdm import tqdm from torch_geometric.loader import NeighborLoader from torch_geometric.nn import GATConv from torch_geometric.profile import get_stats_summary, profileit class GAT(torch.nn.Module): def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int, heads: int, ): super().__init__() self.num_layers = num_layers self.convs = torch.nn.ModuleList() self.convs.append(GATConv(dataset.num_features, hidden_channels, heads)) for _ in range(num_layers - 2): self.convs.append( GATConv(heads * hidden_channels, hidden_channels, heads)) self.convs.append( GATConv(heads * hidden_channels, out_channels, heads, concat=False)) self.skips = torch.nn.ModuleList() self.skips.append(Lin(dataset.num_features, hidden_channels * heads)) for _ in range(num_layers - 2): self.skips.append( Lin(hidden_channels * heads, hidden_channels * heads)) self.skips.append(Lin(hidden_channels * heads, out_channels)) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: for i, (conv, skip) in enumerate(zip(self.convs, self.skips)): x = conv(x, edge_index) + skip(x) if i != self.num_layers - 1: x = F.elu(x) x = F.dropout(x, p=0.5, training=self.training) return x def inference( self, x_all: Tensor, device: Union[str, torch.device], subgraph_loader: NeighborLoader, ) -> Tensor: pbar = tqdm(total=x_all.size(0) * self.num_layers) pbar.set_description("Evaluating") # Compute representations of nodes layer by layer, using *all* # available edges. 
This leads to faster computation in contrast to # immediately computing the final representations of each batch. for i in range(self.num_layers): xs = [] for batch in subgraph_loader: x = x_all[batch.n_id].to(device) edge_index = batch.edge_index.to(device) x = self.convs[i](x, edge_index) + self.skips[i](x) x = x[:batch.batch_size] if i != self.num_layers - 1: x = F.elu(x) xs.append(x.cpu()) pbar.update(batch.batch_size) x_all = torch.cat(xs, dim=0) pbar.close() return x_all @profileit('xpu') def train_step(model: Any, optimizer: Any, x: Tensor, edge_index: Tensor, y: Tensor, bs: int) -> float: optimizer.zero_grad() out = model(x, edge_index)[:bs] loss = F.cross_entropy(out, y[:bs].squeeze()) loss.backward() optimizer.step() return float(loss) def run(rank: int, world_size: int, dataset: PygNodePropPredDataset): device = f"xpu:{rank}" split_idx = dataset.get_idx_split() split_idx["train"] = (split_idx["train"].split( split_idx["train"].size(0) // world_size, dim=0)[rank].clone()) data = dataset[0].to(device, "x", "y") kwargs = dict(batch_size=1024, num_workers=0, pin_memory=True) train_loader = NeighborLoader(data, input_nodes=split_idx["train"], num_neighbors=[10, 10, 5], **kwargs) if rank == 0: subgraph_loader = NeighborLoader(copy.copy(data), num_neighbors=[-1], **kwargs) evaluator = Evaluator(name="ogbn-products") torch.manual_seed(12345) model = GAT(dataset.num_features, 128, dataset.num_classes, num_layers=3, heads=4).to(device) model = DDP(model, device_ids=[device]) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(1, 21): epoch_stats = [] model.train() for batch in train_loader: batch = batch.to(device) loss, stats = train_step(model, optimizer, batch.x, batch.edge_index, batch.y, batch.batch_size) epoch_stats.append(stats) dist.barrier() if rank == 0: print(f"Epoch: {epoch:02d}, Loss: {loss:.4f}") print(f"Epoch: {epoch:02d}, Rank: {rank}, " f"Stats: {get_stats_summary(epoch_stats)}") if rank == 0 and epoch % 5 == 0: # 
Evaluation on a single GPU model.eval() with torch.no_grad(): out = model.module.inference(data.x, device, subgraph_loader) y_true = data.y.to(out.device) y_pred = out.argmax(dim=-1, keepdim=True) train_acc = evaluator.eval({ "y_true": y_true[split_idx["train"]], "y_pred": y_pred[split_idx["train"]], })["acc"] val_acc = evaluator.eval({ "y_true": y_true[split_idx["valid"]], "y_pred": y_pred[split_idx["valid"]], })["acc"] test_acc = evaluator.eval({ "y_true": y_true[split_idx["test"]], "y_pred": y_pred[split_idx["test"]], })["acc"] print(f"Train: {train_acc:.4f}, Val: {val_acc:.4f}, " f"Test: {test_acc:.4f}") dist.barrier() dist.destroy_process_group() def get_dist_params() -> Tuple[int, int, str]: master_addr = "127.0.0.1" master_port = "29500" os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = master_port mpi_rank = int(os.environ.get("PMI_RANK", -1)) mpi_world_size = int(os.environ.get("PMI_SIZE", -1)) rank = mpi_rank if mpi_world_size > 0 else os.environ.get("RANK", 0) world_size = (mpi_world_size if mpi_world_size > 0 else os.environ.get( "WORLD_SIZE", 1)) os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) init_method = f"tcp://{master_addr}:{master_port}" return rank, world_size, init_method if __name__ == "__main__": rank, world_size, init_method = get_dist_params() dist.init_process_group(backend="ccl", init_method=init_method, world_size=world_size, rank=rank) path = osp.join(osp.dirname(osp.realpath(__file__)), "../../data", "ogbn-products") dataset = PygNodePropPredDataset("ogbn-products", path) run(rank, world_size, dataset) ================================================ FILE: examples/multi_gpu/mag240m_graphsage.py ================================================ import argparse import os import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F from ogb.lsc import MAG240MDataset from torch.nn.parallel import DistributedDataParallel from torchmetrics import 
Accuracy from tqdm import tqdm from torch_geometric.loader import NeighborLoader from torch_geometric.nn import BatchNorm, HeteroConv, SAGEConv def common_step(batch, model): batch_size = batch['paper'].batch_size x_dict = model(batch.x_dict, batch.edge_index_dict) y_hat = x_dict['paper'][:batch_size] y = batch['paper'].y[:batch_size].to(torch.long) return y_hat, y def training_step(batch, acc, model): y_hat, y = common_step(batch, model) train_loss = F.cross_entropy(y_hat, y) acc(y_hat, y) return train_loss def validation_step(batch, acc, model): y_hat, y = common_step(batch, model) acc(y_hat, y) class HeteroSAGEConv(torch.nn.Module): def __init__(self, in_channels, out_channels, dropout, node_types, edge_types, is_output_layer=False): super().__init__() self.conv = HeteroConv({ edge_type: SAGEConv(in_channels, out_channels) for edge_type in edge_types }) if not is_output_layer: self.dropout = torch.nn.Dropout(dropout) self.norm_dict = torch.nn.ModuleDict({ node_type: BatchNorm(out_channels) for node_type in node_types }) self.is_output_layer = is_output_layer def forward(self, x_dict, edge_index_dict): x_dict = self.conv(x_dict, edge_index_dict) if not self.is_output_layer: for node_type, x in x_dict.items(): x = self.dropout(x.relu()) x = self.norm_dict[node_type](x) x_dict[node_type] = x return x_dict class HeteroGraphSAGE(torch.nn.Module): def __init__(self, in_channels, hidden_channels, num_layers, out_channels, dropout, node_types, edge_types): super().__init__() self.convs = torch.nn.ModuleList() for i in range(num_layers): # Since authors and institution do not come with features, we learn # them via the GNN. 
However, this also means we need to exclude # them as source types in the first two iterations: if i == 0: edge_types_of_layer = [ edge_type for edge_type in edge_types if edge_type[0] == 'paper' ] elif i == 1: edge_types_of_layer = [ edge_type for edge_type in edge_types if edge_type[0] != 'institution' ] else: edge_types_of_layer = edge_types conv = HeteroSAGEConv( in_channels if i == 0 else hidden_channels, out_channels if i == num_layers - 1 else hidden_channels, dropout=dropout, node_types=node_types, edge_types=edge_types_of_layer, is_output_layer=i == num_layers - 1, ) self.convs.append(conv) def forward(self, x_dict, edge_index_dict): for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) return x_dict def run( rank, data, num_devices, num_epochs, num_steps_per_epoch, log_every_n_steps, batch_size, num_neighbors, hidden_channels, dropout, num_val_steps, lr, ): if num_devices > 1: if rank == 0: print("Setting up distributed...") os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=num_devices) acc = Accuracy(task='multiclass', num_classes=data.num_classes) model = HeteroGraphSAGE( in_channels=-1, hidden_channels=hidden_channels, num_layers=len(num_neighbors), out_channels=data.num_classes, dropout=dropout, node_types=data.node_types, edge_types=data.edge_types, ) train_idx = data['paper'].train_mask.nonzero(as_tuple=False).view(-1) val_idx = data['paper'].val_mask.nonzero(as_tuple=False).view(-1) if num_devices > 1: # Split indices into `num_devices` many chunks: train_idx = train_idx.split(train_idx.size(0) // num_devices)[rank] val_idx = val_idx.split(val_idx.size(0) // num_devices)[rank] # Delete unused tensors to not sample: del data['paper'].train_mask del data['paper'].val_mask del data['paper'].test_mask del data['paper'].year kwargs = dict( batch_size=batch_size, num_workers=16, persistent_workers=True, num_neighbors=num_neighbors, drop_last=True, ) train_loader = 
NeighborLoader( data, input_nodes=('paper', train_idx), shuffle=True, **kwargs, ) val_loader = NeighborLoader(data, input_nodes=('paper', val_idx), **kwargs) if num_devices > 0: model = model.to(rank) acc = acc.to(rank) if num_devices > 1: model = DistributedDataParallel(model, device_ids=[rank]) optimizer = torch.optim.Adam(model.parameters(), lr=lr) for epoch in range(1, num_epochs + 1): model.train() for i, batch in enumerate(tqdm(train_loader)): if num_steps_per_epoch >= 0 and i >= num_steps_per_epoch: break if num_devices > 0: batch = batch.to(rank, 'x', 'y', 'edge_index') # Features loaded in as 16 bits, train in 32 bits: batch['paper'].x = batch['paper'].x.to(torch.float32) optimizer.zero_grad() loss = training_step(batch, acc, model) loss.backward() optimizer.step() if i % log_every_n_steps == 0: if rank == 0: print(f"Epoch: {epoch:02d}, Step: {i:d}, " f"Loss: {loss:.4f}, " f"Train Acc: {acc.compute():.4f}") if num_devices > 1: dist.barrier() if rank == 0: print(f"Epoch: {epoch:02d}, Loss: {loss:.4f}, " f"Train Acc :{acc.compute():.4f}") acc.reset() model.eval() with torch.no_grad(): for i, batch in enumerate(tqdm(val_loader)): if num_val_steps >= 0 and i >= num_val_steps: break if num_devices > 0: batch = batch.to(rank, 'x', 'y', 'edge_index') batch['paper'].x = batch['paper'].x.to(torch.float32) validation_step(batch, acc, model) if num_devices > 1: dist.barrier() if rank == 0: print(f"Val Acc: {acc.compute():.4f}") acc.reset() model.eval() if num_devices > 1: dist.destroy_process_group() if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--hidden_channels", type=int, default=1024) parser.add_argument("--batch_size", type=int, default=1024) parser.add_argument("--dropout", type=float, default=0.5) parser.add_argument("--lr", type=float, default=0.001) parser.add_argument("--num_epochs", type=int, default=20) parser.add_argument("--num_steps_per_epoch", type=int, default=-1) parser.add_argument("--log_every_n_steps", 
type=int, default=100) parser.add_argument("--num_val_steps", type=int, default=-1, help=50) parser.add_argument("--num_neighbors", type=str, default="25-15") parser.add_argument("--num_devices", type=int, default=1) args = parser.parse_args() args.num_neighbors = [int(i) for i in args.num_neighbors.split('-')] import warnings warnings.simplefilter("ignore") if not torch.cuda.is_available(): args.num_devices = 0 elif args.num_devices > torch.cuda.device_count(): args.num_devices = torch.cuda.device_count() dataset = MAG240MDataset() data = dataset.to_pyg_hetero_data() if args.num_devices > 1: print("Let's use", args.num_devices, "GPUs!") from torch.multiprocessing.spawn import ProcessExitedException try: mp.spawn( run, args=( data, args.num_devices, args.num_epochs, args.num_steps_per_epoch, args.log_every_n_steps, args.batch_size, args.num_neighbors, args.hidden_channels, args.dropout, args.num_val_steps, args.lr, ), nprocs=args.num_devices, join=True, ) except ProcessExitedException as e: print("torch.multiprocessing.spawn.ProcessExitedException:", e) print("Exceptions/SIGBUS/Errors may be caused by a lack of RAM") else: run( 0, data, args.num_devices, args.num_epochs, args.num_steps_per_epoch, args.log_every_n_steps, args.batch_size, args.num_neighbors, args.hidden_channels, args.dropout, args.num_val_steps, args.lr, ) ================================================ FILE: examples/multi_gpu/model_parallel.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCNConv from torch_geometric.transforms import NormalizeFeatures if torch.cuda.device_count() < 2: quit('This example requires multiple GPUs') path = osp.dirname(osp.realpath(__file__)) path = osp.join(path, '..', '..', 'data', 'Planetoid') dataset = Planetoid(root=path, name='Cora', transform=NormalizeFeatures()) data = dataset[0].to('cuda:0') class 
GCN(torch.nn.Module): def __init__(self, in_channels, out_channels, device1, device2): super().__init__() self.device1 = device1 self.device2 = device2 self.conv1 = GCNConv(in_channels, 16).to(device1) self.conv2 = GCNConv(16, out_channels).to(device2) def forward(self, x, edge_index): x = F.dropout(x, p=0.5, training=self.training) x = self.conv1(x, edge_index).relu() # Move data to the second device: x, edge_index = x.to(self.device2), edge_index.to(self.device2) x = F.dropout(x, p=0.5, training=self.training) x = self.conv2(x, edge_index) return x model = GCN( dataset.num_features, dataset.num_classes, device1='cuda:0', device2='cuda:1', ) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index).to('cuda:0') loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() out = model(data.x, data.edge_index) pred = out.argmax(dim=-1).to('cuda:0') accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs best_val_acc = test_acc = 0 times = [] for epoch in range(1, 201): loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/multi_gpu/papers100m_gcn.py ================================================ import argparse import os import tempfile import time import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F from ogb.nodeproppred import PygNodePropPredDataset from torch.nn.parallel import DistributedDataParallel from torchmetrics import Accuracy import 
torch_geometric from torch_geometric.loader import NeighborLoader def get_num_workers(world_size): num_work = None if hasattr(os, "sched_getaffinity"): try: num_work = len(os.sched_getaffinity(0)) / (2 * world_size) except Exception: pass if num_work is None: num_work = os.cpu_count() / (2 * world_size) return int(num_work) def run_train(rank, data, world_size, model, epochs, batch_size, fan_out, split_idx, num_classes, wall_clock_start, tempdir=None, num_layers=3): # init pytorch worker os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) if world_size > 1: split_idx['train'] = split_idx['train'].split( split_idx['train'].size(0) // world_size, dim=0)[rank].clone() split_idx['valid'] = split_idx['valid'].split( split_idx['valid'].size(0) // world_size, dim=0)[rank].clone() split_idx['test'] = split_idx['test'].split( split_idx['test'].size(0) // world_size, dim=0)[rank].clone() model = model.to(rank) model = DistributedDataParallel(model, device_ids=[rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) kwargs = dict( num_neighbors=[fan_out] * num_layers, batch_size=batch_size, ) num_work = get_num_workers(world_size) train_loader = NeighborLoader(data, input_nodes=split_idx['train'], num_workers=num_work, shuffle=True, drop_last=True, **kwargs) val_loader = NeighborLoader(data, input_nodes=split_idx['valid'], num_workers=num_work, **kwargs) test_loader = NeighborLoader(data, input_nodes=split_idx['test'], num_workers=num_work, **kwargs) eval_steps = 1000 warmup_steps = 20 acc = Accuracy(task="multiclass", num_classes=num_classes).to(rank) dist.barrier() torch.cuda.synchronize() if rank == 0: prep_time = round(time.perf_counter() - wall_clock_start, 2) print("Total time before training begins (prep_time) =", prep_time, "seconds") print("Beginning training...") for epoch in range(epochs): for i, batch in enumerate(train_loader): if i == 
warmup_steps: torch.cuda.synchronize() start = time.time() batch = batch.to(rank) batch_size = batch.num_sampled_nodes[0] batch.y = batch.y.to(torch.long) optimizer.zero_grad() out = model(batch.x, batch.edge_index) loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size]) loss.backward() optimizer.step() if rank == 0 and i % 10 == 0: print("Epoch: " + str(epoch) + ", Iteration: " + str(i) + ", Loss: " + str(loss)) nb = i + 1.0 dist.barrier() torch.cuda.synchronize() if rank == 0: print("Average Training Iteration Time:", (time.time() - start) / (nb - warmup_steps), "s/iter") with torch.no_grad(): for i, batch in enumerate(val_loader): if i >= eval_steps: break batch = batch.to(rank) batch_size = batch.num_sampled_nodes[0] batch.y = batch.y.to(torch.long) out = model(batch.x, batch.edge_index) acc_i = acc( # noqa out[:batch_size].softmax(dim=-1), batch.y[:batch_size]) acc_sum = acc.compute() if rank == 0: print(f"Validation Accuracy: {acc_sum * 100.0:.4f}%", ) dist.barrier() acc.reset() with torch.no_grad(): for batch in test_loader: batch = batch.to(rank) batch_size = batch.num_sampled_nodes[0] batch.y = batch.y.to(torch.long) out = model(batch.x, batch.edge_index) acc_i = acc( # noqa out[:batch_size].softmax(dim=-1), batch.y[:batch_size]) acc_sum = acc.compute() if rank == 0: print(f"Test Accuracy: {acc_sum * 100.0:.4f}%", ) dist.barrier() acc.reset() if rank == 0: total_time = round(time.perf_counter() - wall_clock_start, 2) print("Total Program Runtime (total_time) =", total_time, "seconds") print("total_time - prep_time =", total_time - prep_time, "seconds") if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--hidden_channels', type=int, default=256) parser.add_argument('--num_layers', type=int, default=2) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--epochs', type=int, default=20) parser.add_argument('--batch_size', type=int, default=1024) parser.add_argument('--fan_out', type=int, 
default=30) parser.add_argument( "--use_gat_conv", action='store_true', help="Whether or not to use GATConv. (Defaults to using GCNConv)", ) parser.add_argument( "--n_gat_conv_heads", type=int, default=4, help="If using GATConv, number of attention heads to use", ) parser.add_argument( "--n_devices", type=int, default=-1, help="1-8 to use that many GPUs. Defaults to all available GPUs") args = parser.parse_args() wall_clock_start = time.perf_counter() if args.n_devices == -1: world_size = torch.cuda.device_count() else: world_size = args.n_devices import psutil gb_ram_needed = 190 + 200 * world_size if (psutil.virtual_memory().total / (1024**3)) < gb_ram_needed: print("Warning: may not have enough RAM to use this many GPUs.") print("Consider upgrading RAM or using less GPUs if an error occurs.") print("Estimated RAM Needed: ~" + str(gb_ram_needed)) print('Let\'s use', world_size, 'GPUs!') dataset = PygNodePropPredDataset(name='ogbn-papers100M', root='/datasets/ogb_datasets') split_idx = dataset.get_idx_split() data = dataset[0] data.y = data.y.reshape(-1) if args.use_gat_conv: model = torch_geometric.nn.models.GAT(dataset.num_features, args.hidden_channels, args.num_layers, dataset.num_classes, heads=args.n_gat_conv_heads) else: model = torch_geometric.nn.models.GCN( dataset.num_features, args.hidden_channels, args.num_layers, dataset.num_classes, ) print("Data =", data) with tempfile.TemporaryDirectory() as tempdir: if world_size > 1: mp.spawn( run_train, args=(data, world_size, model, args.epochs, args.batch_size, args.fan_out, split_idx, dataset.num_classes, wall_clock_start, tempdir, args.num_layers), nprocs=world_size, join=True) else: run_train(0, data, world_size, model, args.epochs, args.batch_size, args.fan_out, split_idx, dataset.num_classes, wall_clock_start, tempdir, args.num_layers) ================================================ FILE: examples/multi_gpu/papers100m_gcn_multinode.py ================================================ """Multi-node 
multi-GPU example on ogbn-papers100m. Example way to run using srun: srun -l -N --ntasks-per-node= \ --container-name=cont --container-image= \ --container-mounts=/ogb-papers100m/:/workspace/dataset python3 path_to_script.py """ import os import time from typing import Optional import torch import torch.distributed as dist import torch.nn.functional as F from ogb.nodeproppred import PygNodePropPredDataset from torch.nn.parallel import DistributedDataParallel from torchmetrics import Accuracy from torch_geometric.loader import NeighborLoader from torch_geometric.nn import GCN def get_num_workers() -> int: num_workers = None if hasattr(os, "sched_getaffinity"): try: num_workers = len(os.sched_getaffinity(0)) // 2 except Exception: pass if num_workers is None: num_workers = os.cpu_count() // 2 return num_workers def run(world_size, data, split_idx, model, acc, wall_clock_start): local_id = int(os.environ['LOCAL_RANK']) rank = torch.distributed.get_rank() torch.cuda.set_device(local_id) device = torch.device(local_id) if rank == 0: print(f'Using {nprocs} GPUs...') split_idx['train'] = split_idx['train'].split( split_idx['train'].size(0) // world_size, dim=0)[rank].clone() split_idx['valid'] = split_idx['valid'].split( split_idx['valid'].size(0) // world_size, dim=0)[rank].clone() split_idx['test'] = split_idx['test'].split( split_idx['test'].size(0) // world_size, dim=0)[rank].clone() model = DistributedDataParallel(model.to(device), device_ids=[local_id]) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) kwargs = dict( data=data, batch_size=1024, num_workers=get_num_workers(), num_neighbors=[30, 30], ) train_loader = NeighborLoader( input_nodes=split_idx['train'], shuffle=True, drop_last=True, **kwargs, ) val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs) test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs) val_steps = 1000 warmup_steps = 100 acc = acc.to(device) dist.barrier() torch.cuda.synchronize() if 
rank == 0: prep_time = round(time.perf_counter() - wall_clock_start, 2) print("Total time before training begins (prep_time)=", prep_time, "seconds") print("Beginning training...") for epoch in range(1, 21): model.train() for i, batch in enumerate(train_loader): if i == warmup_steps: torch.cuda.synchronize() start = time.time() batch = batch.to(device) optimizer.zero_grad() y = batch.y[:batch.batch_size].view(-1).to(torch.long) out = model(batch.x, batch.edge_index)[:batch.batch_size] loss = F.cross_entropy(out, y) loss.backward() optimizer.step() if rank == 0 and i % 10 == 0: print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}') dist.barrier() torch.cuda.synchronize() if rank == 0: sec_per_iter = (time.time() - start) / (i + 1 - warmup_steps) print(f"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter") @torch.no_grad() def test(loader: NeighborLoader, num_steps: Optional[int] = None): model.eval() for j, batch in enumerate(loader): if num_steps is not None and j >= num_steps: break batch = batch.to(device) out = model(batch.x, batch.edge_index)[:batch.batch_size] y = batch.y[:batch.batch_size].view(-1).to(torch.long) acc(out, y) acc_sum = acc.compute() return acc_sum eval_acc = test(val_loader, num_steps=val_steps) if rank == 0: print(f"Val Accuracy: {eval_acc:.4f}%", ) acc.reset() dist.barrier() test_acc = test(test_loader) if rank == 0: print(f"Test Accuracy: {test_acc:.4f}%", ) dist.barrier() acc.reset() torch.cuda.synchronize() if rank == 0: total_time = round(time.perf_counter() - wall_clock_start, 2) print("Total Program Runtime (total_time) =", total_time, "seconds") print("total_time - prep_time =", total_time - prep_time, "seconds") if __name__ == '__main__': wall_clock_start = time.perf_counter() # Setup multi-node: torch.distributed.init_process_group("nccl") nprocs = dist.get_world_size() assert dist.is_initialized(), "Distributed cluster not initialized" dataset = PygNodePropPredDataset(name='ogbn-papers100M') split_idx = 
dataset.get_idx_split() model = GCN(dataset.num_features, 256, 2, dataset.num_classes) acc = Accuracy(task="multiclass", num_classes=dataset.num_classes) data = dataset[0] data.y = data.y.reshape(-1) run(nprocs, data, split_idx, model, acc, wall_clock_start) ================================================ FILE: examples/multi_gpu/pcqm4m_ogb.py ================================================ # Code adapted from OGB. # https://github.com/snap-stanford/ogb/tree/master/examples/lsc/pcqm4m-v2 import argparse import math import os import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F import torch.optim as optim from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter from tqdm.auto import tqdm from torch_geometric.data import Data from torch_geometric.datasets import PCQM4Mv2 from torch_geometric.io import fs from torch_geometric.loader import DataLoader from torch_geometric.nn import ( GlobalAttention, MessagePassing, Set2Set, global_add_pool, global_max_pool, global_mean_pool, ) from torch_geometric.utils import degree try: from ogb.lsc import PCQM4Mv2Evaluator, PygPCQM4Mv2Dataset except ImportError as e: raise ImportError( "`PygPCQM4Mv2Dataset` requires rdkit (`pip install rdkit`)") from e from ogb.utils import smiles2graph def ogb_from_smiles_wrapper(smiles, *args, **kwargs): """Returns `torch_geometric.data.Data` object from smiles while `ogb.utils.smiles2graph` returns a dict of np arrays. """ data_dict = smiles2graph(smiles, *args, **kwargs) return Data( x=torch.from_numpy(data_dict['node_feat']), edge_index=torch.from_numpy(data_dict['edge_index']), edge_attr=torch.from_numpy(data_dict['edge_feat']), smiles=smiles, ) class GINConv(MessagePassing): def __init__(self, emb_dim): r"""GINConv. 
Args: emb_dim (int): node embedding dimensionality """ super().__init__(aggr="add") self.mlp = torch.nn.Sequential( torch.nn.Linear(emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU(), torch.nn.Linear(emb_dim, emb_dim), ) self.eps = torch.nn.Parameter(torch.Tensor([0])) self.bond_encoder = BondEncoder(emb_dim=emb_dim) def forward(self, x, edge_index, edge_attr): edge_embedding = self.bond_encoder(edge_attr) return self.mlp( (1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) def message(self, x_j, edge_attr): return F.relu(x_j + edge_attr) def update(self, aggr_out): return aggr_out class GCNConv(MessagePassing): def __init__(self, emb_dim): super().__init__(aggr='add') self.linear = torch.nn.Linear(emb_dim, emb_dim) self.root_emb = torch.nn.Embedding(1, emb_dim) self.bond_encoder = BondEncoder(emb_dim=emb_dim) def forward(self, x, edge_index, edge_attr): x = self.linear(x) edge_embedding = self.bond_encoder(edge_attr) row, col = edge_index deg = degree(row, x.size(0), dtype=x.dtype) + 1 deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] return self.propagate( edge_index, x=x, edge_attr=edge_embedding, norm=norm ) + F.relu(x + self.root_emb.weight) * 1. / deg.view(-1, 1) def message(self, x_j, edge_attr, norm): return norm.view(-1, 1) * F.relu(x_j + edge_attr) def update(self, aggr_out): return aggr_out class GNNNode(torch.nn.Module): def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'): r"""GNN Node. Args: emb_dim (int): node embedding dimensionality. num_layers (int): number of GNN message passing layers. residual (bool): whether to add residual connection. drop_ratio (float): dropout ratio. JK (str): "last" or "sum" to choose JK concat strat. residual (bool): Whether or not to add the residual gnn_type (str): Type of GNN to use. 
""" super().__init__() if num_layers < 2: raise ValueError("Number of GNN layers must be greater than 1.") self.num_layers = num_layers self.drop_ratio = drop_ratio self.JK = JK self.residual = residual self.atom_encoder = AtomEncoder(emb_dim) self.convs = torch.nn.ModuleList() self.batch_norms = torch.nn.ModuleList() for _ in range(num_layers): if gnn_type == 'gin': self.convs.append(GINConv(emb_dim)) elif gnn_type == 'gcn': self.convs.append(GCNConv(emb_dim)) else: raise ValueError(f'Undefined GNN type called {gnn_type}') self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) def forward(self, batched_data): x = batched_data.x edge_index = batched_data.edge_index edge_attr = batched_data.edge_attr # compute input node embedding h_list = [self.atom_encoder(x)] for layer in range(self.num_layers): h = self.convs[layer](h_list[layer], edge_index, edge_attr) h = self.batch_norms[layer](h) if layer == self.num_layers - 1: # remove relu for the last layer h = F.dropout(h, self.drop_ratio, training=self.training) else: h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) if self.residual: h += h_list[layer] h_list.append(h) # Different implementations of Jk-concat if self.JK == "last": node_representation = h_list[-1] elif self.JK == "sum": node_representation = 0 for layer in range(self.num_layers + 1): node_representation += h_list[layer] return node_representation class GNNNodeVirtualNode(torch.nn.Module): """Outputs node representations.""" def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'): super().__init__() if num_layers < 2: raise ValueError("Number of GNN layers must be greater than 1.") self.num_layers = num_layers self.drop_ratio = drop_ratio self.JK = JK self.residual = residual self.atom_encoder = AtomEncoder(emb_dim) # set the initial virtual node embedding to 0. 
self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim) torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) self.convs = torch.nn.ModuleList() self.batch_norms = torch.nn.ModuleList() self.mlp_virtualnode_list = torch.nn.ModuleList() for _ in range(num_layers): if gnn_type == 'gin': self.convs.append(GINConv(emb_dim)) elif gnn_type == 'gcn': self.convs.append(GCNConv(emb_dim)) else: raise ValueError('Undefined GNN type called {gnn_type}') self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) for _ in range(num_layers - 1): self.mlp_virtualnode_list.append( torch.nn.Sequential( torch.nn.Linear(emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU(), torch.nn.Linear(emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU(), )) def forward(self, batched_data): x = batched_data.x edge_index = batched_data.edge_index edge_attr = batched_data.edge_attr batch = batched_data.batch # virtual node embeddings for graphs virtualnode_embedding = self.virtualnode_embedding( torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to( edge_index.device)) h_list = [self.atom_encoder(x)] for layer in range(self.num_layers): # add message from virtual nodes to graph nodes h_list[layer] = h_list[layer] + virtualnode_embedding[batch] # Message passing among graph nodes h = self.convs[layer](h_list[layer], edge_index, edge_attr) h = self.batch_norms[layer](h) if layer == self.num_layers - 1: # remove relu for the last layer h = F.dropout(h, self.drop_ratio, training=self.training) else: h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) if self.residual: h = h + h_list[layer] h_list.append(h) # update the virtual nodes if layer < self.num_layers - 1: # add message from graph nodes to virtual nodes virtualnode_embedding_temp = global_add_pool( h_list[layer], batch) + virtualnode_embedding # transform virtual nodes using MLP if self.residual: virtualnode_embedding = virtualnode_embedding + F.dropout( self.mlp_virtualnode_list[layer] 
(virtualnode_embedding_temp), self.drop_ratio, training=self.training) else: virtualnode_embedding = F.dropout( self.mlp_virtualnode_list[layer]( virtualnode_embedding_temp), self.drop_ratio, training=self.training) # Different implementations of Jk-concat if self.JK == "last": node_representation = h_list[-1] elif self.JK == "sum": node_representation = 0 for layer in range(self.num_layers + 1): node_representation += h_list[layer] return node_representation class GNN(torch.nn.Module): def __init__( self, num_tasks=1, num_layers=5, emb_dim=300, gnn_type='gin', virtual_node=True, residual=False, drop_ratio=0, JK="last", graph_pooling="sum", ): r"""GNN. Args: num_tasks (int): number of labels to be predicted num_layers (int): number of gnn layers. emb_dim (int): embedding dim to use. gnn_type (str): Type of GNN to use. virtual_node (bool): whether to add virtual node or not. residual (bool): Whether or not to add the residual drop_ratio (float): dropout ratio. JK (str): "last" or "sum" to choose JK concat strat. graph_pooling (str): Graph pooling strat to use. 
""" super().__init__() if num_layers < 2: raise ValueError("Number of GNN layers must be greater than 1.") self.num_layers = num_layers self.drop_ratio = drop_ratio self.JK = JK self.emb_dim = emb_dim self.num_tasks = num_tasks self.graph_pooling = graph_pooling if virtual_node: self.gnn_node = GNNNodeVirtualNode( num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type, ) else: self.gnn_node = GNNNode( num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type, ) # Pooling function to generate whole-graph embeddings if self.graph_pooling == "sum": self.pool = global_add_pool elif self.graph_pooling == "mean": self.pool = global_mean_pool elif self.graph_pooling == "max": self.pool = global_max_pool elif self.graph_pooling == "attention": self.pool = GlobalAttention(gate_nn=torch.nn.Sequential( torch.nn.Linear(emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU(), torch.nn.Linear(emb_dim, 1), )) elif self.graph_pooling == "set2set": self.pool = Set2Set(emb_dim, processing_steps=2) else: raise ValueError("Invalid graph pooling type.") if graph_pooling == "set2set": self.graph_pred_linear = torch.nn.Linear(2 * emb_dim, num_tasks) else: self.graph_pred_linear = torch.nn.Linear(emb_dim, num_tasks) def forward(self, batched_data): h_node = self.gnn_node(batched_data) h_graph = self.pool(h_node, batched_data.batch) output = self.graph_pred_linear(h_graph) if self.training: return output else: # At inference time, we clamp the value between 0 and 20 return torch.clamp(output, min=0, max=20) def train(model, rank, device, loader, optimizer): model.train() reg_criterion = torch.nn.L1Loss() loss_accum = 0.0 for step, batch in enumerate( # noqa: B007 tqdm(loader, desc="Training", disable=(rank > 0))): batch = batch.to(device) pred = model(batch).view(-1, ) optimizer.zero_grad() loss = reg_criterion(pred, batch.y) loss.backward() optimizer.step() loss_accum += loss.detach().cpu().item() return loss_accum / 
(step + 1) def eval(model, device, loader, evaluator): model.eval() y_true = [] y_pred = [] for batch in tqdm(loader, desc="Evaluating"): batch = batch.to(device) with torch.no_grad(): pred = model(batch).view(-1, ) y_true.append(batch.y.view(pred.shape).detach().cpu()) y_pred.append(pred.detach().cpu()) y_true = torch.cat(y_true, dim=0) y_pred = torch.cat(y_pred, dim=0) input_dict = {"y_true": y_true, "y_pred": y_pred} return evaluator.eval(input_dict)["mae"] def test(model, device, loader): model.eval() y_pred = [] for batch in tqdm(loader, desc="Testing"): batch = batch.to(device) with torch.no_grad(): pred = model(batch).view(-1, ) y_pred.append(pred.detach().cpu()) y_pred = torch.cat(y_pred, dim=0) return y_pred def run(rank, dataset, args): num_devices = args.num_devices device = torch.device( "cuda:" + str(rank)) if num_devices > 0 else torch.device("cpu") if num_devices > 1: os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" dist.init_process_group("nccl", rank=rank, world_size=num_devices) if args.on_disk_dataset: train_idx = torch.arange(len(dataset.indices())) else: split_idx = dataset.get_idx_split() train_idx = split_idx["train"] if num_devices > 1: num_splits = math.ceil(train_idx.size(0) / num_devices) train_idx = train_idx.split(num_splits)[rank] if args.train_subset: subset_ratio = 0.1 n = len(train_idx) subset_idx = torch.randperm(n)[:int(subset_ratio * n)] train_dataset = dataset[train_idx[subset_idx]] else: train_dataset = dataset[train_idx] train_loader = DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, ) if rank == 0: if args.on_disk_dataset: valid_dataset = PCQM4Mv2(root='on_disk_dataset/', split="val", from_smiles_func=ogb_from_smiles_wrapper) test_dev_dataset = PCQM4Mv2( root='on_disk_dataset/', split="test", from_smiles_func=ogb_from_smiles_wrapper) test_challenge_dataset = PCQM4Mv2( root='on_disk_dataset/', split="holdout", from_smiles_func=ogb_from_smiles_wrapper) 
else: valid_dataset = dataset[split_idx["valid"]] test_dev_dataset = dataset[split_idx["test-dev"]] test_challenge_dataset = dataset[split_idx["test-challenge"]] valid_loader = DataLoader( valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, ) if args.save_test_dir != '': testdev_loader = DataLoader( test_dev_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, ) testchallenge_loader = DataLoader( test_challenge_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, ) if args.checkpoint_dir != '': os.makedirs(args.checkpoint_dir, exist_ok=True) evaluator = PCQM4Mv2Evaluator() gnn_type, virtual_node = args.gnn.split('-') model = GNN( gnn_type=gnn_type, virtual_node=virtual_node, num_layers=args.num_layers, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, graph_pooling=args.graph_pooling, ) if num_devices > 0: model = model.to(rank) if num_devices > 1: model = DistributedDataParallel(model, device_ids=[rank]) optimizer = optim.Adam(model.parameters(), lr=0.001) if args.log_dir != '': writer = SummaryWriter(log_dir=args.log_dir) best_valid_mae = 1000 if args.train_subset: scheduler = StepLR(optimizer, step_size=300, gamma=0.25) args.epochs = 1000 else: scheduler = StepLR(optimizer, step_size=30, gamma=0.25) current_epoch = 1 checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pt') if os.path.isfile(checkpoint_path): checkpoint = fs.torch_load(checkpoint_path) current_epoch = checkpoint['epoch'] + 1 model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) scheduler.load_state_dict(checkpoint['scheduler_state_dict']) best_valid_mae = checkpoint['best_val_mae'] print(f"Found checkpoint, resume training at epoch {current_epoch}") for epoch in range(current_epoch, args.epochs + 1): train_mae = train(model, rank, device, train_loader, optimizer) if num_devices > 1: dist.barrier() if rank == 0: valid_mae = eval( 
model.module if isinstance(model, DistributedDataParallel) else model, device, valid_loader, evaluator) print(f"Epoch {epoch:02d}, " f"Train MAE: {train_mae:.4f}, " f"Val MAE: {valid_mae:.4f}") if args.log_dir != '': writer.add_scalar('valid/mae', valid_mae, epoch) writer.add_scalar('train/mae', train_mae, epoch) if valid_mae < best_valid_mae: best_valid_mae = valid_mae if args.checkpoint_dir != '': checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_val_mae': best_valid_mae, } torch.save(checkpoint, checkpoint_path) if args.save_test_dir != '': test_model = model.module if isinstance( model, DistributedDataParallel) else model testdev_pred = test(test_model, device, testdev_loader) evaluator.save_test_submission( {'y_pred': testdev_pred.cpu().detach().numpy()}, args.save_test_dir, mode='test-dev', ) testchallenge_pred = test(test_model, device, testchallenge_loader) evaluator.save_test_submission( {'y_pred': testchallenge_pred.cpu().detach().numpy()}, args.save_test_dir, mode='test-challenge', ) print(f'Best validation MAE so far: {best_valid_mae}') if num_devices > 1: dist.barrier() scheduler.step() if rank == 0 and args.log_dir != '': writer.close() if __name__ == "__main__": parser = argparse.ArgumentParser( description='GNN baselines on pcqm4m with Pytorch Geometrics', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--gnn', type=str, default='gin-virtual', choices=['gin', 'gin-virtual', 'gcn', 'gcn-virtual'], help='GNN architecture') parser.add_argument('--graph_pooling', type=str, default='sum', help='graph pooling strategy mean or sum') parser.add_argument('--drop_ratio', type=float, default=0, help='dropout ratio') parser.add_argument('--num_layers', type=int, default=5, help='number of GNN message passing layers') parser.add_argument('--emb_dim', type=int, default=600, help='dimensionality of hidden units in 
GNNs') parser.add_argument('--train_subset', action='store_true') parser.add_argument('--batch_size', type=int, default=256, help='input batch size for training') parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train') parser.add_argument('--num_workers', type=int, default=0, help='number of workers') parser.add_argument('--log_dir', type=str, default="", help='tensorboard log directory') parser.add_argument('--checkpoint_dir', type=str, default='', help='directory to save checkpoint') parser.add_argument('--save_test_dir', type=str, default='', help='directory to save test submission file') parser.add_argument('--num_devices', type=int, default='1', help="Number of GPUs, if 0 runs on the CPU") parser.add_argument('--on_disk_dataset', action='store_true') args = parser.parse_args() available_gpus = torch.cuda.device_count() if torch.cuda.is_available( ) else 0 if args.num_devices > available_gpus: if available_gpus == 0: print("No GPUs available, running w/ CPU...") else: raise ValueError(f"Cannot train with {args.num_devices} GPUs: " f"available GPUs count {available_gpus}") # automatic dataloading and splitting if args.on_disk_dataset: dataset = PCQM4Mv2(root='on_disk_dataset/', split='train', from_smiles_func=ogb_from_smiles_wrapper) else: dataset = PygPCQM4Mv2Dataset(root='dataset/') if args.num_devices > 1: mp.spawn(run, args=(dataset, args), nprocs=args.num_devices, join=True) else: run(0, dataset, args) ================================================ FILE: examples/multi_gpu/taobao.py ================================================ # An Multi GPU implementation of unsupervised bipartite GraphSAGE # using the Alibaba Taobao dataset. 
import argparse import os import os.path as osp import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F import tqdm from sklearn.metrics import roc_auc_score from torch.nn import Embedding, Linear from torch.nn.parallel import DistributedDataParallel import torch_geometric.transforms as T from torch_geometric.datasets import Taobao from torch_geometric.loader import LinkNeighborLoader from torch_geometric.nn import SAGEConv from torch_geometric.utils.convert import to_scipy_sparse_matrix class ItemGNNEncoder(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv(-1, hidden_channels) self.conv2 = SAGEConv(hidden_channels, hidden_channels) self.lin = Linear(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index).relu() return self.lin(x) class UserGNNEncoder(torch.nn.Module): def __init__(self, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv((-1, -1), hidden_channels) self.conv2 = SAGEConv((-1, -1), hidden_channels) self.conv3 = SAGEConv((-1, -1), hidden_channels) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): item_x = self.conv1( x_dict['item'], edge_index_dict[('item', 'to', 'item')], ).relu() user_x = self.conv2( (x_dict['item'], x_dict['user']), edge_index_dict[('item', 'rev_to', 'user')], ).relu() user_x = self.conv3( (item_x, user_x), edge_index_dict[('item', 'rev_to', 'user')], ).relu() return self.lin(user_x) class EdgeDecoder(torch.nn.Module): def __init__(self, hidden_channels): super().__init__() self.lin1 = Linear(2 * hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, 1) def forward(self, z_src, z_dst, edge_label_index): row, col = edge_label_index z = torch.cat([z_src[row], z_dst[col]], dim=-1) z = self.lin1(z).relu() z = self.lin2(z) return z.view(-1) class Model(torch.nn.Module): 
def __init__(self, num_users, num_items, hidden_channels, out_channels): super().__init__() self.user_emb = Embedding(num_users, hidden_channels) self.item_emb = Embedding(num_items, hidden_channels) self.item_encoder = ItemGNNEncoder(hidden_channels, out_channels) self.user_encoder = UserGNNEncoder(hidden_channels, out_channels) self.decoder = EdgeDecoder(out_channels) def forward(self, x_dict, edge_index_dict, edge_label_index): z_dict = {} x_dict['user'] = self.user_emb(x_dict['user']) x_dict['item'] = self.item_emb(x_dict['item']) z_dict['item'] = self.item_encoder( x_dict['item'], edge_index_dict[('item', 'to', 'item')], ) z_dict['user'] = self.user_encoder(x_dict, edge_index_dict) return self.decoder(z_dict['user'], z_dict['item'], edge_label_index) def run_train(rank, data, train_data, val_data, test_data, args, world_size): if rank == 0: print("Setting up Data Loaders...") train_edge_label_idx = train_data[('user', 'to', 'item')].edge_label_index train_edge_label_idx = train_edge_label_idx.split( train_edge_label_idx.size(1) // world_size, dim=1)[rank].clone() train_loader = LinkNeighborLoader( data=train_data, num_neighbors=[8, 4], edge_label_index=(('user', 'to', 'item'), train_edge_label_idx), neg_sampling='binary', batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True, ) val_loader = LinkNeighborLoader( data=val_data, num_neighbors=[8, 4], edge_label_index=( ('user', 'to', 'item'), val_data[('user', 'to', 'item')].edge_label_index, ), edge_label=val_data[('user', 'to', 'item')].edge_label, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, ) test_loader = LinkNeighborLoader( data=test_data, num_neighbors=[8, 4], edge_label_index=( ('user', 'to', 'item'), test_data[('user', 'to', 'item')].edge_label_index, ), edge_label=test_data[('user', 'to', 'item')].edge_label, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, ) def train(): model.train() total_loss = total_examples = 0 for 
batch in tqdm.tqdm(train_loader, disable=rank != 0): batch = batch.to(rank) optimizer.zero_grad() pred = model( batch.x_dict, batch.edge_index_dict, batch['user', 'item'].edge_label_index, ) loss = F.binary_cross_entropy_with_logits( pred, batch['user', 'item'].edge_label) loss.backward() optimizer.step() total_loss += float(loss) total_examples += pred.numel() return total_loss / total_examples @torch.no_grad() def test(loader): model.eval() preds, targets = [], [] for batch in tqdm.tqdm(loader, disable=rank != 0): batch = batch.to(rank) pred = model( batch.x_dict, batch.edge_index_dict, batch['user', 'item'].edge_label_index, ).sigmoid().view(-1).cpu() target = batch['user', 'item'].edge_label.long().cpu() preds.append(pred) targets.append(target) pred = torch.cat(preds, dim=0).numpy() target = torch.cat(targets, dim=0).numpy() return roc_auc_score(target, pred) os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) model = Model( num_users=data['user'].num_nodes, num_items=data['item'].num_nodes, hidden_channels=64, out_channels=64, ).to(rank) # Initialize lazy modules for batch in train_loader: batch = batch.to(rank) _ = model( batch.x_dict, batch.edge_index_dict, batch['user', 'item'].edge_label_index, ) break model = DistributedDataParallel(model, device_ids=[rank]) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) best_val_auc = 0 for epoch in range(1, args.epochs): print("Train") loss = train() if rank == 0: print("Val") val_auc = test(val_loader) best_val_auc = max(best_val_auc, val_auc) if rank == 0: print( f'Epoch: {epoch:02d}, Loss: {loss:4f}, Val AUC: {val_auc:.4f}') if rank == 0: print("Test") test_auc = test(test_loader) print(f'Total {args.epochs:02d} epochs: Final Loss: {loss:4f}, ' f'Best Val AUC: {best_val_auc:.4f}, ' f'Test AUC: {test_auc:.4f}') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--num_workers', 
type=int, default=16, help="Number of workers per dataloader") parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--epochs', type=int, default=21) parser.add_argument('--batch_size', type=int, default=2048) parser.add_argument( '--dataset_root_dir', type=str, default=osp.join(osp.dirname(osp.realpath(__file__)), '../../data/Taobao')) args = parser.parse_args() def pre_transform(data): # Compute sparsified item<>item relationships through users: print('Computing item<>item relationships...') mat = to_scipy_sparse_matrix(data['user', 'item'].edge_index).tocsr() mat = mat[:data['user'].num_nodes, :data['item'].num_nodes] comat = mat.T @ mat comat.setdiag(0) comat = comat >= 3. comat = comat.tocoo() row = torch.from_numpy(comat.row).to(torch.long) col = torch.from_numpy(comat.col).to(torch.long) data['item', 'item'].edge_index = torch.stack([row, col], dim=0) return data dataset = Taobao(args.dataset_root_dir, pre_transform=pre_transform) data = dataset[0] data['user'].x = torch.arange(0, data['user'].num_nodes) data['item'].x = torch.arange(0, data['item'].num_nodes) # Only consider user<>item relationships for simplicity: del data['category'] del data['item', 'category'] del data['user', 'item'].time del data['user', 'item'].behavior # Add a reverse ('item', 'rev_to', 'user') relation for message passing: data = T.ToUndirected()(data) # Perform a link-level split into training, validation, and test edges: print('Computing data splits...') train_data, val_data, test_data = T.RandomLinkSplit( num_val=0.1, num_test=0.1, neg_sampling_ratio=1.0, add_negative_train_samples=False, edge_types=[('user', 'to', 'item')], rev_edge_types=[('item', 'rev_to', 'user')], )(data) print('Done!') world_size = torch.cuda.device_count() print('Let\'s use', world_size, 'GPUs!') mp.spawn(run_train, args=(data, train_data, val_data, test_data, args, world_size), nprocs=world_size, join=True) ================================================ FILE: examples/mutag_gin.py 
================================================ import argparse import os.path as osp import time import torch import torch.nn.functional as F from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.logging import init_wandb, log from torch_geometric.nn import MLP, GINConv, global_add_pool parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='MUTAG') parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--hidden_channels', type=int, default=32) parser.add_argument('--num_layers', type=int, default=5) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--wandb', action='store_true', help='Track experiment') args = parser.parse_args() if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # MPS is currently slower than CPU due to missing int64 min/max ops device = torch.device('cpu') else: device = torch.device('cpu') init_wandb( name=f'GIN-{args.dataset}', batch_size=args.batch_size, lr=args.lr, epochs=args.epochs, hidden_channels=args.hidden_channels, num_layers=args.num_layers, device=device, ) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TU') dataset = TUDataset(path, name=args.dataset).shuffle() train_loader = DataLoader(dataset[:0.9], args.batch_size, shuffle=True) test_loader = DataLoader(dataset[0.9:], args.batch_size) class GIN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers): super().__init__() self.convs = torch.nn.ModuleList() for _ in range(num_layers): mlp = MLP([in_channels, hidden_channels, hidden_channels]) self.convs.append(GINConv(nn=mlp, train_eps=False)) in_channels = hidden_channels self.mlp = MLP([hidden_channels, hidden_channels, out_channels], norm=None, dropout=0.5) def forward(self, x, edge_index, 
batch): for conv in self.convs: x = conv(x, edge_index).relu() x = global_add_pool(x, batch) return self.mlp(x) model = GIN( in_channels=dataset.num_features, hidden_channels=args.hidden_channels, out_channels=dataset.num_classes, num_layers=args.num_layers, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.batch) loss = F.cross_entropy(out, data.y) loss.backward() optimizer.step() total_loss += float(loss) * data.num_graphs return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() total_correct = 0 for data in loader: data = data.to(device) out = model(data.x, data.edge_index, data.batch) pred = out.argmax(dim=-1) total_correct += int((pred == data.y).sum()) return total_correct / len(loader.dataset) times = [] for epoch in range(1, args.epochs + 1): start = time.time() loss = train() train_acc = test(train_loader) test_acc = test(test_loader) log(Epoch=epoch, Loss=loss, Train=train_acc, Test=test_acc) times.append(time.time() - start) print(f'Median time per epoch: {torch.tensor(times).median():.4f}s') ================================================ FILE: examples/node2vec.py ================================================ import os.path as osp import sys import matplotlib.pyplot as plt import torch from sklearn.manifold import TSNE from torch_geometric.datasets import Planetoid from torch_geometric.nn import Node2Vec path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, name='Cora') data = dataset[0] device = 'cuda' if torch.cuda.is_available() else 'cpu' model = Node2Vec( data.edge_index, embedding_dim=128, walk_length=20, context_size=10, walks_per_node=10, num_negative_samples=1, p=1.0, q=1.0, sparse=True, ).to(device) num_workers = 4 if sys.platform == 'linux' else 0 loader = 
model.loader(batch_size=128, shuffle=True, num_workers=num_workers) optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01) def train(): model.train() total_loss = 0 for pos_rw, neg_rw in loader: optimizer.zero_grad() loss = model.loss(pos_rw.to(device), neg_rw.to(device)) loss.backward() optimizer.step() total_loss += loss.item() return total_loss / len(loader) @torch.no_grad() def test(): model.eval() z = model() acc = model.test( train_z=z[data.train_mask], train_y=data.y[data.train_mask], test_z=z[data.test_mask], test_y=data.y[data.test_mask], max_iter=150, ) return acc for epoch in range(1, 101): loss = train() acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Acc: {acc:.4f}') @torch.no_grad() def plot_points(colors): model.eval() z = model().cpu().numpy() z = TSNE(n_components=2).fit_transform(z) y = data.y.cpu().numpy() plt.figure(figsize=(8, 8)) for i in range(dataset.num_classes): plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i]) plt.axis('off') plt.show() colors = [ '#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700' ] plot_points(colors) ================================================ FILE: examples/ogbn_proteins_deepgcn.py ================================================ import torch import torch.nn.functional as F from ogb.nodeproppred import Evaluator, PygNodePropPredDataset from torch.nn import LayerNorm, Linear, ReLU from tqdm import tqdm from torch_geometric.loader import RandomNodeLoader from torch_geometric.nn import DeepGCNLayer, GENConv from torch_geometric.utils import scatter dataset = PygNodePropPredDataset('ogbn-proteins', root='../data') splitted_idx = dataset.get_idx_split() data = dataset[0] data.node_species = None data.y = data.y.to(torch.float) # Initialize features of nodes by aggregating edge features. row, col = data.edge_index data.x = scatter(data.edge_attr, col, dim_size=data.num_nodes, reduce='sum') # Set split indices to masks. 
for split in ['train', 'valid', 'test']:
    # Turn each split's index tensor into a boolean node mask so that the
    # masks survive RandomNodeLoader's node sub-sampling.
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

train_loader = RandomNodeLoader(data, num_parts=40, shuffle=True,
                                num_workers=5)
test_loader = RandomNodeLoader(data, num_parts=5, num_workers=5)


class DeeperGCN(torch.nn.Module):
    """Deep stack of ``DeepGCNLayer`` ('res+' blocks) over GENConv.

    NOTE(review): input/output sizes are read from the module-level ``data``
    object rather than passed in as arguments.
    """
    def __init__(self, hidden_channels, num_layers):
        super().__init__()

        self.node_encoder = Linear(data.x.size(-1), hidden_channels)
        self.edge_encoder = Linear(data.edge_attr.size(-1), hidden_channels)

        self.layers = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
            conv = GENConv(hidden_channels, hidden_channels, aggr='softmax',
                           t=1.0, learn_t=True, num_layers=2, norm='layer')
            norm = LayerNorm(hidden_channels, elementwise_affine=True)
            act = ReLU(inplace=True)

            # ckpt_grad is truthy for layers whose index is not a multiple
            # of 3, i.e. gradient checkpointing on two of every three layers.
            layer = DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1,
                                 ckpt_grad=i % 3)
            self.layers.append(layer)

        self.lin = Linear(hidden_channels, data.y.size(-1))

    def forward(self, x, edge_index, edge_attr):
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)

        # The first layer's conv is applied directly. Its norm/act are
        # reused after the last layer to close the pre-activation stack.
        x = self.layers[0].conv(x, edge_index, edge_attr)

        for layer in self.layers[1:]:
            x = layer(x, edge_index, edge_attr)

        x = self.layers[0].act(self.layers[0].norm(x))
        x = F.dropout(x, p=0.1, training=self.training)

        return self.lin(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DeeperGCN(hidden_channels=64, num_layers=28).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
evaluator = Evaluator('ogbn-proteins')


def train(epoch):
    """One epoch over random node partitions.

    Returns the training loss averaged over the training-mask nodes seen.
    """
    model.train()

    pbar = tqdm(total=len(train_loader))
    pbar.set_description(f'Training epoch: {epoch:04d}')

    total_loss = total_examples = 0
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr)
        # Loss is computed on training nodes only.
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * int(data.train_mask.sum())
        total_examples += int(data.train_mask.sum())

        pbar.update(1)

    pbar.close()

    return total_loss / total_examples


@torch.no_grad()
def test():
    """Return (train, valid, test) ROC-AUC from the OGB evaluator.

    NOTE(review): the progress-bar label reads the module-level loop
    variable ``epoch``, which is not a parameter of this function.
    """
    model.eval()

    y_true = {'train': [], 'valid': [], 'test': []}
    y_pred = {'train': [], 'valid': [], 'test': []}

    pbar = tqdm(total=len(test_loader))
    pbar.set_description(f'Evaluating epoch: {epoch:04d}')

    for data in test_loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr)

        # Collect predictions and labels separately for each split mask.
        for split in y_true.keys():
            mask = data[f'{split}_mask']
            y_true[split].append(data.y[mask].cpu())
            y_pred[split].append(out[mask].cpu())

        pbar.update(1)

    pbar.close()

    train_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['train'], dim=0),
        'y_pred': torch.cat(y_pred['train'], dim=0),
    })['rocauc']

    valid_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['valid'], dim=0),
        'y_pred': torch.cat(y_pred['valid'], dim=0),
    })['rocauc']

    test_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['test'], dim=0),
        'y_pred': torch.cat(y_pred['test'], dim=0),
    })['rocauc']

    return train_rocauc, valid_rocauc, test_rocauc


for epoch in range(1, 1001):
    loss = train(epoch)
    train_rocauc, valid_rocauc, test_rocauc = test()
    print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')
================================================
FILE: examples/ogbn_train.py
================================================
import argparse
import os.path as osp
import time

import psutil
import torch
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from torch import Tensor
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.models import GAT, GraphSAGE, Polynormer, SGFormer
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    to_undirected,
)

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
    '--dataset',
    type=str,
    default='ogbn-arxiv',
    choices=['ogbn-papers100M', 'ogbn-products', 'ogbn-arxiv'],
    help='Dataset name.',
)
parser.add_argument(
    '--dataset_dir',
    type=str,
    default='./data',
    help='Root directory of dataset.',
)
parser.add_argument(
    "--model",
    type=str.lower,
    default='SGFormer',
    choices=['sage', 'gat', 'sgformer', 'polynormer'],
    help="Model used for training",
)

parser.add_argument('-e', '--epochs', type=int, default=50)
parser.add_argument('-le', '--local_epochs', type=int, default=50,
                    help='warmup epochs for polynormer')
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--num_heads', type=int, default=1,
                    help='number of heads for GAT or Graph Transformer model.')
parser.add_argument('-b', '--batch_size', type=int, default=1024)
parser.add_argument('--num_workers', type=int, default=12)
parser.add_argument('--fan_out', type=int, default=10,
                    help='number of neighbors in each layer')
parser.add_argument('--hidden_channels', type=int, default=256)
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--wd', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument(
    '--use_directed_graph',
    action='store_true',
    help='Whether or not to use directed graph',
)
parser.add_argument(
    '--add_self_loop',
    action='store_true',
    help='Whether or not to add self loop',
)
args = parser.parse_args()

wall_clock_start = time.perf_counter()

# Rough RAM check for the largest dataset (threshold in GiB).
if (args.dataset == 'ogbn-papers100M'
        and (psutil.virtual_memory().total / (1024**3)) < 390):
    print('Warning: may not have enough RAM to run this example.')
    print('Consider upgrading RAM if an error occurs.')
    print('Estimated RAM Needed: ~390GB.')

# NOTE(review): "effect" in the message below should read "affect", and
# "See for details" appears to be missing its reference link.
if args.model == 'polynormer' and args.num_layers != 7:
    print("The original polynormer paper recommends 7 layers, you have chosen",
          args.num_layers, "which may effect results. See for details")

print(f'Training {args.dataset} with {args.model} model.')

seed_everything(123)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Polynormer trains `local_epochs` warm-up epochs on top of `epochs`.
num_epochs = args.epochs
if args.model == 'polynormer':
    num_epochs += args.local_epochs
num_layers = args.num_layers
num_workers = args.num_workers
num_hidden_channels = args.hidden_channels
batch_size = args.batch_size

root = osp.join(args.dataset_dir, args.dataset)
print('The root is: ', root)
dataset = PygNodePropPredDataset(name=args.dataset, root=root)
split_idx = dataset.get_idx_split()

data = dataset[0]
if not args.use_directed_graph:
    data.edge_index = to_undirected(data.edge_index, reduce='mean')
if args.add_self_loop:
    # Remove existing self loops first so each node gets exactly one.
    data.edge_index, _ = remove_self_loops(data.edge_index)
    data.edge_index, _ = add_self_loops(data.edge_index,
                                        num_nodes=data.num_nodes)

# Only node features and labels are moved to `device`; edge indices are
# moved per batch inside the training loop.
data.to(device, 'x', 'y')


def get_loader(input_nodes: Tensor) -> NeighborLoader:
    """Build a shuffled NeighborLoader seeded at ``input_nodes`` (an index
    tensor from ``split_idx``). Subgraphs are disjoint per seed node for the
    transformer-style models."""
    return NeighborLoader(
        data,
        input_nodes=input_nodes,
        num_neighbors=[args.fan_out] * num_layers,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        disjoint=args.model in ['sgformer', 'polynormer'],
    )


train_loader = get_loader(split_idx['train'])
val_loader = get_loader(split_idx['valid'])
test_loader = get_loader(split_idx['test'])


def train(epoch: int) -> tuple[Tensor, float]:
    model.train()
    pbar = tqdm(total=split_idx['train'].size(0))
    pbar.set_description(f'Epoch {epoch:02d}')
    total_loss = total_correct = 0
    for batch in train_loader:
        optimizer.zero_grad()
        if args.model in ['sgformer', 'polynormer']:
            # Polynormer switches on its global attention once the local
            # warm-up epochs are over.
            if args.model == 'polynormer' and epoch == args.local_epochs:
                print('start global attention')
                model._global = True
            out = model(batch.x, batch.edge_index.to(device),
                        batch.batch.to(device))[:batch.batch_size]
        else:
            out = model(batch.x,
                        batch.edge_index.to(device))[:batch.batch_size]
        # Only the first `batch_size` rows are the seed nodes.
        y = batch.y[:batch.batch_size].squeeze().to(torch.long)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()

        total_loss
+= float(loss) total_correct += int(out.argmax(dim=-1).eq(y).sum()) pbar.update(batch.batch_size) pbar.close() loss = total_loss / len(train_loader) approx_acc = total_correct / split_idx['train'].size(0) return loss, approx_acc @torch.no_grad() def test(loader: NeighborLoader) -> float: model.eval() total_correct = total_examples = 0 for batch in loader: batch = batch.to(device) batch_size = batch.num_sampled_nodes[0] if args.model in ['sgformer', 'polynormer']: out = model(batch.x, batch.edge_index, batch.batch)[:batch.batch_size] else: out = model(batch.x, batch.edge_index)[:batch_size] pred = out.argmax(dim=-1) y = batch.y[:batch_size].view(-1).to(torch.long) total_correct += int((pred == y).sum()) total_examples += y.size(0) return total_correct / total_examples def get_model(model_name: str) -> torch.nn.Module: if model_name == 'gat': model = GAT( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, num_layers=num_layers, out_channels=dataset.num_classes, dropout=args.dropout, heads=args.num_heads, ) elif model_name == 'sage': model = GraphSAGE( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, num_layers=num_layers, out_channels=dataset.num_classes, dropout=args.dropout, ) elif model_name == 'sgformer': model = SGFormer( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, out_channels=dataset.num_classes, trans_num_heads=args.num_heads, trans_dropout=args.dropout, gnn_num_layers=num_layers, gnn_dropout=args.dropout, ) elif model_name == 'polynormer': model = Polynormer( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, out_channels=dataset.num_classes, local_layers=num_layers, ) else: raise ValueError(f'Unsupported model type: {model_name}') return model model = get_model(args.model).to(device) model.reset_parameters() optimizer = torch.optim.Adam( model.parameters(), lr=args.lr, weight_decay=args.wd, ) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', 
patience=5) print(f'Total time before training begins took ' f'{time.perf_counter() - wall_clock_start:.4f}s') print('Training...') times = [] train_times = [] inference_times = [] best_val = 0. for epoch in range(1, num_epochs + 1): train_start = time.perf_counter() loss, _ = train(epoch) train_times.append(time.perf_counter() - train_start) inference_start = time.perf_counter() val_acc = test(val_loader) inference_times.append(time.perf_counter() - inference_start) print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, ', f'Train Time: {train_times[-1]:.4f}s') print(f'Val: {val_acc * 100.0:.2f}%,') if val_acc > best_val: best_val = val_acc times.append(time.perf_counter() - train_start) for param_group in optimizer.param_groups: print('lr:') print(param_group['lr']) scheduler.step(val_acc) print(f'Average Epoch Time on training: ' f'{torch.tensor(train_times).mean():.4f}s') print(f'Average Epoch Time on inference: ' f'{torch.tensor(inference_times).mean():.4f}s') print(f'Average Epoch Time: {torch.tensor(times).mean():.4f}s') print(f'Median Epoch Time: {torch.tensor(times).median():.4f}s') print(f'Best Validation Accuracy: {100.0 * best_val:.2f}%') print('Testing...') test_final_acc = test(test_loader) print(f'Test Accuracy: {100.0 * test_final_acc:.2f}%') print(f'Total Program Runtime: ' f'{time.perf_counter() - wall_clock_start:.4f}s') ================================================ FILE: examples/ogc.py ================================================ # The OGC method from the "From Cluster Assumption to Graph Convolution: # Graph-based Semi-Supervised Learning Revisited" paper. 
# ArXiv: https://arxiv.org/abs/2309.13599 # Datasets CiteSeer Cora PubMed # Acc 0.774 0.869 0.837 # Time 3.76 1.53 2.92 import argparse import os.path as osp import time import warnings import torch import torch.nn.functional as F from torch import Tensor import torch_geometric.transforms as T from torch_geometric.data import Data from torch_geometric.datasets import Planetoid from torch_geometric.utils import one_hot warnings.filterwarnings('ignore', '.*Sparse CSR tensor support.*') decline = 0.9 # decline rate eta_sup = 0.001 # learning rate for supervised loss eta_W = 0.5 # learning rate for updating W beta = 0.1 # moving probability that a node moves to neighbors max_sim_tol = 0.995 # max label prediction similarity between iterations max_patience = 2 # tolerance for consecutive similar test predictions parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='Cora') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') transform = T.Compose([ T.NormalizeFeatures(), T.GCNNorm(), T.ToSparseTensor(layout=torch.sparse_csr), ]) dataset = Planetoid(path, name=args.dataset, transform=transform) data = dataset[0].to(device) y_one_hot = one_hot(data.y, dataset.num_classes) data.trainval_mask = data.train_mask | data.val_mask # LIM track, else use trainval_mask to construct S S = torch.diag(data.train_mask).float().to_sparse() I_N = torch.eye(data.num_nodes).to_sparse(layout=torch.sparse_csr).to(device) # Lazy random walk (also known as lazy graph convolution): lazy_adj = beta * data.adj_t + (1 - beta) * I_N class LinearNeuralNetwork(torch.nn.Module): def __init__(self, num_features: int, num_classes: int, bias: bool = True): super().__init__() self.W = torch.nn.Linear(num_features, num_classes, bias=bias) def forward(self, x: Tensor) -> Tensor: return self.W(x) @torch.no_grad() def test(self, U: Tensor, y_one_hot: 
Tensor, data: Data): self.eval() out = self(U) loss = F.mse_loss( out[data.trainval_mask], y_one_hot[data.trainval_mask], ) accs = [] pred = out.argmax(dim=-1) for _, mask in data('trainval_mask', 'test_mask'): accs.append(float((pred[mask] == data.y[mask]).sum() / mask.sum())) return float(loss), accs[0], accs[1], pred def update_W(self, U: Tensor, y_one_hot: Tensor, data: Data): optimizer = torch.optim.SGD(self.parameters(), lr=eta_W) self.train() optimizer.zero_grad() pred = self(U) loss = F.mse_loss(pred[data.trainval_mask], y_one_hot[ data.trainval_mask, ], reduction='sum') loss.backward() optimizer.step() return self(U).data, self.W.weight.data model = LinearNeuralNetwork( num_features=dataset.num_features, num_classes=dataset.num_classes, bias=False, ).to(device) def update_U(U: Tensor, y_one_hot: Tensor, pred: Tensor, W: Tensor): global eta_sup # Update the smoothness loss via LGC: U = lazy_adj @ U # Update the supervised loss via SEB: dU_sup = 2 * (S @ (-y_one_hot + pred)) @ W U = U - eta_sup * dU_sup eta_sup = eta_sup * decline return U def ogc() -> float: U = data.x _, _, last_acc, last_pred = model.test(U, y_one_hot, data) patience = 0 for i in range(1, 65): # Updating W by training a simple linear neural network: pred, W = model.update_W(U, y_one_hot, data) # Updating U by LGC and SEB jointly: U = update_U(U, y_one_hot, pred, W) loss, trainval_acc, test_acc, pred = model.test(U, y_one_hot, data) print(f'Epoch: {i:02d}, Loss: {loss:.4f}, ' f'Train+Val Acc: {trainval_acc:.4f} Test Acc {test_acc:.4f}') sim_rate = float((pred == last_pred).sum()) / pred.size(0) if (sim_rate > max_sim_tol): patience += 1 if (patience > max_patience): break last_acc, last_pred = test_acc, pred return last_acc start_time = time.time() test_acc = ogc() print(f'Test Accuracy: {test_acc:.4f}') print(f'Total Time: {time.time() - start_time:.4f}s') ================================================ FILE: examples/pmlp.py ================================================ import 
os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import PMLP device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures()) data = dataset[0].to(device) model = PMLP( in_channels=dataset.num_features, hidden_channels=16, out_channels=dataset.num_classes, num_layers=2, dropout=0.5, norm=False, ).to(device) optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=0.01) def train(): model.train() optimizer.zero_grad() out = model(data.x) # MLP during training. loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() out = model(data.x, data.edge_index) pred = out.argmax(dim=-1) accs = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) return accs best_val_acc = final_test_acc = 0 for epoch in range(1, 201): loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' f'Test: {test_acc:.4f}') ================================================ FILE: examples/pna.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential from torch.optim.lr_scheduler import ReduceLROnPlateau import torch_geometric from torch_geometric.datasets import ZINC from torch_geometric.loader import DataLoader from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool from torch_geometric.utils import degree path = 
osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC') train_dataset = ZINC(path, subset=True, split='train') val_dataset = ZINC(path, subset=True, split='val') test_dataset = ZINC(path, subset=True, split='test') train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=128) test_loader = DataLoader(test_dataset, batch_size=128) # Compute the maximum in-degree in the training data. max_degree = -1 for data in train_dataset: d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) max_degree = max(max_degree, int(d.max())) # Compute the in-degree histogram tensor deg = torch.zeros(max_degree + 1, dtype=torch.long) for data in train_dataset: d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) deg += torch.bincount(d, minlength=deg.numel()) class Net(torch.nn.Module): def __init__(self): super().__init__() self.node_emb = Embedding(21, 75) self.edge_emb = Embedding(4, 50) aggregators = ['mean', 'min', 'max', 'std'] scalers = ['identity', 'amplification', 'attenuation'] self.convs = ModuleList() self.batch_norms = ModuleList() for _ in range(4): conv = PNAConv(in_channels=75, out_channels=75, aggregators=aggregators, scalers=scalers, deg=deg, edge_dim=50, towers=5, pre_layers=1, post_layers=1, divide_input=False) self.convs.append(conv) self.batch_norms.append(BatchNorm(75)) self.mlp = Sequential(Linear(75, 50), ReLU(), Linear(50, 25), ReLU(), Linear(25, 1)) def forward(self, x, edge_index, edge_attr, batch): x = self.node_emb(x.squeeze()) edge_attr = self.edge_emb(edge_attr) for conv, batch_norm in zip(self.convs, self.batch_norms): x = F.relu(batch_norm(conv(x, edge_index, edge_attr))) x = global_add_pool(x, batch) return self.mlp(x) if torch.cuda.is_available(): device = torch.device('cuda') elif torch_geometric.is_xpu_available(): device = torch.device('xpu') else: device = torch.device('cpu') model = Net().to(device) optimizer = 
torch.optim.Adam(model.parameters(), lr=0.001) scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=0.00001) def train(epoch): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.edge_attr, data.batch) loss = (out.squeeze() - data.y).abs().mean() loss.backward() total_loss += loss.item() * data.num_graphs optimizer.step() return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() total_error = 0 for data in loader: data = data.to(device) out = model(data.x, data.edge_index, data.edge_attr, data.batch) total_error += (out.squeeze() - data.y).abs().sum().item() return total_error / len(loader.dataset) for epoch in range(1, 301): loss = train(epoch) val_mae = test(val_loader) test_mae = test(test_loader) scheduler.step(val_mae) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, ' f'Test: {test_mae:.4f}') ================================================ FILE: examples/point_transformer_classification.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear as Lin import torch_geometric.transforms as T from torch_geometric.datasets import ModelNet from torch_geometric.loader import DataLoader from torch_geometric.nn import ( MLP, PointTransformerConv, fps, global_mean_pool, knn, knn_graph, ) from torch_geometric.typing import WITH_TORCH_CLUSTER from torch_geometric.utils import scatter if not WITH_TORCH_CLUSTER: quit("This example requires 'torch-cluster'") path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data/ModelNet10') pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024) train_dataset = ModelNet(path, '10', True, transform, pre_transform) test_dataset = ModelNet(path, '10', False, transform, pre_transform) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) class TransformerBlock(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.lin_in = Lin(in_channels, in_channels) self.lin_out = Lin(out_channels, out_channels) self.pos_nn = MLP([3, 64, out_channels], norm=None, plain_last=False) self.attn_nn = MLP([out_channels, 64, out_channels], norm=None, plain_last=False) self.transformer = PointTransformerConv(in_channels, out_channels, pos_nn=self.pos_nn, attn_nn=self.attn_nn) def forward(self, x, pos, edge_index): x = self.lin_in(x).relu() x = self.transformer(x, pos, edge_index) x = self.lin_out(x).relu() return x class TransitionDown(torch.nn.Module): """Samples the input point cloud by a ratio percentage to reduce cardinality and uses an mlp to augment features dimensionnality. """ def __init__(self, in_channels, out_channels, ratio=0.25, k=16): super().__init__() self.k = k self.ratio = ratio self.mlp = MLP([in_channels, out_channels], plain_last=False) def forward(self, x, pos, batch): # FPS sampling id_clusters = fps(pos, ratio=self.ratio, batch=batch) # compute for each cluster the k nearest points sub_batch = batch[id_clusters] if batch is not None else None # beware of self loop id_k_neighbor = knn(pos, pos[id_clusters], k=self.k, batch_x=batch, batch_y=sub_batch) # transformation of features through a simple MLP x = self.mlp(x) # Max pool onto each cluster the features from knn in points x_out = scatter(x[id_k_neighbor[1]], id_k_neighbor[0], dim=0, dim_size=id_clusters.size(0), reduce='max') # keep only the clusters and their max-pooled features sub_pos, out = pos[id_clusters], x_out return out, sub_pos, sub_batch class Net(torch.nn.Module): def __init__(self, in_channels, out_channels, dim_model, k=16): super().__init__() self.k = k # dummy feature is created if there is none given in_channels = max(in_channels, 1) # first block self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False) 
self.transformer_input = TransformerBlock(in_channels=dim_model[0], out_channels=dim_model[0]) # backbone layers self.transformers_down = torch.nn.ModuleList() self.transition_down = torch.nn.ModuleList() for i in range(len(dim_model) - 1): # Add Transition Down block followed by a Transformer block self.transition_down.append( TransitionDown(in_channels=dim_model[i], out_channels=dim_model[i + 1], k=self.k)) self.transformers_down.append( TransformerBlock(in_channels=dim_model[i + 1], out_channels=dim_model[i + 1])) # class score computation self.mlp_output = MLP([dim_model[-1], 64, out_channels], norm=None) def forward(self, x, pos, batch=None): # add dummy features in case there is none if x is None: x = torch.ones((pos.shape[0], 1), device=pos.get_device()) # first block x = self.mlp_input(x) edge_index = knn_graph(pos, k=self.k, batch=batch) x = self.transformer_input(x, pos, edge_index) # backbone for i in range(len(self.transformers_down)): x, pos, batch = self.transition_down[i](x, pos, batch=batch) edge_index = knn_graph(pos, k=self.k, batch=batch) x = self.transformers_down[i](x, pos, edge_index) # GlobalAveragePooling x = global_mean_pool(x, batch) # Class score out = self.mlp_output(x) return F.log_softmax(out, dim=-1) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.pos, data.batch) loss = F.nll_loss(out, data.y) loss.backward() total_loss += loss.item() * data.num_graphs optimizer.step() return total_loss / len(train_dataset) @torch.no_grad() def test(loader): model.eval() correct = 0 for data in loader: data = data.to(device) pred = model(data.x, data.pos, data.batch).max(dim=1)[1] correct += pred.eq(data.y).sum().item() return correct / len(loader.dataset) if __name__ == '__main__': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(0, train_dataset.num_classes, dim_model=[32, 64, 128, 256, 512], k=16).to(device) optimizer = 
torch.optim.Adam(model.parameters(), lr=0.001) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) for epoch in range(1, 201): loss = train() test_acc = test(test_loader) print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Test: {test_acc:.4f}') scheduler.step() ================================================ FILE: examples/point_transformer_segmentation.py ================================================ import os.path as osp import torch import torch.nn.functional as F from point_transformer_classification import TransformerBlock, TransitionDown from torchmetrics.functional import jaccard_index import torch_geometric.transforms as T from torch_geometric.datasets import ShapeNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, knn_graph, knn_interpolate from torch_geometric.typing import WITH_TORCH_CLUSTER from torch_geometric.utils import scatter if not WITH_TORCH_CLUSTER: quit("This example requires 'torch-cluster'") category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') transform = T.Compose([ T.RandomJitter(0.01), T.RandomRotate(15, axis=0), T.RandomRotate(15, axis=1), T.RandomRotate(15, axis=2), ]) pre_transform = T.NormalizeScale() train_dataset = ShapeNet(path, category, split='trainval', transform=transform, pre_transform=pre_transform) test_dataset = ShapeNet(path, category, split='test', pre_transform=pre_transform) train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False) class TransitionUp(torch.nn.Module): """Reduce features dimensionality and interpolate back to higher resolution and cardinality. 
""" def __init__(self, in_channels, out_channels): super().__init__() self.mlp_sub = MLP([in_channels, out_channels], plain_last=False) self.mlp = MLP([out_channels, out_channels], plain_last=False) def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None): # transform low-res features and reduce the number of features x_sub = self.mlp_sub(x_sub) # interpolate low-res feats to high-res points x_interpolated = knn_interpolate(x_sub, pos_sub, pos, k=3, batch_x=batch_sub, batch_y=batch) x = self.mlp(x) + x_interpolated return x class Net(torch.nn.Module): def __init__(self, in_channels, out_channels, dim_model, k=16): super().__init__() self.k = k # dummy feature is created if there is none given in_channels = max(in_channels, 1) # first block self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False) self.transformer_input = TransformerBlock( in_channels=dim_model[0], out_channels=dim_model[0], ) # backbone layers self.transformers_up = torch.nn.ModuleList() self.transformers_down = torch.nn.ModuleList() self.transition_up = torch.nn.ModuleList() self.transition_down = torch.nn.ModuleList() for i in range(0, len(dim_model) - 1): # Add Transition Down block followed by a Point Transformer block self.transition_down.append( TransitionDown(in_channels=dim_model[i], out_channels=dim_model[i + 1], k=self.k)) self.transformers_down.append( TransformerBlock(in_channels=dim_model[i + 1], out_channels=dim_model[i + 1])) # Add Transition Up block followed by Point Transformer block self.transition_up.append( TransitionUp(in_channels=dim_model[i + 1], out_channels=dim_model[i])) self.transformers_up.append( TransformerBlock(in_channels=dim_model[i], out_channels=dim_model[i])) # summit layers self.mlp_summit = MLP([dim_model[-1], dim_model[-1]], norm=None, plain_last=False) self.transformer_summit = TransformerBlock( in_channels=dim_model[-1], out_channels=dim_model[-1], ) # class score computation self.mlp_output = MLP([dim_model[0], 64, out_channels], 
norm=None) def forward(self, x, pos, batch=None): # add dummy features in case there is none if x is None: x = torch.ones((pos.shape[0], 1)).to(pos.get_device()) out_x = [] out_pos = [] out_batch = [] # first block x = self.mlp_input(x) edge_index = knn_graph(pos, k=self.k, batch=batch) x = self.transformer_input(x, pos, edge_index) # save outputs for skipping connections out_x.append(x) out_pos.append(pos) out_batch.append(batch) # backbone down : #reduce cardinality and augment dimensionnality for i in range(len(self.transformers_down)): x, pos, batch = self.transition_down[i](x, pos, batch=batch) edge_index = knn_graph(pos, k=self.k, batch=batch) x = self.transformers_down[i](x, pos, edge_index) out_x.append(x) out_pos.append(pos) out_batch.append(batch) # summit x = self.mlp_summit(x) edge_index = knn_graph(pos, k=self.k, batch=batch) x = self.transformer_summit(x, pos, edge_index) # backbone up : augment cardinality and reduce dimensionnality n = len(self.transformers_down) for i in range(n): x = self.transition_up[-i - 1](x=out_x[-i - 2], x_sub=x, pos=out_pos[-i - 2], pos_sub=out_pos[-i - 1], batch_sub=out_batch[-i - 1], batch=out_batch[-i - 2]) edge_index = knn_graph(out_pos[-i - 2], k=self.k, batch=out_batch[-i - 2]) x = self.transformers_up[-i - 1](x, out_pos[-i - 2], edge_index) # Class score out = self.mlp_output(x) return F.log_softmax(out, dim=-1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(3, train_dataset.num_classes, dim_model=[32, 64, 128, 256, 512], k=16).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) def train(): model.train() total_loss = correct_nodes = total_nodes = 0 for i, data in enumerate(train_loader): data = data.to(device) optimizer.zero_grad() out = model(data.x, data.pos, data.batch) loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() total_loss += loss.item() correct_nodes += 
out.argmax(dim=1).eq(data.y).sum().item() total_nodes += data.num_nodes if (i + 1) % 10 == 0: print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} ' f'Train Acc: {correct_nodes / total_nodes:.4f}') total_loss = correct_nodes = total_nodes = 0 def test(loader): model.eval() ious, categories = [], [] y_map = torch.empty(loader.dataset.num_classes, device=device).long() for data in loader: data = data.to(device) outs = model(data.x, data.pos, data.batch) sizes = (data.ptr[1:] - data.ptr[:-1]).tolist() for out, y, category in zip(outs.split(sizes), data.y.split(sizes), data.category.tolist()): category = list(ShapeNet.seg_classes.keys())[category] part = ShapeNet.seg_classes[category] part = torch.tensor(part, device=device) y_map[part] = torch.arange(part.size(0), device=device) iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y], num_classes=part.size(0), absent_score=1.0) ious.append(iou) categories.append(data.category) iou = torch.tensor(ious, device=device) category = torch.cat(categories, dim=0) mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU. return float(mean_iou.mean()) # Global IoU. 
for epoch in range(1, 100): train() iou = test(test_loader) print(f'Epoch: {epoch:03d}, Test IoU: {iou:.4f}') scheduler.step() ================================================ FILE: examples/pointnet2_classification.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import ModelNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius from torch_geometric.typing import WITH_TORCH_CLUSTER if not WITH_TORCH_CLUSTER: quit("This example requires 'torch-cluster'") class SAModule(torch.nn.Module): def __init__(self, ratio, r, nn): super().__init__() self.ratio = ratio self.r = r self.conv = PointNetConv(nn, add_self_loops=False) def forward(self, x, pos, batch): idx = fps(pos, batch, ratio=self.ratio) row, col = radius(pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=64) edge_index = torch.stack([col, row], dim=0) x_dst = None if x is None else x[idx] x = self.conv((x, x_dst), (pos, pos[idx]), edge_index) pos, batch = pos[idx], batch[idx] return x, pos, batch class GlobalSAModule(torch.nn.Module): def __init__(self, nn): super().__init__() self.nn = nn def forward(self, x, pos, batch): x = self.nn(torch.cat([x, pos], dim=1)) x = global_max_pool(x, batch) pos = pos.new_zeros((x.size(0), 3)) batch = torch.arange(x.size(0), device=batch.device) return x, pos, batch class Net(torch.nn.Module): def __init__(self): super().__init__() # Input channels account for both `pos` and node features. 
self.sa1_module = SAModule(0.5, 0.2, MLP([3, 64, 64, 128])) self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256])) self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024])) self.mlp = MLP([1024, 512, 256, 10], dropout=0.5, norm=None) def forward(self, data): sa0_out = (data.x, data.pos, data.batch) sa1_out = self.sa1_module(*sa0_out) sa2_out = self.sa2_module(*sa1_out) sa3_out = self.sa3_module(*sa2_out) x, pos, batch = sa3_out return self.mlp(x).log_softmax(dim=-1) def train(epoch): model.train() for data in train_loader: data = data.to(device) optimizer.zero_grad() loss = F.nll_loss(model(data), data.y) loss.backward() optimizer.step() def test(loader): model.eval() correct = 0 for data in loader: data = data.to(device) with torch.no_grad(): pred = model(data).max(1)[1] correct += pred.eq(data.y).sum().item() return correct / len(loader.dataset) if __name__ == '__main__': path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data/ModelNet10') pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024) train_dataset = ModelNet(path, '10', True, transform, pre_transform) test_dataset = ModelNet(path, '10', False, transform, pre_transform) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=6) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=6) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(1, 201): train(epoch) test_acc = test(test_loader) print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}') ================================================ FILE: examples/pointnet2_segmentation.py ================================================ import os.path as osp import torch import torch.nn.functional as F from pointnet2_classification import GlobalSAModule, SAModule from torchmetrics.functional import jaccard_index import torch_geometric.transforms as T from 
torch_geometric.datasets import ShapeNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, knn_interpolate from torch_geometric.typing import WITH_TORCH_CLUSTER from torch_geometric.utils import scatter if not WITH_TORCH_CLUSTER: quit("This example requires 'torch-cluster'") category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') transform = T.Compose([ T.RandomJitter(0.01), T.RandomRotate(15, axis=0), T.RandomRotate(15, axis=1), T.RandomRotate(15, axis=2) ]) pre_transform = T.NormalizeScale() train_dataset = ShapeNet(path, category, split='trainval', transform=transform, pre_transform=pre_transform) test_dataset = ShapeNet(path, category, split='test', pre_transform=pre_transform) train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True, num_workers=6) test_loader = DataLoader(test_dataset, batch_size=12, shuffle=False, num_workers=6) class FPModule(torch.nn.Module): def __init__(self, k, nn): super().__init__() self.k = k self.nn = nn def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip): x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k) if x_skip is not None: x = torch.cat([x, x_skip], dim=1) x = self.nn(x) return x, pos_skip, batch_skip class Net(torch.nn.Module): def __init__(self, num_classes): super().__init__() # Input channels account for both `pos` and node features. 
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128])) self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256])) self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024])) self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256])) self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128])) self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128])) self.mlp = MLP([128, 128, 128, num_classes], dropout=0.5, norm=None) self.lin1 = torch.nn.Linear(128, 128) self.lin2 = torch.nn.Linear(128, 128) self.lin3 = torch.nn.Linear(128, num_classes) def forward(self, data): sa0_out = (data.x, data.pos, data.batch) sa1_out = self.sa1_module(*sa0_out) sa2_out = self.sa2_module(*sa1_out) sa3_out = self.sa3_module(*sa2_out) fp3_out = self.fp3_module(*sa3_out, *sa2_out) fp2_out = self.fp2_module(*fp3_out, *sa1_out) x, _, _ = self.fp1_module(*fp2_out, *sa0_out) return self.mlp(x).log_softmax(dim=-1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(train_dataset.num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) def train(): model.train() total_loss = correct_nodes = total_nodes = 0 for i, data in enumerate(train_loader): data = data.to(device) optimizer.zero_grad() out = model(data) loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() total_loss += loss.item() correct_nodes += out.argmax(dim=1).eq(data.y).sum().item() total_nodes += data.num_nodes if (i + 1) % 10 == 0: print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} ' f'Train Acc: {correct_nodes / total_nodes:.4f}') total_loss = correct_nodes = total_nodes = 0 @torch.no_grad() def test(loader): model.eval() ious, categories = [], [] y_map = torch.empty(loader.dataset.num_classes, device=device).long() for data in loader: data = data.to(device) outs = model(data) sizes = (data.ptr[1:] - data.ptr[:-1]).tolist() for out, y, category in zip(outs.split(sizes), data.y.split(sizes), data.category.tolist()): category = 
list(ShapeNet.seg_classes.keys())[category] part = ShapeNet.seg_classes[category] part = torch.tensor(part, device=device) y_map[part] = torch.arange(part.size(0), device=device) iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y], num_classes=part.size(0), absent_score=1.0) ious.append(iou) categories.append(data.category) iou = torch.tensor(ious, device=device) category = torch.cat(categories, dim=0) mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU. return float(mean_iou.mean()) # Global IoU. for epoch in range(1, 31): train() iou = test(test_loader) print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}') ================================================ FILE: examples/ppi.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F from sklearn.metrics import f1_score from torch_geometric.datasets import PPI from torch_geometric.loader import DataLoader from torch_geometric.nn import GATConv path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI') train_dataset = PPI(path, split='train') val_dataset = PPI(path, split='val') test_dataset = PPI(path, split='test') train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False) class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GATConv(train_dataset.num_features, 256, heads=4, residual=True) self.conv2 = GATConv(4 * 256, 256, heads=4, residual=True) self.conv3 = GATConv(4 * 256, train_dataset.num_classes, heads=6, concat=False, residual=True) def forward(self, x, edge_index): x = F.elu(self.conv1(x, edge_index)) x = F.elu(self.conv2(x, edge_index)) x = self.conv3(x, edge_index) return x device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) loss_op = torch.nn.BCEWithLogitsLoss() optimizer = 
torch.optim.Adam(model.parameters(), lr=0.005) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() loss = loss_op(model(data.x, data.edge_index), data.y) total_loss += loss.item() * data.num_graphs loss.backward() optimizer.step() return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() ys, preds = [], [] for data in loader: ys.append(data.y) out = model(data.x.to(device), data.edge_index.to(device)) preds.append((out > 0).float().cpu()) y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy() return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0 times = [] for epoch in range(1, 101): start = time.time() loss = train() val_f1 = test(val_loader) test_f1 = test(test_loader) print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, ' f'Test: {test_f1:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/proteins_diff_pool.py ================================================ import os.path as osp import time from math import ceil import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import TUDataset from torch_geometric.loader import DenseDataLoader from torch_geometric.nn import DenseSAGEConv, dense_diff_pool max_nodes = 150 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS_dense') dataset = TUDataset( path, name='PROTEINS', transform=T.ToDense(max_nodes), pre_filter=lambda data: data.num_nodes <= max_nodes, ) dataset = dataset.shuffle() n = (len(dataset) + 9) // 10 test_dataset = dataset[:n] val_dataset = dataset[n:2 * n] train_dataset = dataset[2 * n:] test_loader = DenseDataLoader(test_dataset, batch_size=20) val_loader = DenseDataLoader(val_dataset, batch_size=20) train_loader = DenseDataLoader(train_dataset, batch_size=20) class 
GNN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, normalize=False, lin=True): super().__init__() self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize) self.bn1 = torch.nn.BatchNorm1d(hidden_channels) self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize) self.bn2 = torch.nn.BatchNorm1d(hidden_channels) self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize) self.bn3 = torch.nn.BatchNorm1d(out_channels) if lin is True: self.lin = torch.nn.Linear(2 * hidden_channels + out_channels, out_channels) else: self.lin = None def bn(self, i, x): batch_size, num_nodes, num_channels = x.size() x = x.view(-1, num_channels) x = getattr(self, f'bn{i}')(x) x = x.view(batch_size, num_nodes, num_channels) return x def forward(self, x, adj, mask=None): batch_size, num_nodes, in_channels = x.size() x0 = x x1 = self.bn(1, self.conv1(x0, adj, mask).relu()) x2 = self.bn(2, self.conv2(x1, adj, mask).relu()) x3 = self.bn(3, self.conv3(x2, adj, mask).relu()) x = torch.cat([x1, x2, x3], dim=-1) if self.lin is not None: x = self.lin(x).relu() return x class Net(torch.nn.Module): def __init__(self): super().__init__() num_nodes = ceil(0.25 * max_nodes) self.gnn1_pool = GNN(dataset.num_features, 64, num_nodes) self.gnn1_embed = GNN(dataset.num_features, 64, 64, lin=False) num_nodes = ceil(0.25 * num_nodes) self.gnn2_pool = GNN(3 * 64, 64, num_nodes) self.gnn2_embed = GNN(3 * 64, 64, 64, lin=False) self.gnn3_embed = GNN(3 * 64, 64, 64, lin=False) self.lin1 = torch.nn.Linear(3 * 64, 64) self.lin2 = torch.nn.Linear(64, dataset.num_classes) def forward(self, x, adj, mask=None): s = self.gnn1_pool(x, adj, mask) x = self.gnn1_embed(x, adj, mask) x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask) s = self.gnn2_pool(x, adj) x = self.gnn2_embed(x, adj) x, adj, l2, e2 = dense_diff_pool(x, adj, s) x = self.gnn3_embed(x, adj) x = x.mean(dim=1) x = self.lin1(x).relu() x = self.lin2(x) return F.log_softmax(x, dim=-1), l1 + l2, e1 + 
e2 if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) def train(epoch): model.train() loss_all = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() output, _, _ = model(data.x, data.adj, data.mask) loss = F.nll_loss(output, data.y.view(-1)) loss.backward() loss_all += data.y.size(0) * float(loss) optimizer.step() return loss_all / len(train_dataset) @torch.no_grad() def test(loader): model.eval() correct = 0 for data in loader: data = data.to(device) pred = model(data.x, data.adj, data.mask)[0].max(dim=1)[1] correct += int(pred.eq(data.y.view(-1)).sum()) return correct / len(loader.dataset) best_val_acc = test_acc = 0 times = [] for epoch in range(1, 151): start = time.time() train_loss = train(epoch) val_acc = test(val_loader) if val_acc > best_val_acc: test_acc = test(test_loader) best_val_acc = val_acc print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, ' f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/proteins_dmon_pool.py ================================================ import os.path as osp import time from math import ceil import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import DenseGraphConv, DMoNPooling, GCNConv from torch_geometric.utils import to_dense_adj, to_dense_batch path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS') dataset = TUDataset(path, name='PROTEINS').shuffle() avg_num_nodes = int(dataset._data.x.size(0) / len(dataset)) n = (len(dataset) + 9) // 10 
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
test_loader = DataLoader(test_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
train_loader = DataLoader(train_dataset, batch_size=20)


class Net(torch.nn.Module):
    """GCN encoder, two DMoN pooling stages, dense graph convs and an MLP.

    ``forward`` returns log-probabilities per graph together with the sum of
    the auxiliary losses unpacked from both pooling layers.
    """
    def __init__(self, in_channels, out_channels, hidden_channels=32):
        super().__init__()

        self.conv1 = GCNConv(in_channels, hidden_channels)
        # Cluster counts halve per stage, starting from half of the dataset's
        # average node count.
        num_nodes = ceil(0.5 * avg_num_nodes)
        self.pool1 = DMoNPooling([hidden_channels, hidden_channels], num_nodes)

        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)
        num_nodes = ceil(0.5 * num_nodes)
        self.pool2 = DMoNPooling([hidden_channels, hidden_channels], num_nodes)

        self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()

        # Convert to dense tensors; `to_dense_batch` also yields the node mask
        # consumed by the first pooling layer.
        x, mask = to_dense_batch(x, batch)
        adj = to_dense_adj(edge_index, batch)

        _, x, adj, sp1, _, c1 = self.pool1(x, adj, mask)

        x = self.conv2(x, adj).relu()

        _, x, adj, sp2, _, c2 = self.pool2(x, adj)

        x = self.conv3(x, adj)

        # Average over dim 1 (presumably the pooled-cluster axis) per graph.
        x = x.mean(dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1), sp1 + sp2 + c1 + c2


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(train_loader):
    """Run one epoch and return the per-graph average training loss.

    The loss is NLL plus the pooling losses returned by the model. It is
    normalized by the size of the module-level ``train_dataset``.
    """
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out, tot_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y.view(-1)) + tot_loss
        loss.backward()
        loss_all += data.y.size(0) * float(loss.detach())
        optimizer.step()
    return loss_all / len(train_dataset)


@torch.no_grad()
def test(loader):
    """Return ``(average loss, accuracy)`` over ``loader.dataset``."""
    model.eval()
    correct = 0
    loss_all = 0

    for data in loader:
        data = data.to(device)
        pred, tot_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(pred, data.y.view(-1)) + tot_loss
        loss_all += data.y.size(0) * float(loss.detach())
        correct += int(pred.max(dim=1)[1].eq(data.y.view(-1)).sum())

    return loss_all / len(loader.dataset), correct / len(loader.dataset)


# Train, then report loss/accuracy on all three splits every epoch.
times = []
for epoch in range(1, 101):
    start = time.time()
    train_loss = train(train_loader)
    _, train_acc = test(train_loader)
    val_loss, val_acc = test(val_loader)
    test_loss, test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, '
          f'Train Acc: {train_acc:.3f}, Val Loss: {val_loss:.3f}, '
          f'Val Acc: {val_acc:.3f}, Test Loss: {test_loss:.3f}, '
          f'Test Acc: {test_acc:.3f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")

================================================
FILE: examples/proteins_gmt.py
================================================
import os.path as osp
import time

import torch
import torch.nn.functional as F
from torch.nn import Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GraphMultisetTransformer

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')
dataset = TUDataset(path, name='PROTEINS').shuffle()

# `n` is ceil(len(dataset) / 10): 80% train, 10% validation, 10% test.
n = (len(dataset) + 9) // 10
train_dataset = dataset[2 * n:]
val_dataset = dataset[n:2 * n]
test_dataset = dataset[:n]

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)


class Net(torch.nn.Module):
    """Three GCN layers, Graph Multiset Transformer readout, then an MLP.

    The three 32-channel GCN outputs are concatenated (96 channels) before
    pooling.
    """
    def __init__(self):
        super().__init__()

        self.conv1 = GCNConv(dataset.num_features, 32)
        self.conv2 = GCNConv(32, 32)
        self.conv3 = GCNConv(32, 32)

        self.pool = GraphMultisetTransformer(96, k=10, heads=4)

        self.lin1 = Linear(96, 16)
        self.lin2 = Linear(16, dataset.num_classes)

    def forward(self, x0, edge_index, batch):
        x1 = self.conv1(x0, edge_index).relu()
        x2 = self.conv2(x1, edge_index).relu()
        x3 = self.conv3(x2, edge_index).relu()
        x = torch.cat([x1, x2, x3], dim=-1)

        x = self.pool(x, batch)

        x = self.lin1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)

        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)


def train():
    """Run one epoch and return the per-graph average cross-entropy loss."""
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        total_loss += data.num_graphs * float(loss.detach())
        optimizer.step()
    return total_loss / len(train_dataset)


@torch.no_grad()
def test(loader):
    """Return the classification accuracy over ``loader.dataset``."""
    model.eval()

    total_correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        total_correct += int((out.argmax(dim=-1) == data.y).sum())
    return total_correct / len(loader.dataset)


times = []
for epoch in range(1, 201):
    start = time.time()
    train_loss = train()
    val_acc = test(val_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, '
          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")

================================================
FILE: examples/proteins_mincut_pool.py
================================================
import os.path as osp
import time
from math import ceil

import torch
import torch.nn.functional as F
from torch.nn import Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DenseGraphConv, GCNConv, dense_mincut_pool
from torch_geometric.utils import to_dense_adj, to_dense_batch

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')
dataset = TUDataset(path, name='PROTEINS').shuffle()
# Average node count per graph; used to size the pooling stages.
avg_num_nodes = int(dataset._data.x.size(0) / len(dataset))
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n] val_dataset = dataset[n:2 * n] train_dataset = dataset[2 * n:] test_loader = DataLoader(test_dataset, batch_size=20) val_loader = DataLoader(val_dataset, batch_size=20) train_loader = DataLoader(train_dataset, batch_size=20) class Net(torch.nn.Module): def __init__(self, in_channels, out_channels, hidden_channels=32): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels) num_nodes = ceil(0.5 * avg_num_nodes) self.pool1 = Linear(hidden_channels, num_nodes) self.conv2 = DenseGraphConv(hidden_channels, hidden_channels) num_nodes = ceil(0.5 * num_nodes) self.pool2 = Linear(hidden_channels, num_nodes) self.conv3 = DenseGraphConv(hidden_channels, hidden_channels) self.lin1 = Linear(hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, out_channels) def forward(self, x, edge_index, batch): x = self.conv1(x, edge_index).relu() x, mask = to_dense_batch(x, batch) adj = to_dense_adj(edge_index, batch) s = self.pool1(x) x, adj, mc1, o1 = dense_mincut_pool(x, adj, s, mask) x = self.conv2(x, adj).relu() s = self.pool2(x) x, adj, mc2, o2 = dense_mincut_pool(x, adj, s) x = self.conv3(x, adj) x = x.mean(dim=1) x = self.lin1(x).relu() x = self.lin2(x) return F.log_softmax(x, dim=-1), mc1 + mc2, o1 + o2 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(dataset.num_features, dataset.num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4) def train(epoch): model.train() loss_all = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out, mc_loss, o_loss = model(data.x, data.edge_index, data.batch) loss = F.nll_loss(out, data.y.view(-1)) + mc_loss + o_loss loss.backward() loss_all += data.y.size(0) * float(loss) optimizer.step() return loss_all / len(train_dataset) @torch.no_grad() def test(loader): model.eval() correct = 0 loss_all = 0 for data in loader: data = data.to(device) pred, mc_loss, o_loss = model(data.x, data.edge_index, 
data.batch) loss = F.nll_loss(pred, data.y.view(-1)) + mc_loss + o_loss loss_all += data.y.size(0) * float(loss) correct += int(pred.max(dim=1)[1].eq(data.y.view(-1)).sum()) return loss_all / len(loader.dataset), correct / len(loader.dataset) times = [] best_val_acc = test_acc = 0 best_val_loss = float('inf') patience = start_patience = 50 for epoch in range(1, 15001): start = time.time() train_loss = train(epoch) _, train_acc = test(train_loader) val_loss, val_acc = test(val_loader) if val_loss < best_val_loss: test_loss, test_acc = test(test_loader) best_val_acc = val_acc patience = start_patience else: patience -= 1 if patience == 0: break print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, ' f'Train Acc: {train_acc:.3f}, Val Loss: {val_loss:.3f}, ' f'Val Acc: {val_acc:.3f}, Test Loss: {test_loss:.3f}, ' f'Test Acc: {test_acc:.3f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/proteins_topk_pool.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import GraphConv, TopKPooling from torch_geometric.nn import global_max_pool as gmp from torch_geometric.nn import global_mean_pool as gap path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS') dataset = TUDataset(path, name='PROTEINS') dataset = dataset.shuffle() n = len(dataset) // 10 test_dataset = dataset[:n] train_dataset = dataset[n:] test_loader = DataLoader(test_dataset, batch_size=60) train_loader = DataLoader(train_dataset, batch_size=60) class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GraphConv(dataset.num_features, 128) self.pool1 = TopKPooling(128, ratio=0.8) self.conv2 = GraphConv(128, 128) self.pool2 = TopKPooling(128, ratio=0.8) self.conv3 = 
GraphConv(128, 128) self.pool3 = TopKPooling(128, ratio=0.8) self.lin1 = torch.nn.Linear(256, 128) self.lin2 = torch.nn.Linear(128, 64) self.lin3 = torch.nn.Linear(64, dataset.num_classes) def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch) x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) x = F.relu(self.conv2(x, edge_index)) x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch) x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) x = F.relu(self.conv3(x, edge_index)) x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch) x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) x = x1 + x2 + x3 x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = F.relu(self.lin2(x)) x = F.log_softmax(self.lin3(x), dim=-1) return x device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.0005) def train(epoch): model.train() loss_all = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, data.y) loss.backward() loss_all += data.num_graphs * loss.item() optimizer.step() return loss_all / len(train_dataset) def test(loader): model.eval() correct = 0 for data in loader: data = data.to(device) pred = model(data).max(dim=1)[1] correct += pred.eq(data.y).sum().item() return correct / len(loader.dataset) for epoch in range(1, 201): loss = train(epoch) train_acc = test(train_loader) test_acc = test(test_loader) print(f'Epoch: {epoch:03d}, Loss: {loss:.5f}, Train Acc: {train_acc:.5f}, ' f'Test Acc: {test_acc:.5f}') ================================================ FILE: examples/pytorch_ignite/README.md ================================================ # Examples for PyTorch Ignite This directory provides examples showcasing the 
integration of PyG with [PyTorch Ingite](https://pytorch.org/ignite/index.html). | Example | Description | | -------------------- | ---------------------------------------------------------------- | | [`gin.py`](./gin.py) | Demonstrates how to implement the GIN model using PyTorch Ignite | ================================================ FILE: examples/pytorch_ignite/gin.py ================================================ import os.path as osp import ignite import ignite.contrib.handlers.tensorboard_logger import ignite.contrib.handlers.tqdm_logger import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric import seed_everything from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import GIN, MLP, global_add_pool class Model(torch.nn.Module): def __init__(self, in_channels: int, out_channels: int, hidden_channels: int = 64, num_layers: int = 3, dropout: float = 0.5): super().__init__() self.gnn = GIN(in_channels, hidden_channels, num_layers, dropout=dropout, jk='cat') self.classifier = MLP([hidden_channels, hidden_channels, out_channels], norm="batch_norm", dropout=dropout) def forward(self, data): x = self.gnn(data.x, data.edge_index) x = global_add_pool(x, data.batch) x = self.classifier(x) return x def main(): seed_everything(42) root = osp.join('data', 'TUDataset') dataset = TUDataset(root, 'IMDB-BINARY', pre_transform=T.OneHotDegree(135)) dataset = dataset.shuffle() test_dataset = dataset[:len(dataset) // 10] val_dataset = dataset[len(dataset) // 10:2 * len(dataset) // 10] train_dataset = dataset[2 * len(dataset) // 10:] train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True) val_loader = DataLoader(val_dataset, batch_size=64, pin_memory=True) test_loader = DataLoader(test_dataset, batch_size=64, pin_memory=True) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Model(dataset.num_node_features, 
dataset.num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) metrics = {'acc': ignite.metrics.Accuracy()} def prepare_batch_fn(batch, device, non_blocking): return (batch.to(device, non_blocking=non_blocking), batch.y.to(device, non_blocking=non_blocking)) trainer = ignite.engine.create_supervised_trainer( model=model, optimizer=optimizer, loss_fn=F.cross_entropy, device=device, prepare_batch=prepare_batch_fn, output_transform=lambda x, y, y_pred, loss: loss.item(), amp_mode='amp', ) # Progress bar for each epoch: pbar = ignite.contrib.handlers.tqdm_logger.ProgressBar() pbar.attach(trainer, output_transform=lambda x: {'loss': x}) def log_metrics(evaluator, loader, tag): def logger(trainer): evaluator.run(loader) print(f'{tag:10} Epoch: {trainer.state.epoch:02d}, ' f'Acc: {evaluator.state.metrics["acc"]:.4f}') return logger train_evaluator = ignite.engine.create_supervised_evaluator( model=model, metrics=metrics, device=device, prepare_batch=prepare_batch_fn, output_transform=lambda x, y, y_pred: (y_pred, y), amp_mode='amp', ) trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=1))(log_metrics( train_evaluator, train_loader, 'Training')) val_evaluator = ignite.engine.create_supervised_evaluator( model=model, metrics=metrics, device=device, prepare_batch=prepare_batch_fn, output_transform=lambda x, y, y_pred: (y_pred, y), amp_mode='amp', ) trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=1))(log_metrics( val_evaluator, val_loader, 'Validation')) test_evaluator = ignite.engine.create_supervised_evaluator( model=model, metrics=metrics, device=device, prepare_batch=prepare_batch_fn, output_transform=lambda x, y, y_pred: (y_pred, y), amp_mode='amp', ) trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=1))(log_metrics( test_evaluator, test_loader, 'Test')) # Save checkpoint of the model based on Accuracy on the validation set: checkpoint_handler = ignite.handlers.Checkpoint( {'model': model}, 'runs/gin', n_saved=2, 
score_name=list(metrics.keys())[0], filename_pattern='best-{global_step}-{score_name}-{score}.pt', global_step_transform=ignite.handlers.global_step_from_engine(trainer), ) val_evaluator.add_event_handler(ignite.engine.Events.EPOCH_COMPLETED, checkpoint_handler) # Create a tensorboard logger to write logs: tb_logger = ignite.contrib.handlers.tensorboard_logger.TensorboardLogger( log_dir=osp.join('runs/example', 'tb_logs')) tb_logger.attach_output_handler( trainer, event_name=ignite.engine.Events.ITERATION_COMPLETED, tag='training', output_transform=lambda loss: {'loss_iteration': loss}) tb_logger.attach_output_handler( trainer, event_name=ignite.engine.Events.EPOCH_COMPLETED, tag='training', output_transform=lambda loss: {'loss_epoch': loss}) tb_logger.attach_output_handler( train_evaluator, event_name=ignite.engine.Events.EPOCH_COMPLETED, tag='training', metric_names='all', global_step_transform=ignite.handlers.global_step_from_engine(trainer), ) tb_logger.attach_output_handler( val_evaluator, event_name=ignite.engine.Events.EPOCH_COMPLETED, tag='validation', metric_names='all', global_step_transform=ignite.handlers.global_step_from_engine(trainer), ) tb_logger.attach_output_handler( test_evaluator, event_name=ignite.engine.Events.EPOCH_COMPLETED, tag='test', metric_names='all', global_step_transform=ignite.handlers.global_step_from_engine(trainer), ) tb_logger.close() trainer.run(train_loader, max_epochs=50) if __name__ == '__main__': main() ================================================ FILE: examples/pytorch_lightning/README.md ================================================ # Examples for PyTorch Lightning This directory provides examples showcasing the integration of PyG with [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning). 
| Example | Description |
| ------------------------------------------ | ------------------------------------------------------------------------------------ |
| [`graph_sage.py`](./graph_sage.py) | Combines PyG and PyTorch Lightning for node classification via the `GraphSAGE` model |
| [`gin.py`](./gin.py) | Combines PyG and PyTorch Lightning for graph classification via the `GIN` model |
| [`relational_gnn.py`](./relational_gnn.py) | Combines PyG and PyTorch Lightning for heterogeneous node classification |

================================================
FILE: examples/pytorch_lightning/gin.py
================================================
import os.path as osp

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy

import torch_geometric.transforms as T
from torch_geometric.data.lightning import LightningDataset
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GIN, MLP, global_add_pool


class Model(pl.LightningModule):
    """GIN graph classifier (``jk='cat'``, sum readout, MLP head).

    Implemented as a LightningModule that tracks multiclass accuracy
    separately for the train, validation and test splits.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 hidden_channels: int = 64, num_layers: int = 3,
                 dropout: float = 0.5):
        super().__init__()
        self.gnn = GIN(in_channels, hidden_channels, num_layers,
                       dropout=dropout, jk='cat')

        self.classifier = MLP([hidden_channels, hidden_channels, out_channels],
                              norm="batch_norm", dropout=dropout)

        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)
        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)
        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)

    def forward(self, x, edge_index, batch):
        x = self.gnn(x, edge_index)
        x = global_add_pool(x, batch)
        x = self.classifier(x)
        return x

    def training_step(self, data, batch_idx):
        y_hat = self(data.x, data.edge_index, data.batch)
        loss = F.cross_entropy(y_hat, data.y)
        # Accuracy is updated with softmax probabilities and logged per epoch.
        self.train_acc(y_hat.softmax(dim=-1), data.y)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,
                 on_epoch=True)
        return loss

    def validation_step(self, data, batch_idx):
        y_hat = self(data.x, data.edge_index, data.batch)
        self.val_acc(y_hat.softmax(dim=-1), data.y)
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=False,
                 on_epoch=True)

    def test_step(self, data, batch_idx):
        y_hat = self(data.x, data.edge_index, data.batch)
        self.test_acc(y_hat.softmax(dim=-1), data.y)
        self.log('test_acc', self.test_acc, prog_bar=True, on_step=False,
                 on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


if __name__ == '__main__':
    root = osp.join('data', 'TUDataset')
    dataset = TUDataset(root, 'IMDB-BINARY', pre_transform=T.OneHotDegree(135))

    # 10% test, 10% validation, 80% training.
    dataset = dataset.shuffle()
    test_dataset = dataset[:len(dataset) // 10]
    val_dataset = dataset[len(dataset) // 10:2 * len(dataset) // 10]
    train_dataset = dataset[2 * len(dataset) // 10:]

    datamodule = LightningDataset(train_dataset, val_dataset, test_dataset,
                                  batch_size=64, num_workers=4)

    model = Model(dataset.num_node_features, dataset.num_classes)

    # DDP across every visible CUDA device.
    devices = torch.cuda.device_count()
    strategy = pl.strategies.DDPStrategy(accelerator='gpu')
    # Keep only the checkpoint with the highest validation accuracy; it is
    # reloaded for testing via `ckpt_path='best'`.
    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc', save_top_k=1,
                                              mode='max')
    trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=50,
                         log_every_n_steps=5, callbacks=[checkpoint])

    trainer.fit(model, datamodule)
    trainer.test(ckpt_path='best', datamodule=datamodule)

================================================
FILE: examples/pytorch_lightning/graph_sage.py
================================================
import os.path as osp

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d
from torchmetrics import Accuracy

from torch_geometric.data.lightning import LightningNodeData
from torch_geometric.datasets import Reddit
from torch_geometric.nn import GraphSAGE


class Model(pl.LightningModule):
    """GraphSAGE node classifier (with BatchNorm1d) as a LightningModule."""
    def __init__(self, in_channels: int, out_channels: int,
                 hidden_channels: int = 256, num_layers: int = 2,
                 dropout: float = 0.5):
        super().__init__()
        self.gnn = GraphSAGE(in_channels, hidden_channels, num_layers,
                             out_channels, dropout=dropout,
                             norm=BatchNorm1d(hidden_channels))

        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)
        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)
        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)

    def forward(self, x, edge_index):
        return self.gnn(x, edge_index)

    def training_step(self, data, batch_idx):
        # Only the first `batch_size` rows are scored (presumably the seed
        # nodes of the neighbor-sampled mini-batch).
        y_hat = self(data.x, data.edge_index)[:data.batch_size]
        y = data.y[:data.batch_size]
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(y_hat.softmax(dim=-1), y)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,
                 on_epoch=True)
        return loss

    def validation_step(self, data, batch_idx):
        y_hat = self(data.x, data.edge_index)[:data.batch_size]
        y = data.y[:data.batch_size]
        self.val_acc(y_hat.softmax(dim=-1), y)
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=False,
                 on_epoch=True)

    def test_step(self, data, batch_idx):
        y_hat = self(data.x, data.edge_index)[:data.batch_size]
        y = data.y[:data.batch_size]
        self.test_acc(y_hat.softmax(dim=-1), y)
        self.log('test_acc', self.test_acc, prog_bar=True, on_step=False,
                 on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


if __name__ == '__main__':
    dataset = Reddit(osp.join('data', 'Reddit'))
    data = dataset[0]

    # Mini-batches of 1024 input nodes, sampling 25 then 10 neighbors per hop.
    datamodule = LightningNodeData(
        data,
        input_train_nodes=data.train_mask,
        input_val_nodes=data.val_mask,
        input_test_nodes=data.test_mask,
        loader='neighbor',
        num_neighbors=[25, 10],
        batch_size=1024,
        num_workers=8,
    )

    model = Model(dataset.num_node_features, dataset.num_classes)

    strategy = pl.strategies.SingleDeviceStrategy('cuda:0')
    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc', save_top_k=1,
                                              mode='max')
    trainer = pl.Trainer(strategy=strategy, devices=1, max_epochs=20,
                         callbacks=[checkpoint])

    trainer.fit(model, datamodule)
    trainer.test(ckpt_path='best', datamodule=datamodule)
================================================
FILE: examples/pytorch_lightning/relational_gnn.py
================================================
import os.path as osp
from typing import Dict, List, Tuple

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor
from torchmetrics import Accuracy

import torch_geometric.transforms as T
from torch_geometric.data import Batch
from torch_geometric.data.lightning import LightningNodeData
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import Linear, SAGEConv, to_hetero
from torch_geometric.typing import EdgeType, NodeType


class GNN(torch.nn.Module):
    """Homogeneous two-layer GraphSAGE model with a linear output head.

    Input sizes are given as ``-1`` (lazy); they are materialized by the
    dummy forward pass performed in :func:`main` before training starts.
    """
    def __init__(self, hidden_channels: int, out_channels: int,
                 dropout: float):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.lin = Linear(-1, out_channels)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = self.conv1(x, edge_index).relu()
        x = self.dropout(x)
        x = self.conv2(x, edge_index).relu()
        x = self.dropout(x)
        return self.lin(x)


class RelationalGNN(LightningModule):
    """LightningModule that classifies ``'paper'`` nodes of a hetero graph.

    Args:
        metadata: ``(node_types, edge_types)`` of the heterogeneous graph,
            passed to :func:`to_hetero`.
        hidden_channels: Hidden size of the underlying :class:`GNN`.
        out_channels: Number of output classes (also used by the
            ``Accuracy`` metrics).
        dropout: Dropout probability inside :class:`GNN`.
    """
    def __init__(
        self,
        metadata: Tuple[List[NodeType], List[EdgeType]],
        hidden_channels: int,
        out_channels: int,
        dropout: float,
    ):
        super().__init__()
        self.save_hyperparameters()

        model = GNN(hidden_channels, out_channels, dropout)

        # Convert the homogeneous GNN model to a heterogeneous variant in
        # which distinct parameters are learned for each node and edge type.
        self.model = to_hetero(model, metadata, aggr='sum')

        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)
        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)
        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
    ) -> Dict[NodeType, Tensor]:
        return self.model(x_dict, edge_index_dict)

    def common_step(self, batch: Batch) -> Tuple[Tensor, Tensor]:
        """Return ``(logits, labels)`` for the first ``batch_size`` papers.

        Only the leading ``batch['paper'].batch_size`` rows are scored;
        presumably these are the sampled input nodes, the rest being
        neighborhood context (TODO confirm against the loader docs).
        """
        batch_size = batch['paper'].batch_size
        y_hat = self(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        y = batch['paper'].y[:batch_size]
        return y_hat, y

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
        y_hat, y = self.common_step(batch)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(y_hat.softmax(dim=-1), y)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,
                 on_epoch=True)
        return loss

    def validation_step(self, batch: Batch, batch_idx: int):
        y_hat, y = self.common_step(batch)
        self.val_acc(y_hat.softmax(dim=-1), y)
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=False,
                 on_epoch=True)

    def test_step(self, batch: Batch, batch_idx: int):
        y_hat, y = self.common_step(batch)
        self.test_acc(y_hat.softmax(dim=-1), y)
        self.log('test_acc', self.test_acc, prog_bar=True, on_step=False,
                 on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


def main():
    """Train on OGB_MAG papers, keep the best ``val_acc`` checkpoint, test it."""
    dataset = OGB_MAG(osp.join('data', 'OGB'), preprocess='metapath2vec',
                      transform=T.ToUndirected(merge=False))
    data = dataset[0]

    datamodule = LightningNodeData(
        data,
        input_train_nodes=('paper', data['paper'].train_mask),
        input_val_nodes=('paper', data['paper'].val_mask),
        input_test_nodes=('paper', data['paper'].test_mask),
        loader='neighbor',
        num_neighbors=[10, 10],
        batch_size=1024,
        num_workers=8,
    )

    model = RelationalGNN(data.metadata(), hidden_channels=64,
                          out_channels=349, dropout=0.0)

    with torch.no_grad():  # Run a dummy forward pass to initialize lazy model
        loader = datamodule.train_dataloader()
        batch = next(iter(loader))
        model.common_step(batch)

    strategy = pl.strategies.SingleDeviceStrategy('cuda:0')
    checkpoint = ModelCheckpoint(monitor='val_acc', save_top_k=1, mode='max')
    trainer = Trainer(strategy=strategy, devices=1, max_epochs=20,
                      log_every_n_steps=5, callbacks=[checkpoint])

    trainer.fit(model, datamodule)
    # Evaluate the checkpoint with the highest `val_acc`, not the last epoch.
    trainer.test(ckpt_path='best', datamodule=datamodule)


if __name__ == "__main__":
    main()
================================================
FILE: examples/qm9_nn_conv.py
================================================
import copy
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import GRU, Linear, ReLU, Sequential

import torch_geometric.transforms as T
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, Set2Set
from torch_geometric.utils import remove_self_loops

# Column of `data.y` to regress (selected by `MyTransform`).
target = 0
# Hidden feature size shared by all layers of the model.
dim = 64


class MyTransform:
    """Keep only the regression target column ``target`` of ``data.y``."""
    def __call__(self, data):
        data = copy.copy(data)
        data.y = data.y[:, target]  # Specify target.
        return data


class Complete:
    """Replace the edge set by a fully-connected graph without self-loops.

    Features of existing edges are scattered into the dense edge list;
    newly created edges get all-zero features.
    """
    def __call__(self, data):
        data = copy.copy(data)
        device = data.edge_index.device

        row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
        col = torch.arange(data.num_nodes, dtype=torch.long, device=device)

        # All (row, col) pairs in row-major order: pair (i, j) sits at
        # position i * num_nodes + j.
        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
        col = col.repeat(data.num_nodes)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            # Same row-major linearization as above, so existing edges land
            # on their matching slot in the dense edge list.
            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
            size = list(data.edge_attr.size())
            size[0] = data.num_nodes * data.num_nodes
            edge_attr = data.edge_attr.new_zeros(size)
            edge_attr[idx] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data


path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')
transform = T.Compose([MyTransform(), Complete(), T.Distance(norm=False)])
dataset = QM9(path, transform=transform).shuffle()

# Normalize targets to mean = 0 and std = 1.
# NOTE(review): statistics are taken over all graphs, i.e. before the
# train/val/test split below, so they include the evaluation splits.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
# Scalars for the selected target; used later to report MAE in raw units.
mean, std = mean[:, target].item(), std[:, target].item()

# Split datasets.
test_dataset = dataset[:10000] val_dataset = dataset[10000:20000] train_dataset = dataset[20000:] test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False) val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False) train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) class Net(torch.nn.Module): def __init__(self): super().__init__() self.lin0 = torch.nn.Linear(dataset.num_features, dim) nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim)) self.conv = NNConv(dim, dim, nn, aggr='mean') self.gru = GRU(dim, dim) self.set2set = Set2Set(dim, processing_steps=3) self.lin1 = torch.nn.Linear(2 * dim, dim) self.lin2 = torch.nn.Linear(dim, 1) def forward(self, data): out = F.relu(self.lin0(data.x)) h = out.unsqueeze(0) for _ in range(3): m = F.relu(self.conv(out, data.edge_index, data.edge_attr)) out, h = self.gru(m.unsqueeze(0), h) out = out.squeeze(0) out = self.set2set(out, data.batch) out = F.relu(self.lin1(out)) out = self.lin2(out) return out.view(-1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001) def train(epoch): model.train() loss_all = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() loss = F.mse_loss(model(data), data.y) loss.backward() loss_all += loss.item() * data.num_graphs optimizer.step() return loss_all / len(train_loader.dataset) def test(loader): model.eval() error = 0 for data in loader: data = data.to(device) error += (model(data) * std - data.y * std).abs().sum().item() # MAE return error / len(loader.dataset) best_val_error = None for epoch in range(1, 301): lr = scheduler.optimizer.param_groups[0]['lr'] loss = train(epoch) val_error = test(val_loader) scheduler.step(val_error) if best_val_error is None or val_error <= best_val_error: test_error = 
test(test_loader) best_val_error = val_error print(f'Epoch: {epoch:03d}, LR: {lr:7f}, Loss: {loss:.7f}, ' f'Val MAE: {val_error:.7f}, Test MAE: {test_error:.7f}') ================================================ FILE: examples/qm9_pretrained_dimenet.py ================================================ import argparse import os.path as osp import torch from torch_geometric.datasets import QM9 from torch_geometric.loader import DataLoader from torch_geometric.nn import DimeNet, DimeNetPlusPlus parser = argparse.ArgumentParser() parser.add_argument('--use_dimenet_plus_plus', action='store_true') args = parser.parse_args() Model = DimeNetPlusPlus if args.use_dimenet_plus_plus else DimeNet path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9') dataset = QM9(path) # DimeNet uses the atomization energy for targets U0, U, H, and G, i.e.: # 7 -> 12, 8 -> 13, 9 -> 14, 10 -> 15 idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 12, 13, 14, 15, 11]) dataset.data.y = dataset.data.y[:, idx] device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') for target in range(12): # Skip target \delta\epsilon, since it can be computed via # \epsilon_{LUMO} - \epsilon_{HOMO}: if target == 4: continue model, datasets = Model.from_qm9_pretrained(path, dataset, target) train_dataset, val_dataset, test_dataset = datasets model = model.to(device) loader = DataLoader(test_dataset, batch_size=256) maes = [] for data in loader: data = data.to(device) with torch.no_grad(): pred = model(data.z, data.pos, data.batch) mae = (pred.view(-1) - data.y[:, target]).abs() maes.append(mae) mae = torch.cat(maes, dim=0) # Report meV instead of eV: mae = 1000 * mae if target in [2, 3, 4, 6, 7, 8, 9, 10] else mae print(f'Target: {target:02d}, MAE: {mae.mean():.5f} ± {mae.std():.5f}') ================================================ FILE: examples/qm9_pretrained_schnet.py ================================================ import argparse import os.path as osp import torch from tqdm import tqdm 
from torch_geometric.datasets import QM9 from torch_geometric.loader import DataLoader from torch_geometric.nn import SchNet parser = argparse.ArgumentParser() parser.add_argument('--cutoff', type=float, default=10.0, help='Cutoff distance for interatomic interactions') args = parser.parse_args() path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9') dataset = QM9(path) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') for target in range(12): model, datasets = SchNet.from_qm9_pretrained(path, dataset, target) train_dataset, val_dataset, test_dataset = datasets model = model.to(device) loader = DataLoader(test_dataset, batch_size=256) maes = [] for data in tqdm(loader): data = data.to(device) with torch.no_grad(): pred = model(data.z, data.pos, data.batch) mae = (pred.view(-1) - data.y[:, target]).abs() maes.append(mae) mae = torch.cat(maes, dim=0) # Report meV instead of eV. mae = 1000 * mae if target in [2, 3, 4, 6, 7, 8, 9, 10] else mae print(f'Target: {target:02d}, MAE: {mae.mean():.5f} ± {mae.std():.5f}') ================================================ FILE: examples/quiver/README.md ================================================ # Using Quiver for PyG Examples **[Quiver](https://github.com/quiver-team/torch-quiver)** is a **GPU-optimized distributed library** for PyG. It can speed up graph sampling and feature aggregation through GPU when running PyG examples. ## Installation Assuming you have installed PyTorch and PyG, you can install Quiver as follows: ```bash pip install torch-quiver>=0.1.1 ``` ## Usage The API and design documentation of Quiver can be found [here](https://github.com/quiver-team/torch-quiver). 
## Examples We provide several examples to showcase the usage of Quiver within PyG: ### Single-GPU Training The single-GPU example leverages Quiver's ability of **(i)** GPU-based graph sampling and feature aggregation, and **(ii)** GNN data caching algorithm (which cache hot data in GPU memory) while enabling fast access to CPU data using a Quiver shared tensor implementation: ```bash python single_gpu_quiver.py ``` ### Multi-GPU Training The multi-GPU example further leverages Quiver's ability of **(i)** distributing sampling and feature aggregation to multiple GPUs, and **(ii)** using multi-GPU memories to cache and replicate hot GNN data: ```bash python multi_gpu_quiver.py ``` ### Distributed Training A Quiver-based distributed PyG example is coming soon. ================================================ FILE: examples/quiver/multi_gpu_quiver.py ================================================ # This script shows how to use Quiver in an existing PyG example: # https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling.py import os from math import ceil import quiver import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborSampler from torch_geometric.nn import SAGEConv class SAGE(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2): super().__init__() self.num_layers = num_layers self.convs = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) for _ in range(self.num_layers - 2): self.convs.append(SAGEConv(hidden_channels, hidden_channels)) self.convs.append(SAGEConv(hidden_channels, out_channels)) def forward(self, x, adjs): for i, (edge_index, _, size) in enumerate(adjs): x_target = x[:size[1]] # Target nodes are always placed first. 
x = self.convs[i]((x, x_target), edge_index) if i != self.num_layers - 1: x = F.relu(x) x = F.dropout(x, p=0.5, training=self.training) return x.log_softmax(dim=-1) @torch.no_grad() def inference(self, x_all, device, subgraph_loader): pbar = tqdm(total=x_all.size(0) * self.num_layers) pbar.set_description('Evaluating') for i in range(self.num_layers): xs = [] for batch_size, n_id, adj in subgraph_loader: edge_index, _, size = adj.to(device) x = x_all[n_id].to(device) x_target = x[:size[1]] x = self.convs[i]((x, x_target), edge_index) if i != self.num_layers - 1: x = F.relu(x) xs.append(x.cpu()) pbar.update(batch_size) x_all = torch.cat(xs, dim=0) pbar.close() return x_all def run(rank, world_size, dataset, quiver_feature, quiver_sampler): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) torch.cuda.set_device(rank) data = dataset[0] train_idx = data.train_mask.nonzero(as_tuple=False).view(-1) train_idx = train_idx.split(ceil(train_idx.size(0) / world_size))[rank] train_loader = torch.utils.data.DataLoader(train_idx, batch_size=1024, shuffle=True, num_workers=0) if rank == 0: subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1], batch_size=2048, shuffle=False, num_workers=6) torch.manual_seed(12345) model = SAGE(dataset.num_features, 256, dataset.num_classes).to(rank) model = DistributedDataParallel(model, device_ids=[rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) y = data.y.to(rank) for epoch in range(1, 21): model.train() for seeds in train_loader: n_id, batch_size, adjs = quiver_sampler.sample(seeds) adjs = [adj.to(rank) for adj in adjs] optimizer.zero_grad() out = model(quiver_feature[n_id].to(rank), adjs) loss = F.nll_loss(out, y[n_id[:batch_size]]) loss.backward() optimizer.step() dist.barrier() if rank == 0: print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}') if rank == 0 and epoch % 5 == 0: # We evaluate on a single GPU for now 
model.eval() with torch.no_grad(): out = model.module.inference(quiver_feature, rank, subgraph_loader) res = out.argmax(dim=-1) == data.y acc1 = int(res[data.train_mask].sum()) / int(data.train_mask.sum()) acc2 = int(res[data.val_mask].sum()) / int(data.val_mask.sum()) acc3 = int(res[data.test_mask].sum()) / int(data.test_mask.sum()) print(f'Train: {acc1:.4f}, Val: {acc2:.4f}, Test: {acc3:.4f}') dist.barrier() dist.destroy_process_group() if __name__ == '__main__': dataset = Reddit('../../data/Reddit') data = dataset[0] world_size = torch.cuda.device_count() print('Let\'s use', world_size, 'GPUs!') ######################################################################## # The below code enable Quiver for PyG. # Please refer to: https://torch-quiver.readthedocs.io/en/latest/api/ for # how to configure the CSRTopo, Sampler and Feature of Quiver. ######################################################################## csr_topo = quiver.CSRTopo(data.edge_index) # Quiver quiver_sampler = quiver.pyg.GraphSageSampler(csr_topo, [25, 10], 0, mode='GPU') # Quiver quiver_feature = quiver.Feature(rank=0, device_list=list(range(world_size)), device_cache_size="2G", cache_policy="device_replicate", csr_topo=csr_topo) # Quiver quiver_feature.from_cpu_tensor(data.x) # Quiver mp.spawn(run, args=(world_size, dataset, quiver_feature, quiver_sampler), nprocs=world_size, join=True) ================================================ FILE: examples/quiver/single_gpu_quiver.py ================================================ # This script shows how to use Quiver in an existing PyG example: # https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py import os.path as osp import quiver import torch import torch.nn.functional as F from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborSampler from torch_geometric.nn import SAGEConv path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit') dataset = 
Reddit(path) data = dataset[0] train_idx = data.train_mask.nonzero(as_tuple=False).view(-1) ################################ # Step 1: Using Quiver's sampler ################################ train_loader = torch.utils.data.DataLoader(train_idx, batch_size=1024, shuffle=True, drop_last=True) # Quiver ######################################################################## # The below code enable Quiver for PyG. # Please refer to: https://torch-quiver.readthedocs.io/en/latest/api/ for # how to configure the CSRTopo, Sampler and Feature of Quiver. ######################################################################## csr_topo = quiver.CSRTopo(data.edge_index) # Quiver quiver_sampler = quiver.pyg.GraphSageSampler(csr_topo, sizes=[25, 10], device=0, mode='GPU') # Quiver subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1], batch_size=1024, shuffle=False, num_workers=12) class SAGE(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.num_layers = 2 self.convs = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) self.convs.append(SAGEConv(hidden_channels, out_channels)) def forward(self, x, adjs): # `train_loader` computes the k-hop neighborhood of a batch of nodes, # and returns, for each layer, a bipartite graph object, holding the # bipartite edges `edge_index`, the index `e_id` of the original edges, # and the size/shape `size` of the bipartite graph. # Target nodes are also included in the source nodes so that one can # easily apply skip-connections or add self-loops. for i, (edge_index, _, size) in enumerate(adjs): x_target = x[:size[1]] # Target nodes are always placed first. 
x = self.convs[i]((x, x_target), edge_index) if i != self.num_layers - 1: x = F.relu(x) x = F.dropout(x, p=0.5, training=self.training) return x.log_softmax(dim=-1) def inference(self, x_all): pbar = tqdm(total=x_all.size(0) * self.num_layers) pbar.set_description('Evaluating') # Compute representations of nodes layer by layer, using *all* # available edges. This leads to faster computation in contrast to # immediately computing the final representations of each batch. for i in range(self.num_layers): xs = [] for batch_size, n_id, adj in subgraph_loader: edge_index, _, size = adj.to(device) x = x_all[n_id].to(device) x_target = x[:size[1]] x = self.convs[i]((x, x_target), edge_index) if i != self.num_layers - 1: x = F.relu(x) xs.append(x.cpu()) pbar.update(batch_size) x_all = torch.cat(xs, dim=0) pbar.close() return x_all device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = SAGE(dataset.num_features, 256, dataset.num_classes) model = model.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) ################################ # Step 2: Using Quiver's Feature ################################ x = quiver.Feature(rank=0, device_list=[0], device_cache_size="4G", cache_policy="device_replicate", csr_topo=csr_topo) # Quiver x.from_cpu_tensor(data.x) # Quiver y = data.y.squeeze().to(device) def train(epoch): model.train() pbar = tqdm(total=int(data.train_mask.sum())) pbar.set_description(f'Epoch {epoch:02d}') total_loss = total_correct = 0 ############################################ # Step 3: Training the PyG Model with Quiver ############################################ # for batch_size, n_id, adjs in train_loader: # Original PyG Code for seeds in train_loader: # Quiver n_id, batch_size, adjs = quiver_sampler.sample(seeds) # Quiver # `adjs` holds a list of `(edge_index, e_id, size)` tuples. 
adjs = [adj.to(device) for adj in adjs] optimizer.zero_grad() out = model(x[n_id], adjs) loss = F.nll_loss(out, y[n_id[:batch_size]]) loss.backward() optimizer.step() total_loss += float(loss) total_correct += int(out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum()) pbar.update(batch_size) pbar.close() loss = total_loss / len(train_loader) approx_acc = total_correct / int(data.train_mask.sum()) return loss, approx_acc @torch.no_grad() def test(): model.eval() out = model.inference(x) y_true = y.cpu().unsqueeze(-1) y_pred = out.argmax(dim=-1, keepdim=True) results = [] for mask in [data.train_mask, data.val_mask, data.test_mask]: results += [int(y_pred[mask].eq(y_true[mask]).sum()) / int(mask.sum())] return results for epoch in range(1, 11): loss, acc = train(epoch) print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}') train_acc, val_acc, test_acc = test() print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' f'Test: {test_acc:.4f}') ================================================ FILE: examples/randlanet_classification.py ================================================ """An adaptation of RandLA-Net to the classification task, which was not addressed in the `"RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds" `_ paper. 
""" import os.path as osp import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Linear from tqdm import tqdm import torch_geometric.transforms as T from torch_geometric.datasets import ModelNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP from torch_geometric.nn.aggr import MaxAggregation from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.pool import knn_graph from torch_geometric.nn.pool.decimation import decimation_indices from torch_geometric.typing import WITH_TORCH_CLUSTER from torch_geometric.utils import softmax if not WITH_TORCH_CLUSTER: quit("This example requires 'torch-cluster'") # Default activation and batch norm parameters used by RandLA-Net: lrelu02_kwargs = {'negative_slope': 0.2} bn099_kwargs = {'momentum': 0.01, 'eps': 1e-6} class SharedMLP(MLP): """SharedMLP following RandLA-Net paper.""" def __init__(self, *args, **kwargs): # BN + Act always active even at last layer. kwargs['plain_last'] = False # LeakyRelu with 0.2 slope by default. 
kwargs['act'] = kwargs.get('act', 'LeakyReLU') kwargs['act_kwargs'] = kwargs.get('act_kwargs', lrelu02_kwargs) # BatchNorm with 1 - 0.99 = 0.01 momentum # and 1e-6 eps by default (tensorflow momentum != pytorch momentum) kwargs['norm_kwargs'] = kwargs.get('norm_kwargs', bn099_kwargs) super().__init__(*args, **kwargs) class LocalFeatureAggregation(MessagePassing): """Positional encoding of points in a neighborhood.""" def __init__(self, channels): super().__init__(aggr='add') self.mlp_encoder = SharedMLP([10, channels // 2]) self.mlp_attention = SharedMLP([channels, channels], bias=False, act=None, norm=None) self.mlp_post_attention = SharedMLP([channels, channels]) def forward(self, edge_index, x, pos): out = self.propagate(edge_index, x=x, pos=pos) # N, d_out out = self.mlp_post_attention(out) # N, d_out return out def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor, index: Tensor) -> Tensor: """Local Spatial Encoding (locSE) and attentive pooling of features. Args: x_j (Tensor): neighboors features (K,d) pos_i (Tensor): centroid position (repeated) (K,3) pos_j (Tensor): neighboors positions (K,3) index (Tensor): index of centroid positions (e.g. [0,...,0,1,...,1,...,N,...,N]) Returns: (Tensor): locSE weighted by feature attention scores. """ # Encode local neighborhood structural information pos_diff = pos_j - pos_i distance = torch.sqrt((pos_diff * pos_diff).sum(1, keepdim=True)) relative_infos = torch.cat([pos_i, pos_j, pos_diff, distance], dim=1) # N * K, d local_spatial_encoding = self.mlp_encoder(relative_infos) # N * K, d local_features = torch.cat([x_j, local_spatial_encoding], dim=1) # N * K, 2d # Attention will weight the different features of x # along the neighborhood dimension. 
att_features = self.mlp_attention(local_features) # N * K, d_out att_scores = softmax(att_features, index=index) # N * K, d_out return att_scores * local_features # N * K, d_out class DilatedResidualBlock(torch.nn.Module): def __init__( self, num_neighbors, d_in: int, d_out: int, ): super().__init__() self.num_neighbors = num_neighbors self.d_in = d_in self.d_out = d_out # MLP on input self.mlp1 = SharedMLP([d_in, d_out // 8]) # MLP on input, and the result is summed with the output of mlp2 self.shortcut = SharedMLP([d_in, d_out], act=None) # MLP on output self.mlp2 = SharedMLP([d_out // 2, d_out], act=None) self.lfa1 = LocalFeatureAggregation(d_out // 4) self.lfa2 = LocalFeatureAggregation(d_out // 2) self.lrelu = torch.nn.LeakyReLU(**lrelu02_kwargs) def forward(self, x, pos, batch): edge_index = knn_graph(pos, self.num_neighbors, batch=batch, loop=True) shortcut_of_x = self.shortcut(x) # N, d_out x = self.mlp1(x) # N, d_out//8 x = self.lfa1(edge_index, x, pos) # N, d_out//2 x = self.lfa2(edge_index, x, pos) # N, d_out//2 x = self.mlp2(x) # N, d_out x = self.lrelu(x + shortcut_of_x) # N, d_out return x, pos, batch def decimate(tensors, ptr: Tensor, decimation_factor: int): """Decimates each element of the given tuple of tensors.""" idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor) tensors_decim = tuple(tensor[idx_decim] for tensor in tensors) return tensors_decim, ptr_decim class Net(torch.nn.Module): def __init__( self, num_features, num_classes, decimation: int = 4, num_neighboors: int = 16, return_logits: bool = False, ): super().__init__() self.decimation = decimation # An option to return logits instead of log probabilities: self.return_logits = return_logits self.fc0 = Linear(in_features=num_features, out_features=8) # 2 DilatedResidualBlock converges better than 4 on ModelNet. 
self.block1 = DilatedResidualBlock(num_neighboors, 8, 32) self.block2 = DilatedResidualBlock(num_neighboors, 32, 128) self.mlp1 = SharedMLP([128, 128]) self.max_agg = MaxAggregation() self.mlp_classif = SharedMLP([128, 32], dropout=[0.5]) self.fc_classif = Linear(32, num_classes) def forward(self, x, pos, batch, ptr): x = x if x is not None else pos b1 = self.block1(self.fc0(x), pos, batch) b1_decimated, ptr1 = decimate(b1, ptr, self.decimation) b2 = self.block2(*b1_decimated) b2_decimated, _ = decimate(b2, ptr1, self.decimation) x = self.mlp1(b2_decimated[0]) x = self.max_agg(x, b2_decimated[2]) x = self.mlp_classif(x) logits = self.fc_classif(x) return logits if self.return_logits else logits.log_softmax(dim=-1) def train(epoch): model.train() total_loss = 0 for data in tqdm(train_loader): data = data.to(device) optimizer.zero_grad() out = model(data.x, data.pos, data.batch, data.ptr) loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() total_loss += data.num_graphs * float(loss) return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() correct = 0 for data in loader: data = data.to(device) out = model(data.x, data.pos, data.batch, data.ptr) correct += int((out.argmax(dim=-1) == data.y).sum()) return correct / len(loader.dataset) if __name__ == '__main__': path = osp.dirname(osp.realpath(__file__)) path = osp.join(path, '..', 'data/ModelNet10') pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024) train_dataset = ModelNet(path, '10', True, transform, pre_transform) test_dataset = ModelNet(path, '10', False, transform, pre_transform) train_loader = DataLoader(train_dataset, 32, shuffle=True, num_workers=6) test_loader = DataLoader(test_dataset, 32, shuffle=False, num_workers=6) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(3, train_dataset.num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) scheduler = 
torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) for epoch in range(1, 201): loss = train(epoch) test_acc = test(test_loader) print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test: {test_acc:.4f}') scheduler.step() ================================================ FILE: examples/randlanet_segmentation.py ================================================ """An implementation of RandLA-Net based on the `"RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds" `_ paper. """ import os.path as osp import torch import torch.nn.functional as F from randlanet_classification import DilatedResidualBlock, SharedMLP, decimate from torch.nn import Linear from torchmetrics.functional import jaccard_index from tqdm import tqdm import torch_geometric.transforms as T from torch_geometric.datasets import ShapeNet from torch_geometric.loader import DataLoader from torch_geometric.nn import knn_interpolate from torch_geometric.typing import WITH_TORCH_CLUSTER from torch_geometric.utils import scatter if not WITH_TORCH_CLUSTER: quit("This example requires 'torch-cluster'") category = 'Airplane' # Pass in `None` to train on all categories. 
category_num_classes = 4 # 4 for Airplane - see ShapeNet for details path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') transform = T.Compose([ T.RandomJitter(0.01), T.RandomRotate(15, axis=0), T.RandomRotate(15, axis=1), T.RandomRotate(15, axis=2), ]) pre_transform = T.NormalizeScale() train_dataset = ShapeNet( path, category, split='trainval', transform=transform, pre_transform=pre_transform, ) test_dataset = ShapeNet( path, category, split='test', pre_transform=pre_transform, ) train_loader = DataLoader(train_dataset, 12, shuffle=True, num_workers=6) test_loader = DataLoader(test_dataset, 12, shuffle=False, num_workers=6) class FPModule(torch.nn.Module): """Upsampling with a skip connection.""" def __init__(self, k, nn): super().__init__() self.k = k self.nn = nn def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip): x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k) x = torch.cat([x, x_skip], dim=1) x = self.nn(x) return x, pos_skip, batch_skip class Net(torch.nn.Module): def __init__( self, num_features: int, num_classes: int, decimation: int = 4, num_neighbors: int = 16, return_logits: bool = False, ): super().__init__() self.decimation = decimation # An option to return logits instead of log probabilities: self.return_logits = return_logits # Authors use 8, which is a bottleneck # for the final MLP, and also when num_classes>8 # or num_features>8. 
d_bottleneck = max(32, num_classes, num_features) self.fc0 = Linear(num_features, d_bottleneck) self.block1 = DilatedResidualBlock(num_neighbors, d_bottleneck, 32) self.block2 = DilatedResidualBlock(num_neighbors, 32, 128) self.block3 = DilatedResidualBlock(num_neighbors, 128, 256) self.block4 = DilatedResidualBlock(num_neighbors, 256, 512) self.mlp_summit = SharedMLP([512, 512]) self.fp4 = FPModule(1, SharedMLP([512 + 256, 256])) self.fp3 = FPModule(1, SharedMLP([256 + 128, 128])) self.fp2 = FPModule(1, SharedMLP([128 + 32, 32])) self.fp1 = FPModule(1, SharedMLP([32 + 32, d_bottleneck])) self.mlp_classif = SharedMLP([d_bottleneck, 64, 32], dropout=[0.0, 0.5]) self.fc_classif = Linear(32, num_classes) def forward(self, x, pos, batch, ptr): x = x if x is not None else pos b1_out = self.block1(self.fc0(x), pos, batch) b1_out_decimated, ptr1 = decimate(b1_out, ptr, self.decimation) b2_out = self.block2(*b1_out_decimated) b2_out_decimated, ptr2 = decimate(b2_out, ptr1, self.decimation) b3_out = self.block3(*b2_out_decimated) b3_out_decimated, ptr3 = decimate(b3_out, ptr2, self.decimation) b4_out = self.block4(*b3_out_decimated) b4_out_decimated, _ = decimate(b4_out, ptr3, self.decimation) mlp_out = ( self.mlp_summit(b4_out_decimated[0]), b4_out_decimated[1], b4_out_decimated[2], ) fp4_out = self.fp4(*mlp_out, *b3_out_decimated) fp3_out = self.fp3(*fp4_out, *b2_out_decimated) fp2_out = self.fp2(*fp3_out, *b1_out_decimated) fp1_out = self.fp1(*fp2_out, *b1_out) x = self.mlp_classif(fp1_out[0]) logits = self.fc_classif(x) if self.return_logits: return logits probas = logits.log_softmax(dim=-1) return probas device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(3, category_num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def train(): model.train() total_loss = correct_nodes = total_nodes = 0 for i, data in tqdm(enumerate(train_loader)): data = data.to(device) optimizer.zero_grad() out = model(data.x, data.pos, 
data.batch, data.ptr) loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() total_loss += loss.item() correct_nodes += out.argmax(dim=1).eq(data.y).sum().item() total_nodes += data.num_nodes if (i + 1) % 10 == 0: print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} ' f'Train Acc: {correct_nodes / total_nodes:.4f}') total_loss = correct_nodes = total_nodes = 0 @torch.no_grad() def test(loader): model.eval() ious, categories = [], [] y_map = torch.empty(loader.dataset.num_classes, device=device).long() for data in loader: data = data.to(device) outs = model(data.x, data.pos, data.batch, data.ptr) sizes = (data.ptr[1:] - data.ptr[:-1]).tolist() for out, y, category in zip(outs.split(sizes), data.y.split(sizes), data.category.tolist()): category = list(ShapeNet.seg_classes.keys())[category] part = ShapeNet.seg_classes[category] part = torch.tensor(part, device=device) y_map[part] = torch.arange(part.size(0), device=device) iou = jaccard_index( out[:, part].argmax(dim=-1), y_map[y], num_classes=part.size(0), absent_score=1.0, ) ious.append(iou) categories.append(data.category) iou = torch.tensor(ious, device=device) category = torch.cat(categories, dim=0) mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU. return float(mean_iou.mean()) # Global IoU. for epoch in range(1, 31): train() iou = test(test_loader) print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}') ================================================ FILE: examples/rdl.py ================================================ """This example demonstrates how to train a Relational Deep Learning model using RelBench. Please refer to: 1. https://arxiv.org/abs/2407.20060 for RelBench, and 2. https://github.com/snap-stanford/relbench for reproducing the results reported on the RelBench paper. 
""" import argparse import math import operator import os from typing import Any, Dict, List, NamedTuple, Optional, Tuple import numpy as np import pandas as pd import torch import torch_frame from relbench.base import EntityTask, Table, TaskType from relbench.datasets import get_dataset, get_dataset_names from relbench.modeling.graph import make_pkey_fkey_graph from relbench.modeling.utils import get_stype_proposal from relbench.tasks import get_task, get_task_names from sentence_transformers import SentenceTransformer from torch import Tensor from torch_frame.config.text_embedder import TextEmbedderConfig from torch_frame.data.stats import StatType from torch_frame.nn.models import ResNet from tqdm import tqdm from torch_geometric.data import HeteroData from torch_geometric.loader import NeighborLoader from torch_geometric.nn import ( MLP, HeteroConv, LayerNorm, PositionalEncoding, SAGEConv, ) from torch_geometric.seed import seed_everything from torch_geometric.typing import EdgeType, NodeType class GloveTextEmbedding: """GloveTextEmbedding based on SentenceTransformer.""" def __init__(self, device: Optional[torch.device] = None) -> None: self.model = SentenceTransformer( "sentence-transformers/average_word_embeddings_glove.6B.300d", device=device, ) def __call__(self, sentences: List[str]) -> Tensor: return torch.from_numpy(self.model.encode(sentences)) class HeteroEncoder(torch.nn.Module): r"""HeteroEncoder based on PyTorch Frame implemented with ResNet. A heterogeneous encoder that processes different node types using PyTorch Frame models. For each node type, it creates a separate encoder model that processes the node features according to their data types (categorical, numerical, etc). Args: channels: The output channels for each node type. num_layers: The number of layers for the ResNet. col_names_dict: A dictionary mapping from node type to column names dictionary compatible with PyTorch Frame. 
stats_dict: A dictionary containing statistics for each column in each node type. Used for feature normalization and encoding. """ def __init__( self, channels: int, num_layers: int, col_names_dict: Dict[NodeType, Dict[torch_frame.stype, List[str]]], stats_dict: Dict[NodeType, Dict[str, Dict[StatType, Any]]], ) -> None: super().__init__() self.encoders = torch.nn.ModuleDict() for node_type in col_names_dict.keys(): stype_encoder_dict = { torch_frame.categorical: torch_frame.nn.EmbeddingEncoder(), torch_frame.numerical: torch_frame.nn.LinearEncoder(), torch_frame.multicategorical: torch_frame.nn.MultiCategoricalEmbeddingEncoder(), torch_frame.embedding: torch_frame.nn.LinearEmbeddingEncoder(), torch_frame.timestamp: torch_frame.nn.TimestampEncoder() } torch_frame_model = ResNet( channels=channels, num_layers=num_layers, out_channels=channels, col_stats=stats_dict[node_type], col_names_dict=col_names_dict[node_type], stype_encoder_dict=stype_encoder_dict, ) self.encoders[node_type] = torch_frame_model def reset_parameters(self) -> None: """Reset the parameters of all encoder models.""" for encoder in self.encoders.values(): encoder.reset_parameters() def forward( self, tf_dict: Dict[NodeType, torch_frame.TensorFrame], ) -> Dict[NodeType, Tensor]: """Forward pass of the heterogeneous encoder. Args: tf_dict: A dictionary mapping node types to their corresponding TensorFrame objects containing the node features. Returns: Dictionary mapping node types to their encoded representations. Each tensor has shape ``[num_nodes, channels]``. """ return { node_type: self.encoders[node_type](tf) for node_type, tf in tf_dict.items() } class HeteroTemporalEncoder(torch.nn.Module): """HeteroTemporalEncoder class that uses PositionalEncoding to encode temporal information for heterogeneous graphs. This encoder computes relative time embeddings between a seed time and node timestamps, converting the time differences from seconds to days. 
It applies positional encoding followed by a linear transformation for each node type. Args: node_types: List of node types in the heterogeneous graph channels: Number of channels/dimensions for the encoded embeddings Example: >>> encoder = HeteroTemporalEncoder(['user', 'item'], channels=64) >>> seed_time = torch.tensor([1000]) # Reference timestamp >>> time_dict = {'user': torch.tensor([800, 900]), >>> 'item': torch.tensor([700, 850])} >>> batch_dict = {'user': torch.tensor([0, 0]), >>> 'item': torch.tensor([0, 0])} >>> out_dict = encoder(seed_time, time_dict, batch_dict) >>> out_dict['user'].shape torch.Size([2, 64]) """ def __init__(self, node_types: List[NodeType], channels: int) -> None: super().__init__() self.encoder_dict = torch.nn.ModuleDict({ node_type: PositionalEncoding(channels) for node_type in node_types }) self.lin_dict = torch.nn.ModuleDict({ node_type: torch.nn.Linear(channels, channels) for node_type in node_types }) def reset_parameters(self) -> None: """Reset the parameters of all encoders and linear layers.""" for encoder in self.encoder_dict.values(): encoder.reset_parameters() for lin in self.lin_dict.values(): lin.reset_parameters() def forward( self, seed_time: Tensor, time_dict: Dict[NodeType, Tensor], batch_dict: Dict[NodeType, Tensor], ) -> Dict[NodeType, Tensor]: """Forward pass of the temporal encoder. Args: seed_time: Reference timestamps for computing relative times time_dict: Dictionary mapping node types to their timestamps batch_dict: Dictionary mapping node types to batch assignments Returns: Dictionary mapping node types to their temporal embeddings """ out_dict: Dict[NodeType, Tensor] = {} for node_type, time in time_dict.items(): rel_time = seed_time[batch_dict[node_type]] - time rel_time = rel_time / (60 * 60 * 24) # Convert seconds to days. 
x = self.encoder_dict[node_type](rel_time) x = self.lin_dict[node_type](x) out_dict[node_type] = x return out_dict class HeteroGraphSAGE(torch.nn.Module): """Heterogeneous GraphSAGE model with layer normalization. This model implements a heterogeneous version of GraphSAGE that operates on multiple node and edge types. Each layer consists of a heterogeneous graph convolution followed by layer normalization and ReLU activation. Args: node_types: List of node types in the graph edge_types: List of edge types in the graph channels: Number of channels/features aggr: Node aggregation scheme. num_layers: Number of graph convolution layers. Example: >>> model = HeteroGraphSAGE( >>> node_types=['user', 'item'], >>> edge_types=[('user', 'rates', 'item')], >>> channels=64) >>> out_dict = model(x_dict, edge_index_dict) """ def __init__( self, node_types: List[NodeType], edge_types: List[EdgeType], channels: int, aggr: str = "mean", num_layers: int = 2, ) -> None: super().__init__() self.convs = torch.nn.ModuleList() for _ in range(num_layers): conv = HeteroConv( { edge_type: SAGEConv( (channels, channels), channels, aggr=aggr) for edge_type in edge_types }, aggr="sum", ) self.convs.append(conv) self.norms = torch.nn.ModuleList() for _ in range(num_layers): norm_dict = torch.nn.ModuleDict() for node_type in node_types: norm_dict[node_type] = LayerNorm(channels, mode="node") self.norms.append(norm_dict) def reset_parameters(self) -> None: """Reset the parameters of all convolution and normalization layers.""" for conv in self.convs: conv.reset_parameters() for norm_dict in self.norms: for norm in norm_dict.values(): norm.reset_parameters() def forward( self, x_dict: Dict[NodeType, Tensor], edge_index_dict: Dict[NodeType, Tensor], ) -> Dict[NodeType, Tensor]: """Forward pass of the heterogeneous GraphSAGE model. 
Args: x_dict: Node feature dictionary edge_index_dict: Edge index dictionary Returns: Updated node features after message passing """ for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)): x_dict = conv(x_dict, edge_index_dict) x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()} x_dict = {key: x.relu() for key, x in x_dict.items()} return x_dict class Model(torch.nn.Module): """A heterogeneous graph neural network model for temporal graph learning. This model consists of: 1. A heterogeneous feature encoder for node attributes 2. A temporal encoder for handling time information 3. A heterogeneous GraphSAGE model for message passing 4. An MLP head for final predictions Args: node_types: List of node types in the graph edge_types: List of edge types in the graph col_names_dict: Dictionary mapping node types to their column names and types temporal_node_types: List of node types with temporal features col_stats_dict: Statistics of node features num_layers: Number of GNN layers channels: Hidden dimension size out_channels: Output dimension size aggr: Aggregation method for GNN norm: Normalization method for MLP """ def __init__( self, node_types: List[NodeType], edge_types: List[EdgeType], col_names_dict: Dict[NodeType, Dict[torch_frame.stype, List[str]]], temporal_node_types: List[NodeType], col_stats_dict: Dict[NodeType, Dict[str, Dict[StatType, Any]]], num_layers: int, channels: int, out_channels: int, aggr: str, norm: str, ) -> None: super().__init__() self.encoder = HeteroEncoder( channels=channels, num_layers=num_layers, col_names_dict=col_names_dict, stats_dict=col_stats_dict, ) self.temporal_encoder = HeteroTemporalEncoder( node_types=temporal_node_types, channels=channels, ) self.gnn = HeteroGraphSAGE( node_types=node_types, edge_types=edge_types, channels=channels, aggr=aggr, num_layers=num_layers, ) self.head = MLP( channels, out_channels=out_channels, norm=norm, num_layers=1, ) self.reset_parameters() def reset_parameters(self) -> 
None: """Reset the parameters of all model components.""" self.encoder.reset_parameters() self.temporal_encoder.reset_parameters() self.gnn.reset_parameters() self.head.reset_parameters() def forward( self, batch: HeteroData, entity_table: NodeType, ) -> Tensor: """Forward pass of the model. Steps: 1. Get seed time from entity table 2. Encode node features using HeteroEncoder 3. Encode temporal features using HeteroTemporalEncoder 4. Add temporal embeddings to node features 5. Apply graph neural network (HeteroGraphSAGE) 6. Apply final MLP head to target node embeddings Args: batch: Batch of heterogeneous graph data entity_table: The target node type for prediction Returns: Tensor: Predictions for nodes in the entity table """ seed_time = batch[entity_table].seed_time x_dict = self.encoder(batch.tf_dict) rel_time_dict = self.temporal_encoder( seed_time, batch.time_dict, batch.batch_dict, ) for node_type, rel_time in rel_time_dict.items(): x_dict[node_type] = x_dict[node_type] + rel_time x_dict = self.gnn(x_dict, batch.edge_index_dict) return self.head(x_dict[entity_table][:seed_time.size(0)]) class AttachTargetTransform: r"""Attach the target label to the heterogeneous mini-batch. The batch consists of disjoins subgraphs loaded via temporal sampling. The same input node can occur multiple times with different timestamps, and thus different subgraphs and labels. Hence labels cannot be stored in the graph object directly, and must be attached to the batch after the batch is created. """ def __init__(self, entity: str, target: Tensor) -> None: self.entity = entity self.target = target def __call__(self, batch: HeteroData) -> HeteroData: batch[self.entity].y = self.target[batch[self.entity].input_id] return batch class TrainingTableInput(NamedTuple): r"""Training table input for node prediction tasks. A container for organizing input data needed for node-level predictions. 
Attributes: nodes: Tuple of (node_type, indices_tensor) containing the node type identifier and Tensor of node IDs to predict on. time: Optional Tensor of timestamps for temporal sampling. Shape matches node indices. None if task is not temporal. target: Optional Tensor of ground truth labels/values. Shape matches node indices. None during inference. transform: Optional transform that attaches target labels to batches during training. Needed for temporal sampling where nodes can appear multiple times with different labels. """ nodes: Tuple[NodeType, Tensor] time: Optional[Tensor] target: Optional[Tensor] transform: Optional[AttachTargetTransform] def get_task_type_params( task: EntityTask) -> Tuple[int, torch.nn.Module, str, bool]: r"""Get task-specific optimization parameters based on task type. Args: task: Task specification containing task type. Returns: Tuple containing: - out_channels: Number of output channels - loss_fn: Loss function - tune_metric: Metric to optimize - higher_is_better: Whether higher metric values are better """ if task.task_type == TaskType.REGRESSION: out_channels = 1 loss_fn = torch.nn.L1Loss() tune_metric = "mae" higher_is_better = False elif task.task_type == TaskType.BINARY_CLASSIFICATION: out_channels = 1 loss_fn = torch.nn.BCEWithLogitsLoss() tune_metric = "roc_auc" higher_is_better = True else: raise ValueError(f"Unsupported task type: {task.task_type}") return out_channels, loss_fn, tune_metric, higher_is_better def to_unix_time(ser: pd.Series) -> np.ndarray: r"""Convert a pandas Timestamp series to UNIX timestamp in seconds. Args: ser: Input pandas Series containing datetime values. Returns: Array of UNIX timestamps in seconds. 
""" assert ser.dtype in [np.dtype("datetime64[s]"), np.dtype("datetime64[ns]")] unix_time = ser.astype("int64").values if ser.dtype == np.dtype("datetime64[ns]"): unix_time //= 10**9 return unix_time def get_train_table_input( split_table: Table, task: EntityTask, ) -> TrainingTableInput: r"""Get the training table input for node prediction. Processes a table split and task to create a TrainingTableInput object containing: 1. Node indices for the target entity type 2. Optional timestamps for temporal sampling 3. Optional target labels/values for training 4. Optional transform to attach labels during batch loading Args: split_table: Table containing node IDs, optional timestamps, and optional target values to predict. task: Task specification containing entity table name, entity column name, target column name, etc. Returns: Container with processed node indices, timestamps, target values and transform needed for training/inference. """ nodes = torch.from_numpy( split_table.df[task.entity_col].astype(int).values) time: Optional[Tensor] = None if split_table.time_col is not None: time = torch.from_numpy( to_unix_time(split_table.df[split_table.time_col])) target: Optional[Tensor] = None transform: Optional[AttachTargetTransform] = None if task.target_col in split_table.df: target = torch.from_numpy( split_table.df[task.target_col].values.astype(float)) transform = AttachTargetTransform(task.entity_table, target) return TrainingTableInput( nodes=(task.entity_table, nodes), time=time, target=target, transform=transform, ) def train( model: Model, train_loader: NeighborLoader, task: EntityTask, optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module, device: torch.device, ) -> float: model.train() loss_accum = torch.zeros(1, device=device).squeeze_() count_accum = 0 for batch in tqdm(train_loader): batch = batch.to(device) optimizer.zero_grad() pred = model(batch, task.entity_table) pred = pred.view(-1) if pred.size(1) == 1 else pred # Get the target column name from 
the task loss = loss_fn(pred, batch[task.entity_table].y.float()) loss.backward() optimizer.step() loss *= pred.size(0) loss_accum += loss count_accum += pred.size(0) return loss_accum.item() / count_accum @torch.no_grad() def test( test_loader: NeighborLoader, model: Model, task: EntityTask, device: torch.device, ) -> np.ndarray: model.eval() pred_list = [] for batch in tqdm(test_loader): batch = batch.to(device) pred = model(batch, task.entity_table) pred = pred.view(-1) if pred.size(1) == 1 else pred pred_list.append(pred.detach().cpu()) return torch.cat(pred_list, dim=0).numpy() def main(): seed_everything(42) parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--dataset", type=str, default="rel-f1", choices=get_dataset_names()) parser.add_argument( "--task", type=str, default=None, help="See available tasks at https://relbench.stanford.edu/") parser.add_argument("--batch_size", type=int, default=512) parser.add_argument("--temporal_strategy", type=str, default="uniform", choices=["uniform", "last"]) parser.add_argument("--num_neighbors", type=list, default=[128, 128]) parser.add_argument("--channels", type=int, default=128) parser.add_argument("--aggr", type=str, default="sum") parser.add_argument("--norm", type=str, default="batch_norm") parser.add_argument("--epochs", type=int, default=10) parser.add_argument("--lr", type=float, default=0.005) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Using device:", device) print("Loading dataset and task...") assert args.task in get_task_names(args.dataset), ( f"Invalid --task '{args.task}' for --dataset '{args.dataset}'. 
" f"Available tasks: {get_task_names(args.dataset)}") dataset = get_dataset(name=args.dataset, download=True) task = get_task( dataset_name=args.dataset, task_name=args.task, download=True, ) print(f"Task type: {task.task_type}") print(f"Target column: '{task.target_col}'") print(f"Entity table: '{task.entity_table}'") print("Getting column to stype dictionary...") db = dataset.get_db() col_to_stype_dict = get_stype_proposal(db) print("Column to stype dictionary: ", col_to_stype_dict) print("Defining text embedder...") text_embedder_cfg = TextEmbedderConfig( text_embedder=GloveTextEmbedding(device=device), batch_size=256, ) # Transform the dataset into a HeteroData object with torch_frame features # See also: # https://github.com/snap-stanford/relbench/blob/v1.1.0/relbench/modeling/graph.py#L20-L111 # noqa: E501 print("Transforming dataset into HeteroData object...") data, col_stats_dict = make_pkey_fkey_graph( db, col_to_stype_dict=col_to_stype_dict, # specified column types text_embedder_cfg=text_embedder_cfg, # our chosen text encoder cache_dir=os.path.join( # store materialized graph for convenience "./data", f"{args.dataset}_{args.task}_materialized_cache", ), ) print("Preparing data loaders...") loader_dict = {} num_neighbors_dict = { edge_type: args.num_neighbors for edge_type in data.edge_types } for split in ["train", "val", "test"]: table = task.get_table(split) print(f"Creating '{split}' dataloader with columns: " f"{list(table.df.columns)}") table_input = get_train_table_input(split_table=table, task=task) loader_dict[split] = NeighborLoader( data=data, num_neighbors=num_neighbors_dict, input_nodes=table_input.nodes, input_time=table_input.time, time_attr="time", transform=table_input.transform, batch_size=args.batch_size, temporal_strategy=args.temporal_strategy, shuffle=split == "train", num_workers=4, persistent_workers=True, ) print("Getting task-specific parameters...") out_channels, loss_fn, tune_metric, higher_is_better = \ 
get_task_type_params(task) print("out_channels: ", out_channels) print("loss_fn: ", loss_fn) print("tune_metric: ", tune_metric) print("higher_is_better: ", higher_is_better) print("Initializing the model...") col_names_dict = { node_type: data[node_type].tf.col_names_dict for node_type in data.node_types } temporal_node_types = [ node_type for node_type in data.node_types if "time" in data[node_type] ] model = Model( node_types=data.node_types, # Include all node types edge_types=data.edge_types, # Include all edge types col_names_dict=col_names_dict, col_stats_dict=col_stats_dict, temporal_node_types=temporal_node_types, num_layers=len(args.num_neighbors), channels=args.channels, out_channels=out_channels, aggr=args.aggr, norm=args.norm, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) print("Training the model...") best_val_metric = -math.inf if higher_is_better else math.inf for epoch in range(1, args.epochs + 1): train_loss = train( model=model, train_loader=loader_dict["train"], task=task, optimizer=optimizer, loss_fn=loss_fn, device=device, ) val_pred = test( test_loader=loader_dict["val"], model=model, task=task, device=device, ) val_metrics = task.evaluate(val_pred, task.get_table("val")) print( f"Epoch: {epoch:02d}, " f"train_loss: {train_loss:.4f}, " f"{', '.join([f'val_{k}: {v:.4f}' for k, v in val_metrics.items()])}" # noqa: E501 ) is_better_op = operator.gt if higher_is_better else operator.lt if is_better_op(val_metrics[tune_metric], best_val_metric): best_val_metric = val_metrics[tune_metric] torch.save(model.state_dict(), "best_model.pt") print("Testing the best model...") model.load_state_dict(torch.load("best_model.pt")) test_pred = test( test_loader=loader_dict["test"], model=model, task=task, device=device, ) test_metrics = task.evaluate(test_pred) print( f"{', '.join([f'test_{k}: {v:.4f}' for k, v in test_metrics.items()])}" ) if __name__ == "__main__": main() ================================================ FILE: 
examples/rect.py ================================================ import argparse import copy import os.path as osp import torch from sklearn.linear_model import LogisticRegression import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import RECT_L # RECT focuses on the zero-shot, i.e. completely-imbalanced label setting: # For this, we first remove "unseen" classes from the training set and train a # RECT (or more specifically its supervised part RECT-L) model in the zero-shot # label scenario. Lastly, we train a simple classifier to evaluate the final # performance of the embeddings based on the original labels. # Datasets Citeseer Cora Pubmed # Unseen Classes [1, 2, 5] [3, 4] [1, 2, 3] [3, 4, 6] [2] # RECT-L 66.30 68.20 74.60 71.20 75.30 # GCN 51.80 55.70 55.80 57.10 59.80 # NodeFeats 61.40 61.40 57.50 57.50 73.10 parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='Cora', choices=['Cora', 'CiteSeer', 'PubMed']) parser.add_argument('--unseen-classes', type=int, nargs='*', default=[1, 2, 3]) args = parser.parse_args() path = osp.join(osp.dirname(osp.realpath(__file__)), '../data/Planetoid') train_mask_original = Planetoid(path, args.dataset)[0].train_mask.clone() transform = T.Compose([ T.NormalizeFeatures(), T.SVDFeatureReduction(200), T.GDC(), ]) dataset = Planetoid(path, args.dataset, transform=transform) data = dataset[0] zs_data = T.RemoveTrainingClasses(args.unseen_classes)(copy.copy(data)) model = RECT_L(200, 200, normalize=False, dropout=0.0) zs_data.y = model.get_semantic_labels(zs_data.x, zs_data.y, zs_data.train_mask) if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') model, zs_data = model.to(device), zs_data.to(device) criterion = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.Adam(model.parameters(), 
lr=0.001, weight_decay=5e-4) model.train() for epoch in range(1, 201): optimizer.zero_grad() out = model(zs_data.x, zs_data.edge_index, zs_data.edge_attr) loss = criterion(out[zs_data.train_mask], zs_data.y) loss.backward() optimizer.step() print(f'Epoch {epoch:03d}, Loss {loss:.4f}') model.eval() with torch.no_grad(): h = model.embed(zs_data.x, zs_data.edge_index, zs_data.edge_attr).cpu() reg = LogisticRegression() reg.fit(h[data.train_mask].numpy(), data.y[data.train_mask].numpy()) test_acc = reg.score(h[data.test_mask].numpy(), data.y[data.test_mask].numpy()) print(f'Test Acc: {test_acc:.4f}') ================================================ FILE: examples/reddit.py ================================================ import copy import os.path as osp import time import torch import torch.nn.functional as F from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborLoader from torch_geometric.nn import SAGEConv device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit') dataset = Reddit(path) # Already send node features/labels to GPU for faster access during sampling: data = dataset[0].to(device, 'x', 'y') kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True} train_loader = NeighborLoader(data, input_nodes=data.train_mask, num_neighbors=[25, 10], shuffle=True, **kwargs) subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None, num_neighbors=[-1], shuffle=False, **kwargs) # No need to maintain these features during evaluation: del subgraph_loader.data.x, subgraph_loader.data.y # Add global node index information. 
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)


class SAGE(torch.nn.Module):
    """Two-layer GraphSAGE node classifier (ReLU + dropout between layers)."""
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # Activation and dropout on every layer except the last:
            if i < len(self.convs) - 1:
                x = x.relu_()
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        """Layer-wise inference over all nodes of ``subgraph_loader``.

        Returns the last layer's outputs, gathered on the CPU in loader
        order.
        """
        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch:
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                # ``x_all`` is ``data.x`` (on ``device``) for the first layer
                # and a CPU tensor afterwards, so index on its own device:
                x = x_all[batch.n_id.to(x_all.device)].to(device)
                x = conv(x, batch.edge_index.to(device))
                if i < len(self.convs) - 1:
                    x = x.relu_()
                # Keep only the seed nodes' outputs of this batch:
                xs.append(x[:batch.batch_size].cpu())
                pbar.update(batch.batch_size)
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all


model = SAGE(dataset.num_features, 256, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(epoch):
    """Run one epoch; return (mean loss, approx. accuracy) on seed nodes."""
    model.train()

    pbar = tqdm(total=int(len(train_loader.dataset)))
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = total_correct = total_examples = 0
    for batch in train_loader:
        optimizer.zero_grad()
        # ``x``/``y`` were moved to ``device`` up-front; only ``edge_index``
        # still needs to be moved per batch:
        y = batch.y[:batch.batch_size]
        y_hat = model(batch.x, batch.edge_index.to(device))[:batch.batch_size]
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * batch.batch_size
        total_correct += int((y_hat.argmax(dim=-1) == y).sum())
        total_examples += batch.batch_size
        pbar.update(batch.batch_size)
    pbar.close()

    return total_loss / total_examples, total_correct / total_examples


@torch.no_grad()
def test():
    """Return [train, val, test] accuracy from full-neighborhood inference."""
    model.eval()
    y_hat = model.inference(data.x, subgraph_loader).argmax(dim=-1)
    y = data.y.to(y_hat.device)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((y_hat[mask] == y[mask]).sum()) / int(mask.sum()))
    return accs


times = []
for epoch in range(1, 11):
    start = time.time()
    loss, acc = train(epoch)
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {test_acc:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
================================================
FILE: examples/renet.py
================================================
import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F

from torch_geometric.datasets import GDELT, ICEWS18
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.re_net import RENet

parser = argparse.ArgumentParser()
parser.add_argument(
    '--dataset',
    type=str,
    default='GDELT',
    choices=['ICEWS18', 'GDELT'],
)
# Passed to ``RENet.pre_transform`` and the model's ``seq_len`` below:
parser.add_argument('--seq_len', type=int, default=10)
args = parser.parse_args()

# Load the dataset and precompute history objects:
path = osp.dirname(osp.realpath(__file__))
path = osp.join(path, '..', 'data', args.dataset)

pre_transform = RENet.pre_transform(args.seq_len)

# ``choices`` above guarantees exactly one of these branches runs:
if args.dataset == 'ICEWS18':
    train_dataset = ICEWS18(path, pre_transform=pre_transform)
    test_dataset = ICEWS18(path, split='test', pre_transform=pre_transform)
elif args.dataset == 'GDELT':
    train_dataset = GDELT(path, pre_transform=pre_transform)
    test_dataset = GDELT(path, split='test', pre_transform=pre_transform)

# Create dataloader for training and test dataset.
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True,
                          follow_batch=['h_sub', 'h_obj'], num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False,
                         follow_batch=['h_sub', 'h_obj'], num_workers=6)

# Initialize model and optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RENet(
    train_dataset.num_nodes,
    train_dataset.num_rels,
    hidden_channels=200,
    seq_len=args.seq_len,
    dropout=0.5,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                             weight_decay=0.00001)


def train():
    """Run one training epoch over ``train_loader``."""
    model.train()
    # Train model via multi-class classification against the corresponding
    # object and subject entities.
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        log_prob_obj, log_prob_sub = model(data)
        # The total loss is the sum of the object and subject NLL terms.
        loss_obj = F.nll_loss(log_prob_obj, data.obj)
        loss_sub = F.nll_loss(log_prob_sub, data.sub)
        loss = loss_obj + loss_sub
        loss.backward()
        optimizer.step()


def test(loader):
    """Evaluate the model on ``loader``.

    Returns a 4-element list that the caller unpacks as
    ``(MRR, Hits@1, Hits@3, Hits@10)``. It is accumulated over the object and
    subject predictions and divided by ``2 * len(loader.dataset)``.
    """
    model.eval()
    # Compute Mean Reciprocal Rank (MRR) and Hits@1/3/10.
    result = torch.tensor([0, 0, 0, 0], dtype=torch.float)
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            log_prob_obj, log_prob_sub = model(data)
        # Weight by the number of queries in the batch. This presumably
        # undoes a per-batch mean inside ``model.test`` (TODO: confirm in
        # RENet.test).
        result += model.test(log_prob_obj, data.obj) * data.obj.size(0)
        result += model.test(log_prob_sub, data.sub) * data.sub.size(0)
    # Every sample contributes two queries (object and subject).
    result = result / (2 * len(loader.dataset))
    return result.tolist()


times = []
for epoch in range(1, 21):
    start = time.time()
    train()
    mrr, hits1, hits3, hits10 = test(test_loader)
    print(f'Epoch: {epoch:02d}, MRR: {mrr:.4f}, Hits@1: {hits1:.4f}, '
          f'Hits@3: {hits3:.4f}, Hits@10: {hits10:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")


================================================
FILE: examples/rev_gnn.py
================================================
# Peak GPU memory usage is around 1.57 G
# | RevGNN Models           | Test Acc        | Val Acc         |
# |-------------------------|-----------------|-----------------|
# | 112 layers 160 channels | 0.8307 ± 0.0030 | 0.9290 ± 0.0007 |
# | 7 layers 160 channels   | 0.8276 ± 0.0027 | 0.9272 ± 0.0006 |

import os.path as osp
import time

import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.loader import RandomNodeLoader
from torch_geometric.nn import GroupAddRev, SAGEConv
from torch_geometric.utils import index_to_mask


class GNNBlock(torch.nn.Module):
    """Pre-norm block: LayerNorm -> ReLU -> (optional dropout mask) -> SAGEConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm = LayerNorm(in_channels, elementwise_affine=True)
        self.conv = SAGEConv(in_channels, out_channels)

    def reset_parameters(self):
        self.norm.reset_parameters()
        self.conv.reset_parameters()

    def forward(self, x, edge_index, dropout_mask=None):
        x = self.norm(x).relu()
        # The mask is supplied by the caller (RevGNN) so that the same mask
        # is applied in every block; it is only used in training mode.
        if self.training and dropout_mask is not None:
            x = x * dropout_mask
        return self.conv(x, edge_index)


class RevGNN(torch.nn.Module):
    """Stack of ``num_layers`` GroupAddRev-wrapped GNNBlocks between an input
    and an output linear layer."""
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_groups=2):
        super().__init__()

        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        self.norm = LayerNorm(hidden_channels, elementwise_affine=True)

        # Each wrapped block operates on hidden_channels // num_groups
        # channels, so the hidden size must split evenly into groups.
        assert hidden_channels % num_groups == 0
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GNNBlock(
                hidden_channels // num_groups,
                hidden_channels // num_groups,
            )
            self.convs.append(GroupAddRev(conv, num_groups=num_groups))

    def reset_parameters(self):
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.norm.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, edge_index):
        x = self.lin1(x)

        # Generate a dropout mask which will be shared across GNN blocks:
        mask = None
        if self.training and self.dropout > 0:
            mask = torch.zeros_like(x).bernoulli_(1 - self.dropout)
            mask = mask.requires_grad_(False)
            # Inverted-dropout scaling: divide by the keep probability.
            mask = mask / (1 - self.dropout)

        for conv in self.convs:
            x = conv(x, edge_index, mask)
        x = self.norm(x).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin2(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Applied per mini-batch inside train()/test(): move to device, then convert
# edge_index into a sparse ``adj_t``.
transform = T.Compose([T.ToDevice(device), T.ToSparseTensor()])
root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'products')
dataset = PygNodePropPredDataset('ogbn-products', root,
                                 transform=T.AddSelfLoops())
evaluator = Evaluator(name='ogbn-products')

data = dataset[0]
split_idx = dataset.get_idx_split()
# Turn the OGB index splits into boolean node masks.
for split in ['train', 'valid', 'test']:
    data[f'{split}_mask'] = index_to_mask(split_idx[split], data.y.shape[0])

train_loader = RandomNodeLoader(data, num_parts=10, shuffle=True,
                                num_workers=5)
# Increase the num_parts of the test loader if you cannot fit
# the full batch graph into your GPU:
test_loader = RandomNodeLoader(data, num_parts=1, num_workers=5)

model = RevGNN(
    in_channels=dataset.num_features,
    hidden_channels=160,
    out_channels=dataset.num_classes,
    num_layers=7,  # You can try 1000 layers for fun
    dropout=0.5,
    num_groups=2,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)


def train(epoch):
    """Train for one epoch over the random node partitions.

    Returns the loss averaged over all training nodes seen in the epoch.
    """
    model.train()

    pbar = tqdm(total=len(train_loader))
    pbar.set_description(f'Training epoch: {epoch:03d}')

    total_loss = total_examples = 0
    for data in train_loader:
        optimizer.zero_grad()

        # Memory-efficient aggregations:
        data = transform(data)
        out = model(data.x, data.adj_t)[data.train_mask]
        loss = F.cross_entropy(out, data.y[data.train_mask].view(-1))
        loss.backward()
        optimizer.step()

        # Weight each partition's mean loss by its number of training nodes.
        total_loss += float(loss) * int(data.train_mask.sum())
        total_examples += int(data.train_mask.sum())
        pbar.update(1)

    pbar.close()

    return total_loss / total_examples


@torch.no_grad()
def test(epoch):
    """Return (train, valid, test) accuracy computed by the OGB evaluator."""
    model.eval()

    y_true = {"train": [], "valid": [], "test": []}
    y_pred = {"train": [], "valid": [], "test": []}

    pbar = tqdm(total=len(test_loader))
    pbar.set_description(f'Evaluating epoch: {epoch:03d}')

    for data in test_loader:
        # Memory-efficient aggregations
        data = transform(data)
        out = model(data.x, data.adj_t).argmax(dim=-1, keepdim=True)

        # Collect per-split labels/predictions on the CPU.
        for split in ['train', 'valid', 'test']:
            mask = data[f'{split}_mask']
            y_true[split].append(data.y[mask].cpu())
            y_pred[split].append(out[mask].cpu())

        pbar.update(1)

    pbar.close()

    train_acc = evaluator.eval({
        'y_true': torch.cat(y_true['train'], dim=0),
        'y_pred': torch.cat(y_pred['train'], dim=0),
    })['acc']

    valid_acc = evaluator.eval({
        'y_true': torch.cat(y_true['valid'], dim=0),
        'y_pred': torch.cat(y_pred['valid'], dim=0),
    })['acc']

    test_acc = evaluator.eval({
        'y_true': torch.cat(y_true['test'], dim=0),
        'y_pred': torch.cat(y_pred['test'], dim=0),
    })['acc']

    return train_acc, valid_acc, test_acc


times = []
# Track the train/test accuracy at the epoch with the best validation score.
best_val = 0.0
final_train = 0.0
final_test = 0.0
for epoch in range(1, 1001):
    start = time.time()
    loss = train(epoch)
    train_acc, val_acc, test_acc = test(epoch)
    if val_acc > best_val:
        best_val = val_acc
        final_train = train_acc
        final_test = test_acc
    print(f'Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {test_acc:.4f}')
    times.append(time.time() - start)

print(f'Final Train: {final_train:.4f}, Best Val: {best_val:.4f}, '
      f'Final Test: {final_test:.4f}')
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")


================================================
FILE: examples/rgat.py
================================================
import os.path as osp
import time

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Entities
from torch_geometric.nn import RGATConv

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')
dataset = Entities(path, 'AIFB')
data = dataset[0]
# Random 16-dimensional input features; the model below is built with
# in_channels=16 to match.
data.x = torch.randn(data.num_nodes, 16)


class RGAT(torch.nn.Module):
    """Two RGATConv layers followed by a linear classifier (log-softmax)."""
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_relations):
        super().__init__()
        self.conv1 = RGATConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGATConv(hidden_channels, hidden_channels, num_relations)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_type):
        x = self.conv1(x, edge_index, edge_type).relu()
        x = self.conv2(x, edge_index, edge_type).relu()
        x = self.lin(x)
        return F.log_softmax(x, dim=-1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
model = RGAT(16, 16, dataset.num_classes, dataset.num_relations).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)


def train():
    """Run one full-batch optimization step and return the loss value."""
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.edge_type)
    loss = F.nll_loss(out[data.train_idx], data.train_y)
    loss.backward()
    optimizer.step()
    return float(loss.detach())


@torch.no_grad()
def test():
    """Return (train accuracy, test accuracy)."""
    model.eval()
    pred = model(data.x, data.edge_index, data.edge_type).argmax(dim=-1)
    train_acc = float((pred[data.train_idx] == data.train_y).float().mean())
    test_acc = float((pred[data.test_idx] == data.test_y).float().mean())
    return train_acc, test_acc


times = []
for epoch in range(1, 51):
    start = time.time()
    loss = train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f} '
          f'Test: {test_acc:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")


================================================
FILE: examples/rgcn.py
================================================
import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Entities
from torch_geometric.nn import FastRGCNConv, RGCNConv
from torch_geometric.utils import k_hop_subgraph

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='AIFB',
                    choices=['AIFB', 'MUTAG', 'BGS', 'AM'])
args = parser.parse_args()

# Trade memory consumption for faster computation.
if args.dataset in ['AIFB', 'MUTAG']:
    Conv = FastRGCNConv
else:
    Conv = RGCNConv

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')
dataset = Entities(path, args.dataset)
data = dataset[0]

# BGS and AM graphs are too big to process them in a full-batch fashion.
# Since our model does only make use of a rather small receptive field, we
# filter the graph to only contain the nodes that are at most 2-hop neighbors
# away from any training/test node.
node_idx = torch.cat([data.train_idx, data.test_idx], dim=0)
node_idx, edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx, 2, data.edge_index, relabel_nodes=True)

data.num_nodes = node_idx.size(0)
data.edge_index = edge_index
data.edge_type = data.edge_type[edge_mask]
# ``mapping`` holds the relabeled positions of the input nodes, in the order
# they were concatenated above (train nodes first, then test nodes).
data.train_idx = mapping[:data.train_idx.size(0)]
# ``data.train_idx`` was just reassigned, but its length is unchanged, so
# this slice still starts right after the train nodes.
data.test_idx = mapping[data.train_idx.size(0):]


class Net(torch.nn.Module):
    """Two-layer relational GCN classifier over the filtered graph."""
    def __init__(self):
        super().__init__()
        # in_channels=data.num_nodes together with x=None in forward():
        # presumably the conv learns a per-node input representation for
        # featureless graphs -- confirm against the RGCNConv docs.
        self.conv1 = Conv(data.num_nodes, 16, dataset.num_relations,
                          num_bases=30)
        self.conv2 = Conv(16, dataset.num_classes, dataset.num_relations,
                          num_bases=30)

    def forward(self, edge_index, edge_type):
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)


if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# The AM dataset is always run on the CPU.
device = torch.device('cpu') if args.dataset == 'AM' else device
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)


def train():
    """Run one full-batch optimization step and return the loss value."""
    model.train()
    optimizer.zero_grad()
    out = model(data.edge_index, data.edge_type)
    loss = F.nll_loss(out[data.train_idx], data.train_y)
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    """Return (train accuracy, test accuracy)."""
    model.eval()
    pred = model(data.edge_index, data.edge_type).argmax(dim=-1)
    train_acc = float((pred[data.train_idx] == data.train_y).float().mean())
    test_acc = float((pred[data.test_idx] == data.test_y).float().mean())
    return train_acc, test_acc


times = []
for epoch in range(1, 51):
    start = time.time()
    loss = train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f} '
          f'Test: {test_acc:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")


================================================
FILE:
examples/rgcn_link_pred.py ================================================ """" Implements the link prediction task on the FB15k237 datasets according to the `"Modeling Relational Data with Graph Convolutional Networks" `_ paper. Caution: This script is executed in a full-batch fashion, and therefore needs to run on CPU (following the experimental setup in the official paper). """ import os.path as osp import time import torch import torch.nn.functional as F from torch.nn import Parameter from tqdm import tqdm from torch_geometric.datasets import RelLinkPredDataset from torch_geometric.nn import GAE, RGCNConv device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'RLPD') dataset = RelLinkPredDataset(path, 'FB15k-237') data = dataset[0].to(device) class RGCNEncoder(torch.nn.Module): def __init__(self, num_nodes, hidden_channels, num_relations): super().__init__() self.node_emb = Parameter(torch.empty(num_nodes, hidden_channels)) self.conv1 = RGCNConv(hidden_channels, hidden_channels, num_relations, num_blocks=5) self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations, num_blocks=5) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.node_emb) self.conv1.reset_parameters() self.conv2.reset_parameters() def forward(self, edge_index, edge_type): x = self.node_emb x = self.conv1(x, edge_index, edge_type).relu_() x = F.dropout(x, p=0.2, training=self.training) x = self.conv2(x, edge_index, edge_type) return x class DistMultDecoder(torch.nn.Module): def __init__(self, num_relations, hidden_channels): super().__init__() self.rel_emb = Parameter(torch.empty(num_relations, hidden_channels)) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.rel_emb) def forward(self, z, edge_index, edge_type): z_src, z_dst = z[edge_index[0]], z[edge_index[1]] rel = self.rel_emb[edge_type] return torch.sum(z_src * rel * z_dst, 
dim=1) model = GAE( RGCNEncoder(data.num_nodes, 500, dataset.num_relations), DistMultDecoder(dataset.num_relations // 2, 500), ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) def negative_sampling(edge_index, num_nodes): # Sample edges by corrupting either the subject or the object of each edge. mask_1 = torch.rand(edge_index.size(1)) < 0.5 mask_2 = ~mask_1 neg_edge_index = edge_index.clone() neg_edge_index[0, mask_1] = torch.randint(num_nodes, (mask_1.sum(), ), device=neg_edge_index.device) neg_edge_index[1, mask_2] = torch.randint(num_nodes, (mask_2.sum(), ), device=neg_edge_index.device) return neg_edge_index def train(): model.train() optimizer.zero_grad() z = model.encode(data.edge_index, data.edge_type) pos_out = model.decode(z, data.train_edge_index, data.train_edge_type) neg_edge_index = negative_sampling(data.train_edge_index, data.num_nodes) neg_out = model.decode(z, neg_edge_index, data.train_edge_type) out = torch.cat([pos_out, neg_out]) gt = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)]) cross_entropy_loss = F.binary_cross_entropy_with_logits(out, gt) reg_loss = z.pow(2).mean() + model.decoder.rel_emb.pow(2).mean() loss = cross_entropy_loss + 1e-2 * reg_loss loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.) 
optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() z = model.encode(data.edge_index, data.edge_type) valid_mrr = compute_mrr(z, data.valid_edge_index, data.valid_edge_type) test_mrr = compute_mrr(z, data.test_edge_index, data.test_edge_type) return valid_mrr, test_mrr @torch.no_grad() def compute_rank(ranks): # fair ranking prediction as the average # of optimistic and pessimistic ranking true = ranks[0] optimistic = (ranks > true).sum() + 1 pessimistic = (ranks >= true).sum() return (optimistic + pessimistic).float() * 0.5 @torch.no_grad() def compute_mrr(z, edge_index, edge_type): ranks = [] for i in tqdm(range(edge_type.numel())): (src, dst), rel = edge_index[:, i], edge_type[i] # Try all nodes as tails, but delete true triplets: tail_mask = torch.ones(data.num_nodes, dtype=torch.bool) for (heads, tails), types in [ (data.train_edge_index, data.train_edge_type), (data.valid_edge_index, data.valid_edge_type), (data.test_edge_index, data.test_edge_type), ]: tail_mask[tails[(heads == src) & (types == rel)]] = False tail = torch.arange(data.num_nodes)[tail_mask] tail = torch.cat([torch.tensor([dst]), tail]) head = torch.full_like(tail, fill_value=src) eval_edge_index = torch.stack([head, tail], dim=0) eval_edge_type = torch.full_like(tail, fill_value=rel) out = model.decode(z, eval_edge_index, eval_edge_type) rank = compute_rank(out) ranks.append(rank) # Try all nodes as heads, but delete true triplets: head_mask = torch.ones(data.num_nodes, dtype=torch.bool) for (heads, tails), types in [ (data.train_edge_index, data.train_edge_type), (data.valid_edge_index, data.valid_edge_type), (data.test_edge_index, data.test_edge_type), ]: head_mask[heads[(tails == dst) & (types == rel)]] = False head = torch.arange(data.num_nodes)[head_mask] head = torch.cat([torch.tensor([src]), head]) tail = torch.full_like(head, fill_value=dst) eval_edge_index = torch.stack([head, tail], dim=0) eval_edge_type = torch.full_like(head, fill_value=rel) out = 
model.decode(z, eval_edge_index, eval_edge_type) rank = compute_rank(out) ranks.append(rank) return (1. / torch.tensor(ranks, dtype=torch.float)).mean() times = [] for epoch in range(1, 10001): start = time.time() loss = train() print(f'Epoch: {epoch:05d}, Loss: {loss:.4f}') if (epoch % 500) == 0: valid_mrr, test_mrr = test() print(f'Val MRR: {valid_mrr:.4f}, Test MRR: {test_mrr:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/seal_link_pred.py ================================================ import math import os.path as osp import time from itertools import chain import numpy as np import torch import torch.nn.functional as F from scipy.sparse.csgraph import shortest_path from sklearn.metrics import roc_auc_score from torch.nn import BCEWithLogitsLoss, Conv1d, MaxPool1d, ModuleList from torch_geometric.data import Data, InMemoryDataset from torch_geometric.datasets import Planetoid from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, GCNConv, SortAggregation from torch_geometric.transforms import RandomLinkSplit from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix class SEALDataset(InMemoryDataset): def __init__(self, dataset, num_hops, split='train'): self._data = dataset[0] self.num_hops = num_hops super().__init__(dataset.root) index = ['train', 'val', 'test'].index(split) self.load(self.processed_paths[index]) @property def processed_file_names(self): return ['SEAL_train_data.pt', 'SEAL_val_data.pt', 'SEAL_test_data.pt'] def process(self): transform = RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True) train_data, val_data, test_data = transform(self._data) self._max_z = 0 # Collect a list of subgraphs for training, validation and testing: train_pos_data_list = self.extract_enclosing_subgraphs( train_data.edge_index, train_data.pos_edge_label_index, 1) 
train_neg_data_list = self.extract_enclosing_subgraphs( train_data.edge_index, train_data.neg_edge_label_index, 0) val_pos_data_list = self.extract_enclosing_subgraphs( val_data.edge_index, val_data.pos_edge_label_index, 1) val_neg_data_list = self.extract_enclosing_subgraphs( val_data.edge_index, val_data.neg_edge_label_index, 0) test_pos_data_list = self.extract_enclosing_subgraphs( test_data.edge_index, test_data.pos_edge_label_index, 1) test_neg_data_list = self.extract_enclosing_subgraphs( test_data.edge_index, test_data.neg_edge_label_index, 0) # Convert node labeling to one-hot features. for data in chain(train_pos_data_list, train_neg_data_list, val_pos_data_list, val_neg_data_list, test_pos_data_list, test_neg_data_list): # We solely learn links from structure, dropping any node features: data.x = F.one_hot(data.z, self._max_z + 1).to(torch.float) train_data_list = train_pos_data_list + train_neg_data_list self.save(train_data_list, self.processed_paths[0]) val_data_list = val_pos_data_list + val_neg_data_list self.save(val_data_list, self.processed_paths[1]) test_data_list = test_pos_data_list + test_neg_data_list self.save(test_data_list, self.processed_paths[2]) def extract_enclosing_subgraphs(self, edge_index, edge_label_index, y): data_list = [] for src, dst in edge_label_index.t().tolist(): sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph( [src, dst], self.num_hops, edge_index, relabel_nodes=True) src, dst = mapping.tolist() # Remove target link from the subgraph. mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst) mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src) sub_edge_index = sub_edge_index[:, mask1 & mask2] # Calculate node labeling. 
z = self.drnl_node_labeling(sub_edge_index, src, dst, num_nodes=sub_nodes.size(0)) data = Data(x=self._data.x[sub_nodes], z=z, edge_index=sub_edge_index, y=y) data_list.append(data) return data_list def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None): # Double-radius node labeling (DRNL). src, dst = (dst, src) if src > dst else (src, dst) adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr() idx = list(range(src)) + list(range(src + 1, adj.shape[0])) adj_wo_src = adj[idx, :][:, idx] idx = list(range(dst)) + list(range(dst + 1, adj.shape[0])) adj_wo_dst = adj[idx, :][:, idx] dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src) dist2src = np.insert(dist2src, dst, 0, axis=0) dist2src = torch.from_numpy(dist2src) dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1) dist2dst = np.insert(dist2dst, src, 0, axis=0) dist2dst = torch.from_numpy(dist2dst) dist = dist2src + dist2dst dist_over_2, dist_mod_2 = dist // 2, dist % 2 z = 1 + torch.min(dist2src, dist2dst) z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1) z[src] = 1. z[dst] = 1. z[torch.isnan(z)] = 0. self._max_z = max(int(z.max()), self._max_z) return z.to(torch.long) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, name='Cora') train_dataset = SEALDataset(dataset, num_hops=2, split='train') val_dataset = SEALDataset(dataset, num_hops=2, split='val') test_dataset = SEALDataset(dataset, num_hops=2, split='test') train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=32) test_loader = DataLoader(test_dataset, batch_size=32) class DGCNN(torch.nn.Module): def __init__(self, hidden_channels, num_layers, GNN=GCNConv, k=0.6): super().__init__() if k < 1: # Transform percentile to number. 
num_nodes = sorted([data.num_nodes for data in train_dataset]) k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1] k = int(max(10, k)) self.convs = ModuleList() self.convs.append(GNN(train_dataset.num_features, hidden_channels)) for _ in range(0, num_layers - 1): self.convs.append(GNN(hidden_channels, hidden_channels)) self.convs.append(GNN(hidden_channels, 1)) conv1d_channels = [16, 32] total_latent_dim = hidden_channels * num_layers + 1 conv1d_kws = [total_latent_dim, 5] self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0]) self.pool = SortAggregation(k) self.maxpool1d = MaxPool1d(2, 2) self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1) dense_dim = int((k - 2) / 2 + 1) dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1] self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, norm=None) def forward(self, x, edge_index, batch): xs = [x] for conv in self.convs: xs += [conv(xs[-1], edge_index).tanh()] x = torch.cat(xs[1:], dim=-1) # Global pooling. 
x = self.pool(x, batch) x = x.unsqueeze(1) # [num_graphs, 1, k * hidden] x = self.conv1(x).relu() x = self.maxpool1d(x) x = self.conv2(x).relu() x = x.view(x.size(0), -1) # [num_graphs, dense_dim] return self.mlp(x) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = DGCNN(hidden_channels=32, num_layers=3).to(device) optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001) criterion = BCEWithLogitsLoss() def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.batch) loss = criterion(out.view(-1), data.y.to(torch.float)) loss.backward() optimizer.step() total_loss += float(loss) * data.num_graphs return total_loss / len(train_dataset) @torch.no_grad() def test(loader): model.eval() y_pred, y_true = [], [] for data in loader: data = data.to(device) logits = model(data.x, data.edge_index, data.batch) y_pred.append(logits.view(-1).cpu()) y_true.append(data.y.view(-1).cpu().to(torch.float)) return roc_auc_score(torch.cat(y_true), torch.cat(y_pred)) times = [] best_val_auc = test_auc = 0 for epoch in range(1, 51): start = time.time() loss = train() val_auc = test(val_loader) if val_auc > best_val_auc: best_val_auc = val_auc test_auc = test(test_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, ' f'Test: {test_auc:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/sgc.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.nn import SGConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset) data = dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = 
SGConv( in_channels=dataset.num_features, out_channels=dataset.num_classes, K=2, cached=True, ) def forward(self): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) return F.log_softmax(x, dim=1) if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') model, data = Net().to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=0.005) def train(): model.train() optimizer.zero_grad() F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward() optimizer.step() @torch.no_grad() def test(): model.eval() out, accs = model(), [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): pred = out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs best_val_acc = test_acc = 0 for epoch in range(1, 101): train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, ' f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/shadow.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch_geometric.datasets import Flickr from torch_geometric.loader import ShaDowKHopSampler from torch_geometric.nn import SAGEConv, global_mean_pool path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Flickr') dataset = Flickr(path) data = dataset[0] kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True} train_loader = ShaDowKHopSampler(data, depth=2, num_neighbors=5, node_idx=data.train_mask, **kwargs) val_loader = ShaDowKHopSampler(data, depth=2, num_neighbors=5, node_idx=data.val_mask, **kwargs) test_loader = ShaDowKHopSampler(data, depth=2, 
num_neighbors=5, node_idx=data.test_mask, **kwargs) class GNN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = SAGEConv(in_channels, hidden_channels) self.conv2 = SAGEConv(hidden_channels, hidden_channels) self.conv3 = SAGEConv(hidden_channels, hidden_channels) self.lin = torch.nn.Linear(2 * hidden_channels, out_channels) def forward(self, x, edge_index, batch, root_n_id): x = self.conv1(x, edge_index).relu() x = F.dropout(x, p=0.3) x = self.conv2(x, edge_index).relu() x = F.dropout(x, p=0.3, training=self.training) x = self.conv3(x, edge_index).relu() x = F.dropout(x, p=0.3, training=self.training) # We merge both central node embeddings and subgraph embeddings: x = torch.cat([x[root_n_id], global_mean_pool(x, batch)], dim=-1) x = self.lin(x) return x device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GNN(dataset.num_features, 256, dataset.num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) def train(): model.train() total_loss = total_examples = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.batch, data.root_n_id) loss = F.cross_entropy(out, data.y) loss.backward() optimizer.step() total_loss += float(loss) * data.num_graphs total_examples += data.num_graphs return total_loss / total_examples @torch.no_grad() def test(loader): model.eval() total_correct = total_examples = 0 for data in loader: data = data.to(device) out = model(data.x, data.edge_index, data.batch, data.root_n_id) total_correct += int((out.argmax(dim=-1) == data.y).sum()) total_examples += data.num_graphs return total_correct / total_examples for epoch in range(1, 51): loss = train() val_acc = test(val_loader) test_acc = test(test_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, ', f'Val: {val_acc:.4f} Test: {test_acc:.4f}') ================================================ FILE: examples/sign.py 
================================================
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch.utils.data import DataLoader

import torch_geometric.transforms as T
from torch_geometric.datasets import Flickr

K = 2
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Flickr')
# T.SIGN(K) is expected to attach pre-propagated features x1..xK, which are
# read below as data[f'x{i}'] -- confirm against the SIGN transform docs.
transform = T.Compose([T.NormalizeFeatures(), T.SIGN(K)])
dataset = Flickr(path, transform=transform)
data = dataset[0]

train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
val_idx = data.val_mask.nonzero(as_tuple=False).view(-1)
test_idx = data.test_mask.nonzero(as_tuple=False).view(-1)

# Mini-batches are batches of node indices. Features are gathered per batch
# inside train()/test().
train_loader = DataLoader(train_idx, batch_size=16 * 1024, shuffle=True)
val_loader = DataLoader(val_idx, batch_size=32 * 1024)
test_loader = DataLoader(test_idx, batch_size=32 * 1024)


class Net(torch.nn.Module):
    """One linear branch per hop (0..K), concatenated and classified."""
    def __init__(self):
        super().__init__()

        self.lins = torch.nn.ModuleList()
        for _ in range(K + 1):
            self.lins.append(Linear(dataset.num_node_features, 1024))
        self.lin = Linear((K + 1) * 1024, dataset.num_classes)

    def forward(self, xs):
        hs = []
        for x, lin in zip(xs, self.lins):
            h = lin(x).relu()
            h = F.dropout(h, p=0.5, training=self.training)
            hs.append(h)
        h = torch.cat(hs, dim=-1)
        h = self.lin(h)
        return h.log_softmax(dim=-1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train():
    """Train for one epoch; return the mean loss per training node."""
    model.train()

    total_loss = total_examples = 0
    for idx in train_loader:
        # Raw features first, then the K propagated feature matrices.
        xs = [data.x[idx].to(device)]
        xs += [data[f'x{i}'][idx].to(device) for i in range(1, K + 1)]
        y = data.y[idx].to(device)

        optimizer.zero_grad()
        out = model(xs)
        loss = F.nll_loss(out, y)
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * idx.numel()
        total_examples += idx.numel()

    return total_loss / total_examples


@torch.no_grad()
def test(loader):
    """Return the accuracy over the node indices yielded by ``loader``."""
    model.eval()

    total_correct = total_examples = 0
    for idx in loader:
        xs = [data.x[idx].to(device)]
        xs += [data[f'x{i}'][idx].to(device) for i in range(1, K + 1)]
        y = data.y[idx].to(device)

        out = model(xs)
        total_correct += int((out.argmax(dim=-1) == y).sum())
        total_examples += idx.numel()

    return total_correct / total_examples


for epoch in range(1, 201):
    loss = train()
    train_acc = test(train_loader)
    val_acc = test(val_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')


================================================
FILE: examples/signed_gcn.py
================================================
import os.path as osp

import torch

from torch_geometric.datasets import BitcoinOTC
from torch_geometric.nn import SignedGCN

name = 'BitcoinOTC-1'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
dataset = BitcoinOTC(path, edge_window_size=1)

# Generate dataset.
# Split every snapshot's edges by the sign of their edge_attr.
pos_edge_indices, neg_edge_indices = [], []
for data in dataset:
    pos_edge_indices.append(data.edge_index[:, data.edge_attr > 0])
    neg_edge_indices.append(data.edge_index[:, data.edge_attr < 0])

if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

pos_edge_index = torch.cat(pos_edge_indices, dim=1).to(device)
neg_edge_index = torch.cat(neg_edge_indices, dim=1).to(device)

# Build and train model.
model = SignedGCN(64, 64, num_layers=2, lamb=5).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) train_pos_edge_index, test_pos_edge_index = model.split_edges(pos_edge_index) train_neg_edge_index, test_neg_edge_index = model.split_edges(neg_edge_index) x = model.create_spectral_features(train_pos_edge_index, train_neg_edge_index) def train(): model.train() optimizer.zero_grad() z = model(x, train_pos_edge_index, train_neg_edge_index) loss = model.loss(z, train_pos_edge_index, train_neg_edge_index) loss.backward() optimizer.step() return loss.item() def test(): model.eval() with torch.no_grad(): z = model(x, train_pos_edge_index, train_neg_edge_index) return model.test(z, test_pos_edge_index, test_neg_edge_index) for epoch in range(101): loss = train() auc, f1 = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, AUC: {auc:.4f}, ' f'F1: {f1:.4f}') ================================================ FILE: examples/super_gat.py ================================================ import os.path as osp import time import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import SuperGATConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) data = dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SuperGATConv(dataset.num_features, 8, heads=8, dropout=0.6, attention_type='MX', edge_sample_ratio=0.8, is_undirected=True) self.conv2 = SuperGATConv(8 * 8, dataset.num_classes, heads=8, concat=False, dropout=0.6, attention_type='MX', edge_sample_ratio=0.8, is_undirected=True) def forward(self, x, edge_index): x = F.dropout(x, p=0.6, training=self.training) x = F.elu(self.conv1(x, edge_index)) att_loss = self.conv1.get_attention_loss() x = F.dropout(x, p=0.6, training=self.training) x = self.conv2(x, 
data.edge_index) att_loss += self.conv2.get_attention_loss() return F.log_softmax(x, dim=-1), att_loss device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model, data = Net().to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) def train(data): model.train() optimizer.zero_grad() out, att_loss = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss += 4.0 * att_loss loss.backward() optimizer.step() @torch.no_grad() def test(data): model.eval() out, accs = model(data.x, data.edge_index)[0], [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): pred = out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs times = [] for epoch in range(1, 501): start = time.time() train(data) train_acc, val_acc, test_acc = test(data) print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' f'Test: {test_acc:.4f}') times.append(time.time() - start) print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") ================================================ FILE: examples/tagcn.py ================================================ import os.path as osp import torch import torch.nn.functional as F import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import TAGConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) data = dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = TAGConv(dataset.num_features, 16) self.conv2 = TAGConv(16, dataset.num_classes) def forward(self): x, edge_index = data.x, data.edge_index x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) if torch.cuda.is_available(): device = 
torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') model, data = Net().to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) def train(): model.train() optimizer.zero_grad() F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward() optimizer.step() @torch.no_grad() def test(): model.eval() out, accs = model(), [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): pred = out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs best_val_acc = test_acc = 0 for epoch in range(1, 201): train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, ' f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/tensorboard_logging.py ================================================ import os.path as osp import torch import torch.nn.functional as F from torch.utils.tensorboard import SummaryWriter import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCNConv dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) data = dataset[0] class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_features, 16) self.conv2 = GCNConv(16, dataset.num_classes) def forward(self, x, edge_index): x = F.relu(self.conv1(x, edge_index, None)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index, None) return F.log_softmax(x, dim=1) if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = 
torch.device('mps') else: device = torch.device('cpu') model, data = Net().to(device), data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return loss.item() @torch.no_grad() def test(): model.eval() out, accs = model(data.x, data.edge_index), [] for _, mask in data('train_mask', 'val_mask', 'test_mask'): pred = out[mask].argmax(1) acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() accs.append(acc) return accs model(data.x, data.edge_index) writer = SummaryWriter() writer.add_graph(model, [data.x, data.edge_index]) best_val_acc = test_acc = 0 for epoch in range(1, 201): train_loss = train() train_acc, val_acc, tmp_test_acc = test() if val_acc > best_val_acc: best_val_acc = val_acc test_acc = tmp_test_acc print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, ' f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}') writer.add_scalar('Loss/train', train_loss, epoch) writer.add_scalar('Accuracy/train', train_acc, epoch) writer.add_scalar('Accuracy/val', val_acc, epoch) writer.add_scalar('Accuracy/test', test_acc, epoch) ================================================ FILE: examples/tgn.py ================================================ # This code achieves a performance of around 96.60%. However, it is not # directly comparable to the results reported by the TGN paper since a # slightly different evaluation setup is used here. # In particular, predictions in the same batch are made in parallel, i.e. # predictions for interactions later in the batch have no access to any # information whatsoever about previous interactions in the same batch. # On the contrary, when sampling node neighborhoods for interactions later in # the batch, the TGN paper code has access to previous interactions in the # batch. 
# While both approaches are correct, together with the authors of the paper we # decided to present this version here as it is more realistic and a better # test bed for future methods. import os.path as osp import torch from sklearn.metrics import average_precision_score, roc_auc_score from torch.nn import Linear from torch_geometric.datasets import JODIEDataset from torch_geometric.loader import TemporalDataLoader from torch_geometric.nn import TGNMemory, TransformerConv from torch_geometric.nn.models.tgn import ( IdentityMessage, LastAggregator, LastNeighborLoader, ) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE') dataset = JODIEDataset(path, name='wikipedia') data = dataset[0] # For small datasets, we can put the whole dataset on GPU and thus avoid # expensive memory transfer costs for mini-batches: data = data.to(device) train_data, val_data, test_data = data.train_val_test_split( val_ratio=0.15, test_ratio=0.15) train_loader = TemporalDataLoader( train_data, batch_size=200, neg_sampling_ratio=1.0, ) val_loader = TemporalDataLoader( val_data, batch_size=200, neg_sampling_ratio=1.0, ) test_loader = TemporalDataLoader( test_data, batch_size=200, neg_sampling_ratio=1.0, ) neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device) class GraphAttentionEmbedding(torch.nn.Module): def __init__(self, in_channels, out_channels, msg_dim, time_enc): super().__init__() self.time_enc = time_enc edge_dim = msg_dim + time_enc.out_channels self.conv = TransformerConv(in_channels, out_channels // 2, heads=2, dropout=0.1, edge_dim=edge_dim) def forward(self, x, last_update, edge_index, t, msg): rel_t = last_update[edge_index[0]] - t rel_t_enc = self.time_enc(rel_t.to(x.dtype)) edge_attr = torch.cat([rel_t_enc, msg], dim=-1) return self.conv(x, edge_index, edge_attr) class LinkPredictor(torch.nn.Module): def __init__(self, in_channels): super().__init__() 
self.lin_src = Linear(in_channels, in_channels) self.lin_dst = Linear(in_channels, in_channels) self.lin_final = Linear(in_channels, 1) def forward(self, z_src, z_dst): h = self.lin_src(z_src) + self.lin_dst(z_dst) h = h.relu() return self.lin_final(h) memory_dim = time_dim = embedding_dim = 100 memory = TGNMemory( data.num_nodes, data.msg.size(-1), memory_dim, time_dim, message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim), aggregator_module=LastAggregator(), ).to(device) gnn = GraphAttentionEmbedding( in_channels=memory_dim, out_channels=embedding_dim, msg_dim=data.msg.size(-1), time_enc=memory.time_enc, ).to(device) link_pred = LinkPredictor(in_channels=embedding_dim).to(device) optimizer = torch.optim.Adam( set(memory.parameters()) | set(gnn.parameters()) | set(link_pred.parameters()), lr=0.0001) criterion = torch.nn.BCEWithLogitsLoss() # Helper vector to map global node indices to local ones. assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device) def train(): memory.train() gnn.train() link_pred.train() memory.reset_state() # Start with a fresh memory. neighbor_loader.reset_state() # Start with an empty graph. total_loss = 0 for batch in train_loader: optimizer.zero_grad() batch = batch.to(device) n_id, edge_index, e_id = neighbor_loader(batch.n_id) assoc[n_id] = torch.arange(n_id.size(0), device=device) # Get updated memory of all nodes involved in the computation. z, last_update = memory(n_id) z = gnn(z, last_update, edge_index, data.t[e_id].to(device), data.msg[e_id].to(device)) pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]]) neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]]) loss = criterion(pos_out, torch.ones_like(pos_out)) loss += criterion(neg_out, torch.zeros_like(neg_out)) # Update memory and neighbor loader with ground-truth state. 
memory.update_state(batch.src, batch.dst, batch.t, batch.msg) neighbor_loader.insert(batch.src, batch.dst) loss.backward() optimizer.step() memory.detach() total_loss += float(loss) * batch.num_events return total_loss / train_data.num_events @torch.no_grad() def test(loader): memory.eval() gnn.eval() link_pred.eval() torch.manual_seed(12345) # Ensure deterministic sampling across epochs. aps, aucs = [], [] for batch in loader: batch = batch.to(device) n_id, edge_index, e_id = neighbor_loader(batch.n_id) assoc[n_id] = torch.arange(n_id.size(0), device=device) z, last_update = memory(n_id) z = gnn(z, last_update, edge_index, data.t[e_id].to(device), data.msg[e_id].to(device)) pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]]) neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]]) y_pred = torch.cat([pos_out, neg_out], dim=0).sigmoid().cpu() y_true = torch.cat( [torch.ones(pos_out.size(0)), torch.zeros(neg_out.size(0))], dim=0) aps.append(average_precision_score(y_true, y_pred)) aucs.append(roc_auc_score(y_true, y_pred)) memory.update_state(batch.src, batch.dst, batch.t, batch.msg) neighbor_loader.insert(batch.src, batch.dst) return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean()) for epoch in range(1, 51): loss = train() print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') val_ap, val_auc = test(val_loader) test_ap, test_auc = test(test_loader) print(f'Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}') print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}') ================================================ FILE: examples/triangles_sag_pool.py ================================================ import copy import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq import torch_geometric.transforms as T from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import GCNConv, 
GINConv, SAGPooling, global_max_pool from torch_geometric.utils import scatter class HandleNodeAttention: def __call__(self, data): data = copy.copy(data) data.attn = torch.softmax(data.x, dim=0).flatten() data.x = None return data transform = T.Compose([HandleNodeAttention(), T.OneHotDegree(max_degree=14)]) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TRIANGLES') dataset = TUDataset(path, name='TRIANGLES', use_node_attr=True, transform=transform) train_loader = DataLoader(dataset[:30000], batch_size=60, shuffle=True) val_loader = DataLoader(dataset[30000:35000], batch_size=60) test_loader = DataLoader(dataset[35000:], batch_size=60) class Net(torch.nn.Module): def __init__(self, in_channels): super().__init__() self.conv1 = GINConv(Seq(Lin(in_channels, 64), ReLU(), Lin(64, 64))) self.pool1 = SAGPooling(64, min_score=0.001, GNN=GCNConv) self.conv2 = GINConv(Seq(Lin(64, 64), ReLU(), Lin(64, 64))) self.pool2 = SAGPooling(64, min_score=0.001, GNN=GCNConv) self.conv3 = GINConv(Seq(Lin(64, 64), ReLU(), Lin(64, 64))) self.lin = torch.nn.Linear(64, 1) def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) x, edge_index, _, batch, perm, score = self.pool1( x, edge_index, None, batch) x = F.relu(self.conv2(x, edge_index)) x, edge_index, _, batch, perm, score = self.pool2( x, edge_index, None, batch) ratio = x.size(0) / data.x.size(0) x = F.relu(self.conv3(x, edge_index)) x = global_max_pool(x, batch) x = self.lin(x).view(-1) attn_loss = F.kl_div(torch.log(score + 1e-14), data.attn[perm], reduction='none') attn_loss = scatter(attn_loss, batch, reduce='mean') return x, attn_loss, ratio device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(dataset.num_features).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out, attn_loss, _ = 
model(data) loss = ((out - data.y).pow(2) + 100 * attn_loss).mean() loss.backward() total_loss += loss.item() * data.num_graphs optimizer.step() return total_loss / len(train_loader.dataset) def test(loader): model.eval() corrects, total_ratio = [], 0 for data in loader: data = data.to(device) out, _, ratio = model(data) pred = out.round().to(torch.long) corrects.append(pred.eq(data.y.to(torch.long))) total_ratio += ratio return torch.cat(corrects, dim=0), total_ratio / len(loader) for epoch in range(1, 301): loss = train() train_correct, train_ratio = test(train_loader) val_correct, val_ratio = test(val_loader) test_correct, test_ratio = test(test_loader) train_acc = train_correct.sum().item() / train_correct.size(0) val_acc = val_correct.sum().item() / val_correct.size(0) test_acc1 = test_correct[:5000].sum().item() / 5000 test_acc2 = test_correct[5000:].sum().item() / 5000 print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.3f}, ' f'Val: {val_acc:.3f}, Test Orig: {test_acc1:.3f}, ' f'Test Large: {test_acc2:.3f}, Train/Val/Test Ratio=' f'{train_ratio:.3f}/{val_ratio:.3f}/{test_ratio:.3f}') ================================================ FILE: examples/unimp_arxiv.py ================================================ import os.path as osp import torch import torch.nn.functional as F from ogb.nodeproppred import PygNodePropPredDataset import torch_geometric.transforms as T from torch_geometric.nn import MaskLabel, TransformerConv from torch_geometric.utils import index_to_mask root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB') dataset = PygNodePropPredDataset('ogbn-arxiv', root, T.ToUndirected()) class UniMP(torch.nn.Module): def __init__(self, in_channels, num_classes, hidden_channels, num_layers, heads, dropout=0.3): super().__init__() self.label_emb = MaskLabel(num_classes, in_channels) self.convs = torch.nn.ModuleList() self.norms = torch.nn.ModuleList() for i in range(1, num_layers + 1): if i < num_layers: out_channels = 
hidden_channels // heads concat = True else: out_channels = num_classes concat = False conv = TransformerConv(in_channels, out_channels, heads, concat=concat, beta=True, dropout=dropout) self.convs.append(conv) in_channels = hidden_channels if i < num_layers: self.norms.append(torch.nn.LayerNorm(hidden_channels)) def forward(self, x, y, edge_index, label_mask): x = self.label_emb(x, y, label_mask) for conv, norm in zip(self.convs, self.norms): x = norm(conv(x, edge_index)).relu() return self.convs[-1](x, edge_index) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') data = dataset[0].to(device) data.y = data.y.view(-1) model = UniMP(dataset.num_features, dataset.num_classes, hidden_channels=64, num_layers=3, heads=2).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005) split_idx = dataset.get_idx_split() train_mask = index_to_mask(split_idx['train'], size=data.num_nodes) val_mask = index_to_mask(split_idx['valid'], size=data.num_nodes) test_mask = index_to_mask(split_idx['test'], size=data.num_nodes) def train(label_rate=0.65): # How many labels to use for propagation. 
model.train() propagation_mask = MaskLabel.ratio_mask(train_mask, ratio=label_rate) supervision_mask = train_mask ^ propagation_mask optimizer.zero_grad() out = model(data.x, data.y, data.edge_index, propagation_mask) loss = F.cross_entropy(out[supervision_mask], data.y[supervision_mask]) loss.backward() optimizer.step() return float(loss) @torch.no_grad() def test(): model.eval() propagation_mask = train_mask out = model(data.x, data.y, data.edge_index, propagation_mask) pred = out[val_mask].argmax(dim=-1) val_acc = int((pred == data.y[val_mask]).sum()) / pred.size(0) propagation_mask = train_mask | val_mask out = model(data.x, data.y, data.edge_index, propagation_mask) pred = out[test_mask].argmax(dim=-1) test_acc = int((pred == data.y[test_mask]).sum()) / pred.size(0) return val_acc, test_acc for epoch in range(1, 501): loss = train() val_acc, test_acc = test() print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, ' f'Test: {test_acc:.4f}') ================================================ FILE: examples/upfd.py ================================================ import argparse import os.path as osp import torch import torch.nn.functional as F from torch.nn import Linear from torch_geometric.datasets import UPFD from torch_geometric.loader import DataLoader from torch_geometric.nn import GATConv, GCNConv, SAGEConv, global_max_pool from torch_geometric.transforms import ToUndirected parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='politifact', choices=['politifact', 'gossipcop']) parser.add_argument('--feature', type=str, default='spacy', choices=['profile', 'spacy', 'bert', 'content']) parser.add_argument('--model', type=str, default='GCN', choices=['GCN', 'GAT', 'SAGE']) args = parser.parse_args() path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'UPFD') train_dataset = UPFD(path, args.dataset, args.feature, 'train', ToUndirected()) val_dataset = UPFD(path, args.dataset, args.feature, 'val', 
ToUndirected()) test_dataset = UPFD(path, args.dataset, args.feature, 'test', ToUndirected()) train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False) class Net(torch.nn.Module): def __init__(self, model, in_channels, hidden_channels, out_channels, concat=False): super().__init__() self.concat = concat if model == 'GCN': self.conv1 = GCNConv(in_channels, hidden_channels) elif model == 'SAGE': self.conv1 = SAGEConv(in_channels, hidden_channels) elif model == 'GAT': self.conv1 = GATConv(in_channels, hidden_channels) if self.concat: self.lin0 = Linear(in_channels, hidden_channels) self.lin1 = Linear(2 * hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, out_channels) def forward(self, x, edge_index, batch): h = self.conv1(x, edge_index).relu() h = global_max_pool(h, batch) if self.concat: # Get the root node (tweet) features of each graph: root = (batch[1:] - batch[:-1]).nonzero(as_tuple=False).view(-1) root = torch.cat([root.new_zeros(1), root + 1], dim=0) news = x[root] news = self.lin0(news).relu() h = self.lin1(torch.cat([news, h], dim=-1)).relu() h = self.lin2(h) return h.log_softmax(dim=-1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net(args.model, train_dataset.num_features, 128, train_dataset.num_classes, concat=True).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01) def train(): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.batch) loss = F.nll_loss(out, data.y) loss.backward() optimizer.step() total_loss += float(loss) * data.num_graphs return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() total_correct = total_examples = 0 for data in loader: data = data.to(device) pred = 
model(data.x, data.edge_index, data.batch).argmax(dim=-1) total_correct += int((pred == data.y).sum()) total_examples += data.num_graphs return total_correct / total_examples for epoch in range(1, 61): loss = train() train_acc = test(train_loader) val_acc = test(val_loader) test_acc = test(test_loader) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') ================================================ FILE: examples/wl_kernel.py ================================================ import argparse import os.path as osp import warnings import torch from sklearn.exceptions import ConvergenceWarning from sklearn.metrics import accuracy_score from sklearn.svm import LinearSVC from torch_geometric.data import Batch from torch_geometric.datasets import TUDataset from torch_geometric.nn import WLConv warnings.filterwarnings('ignore', category=ConvergenceWarning) parser = argparse.ArgumentParser() parser.add_argument('--runs', type=int, default=10) args = parser.parse_args() torch.manual_seed(42) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TU') dataset = TUDataset(path, name='ENZYMES') data = Batch.from_data_list(dataset) class WL(torch.nn.Module): def __init__(self, num_layers): super().__init__() self.convs = torch.nn.ModuleList([WLConv() for _ in range(num_layers)]) def forward(self, x, edge_index, batch=None): hists = [] for conv in self.convs: x = conv(x, edge_index) hists.append(conv.histogram(x, batch, norm=True)) return hists wl = WL(num_layers=5) hists = wl(data.x, data.edge_index, data.batch) test_accs = torch.empty(args.runs, dtype=torch.float) for run in range(1, args.runs + 1): perm = torch.randperm(data.num_graphs) val_index = perm[:data.num_graphs // 10] test_index = perm[data.num_graphs // 10:data.num_graphs // 5] train_index = perm[data.num_graphs // 5:] best_val_acc = 0 for hist in hists: train_hist, train_y = hist[train_index], data.y[train_index] val_hist, val_y = 
hist[val_index], data.y[val_index] test_hist, test_y = hist[test_index], data.y[test_index] for C in [10**3, 10**2, 10**1, 10**0, 10**-1, 10**-2, 10**-3]: model = LinearSVC(C=C, tol=0.01, dual=True) model.fit(train_hist, train_y) val_acc = accuracy_score(val_y, model.predict(val_hist)) if val_acc > best_val_acc: best_val_acc = val_acc test_acc = accuracy_score(test_y, model.predict(test_hist)) test_accs[run - 1] = test_acc print(f'Run: {run:02d}, Val: {best_val_acc:.4f}, Test: {test_acc:.4f}') print(f'Final Test Performance: {test_accs.mean():.4f}±{test_accs.std():.4f}') ================================================ FILE: graphgym/agg_batch.py ================================================ import argparse from torch_geometric.graphgym.utils.agg_runs import agg_batch def parse_args(): """Parses the arguments.""" parser = argparse.ArgumentParser( description='Train a classification model') parser.add_argument('--dir', dest='dir', help='Dir for batch of results', required=True, type=str) parser.add_argument('--metric', dest='metric', help='metric to select best epoch', required=False, type=str, default='auto') return parser.parse_args() args = parse_args() agg_batch(args.dir, args.metric) ================================================ FILE: graphgym/configs/example.yaml ================================================ # The recommended basic settings for GNN out_dir: results dataset: format: PyG name: Cora task: node task_type: classification transductive: true split: [0.8, 0.2] transform: none train: batch_size: 32 eval_period: 20 ckpt_period: 100 model: type: gnn loss_fun: cross_entropy edge_decoding: dot graph_pooling: add gnn: layers_pre_mp: 1 layers_mp: 2 layers_post_mp: 1 dim_inner: 256 layer_type: generalconv stage_type: stack batchnorm: true act: prelu dropout: 0.0 agg: add normalize_adj: false optim: optimizer: adam base_lr: 0.01 max_epoch: 400 ================================================ FILE: graphgym/configs/pyg/example_graph.yaml 
================================================ out_dir: results dataset: format: OGB name: ogbg-molhiv task: graph task_type: classification node_encoder: true node_encoder_name: Atom edge_encoder: true edge_encoder_name: Bond train: batch_size: 128 eval_period: 1 ckpt_period: 100 sampler: full_batch model: type: gnn loss_fun: cross_entropy edge_decoding: dot graph_pooling: add gnn: layers_pre_mp: 1 layers_mp: 2 layers_post_mp: 1 dim_inner: 300 layer_type: generalconv stage_type: stack batchnorm: true act: prelu dropout: 0.0 agg: mean normalize_adj: false optim: optimizer: adam base_lr: 0.01 max_epoch: 100 ================================================ FILE: graphgym/configs/pyg/example_link.yaml ================================================ out_dir: results dataset: format: OGB name: ogbl-collab task: link_pred task_type: classification node_encoder: false node_encoder_name: Atom edge_encoder: false edge_encoder_name: Bond train: batch_size: 128 eval_period: 1 ckpt_period: 100 sampler: full_batch model: type: gnn loss_fun: cross_entropy edge_decoding: dot graph_pooling: add gnn: layers_pre_mp: 1 layers_mp: 2 layers_post_mp: 1 dim_inner: 300 layer_type: gcnconv stage_type: stack batchnorm: true act: prelu dropout: 0.0 agg: mean normalize_adj: false optim: optimizer: adam base_lr: 0.01 max_epoch: 100 ================================================ FILE: graphgym/configs/pyg/example_node.yaml ================================================ out_dir: results dataset: format: PyG name: Cora task: node task_type: classification node_encoder: false node_encoder_name: Atom edge_encoder: false edge_encoder_name: Bond train: batch_size: 128 eval_period: 1 ckpt_period: 100 sampler: full_batch model: type: gnn loss_fun: cross_entropy edge_decoding: dot graph_pooling: add gnn: layers_pre_mp: 0 layers_mp: 2 layers_post_mp: 1 dim_inner: 16 layer_type: gcnconv stage_type: stack batchnorm: false act: prelu dropout: 0.1 agg: mean normalize_adj: false optim: optimizer: adam 
base_lr: 0.01 max_epoch: 200 ================================================ FILE: graphgym/configs_gen.py ================================================ import argparse import copy import csv import os.path as osp import random import numpy as np import yaml from torch_geometric.graphgym.utils.comp_budget import match_baseline_cfg from torch_geometric.graphgym.utils.io import ( makedirs_rm_exist, string_to_python, ) random.seed(123) def parse_args(): """Parses the arguments.""" parser = argparse.ArgumentParser() parser.add_argument('--config', dest='config', help='the base configuration file used for edit', default=None, type=str) parser.add_argument('--grid', dest='grid', help='configuration file for grid search', required=True, type=str) parser.add_argument('--sample_alias', dest='sample_alias', help='configuration file for sample alias', default=None, required=False, type=str) parser.add_argument('--sample_num', dest='sample_num', help='Number of random samples in the space', default=10, type=int) parser.add_argument('--out_dir', dest='out_dir', help='output directory for generated config files', default='configs', type=str) parser.add_argument( '--config_budget', dest='config_budget', help='the base configuration file used for matching computation', default=None, type=str) return parser.parse_args() def get_fname(string): if string is not None: return string.split('/')[-1].split('.')[0] else: return 'default' def grid2list(grid): list_in = [[]] for grid_temp in grid: list_out = [] for val in grid_temp: for list_temp in list_in: list_out.append(list_temp + [val]) list_in = list_out return list_in def lists_distance(l1, l2): assert len(l1) == len(l2) dist = 0 for i in range(len(l1)): if l1[i] != l2[i]: dist += 1 return dist def grid2list_sample(grid, sample=10): configs = [] while len(configs) < sample: config = [] for grid_temp in grid: config.append(random.choice(grid_temp)) if config not in configs: configs.append(config) return configs def 
load_config(fname): if fname is not None: with open(fname) as f: return yaml.load(f, Loader=yaml.FullLoader) else: return {} def load_search_file(fname): with open(fname) as f: out_raw = csv.reader(f, delimiter=' ') outs = [] out = [] for row in out_raw: if '#' in row: continue elif len(row) > 0: assert len(row) == 3, \ 'Exact 1 space between each grid argument file' \ 'And no spaces within each argument is allowed' out.append(row) else: if len(out) > 0: outs.append(out) out = [] if len(out) > 0: outs.append(out) return outs def load_alias_file(fname): with open(fname) as f: file = csv.reader(f, delimiter=' ') for line in file: # noqa: B007 break return line def exclude_list_id(list, id): return [list[i] for i in range(len(list)) if i != id] def gen_grid(args, config, config_budget=None): if config_budget is None: config_budget = {} task_name = f'{get_fname(args.config)}_grid_{get_fname(args.grid)}' fname_start = get_fname(args.config) out_dir = f'{args.out_dir}/{task_name}' makedirs_rm_exist(out_dir) config['out_dir'] = osp.join(config['out_dir'], task_name) outs = load_search_file(args.grid) for i, out in enumerate(outs): vars_label = [row[0].split('.') for row in out] vars_alias = [row[1] for row in out] vars_value = grid2list([string_to_python(row[2]) for row in out]) if i == 0: print(f'Variable label: {vars_label}') print(f'Variable alias: {vars_alias}') for vars in vars_value: config_out = config.copy() fname_out = fname_start for id, var in enumerate(vars): if len(vars_label[id]) == 1: config_out[vars_label[id][0]] = var elif len(vars_label[id]) == 2: if vars_label[id][0] in config_out: # if key1 exist config_out[vars_label[id][0]][vars_label[id][1]] = var else: config_out[vars_label[id][0]] = { vars_label[id][1]: var } else: raise ValueError('Only 2-level config files are supported') var_repr = str(var).strip("[]").strip("''") # noqa: B005 fname_out += f'-{vars_alias[id]}={var_repr}' if len(config_budget) > 0: config_out = match_baseline_cfg(config_out, 
config_budget) with open(f'{out_dir}/{fname_out}.yaml', 'w') as f: yaml.dump(config_out, f, default_flow_style=False) print(f'{len(vars_value)} configurations saved to: {out_dir}') def gen_grid_sample(args, config, config_budget=None, compare_alias_list=None): if config_budget is None: config_budget = {} if compare_alias_list is None: compare_alias_list = [] task_name = f'{get_fname(args.config)}_grid_{get_fname(args.grid)}' fname_start = get_fname(args.config) out_dir = f'{args.out_dir}/{task_name}' makedirs_rm_exist(out_dir) config['out_dir'] = osp.join(config['out_dir'], task_name) outs = load_search_file(args.grid) counts = [] for out in outs: vars_grid = [string_to_python(row[2]) for row in out] count = 1 for var in vars_grid: count *= len(var) counts.append(count) counts = np.array(counts) print('Total size of each chunk of experiment space:', counts) counts = counts / np.sum(counts) counts = np.round(counts * args.sample_num) counts[0] += args.sample_num - np.sum(counts) print('Total sample size of each chunk of experiment space:', counts) for i, out in enumerate(outs): vars_label = [row[0].split('.') for row in out] vars_alias = [row[1] for row in out] if i == 0: print(f'Variable label: {vars_label}') print(f'Variable alias: {vars_alias}') vars_grid = [string_to_python(row[2]) for row in out] for alias in compare_alias_list: alias_id = vars_alias.index(alias) vars_grid_select = copy.deepcopy(vars_grid[alias_id]) vars_grid[alias_id] = [vars_grid[alias_id][0]] vars_value = grid2list_sample(vars_grid, counts[i]) vars_value_new = [] for vars in vars_value: for grid in vars_grid_select: vars[alias_id] = grid vars_value_new.append(copy.deepcopy(vars)) vars_value = vars_value_new vars_grid[alias_id] = vars_grid_select for vars in vars_value: config_out = config.copy() fname_out = fname_start + f'-sample={vars_alias[alias_id]}' for id, var in enumerate(vars): if len(vars_label[id]) == 1: config_out[vars_label[id][0]] = var elif len(vars_label[id]) == 2: if 
vars_label[id][0] in config_out: # if key1 exist config_out[vars_label[id][0]][vars_label[id] [1]] = var else: config_out[vars_label[id][0]] = { vars_label[id][1]: var } else: raise ValueError( 'Only 2-level config files are supported') var_repr = str(var).strip("[]").strip("''") # noqa: B005 fname_out += f'-{vars_alias[id]}={var_repr}' if len(config_budget) > 0: config_out = match_baseline_cfg(config_out, config_budget, verbose=False) with open(f'{out_dir}/{fname_out}.yaml', "w") as f: yaml.dump(config_out, f, default_flow_style=False) print(f'Chunk {i + 1}/{len(outs)}: ' f'Perturbing design dimension {alias}, ' f'{len(vars_value)} configurations saved to: {out_dir}') args = parse_args() config = load_config(args.config) config_budget = load_config(args.config_budget) if args.sample_alias is None: gen_grid(args, config, config_budget) else: alias_list = load_alias_file(args.sample_alias) gen_grid_sample(args, config, config_budget, alias_list) ================================================ FILE: graphgym/custom_graphgym/__init__.py ================================================ from .act import * # noqa from .config import * # noqa from .encoder import * # noqa from .head import * # noqa from .layer import * # noqa from .loader import * # noqa from .loss import * # noqa from .network import * # noqa from .optimizer import * # noqa from .pooling import * # noqa from .stage import * # noqa from .train import * # noqa from .transform import * # noqa ================================================ FILE: graphgym/custom_graphgym/act/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/act/example.py ================================================ from functools import 
partial import torch import torch.nn as nn from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.register import register_act class SWISH(nn.Module): def __init__(self, inplace=False): super().__init__() self.inplace = inplace def forward(self, x): if self.inplace: x.mul_(torch.sigmoid(x)) return x else: return x * torch.sigmoid(x) register_act('swish', partial(SWISH, inplace=cfg.mem.inplace)) register_act('lrelu_03', partial(nn.LeakyReLU, 0.3, inplace=cfg.mem.inplace)) ================================================ FILE: graphgym/custom_graphgym/config/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/config/example.py ================================================ from yacs.config import CfgNode as CN from torch_geometric.graphgym.register import register_config @register_config('example') def set_cfg_example(cfg): r"""This function sets the default config value for customized options :return: customized configuration use by the experiment. 
""" # ----------------------------------------------------------------------- # # Customized options # ----------------------------------------------------------------------- # # example argument cfg.example_arg = 'example' # example argument group cfg.example_group = CN() # then argument can be specified within the group cfg.example_group.example_arg = 'example' ================================================ FILE: graphgym/custom_graphgym/encoder/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/encoder/example.py ================================================ import torch from ogb.utils.features import get_bond_feature_dims from torch_geometric.graphgym.register import ( register_edge_encoder, register_node_encoder, ) @register_node_encoder('example') class ExampleNodeEncoder(torch.nn.Module): """Provides an encoder for integer node features. Args: num_classes (int): The number of classes for the embedding mapping to learn. 
""" def __init__(self, emb_dim, num_classes=None): super().__init__() self.encoder = torch.nn.Embedding(num_classes, emb_dim) torch.nn.init.xavier_uniform_(self.encoder.weight.data) def forward(self, batch): # Encode just the first dimension if more exist batch.x = self.encoder(batch.x[:, 0]) return batch @register_edge_encoder('example') class ExampleEdgeEncoder(torch.nn.Module): def __init__(self, emb_dim): super().__init__() self.bond_embedding_list = torch.nn.ModuleList() full_bond_feature_dims = get_bond_feature_dims() for dim in full_bond_feature_dims: emb = torch.nn.Embedding(dim, emb_dim) torch.nn.init.xavier_uniform_(emb.weight.data) self.bond_embedding_list.append(emb) def forward(self, batch): bond_embedding = 0 for i in range(batch.edge_feature.shape[1]): bond_embedding += \ self.bond_embedding_list[i](batch.edge_attr[:, i]) batch.edge_attr = bond_embedding return batch ================================================ FILE: graphgym/custom_graphgym/head/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/head/example.py ================================================ import torch.nn as nn from torch_geometric.graphgym.register import register_head @register_head('head') class ExampleNodeHead(nn.Module): r"""Head of GNN for node prediction.""" def __init__(self, dim_in, dim_out): super().__init__() self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True) def _apply_index(self, batch): if batch.node_label_index.shape[0] == batch.node_label.shape[0]: return batch.x[batch.node_label_index], batch.node_label else: return batch.x[batch.node_label_index], \ batch.node_label[batch.node_label_index] def forward(self, batch): batch = self.layer_post_mp(batch) pred, 
label = self._apply_index(batch) return pred, label ================================================ FILE: graphgym/custom_graphgym/layer/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/layer/example.py ================================================ import torch import torch.nn as nn from torch.nn import Parameter from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.register import register_layer from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import glorot, zeros # Note: A registered GNN layer should take 'batch' as input # and 'batch' as output # Example 1: Directly define a GraphGym format Conv # take 'batch' as input and 'batch' as output @register_layer('exampleconv1') class ExampleConv1(MessagePassing): r"""Example GNN layer.""" def __init__(self, in_channels, out_channels, bias=True, **kwargs): super().__init__(aggr=cfg.gnn.agg, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.weight = Parameter(torch.empty(in_channels, out_channels)) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): glorot(self.weight) zeros(self.bias) def forward(self, batch): x, edge_index = batch.x, batch.edge_index x = torch.matmul(x, self.weight) batch.x = self.propagate(edge_index, x=x) return batch def message(self, x_j): return x_j def update(self, aggr_out): if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out # Example 2: First define a PyG format Conv layer # Then wrap it to become GraphGym format class ExampleConv2Layer(MessagePassing): r"""Example GNN layer.""" def 
__init__(self, in_channels, out_channels, bias=True, **kwargs): super().__init__(aggr=cfg.gnn.agg, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.weight = Parameter(torch.empty(in_channels, out_channels)) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): glorot(self.weight) zeros(self.bias) def forward(self, x, edge_index): x = torch.matmul(x, self.weight) return self.propagate(edge_index, x=x) def message(self, x_j): return x_j def update(self, aggr_out): if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out @register_layer('exampleconv2') class ExampleConv2(nn.Module): def __init__(self, dim_in, dim_out, bias=False, **kwargs): super().__init__() self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index) return batch ================================================ FILE: graphgym/custom_graphgym/loader/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/loader/example.py ================================================ from torch_geometric.datasets import QM7b from torch_geometric.graphgym.register import register_loader @register_loader('example') def load_dataset_example(format, name, dataset_dir): dataset_dir = f'{dataset_dir}/{name}' if format == 'PyG': if name == 'QM7b': dataset_raw = QM7b(dataset_dir) return dataset_raw ================================================ FILE: graphgym/custom_graphgym/loss/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = 
glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/loss/example.py ================================================ import torch from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.register import register_loss @register_loss('smoothl1') def loss_example(pred, true): if cfg.model.loss_fun == 'smoothl1': l1_loss = torch.nn.SmoothL1Loss() loss = l1_loss(pred, true) return loss, pred ================================================ FILE: graphgym/custom_graphgym/network/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/network/example.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch_geometric.graphgym.models.head # noqa, register module import torch_geometric.graphgym.register as register import torch_geometric.nn as pyg_nn from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.register import register_network @register_network('example') class ExampleGNN(torch.nn.Module): def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'): super().__init__() conv_model = self.build_conv_model(model_type) self.convs = nn.ModuleList() self.convs.append(conv_model(dim_in, dim_in)) for _ in range(num_layers - 1): self.convs.append(conv_model(dim_in, dim_in)) GNNHead = register.head_dict[cfg.dataset.task] self.post_mp = GNNHead(dim_in=dim_in, dim_out=dim_out) def build_conv_model(self, model_type): if model_type == 'GCN': return pyg_nn.GCNConv elif model_type == 'GAT': return 
pyg_nn.GATConv elif model_type == "GraphSage": return pyg_nn.SAGEConv else: raise ValueError(f'Model {model_type} unavailable') def forward(self, batch): x, edge_index = batch.x, batch.edge_index for i in range(len(self.convs)): x = self.convs[i](x, edge_index) x = F.relu(x) x = F.dropout(x, p=0.1, training=self.training) batch.x = x batch = self.post_mp(batch) return batch ================================================ FILE: graphgym/custom_graphgym/optimizer/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/optimizer/example.py ================================================ from typing import Iterator from torch.nn import Parameter from torch.optim import Adagrad, Optimizer from torch.optim.lr_scheduler import ReduceLROnPlateau import torch_geometric.graphgym.register as register @register.register_optimizer('adagrad') def adagrad_optimizer(params: Iterator[Parameter], base_lr: float, weight_decay: float) -> Adagrad: return Adagrad(params, lr=base_lr, weight_decay=weight_decay) @register.register_scheduler('pleateau') def plateau_scheduler(optimizer: Optimizer, patience: int, lr_decay: float) -> ReduceLROnPlateau: return ReduceLROnPlateau(optimizer, patience=patience, factor=lr_decay) ================================================ FILE: graphgym/custom_graphgym/pooling/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/pooling/example.py 
================================================ from torch_geometric.graphgym.register import register_pooling from torch_geometric.utils import scatter @register_pooling('example') def global_example_pool(x, batch, size=None): size = batch.max().item() + 1 if size is None else size return scatter(x, batch, dim=0, dim_size=size, reduce='sum') ================================================ FILE: graphgym/custom_graphgym/stage/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/stage/example.py ================================================ import torch.nn as nn import torch.nn.functional as F from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.models.layer import GeneralLayer from torch_geometric.graphgym.register import register_stage def GNNLayer(dim_in, dim_out, has_act=True): return GeneralLayer(cfg.gnn.layer_type, dim_in, dim_out, has_act) @register_stage('example') class GNNStackStage(nn.Module): r"""Simple stage that stacks GNN layers.""" def __init__(self, dim_in, dim_out, num_layers): super().__init__() for i in range(num_layers): d_in = dim_in if i == 0 else dim_out layer = GNNLayer(d_in, dim_out) self.add_module(f'layer{i}', layer) self.dim_out = dim_out def forward(self, batch): for layer in self.children(): batch = layer(batch) if cfg.gnn.l2norm: batch.x = F.normalize(batch.x, p=2, dim=-1) return batch ================================================ FILE: graphgym/custom_graphgym/train/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not 
f.endswith('__init__.py') ] ================================================ FILE: graphgym/custom_graphgym/train/example.py ================================================ import logging import time import torch from torch_geometric.graphgym.checkpoint import ( clean_ckpt, load_ckpt, save_ckpt, ) from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.loss import compute_loss from torch_geometric.graphgym.register import register_train from torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch def train_epoch(logger, loader, model, optimizer, scheduler): model.train() time_start = time.time() for batch in loader: optimizer.zero_grad() batch.to(torch.device(cfg.device)) pred, true = model(batch) loss, pred_score = compute_loss(pred, true) loss.backward() optimizer.step() logger.update_stats(true=true.detach().cpu(), pred=pred_score.detach().cpu(), loss=loss.item(), lr=scheduler.get_last_lr()[0], time_used=time.time() - time_start, params=cfg.params) time_start = time.time() scheduler.step() def eval_epoch(logger, loader, model): model.eval() time_start = time.time() for batch in loader: batch.to(torch.device(cfg.device)) pred, true = model(batch) loss, pred_score = compute_loss(pred, true) logger.update_stats(true=true.detach().cpu(), pred=pred_score.detach().cpu(), loss=loss.item(), lr=0, time_used=time.time() - time_start, params=cfg.params) time_start = time.time() @register_train('example') def train_example(loggers, loaders, model, optimizer, scheduler): start_epoch = 0 if cfg.train.auto_resume: start_epoch = load_ckpt(model, optimizer, scheduler, cfg.train.epoch_resume) if start_epoch == cfg.optim.max_epoch: logging.info('Checkpoint found, Task already done') else: logging.info('Start from epoch %s', start_epoch) num_splits = len(loggers) for cur_epoch in range(start_epoch, cfg.optim.max_epoch): train_epoch(loggers[0], loaders[0], model, optimizer, scheduler) loggers[0].write_epoch(cur_epoch) if 
is_eval_epoch(cur_epoch): for i in range(1, num_splits): eval_epoch(loggers[i], loaders[i], model) loggers[i].write_epoch(cur_epoch) if is_ckpt_epoch(cur_epoch): save_ckpt(model, optimizer, scheduler, cur_epoch) for logger in loggers: logger.close() if cfg.train.ckpt_clean: clean_ckpt() logging.info('Task done, results saved in %s', cfg.run_dir) ================================================ FILE: graphgym/custom_graphgym/transform/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: graphgym/grids/example.txt ================================================ # Format for each row: name in config.py; alias; range to search # No spaces, except between these 3 fields # Line breaks are used to union different grid search spaces # Feel free to add '#' to add comments # (1) dataset configurations dataset.format format ['PyG'] dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] dataset.task task ['graph'] dataset.transductive trans [False] # (2) The recommended GNN design space, 96 models in total gnn.layers_pre_mp l_pre [1,2] gnn.layers_mp l_mp [2,4,6,8] gnn.layers_post_mp l_post [2,3] gnn.stage_type stage ['skipsum','skipconcat'] gnn.agg agg ['add','mean','max'] ================================================ FILE: graphgym/grids/pyg/example.txt ================================================ # Format for each row: name in config.py; alias; range to search # No spaces, except between these 3 fields # Line breaks are used to union different grid search spaces # Feel free to add '#' to add comments gnn.layers_pre_mp l_pre [1,2] gnn.layers_mp l_mp [2,4,6] gnn.layers_post_mp l_post [1,2] gnn.stage_type stage ['stack','skipsum','skipconcat'] gnn.dim_inner dim [64] optim.base_lr lr [0.01] optim.max_epoch 
epoch [200] ================================================ FILE: graphgym/main.py ================================================ import logging import os import custom_graphgym # noqa, register custom modules import torch from torch_geometric import seed_everything from torch_geometric.graphgym.cmd_args import parse_args from torch_geometric.graphgym.config import ( cfg, dump_cfg, load_cfg, set_out_dir, set_run_dir, ) from torch_geometric.graphgym.logger import set_printing from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.train import GraphGymDataModule, train from torch_geometric.graphgym.utils.agg_runs import agg_runs from torch_geometric.graphgym.utils.comp_budget import params_count from torch_geometric.graphgym.utils.device import auto_select_device if __name__ == '__main__': # Load cmd line args args = parse_args() # Load config file load_cfg(cfg, args) set_out_dir(cfg.out_dir, args.cfg_file) # Set Pytorch environment torch.set_num_threads(cfg.num_threads) dump_cfg(cfg) # Repeat for different random seeds for _ in range(args.repeat): set_run_dir(cfg.out_dir) set_printing() # Set configurations for each run cfg.seed = cfg.seed + 1 seed_everything(cfg.seed) auto_select_device() # Set machine learning pipeline datamodule = GraphGymDataModule() model = create_model() # Print model info logging.info(model) logging.info(cfg) cfg.params = params_count(model) logging.info('Num parameters: %s', cfg.params) train(model, datamodule, logger=True) # Aggregate results from different seeds agg_runs(cfg.out_dir, cfg.metric_best) # When being launched in batch mode, mark a yaml as done if args.mark_done: os.rename(args.cfg_file, f'{args.cfg_file}_done') ================================================ FILE: graphgym/parallel.sh ================================================ CONFIG_DIR=$1 REPEAT=$2 MAX_JOBS=${3:-2} SLEEP=${4:-1} MAIN=${5:-main} ( trap 'kill 0' SIGINT CUR_JOBS=0 for CONFIG in "$CONFIG_DIR"/*.yaml; do if [ 
"$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then ((CUR_JOBS >= MAX_JOBS)) && wait -n python $MAIN.py --cfg $CONFIG --repeat $REPEAT --mark_done & echo $CONFIG sleep $SLEEP ((++CUR_JOBS)) fi done wait ) ================================================ FILE: graphgym/run_batch.sh ================================================ #!/usr/bin/env bash CONFIG=${CONFIG:-example_node} GRID=${GRID:-example} REPEAT=${REPEAT:-3} MAX_JOBS=${MAX_JOBS:-8} SLEEP=${SLEEP:-1} MAIN=${MAIN:-main} # generate configs (after controlling computational budget) # please remove --config_budget, if don't control computational budget python configs_gen.py --config configs/pyg/${CONFIG}.yaml \ --grid grids/pyg/${GRID}.txt \ --out_dir configs #python configs_gen.py --config configs/ChemKG/${CONFIG}.yaml --config_budget configs/ChemKG/${CONFIG}.yaml --grid grids/ChemKG/${GRID}.txt --out_dir configs # run batch of configs # Args: config_dir, num of repeats, max jobs running, sleep time bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $SLEEP $MAIN # rerun missed / stopped experiments bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $SLEEP $MAIN # rerun missed / stopped experiments bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $SLEEP $MAIN # aggregate results for the batch python agg_batch.py --dir results/${CONFIG}_grid_${GRID} ================================================ FILE: graphgym/run_single.sh ================================================ #!/usr/bin/env bash # Test for running a single experiment. --repeat means run how many different random seeds. 
python main.py --cfg configs/pyg/example_node.yaml --repeat 3 # node classification python main.py --cfg configs/pyg/example_link.yaml --repeat 3 # link prediction python main.py --cfg configs/pyg/example_graph.yaml --repeat 3 # graph classification ================================================ FILE: graphgym/sample/dimensions.txt ================================================ act bn drop agg l_mp l_pre l_post stage batch lr optim epoch ================================================ FILE: graphgym/sample/dimensionsatt.txt ================================================ l_tw ================================================ FILE: pyproject.toml ================================================ [build-system] requires=["flit_core >=3.12,<4"] build-backend="flit_core.buildapi" [project] name="torch-geometric" version="2.8.0" authors=[ {name="Matthias Fey", email="matthias@pyg.org"}, ] description="Graph Neural Network Library for PyTorch" readme="README.md" requires-python=">=3.10" keywords=[ "deep-learning", "pytorch", "geometric-deep-learning", "graph-neural-networks", "graph-convolutional-networks", ] license = "MIT" license-files = ["LICENSE"] classifiers=[ "Development Status :: 5 - Production/Stable", "Programming Language :: Python", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Programming Language :: Python :: 3 :: Only", ] dependencies=[ "aiohttp", "fsspec", "jinja2", "numpy", "psutil>=5.8.0", "pyparsing", "requests", "tqdm", "xxhash", ] [project.optional-dependencies] graphgym=[ "protobuf<4.21", "pytorch-lightning", "yacs", ] modelhub=[ "huggingface_hub" ] benchmark=[ "matplotlib", "networkx", "pandas", "protobuf<4.21", "wandb", ] rag=[ "pcst_fast", "datasets", "transformers", "pandas", "sentencepiece", "accelerate", "torchmetrics", "peft", ] test=[ "onnx", "onnxruntime", "onnxscript", 
"pytest", "pytest-cov", ] dev=[ "ipython", "matplotlib-inline", "pre-commit", "torch_geometric[test]", ] full = [ "scipy", "scikit-learn", "ase", "captum<0.7.0", "graphviz", "h5py", "matplotlib", "networkx", "numba<0.60.0", "opt_einsum", "pandas", # See https://github.com/pgmpy/pgmpy/issues/2360. # "pgmpy", "pynndescent", "pytorch-memlab", "rdflib", "rdkit", "scikit-image", "statsmodels", "sympy", "tabulate", "torch_geometric[graphgym, modelhub]", "torchmetrics", "trimesh", ] [project.urls] homepage="https://pyg.org" documentation="https://pytorch-geometric.readthedocs.io" repository="https://github.com/pyg-team/pytorch_geometric.git" changelog="https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md" [tool.flit.module] name="torch_geometric" [tool.yapf] based_on_style = "pep8" split_before_named_assigns = false blank_line_before_nested_class_or_def = false [tool.mypy] files = ["torch_geometric"] install_types = true non_interactive = true ignore_missing_imports = true show_error_codes = true warn_redundant_casts = true warn_unused_configs = true warn_unused_ignores = true disallow_untyped_defs = true disallow_incomplete_defs = true [[tool.mypy.overrides]] ignore_errors = true module = [ "torch_geometric.data.*", "torch_geometric.sampler.*", "torch_geometric.loader.*", "torch_geometric.nn.*", "torch_geometric.explain.*", "torch_geometric.profile.*", "torch_geometric.contrib.*", "torch_geometric.graphgym.*", "torch_geometric.distributed.*", "torch_geometric.llm.*", ] [tool.isort] multi_line_output = 3 include_trailing_comma = true skip = [".gitignore", "__init__.py"] [tool.ruff] # https://docs.astral.sh/ruff/rules src = ["torch_geometric"] line-length = 80 indent-width = 4 target-version = "py310" [tool.ruff.lint] select = [ "B", # flake8-bugbear "D", # pydocstyle ] ignore = [ "B905", # TODO Don't ignore "zip with strict=False" "D100", # TODO Don't ignore "Missing docstring in public module" "D101", # TODO Don't ignore "Missing docstring in public 
class" "D102", # TODO Don't ignore "Missing docstring in public method" "D103", # TODO Don't ignore "Missing docstring in public function" "D104", # TODO Don't ignore "Missing docstring in public package" "D105", # Ignore "Missing docstring in magic method" "D107", # Ignore "Missing docstring in __init__" "D205", # Ignore "blank line required between summary line and description" ] [tool.ruff.format] quote-style = "single" [tool.ruff.lint.pydocstyle] convention = "google" [tool.pytest.ini_options] addopts = [ "--capture=no", "--color=yes", "-vv", ] filterwarnings = [ "ignore:distutils:DeprecationWarning", "ignore:'torch_geometric.contrib' contains experimental code:UserWarning", # Filter `torch` warnings: "ignore:The PyTorch API of nested tensors is in prototype stage:UserWarning", "ignore:scatter_reduce():UserWarning", "ignore:Sparse CSR tensor support is in beta state:UserWarning", "ignore:Sparse CSC tensor support is in beta state:UserWarning", "ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning", # Filter `torch.compile` warnings: "ignore:pkg_resources is deprecated as an API", "ignore:Deprecated call to `pkg_resources.declare_namespace", # Filter `captum` warnings: "ignore:Setting backward hooks on ReLU activations:UserWarning", "ignore:.*did not already require gradients, required_grads has been set automatically:UserWarning", # Filter `pytorch_lightning` warnings: "ignore:GPU available but not used:UserWarning", "error:.*torch_geometric.*:DeprecationWarning", # TODO(rishipuri98): Remove usage of `torch_geometric.distributed` from `torch_geometric.llm` "ignore:.*torch_geometric.distributed.*:DeprecationWarning", # Filter `torch.jit.*` deprication warnings: "ignore:.*torch.jit.*:DeprecationWarning", ] markers = [ "rag: mark test as RAG test", ] [tool.coverage.run] source = ["torch_geometric"] omit = [ "torch_geometric/distributed/*", "torch_geometric/datasets/*", "torch_geometric/data/extract.py", 
"torch_geometric/nn/data_parallel.py", ] [tool.coverage.report] exclude_lines = [ "pragma: no cover", "pass", "raise NotImplementedError", "register_parameter", "torch.cuda.is_available", ] [tool.setuptools] py-modules = [] ================================================ FILE: readthedocs.yml ================================================ version: 2 sphinx: configuration: docs/source/conf.py build: os: ubuntu-24.04 tools: python: "3.10" python: install: - requirements: docs/requirements.txt - method: pip path: . formats: [] ================================================ FILE: test/conftest.py ================================================ import functools import logging import os.path as osp from typing import Callable import pytest import torch import torch_geometric.typing from torch_geometric.data import Dataset from torch_geometric.io import fs def load_dataset(root: str, name: str, *args, **kwargs) -> Dataset: r"""Returns a variety of datasets according to :obj:`name`.""" if 'karate' in name.lower(): from torch_geometric.datasets import KarateClub return KarateClub(*args, **kwargs) if name.lower() in ['cora', 'citeseer', 'pubmed']: from torch_geometric.datasets import Planetoid path = osp.join(root, 'Planetoid', name) return Planetoid(path, name, *args, **kwargs) if name in ['BZR', 'ENZYMES', 'IMDB-BINARY', 'MUTAG']: from torch_geometric.datasets import TUDataset path = osp.join(root, 'TUDataset') return TUDataset(path, name, *args, **kwargs) if name in ['ego-facebook', 'soc-Slashdot0811', 'wiki-vote']: from torch_geometric.datasets import SNAPDataset path = osp.join(root, 'SNAPDataset') return SNAPDataset(path, name, *args, **kwargs) if name.lower() in ['bashapes']: from torch_geometric.datasets import BAShapes return BAShapes(*args, **kwargs) if name in ['citationCiteseer', 'illc1850']: from torch_geometric.datasets import SuiteSparseMatrixCollection path = osp.join(root, 'SuiteSparseMatrixCollection') return SuiteSparseMatrixCollection(path, *args, 
name=name, **kwargs) if 'elliptic' in name.lower(): from torch_geometric.datasets import EllipticBitcoinDataset path = osp.join(root, 'EllipticBitcoinDataset') return EllipticBitcoinDataset(path, *args, **kwargs) if name.lower() in ['hetero']: from torch_geometric.testing import FakeHeteroDataset return FakeHeteroDataset(*args, **kwargs) raise ValueError(f"Cannot load dataset with name '{name}'") @pytest.fixture(scope='session') def get_dataset() -> Callable: # TODO Support memory filesystem on Windows. if torch_geometric.typing.WITH_WINDOWS: root = osp.join('/', 'tmp', 'pyg_test_datasets') else: root = 'memory://pyg_test_datasets' yield functools.partial(load_dataset, root) if fs.exists(root): fs.rm(root) @pytest.fixture def enable_extensions(): # Nothing to do. yield @pytest.fixture def disable_extensions(): def is_setting(name: str) -> bool: if not name.startswith('WITH_'): return False if name.startswith('WITH_PT') or name.startswith('WITH_WINDOWS'): return False return True settings = dir(torch_geometric.typing) settings = [key for key in settings if is_setting(key)] state = {key: getattr(torch_geometric.typing, key) for key in settings} for key in state.keys(): setattr(torch_geometric.typing, key, False) yield for key, value in state.items(): setattr(torch_geometric.typing, key, value) @pytest.fixture def without_extensions(request): request.getfixturevalue(request.param) return request.param == 'disable_extensions' @pytest.fixture(scope='function') def spawn_context(): torch.multiprocessing.set_start_method('spawn', force=True) logging.info("Setting torch.multiprocessing context to 'spawn'") ================================================ FILE: test/contrib/explain/test_pgm_explainer.py ================================================ import numpy as np import pytest import torch from torch_geometric.contrib.explain import PGMExplainer from torch_geometric.explain import Explainer from torch_geometric.explain.config import ModelConfig from 
torch_geometric.nn import GCNConv, global_add_pool from torch_geometric.testing import onlyLinux, withPackage class GCN(torch.nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() self.model_config = model_config if model_config.mode.value == 'multiclass_classification': out_channels = 7 else: out_channels = 1 self.conv1 = GCNConv(3, 16) self.conv2 = GCNConv(16, out_channels) def forward(self, x, edge_index, edge_weight=None, batch=None, **kwargs): x = self.conv1(x, edge_index, edge_weight).relu() x = self.conv2(x, edge_index, edge_weight).relu() if self.model_config.task_level.value == 'graph': x = global_add_pool(x, batch) if self.model_config.mode.value == 'binary_classification': x = x.sigmoid() elif self.model_config.mode.value == 'multiclass_classification': x = x.log_softmax(dim=-1) return x x = torch.randn(8, 3) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ]) target = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2]) edge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]]) @onlyLinux @withPackage('pgmpy', 'pandas') @pytest.mark.parametrize('node_idx', [2, 6]) @pytest.mark.parametrize('task_level, perturbation_mode', [ ('node', 'randint'), ('graph', 'mean'), ('graph', 'max'), ('graph', 'min'), ('graph', 'zero'), ]) def test_pgm_explainer_classification(node_idx, task_level, perturbation_mode): model_config = ModelConfig( mode='multiclass_classification', task_level=task_level, return_type='raw', ) model = GCN(model_config) logits = model(x, edge_index) target = logits.argmax(dim=1) explainer = Explainer( model=model, algorithm=PGMExplainer(feature_index=[0], perturbation_mode=perturbation_mode), explanation_type='phenomenon', node_mask_type='object', model_config=model_config, ) explanation = explainer( x=x, edge_index=edge_index, index=node_idx, target=target, ) assert 'node_mask' in explanation assert 'pgm_stats' in explanation assert explanation.node_mask.size(0) == 
explanation.num_nodes assert explanation.node_mask.min() >= 0 assert explanation.node_mask.max() <= 1 class DummyModel(torch.nn.Module): def forward(self, x, edge_index, **kwargs): return torch.tensor([[0.2, 0.8]], requires_grad=True) def test_batch_perturb_features_on_node(): model = DummyModel() explainer = PGMExplainer(num_samples=1) # just one sample for testing # Minimal graph data with 1 node and 2 features x = torch.randn(1, 2) edge_index = torch.tensor([[0], [0]]) # dummy self-loop indices_to_perturb = np.array([0]) # only node 0 can be perturbed # Simulate kwargs that would include prediction details kwargs = { "soft_pred": torch.tensor([0.4, 0.6]), # soft prediction of original input "pred_label": 1, "num_nodes": 1, } samples = explainer._batch_perturb_features_on_node( model=model, x=x, edge_index=edge_index, indices_to_perturb=indices_to_perturb, **kwargs) assert isinstance(samples, torch.Tensor) assert samples.shape == (1, 2) assert torch.all(samples[0] >= 0) # pred_change should be non-negative ================================================ FILE: test/contrib/nn/models/test_rbcd_attack.py ================================================ import pytest import torch from torch.nn import Linear from torch_geometric.contrib.nn import GRBCDAttack, PRBCDAttack from torch_geometric.nn import GCNConv, global_add_pool from torch_geometric.utils import to_undirected class GCN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(3, 16) self.conv2 = GCNConv(16, 7) def forward(self, x, edge_index, edge_weight): x = self.conv1(x, edge_index, edge_weight).relu() x = self.conv2(x, edge_index, edge_weight) return x class GNN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(3, 16) self.conv2 = GCNConv(16, 16) self.lin = Linear(16, 7) def forward(self, x, edge_index, edge_weight, batch=None): x = self.conv1(x, edge_index, edge_weight).relu() x = self.conv2(x, edge_index, edge_weight).relu() x = global_add_pool(x, 
batch) x = self.lin(x) return x @pytest.mark.parametrize('model', [GCN, GNN]) @pytest.mark.parametrize('budget', [1]) @pytest.mark.parametrize('loss', ['masked', 'margin', 'prob_margin', 'tanh_margin']) @pytest.mark.parametrize('is_undirected', [False, True]) @pytest.mark.parametrize('with_early_stopping', [False, True]) def test_prbcd_attack(model, budget, loss, is_undirected, with_early_stopping): attack = PRBCDAttack(model(), block_size=10_000, epochs=4, epochs_resampling=2, loss=loss, max_final_samples=2, log=False, is_undirected=is_undirected, with_early_stopping=with_early_stopping) assert str(attack) == 'PRBCDAttack()' x = torch.randn(8, 3) edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7]]) if is_undirected: edge_index = to_undirected(edge_index) if model == GNN: y = torch.tensor([0]) # All nodes belong to same graph: kwargs = dict(batch=edge_index.new_zeros(x.size(0))) else: y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0]) kwargs = {} pert_edge_index, pert = attack.attack(x, edge_index, y, budget, **kwargs) m = edge_index.size(1) if budget == 1: assert pert.size() in [(2, 0), (2, 1)] if pert.size(1): if is_undirected: possible_m = [m - 2, m + 2] else: possible_m = [m - 1, m + 1] else: possible_m = [m] assert pert_edge_index.size(1) in possible_m @pytest.mark.parametrize('model', [GCN, GNN]) @pytest.mark.parametrize('budget', [1]) @pytest.mark.parametrize('is_undirected', [False, True]) def test_grbcd_attack(model, budget, is_undirected): attack = GRBCDAttack(model(), block_size=10_000, epochs=4, log=False, is_undirected=is_undirected) assert str(attack) == 'GRBCDAttack()' x = torch.randn(8, 3) edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7]]) if is_undirected: edge_index = to_undirected(edge_index) if model == GNN: y = torch.tensor([0]) # All nodes belong to same graph: kwargs = dict(batch=edge_index.new_zeros(x.size(0))) else: y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0]) kwargs = {} pert_edge_index, pert = 
attack.attack(x, edge_index, y, budget, **kwargs) m = edge_index.size(1) if budget == 1: assert pert.size() == (2, 1) if is_undirected: possible_m = [m - 2, m + 2] else: possible_m = [m - 1, m + 1] assert pert_edge_index.size(1) in possible_m ================================================ FILE: test/data/lightning/test_datamodule.py ================================================ import math from contextlib import contextmanager import pytest import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.lightning import ( LightningDataset, LightningLinkData, LightningNodeData, ) from torch_geometric.nn import global_mean_pool from torch_geometric.sampler import BaseSampler, NeighborSampler from torch_geometric.testing import ( MyFeatureStore, MyGraphStore, get_random_edge_index, has_package, onlyCUDA, onlyFullTest, onlyNeighborSampler, onlyOnline, withPackage, ) try: from pytorch_lightning import LightningModule except ImportError: LightningModule = torch.nn.Module class LinearGraphModule(LightningModule): def __init__(self, in_channels: int, hidden_channels: int, out_channels: int): super().__init__() from torchmetrics import Accuracy self.lin1 = torch.nn.Linear(in_channels, hidden_channels) self.lin2 = torch.nn.Linear(hidden_channels, out_channels) self.train_acc = Accuracy(task='multiclass', num_classes=out_channels) self.val_acc = Accuracy(task='multiclass', num_classes=out_channels) self.test_acc = Accuracy(task='multiclass', num_classes=out_channels) def forward(self, x: Tensor, batch: Data) -> Tensor: # Basic test to ensure that the dataset is not replicated: self.trainer.datamodule.train_dataset._data.x.add_(1) x = self.lin1(x).relu() x = global_mean_pool(x, batch) x = self.lin2(x) return x def training_step(self, data: Data, batch_idx: int): y_hat = self(data.x, data.batch) loss = F.cross_entropy(y_hat, data.y) self.train_acc(y_hat.softmax(dim=-1), data.y) 
self.log('loss', loss, batch_size=data.num_graphs) self.log('train_acc', self.train_acc, batch_size=data.num_graphs) return loss def validation_step(self, data: Data, batch_idx: int): y_hat = self(data.x, data.batch) self.val_acc(y_hat.softmax(dim=-1), data.y) self.log('val_acc', self.val_acc, batch_size=data.num_graphs) def test_step(self, data: Data, batch_idx: int): y_hat = self(data.x, data.batch) self.test_acc(y_hat.softmax(dim=-1), data.y) self.log('test_acc', self.test_acc, batch_size=data.num_graphs) def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.01) @onlyCUDA @onlyOnline @onlyFullTest @withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0') @pytest.mark.parametrize('strategy_type', [None, 'ddp']) def test_lightning_dataset(get_dataset, strategy_type): import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only @contextmanager def expect_rank_zero_user_warning(match: str): if rank_zero_only.rank == 0: with pytest.warns(UserWarning, match=match): yield else: yield dataset = get_dataset(name='MUTAG').shuffle() train_dataset = dataset[:50] val_dataset = dataset[50:80] test_dataset = dataset[80:90] pred_dataset = dataset[90:] devices = 1 if strategy_type is None else torch.cuda.device_count() if strategy_type == 'ddp': strategy = pl.strategies.DDPStrategy(accelerator='gpu') else: strategy = pl.strategies.SingleDeviceStrategy(device='cuda:0') model = LinearGraphModule(dataset.num_features, 64, dataset.num_classes) trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=1, log_every_n_steps=1) with pytest.warns(UserWarning, match="'shuffle=True' option is ignored"): datamodule = LightningDataset(train_dataset, val_dataset, test_dataset, pred_dataset, batch_size=5, num_workers=3, shuffle=True) assert 'shuffle' not in datamodule.kwargs old_x = train_dataset._data.x.clone() if has_package('pytorch_lightning>=2.5.0'): datamodule_repr = ('{Train dataloader: size=50}\n' '{Validation 
dataloader: size=30}\n' '{Test dataloader: size=10}\n' '{Predict dataloader: size=98}') else: datamodule_repr = ('LightningDataset(train_dataset=MUTAG(50), ' 'val_dataset=MUTAG(30), ' 'test_dataset=MUTAG(10), ' 'pred_dataset=MUTAG(98), batch_size=5, ' 'num_workers=3, pin_memory=True, ' 'persistent_workers=True)') assert str(datamodule) == datamodule_repr trainer.fit(model, datamodule) trainer.test(model, datamodule) new_x = train_dataset._data.x assert torch.all(new_x > old_x) # Ensure shared data. assert trainer.validate_loop._data_source.is_defined() assert trainer.test_loop._data_source.is_defined() # Test with `val_dataset=None` and `test_dataset=None`: if strategy_type is None: trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=1, log_every_n_steps=1) datamodule = LightningDataset(train_dataset, batch_size=5) if has_package('pytorch_lightning>=2.5.0'): datamodule_repr = ('{Train dataloader: size=50}\n' '{Validation dataloader: None}\n' '{Test dataloader: None}\n{' 'Predict dataloader: None}') else: datamodule_repr = ('LightningDataset(train_dataset=MUTAG(50), ' 'batch_size=5, num_workers=0, ' 'pin_memory=True, ' 'persistent_workers=False)') assert str(datamodule) == datamodule_repr with expect_rank_zero_user_warning("defined a `validation_step`"): trainer.fit(model, datamodule) assert not trainer.validate_loop._data_source.is_defined() assert not trainer.test_loop._data_source.is_defined() class LinearNodeModule(LightningModule): def __init__(self, in_channels: int, out_channels: int): super().__init__() from torchmetrics import Accuracy self.lin = torch.nn.Linear(in_channels, out_channels) self.train_acc = Accuracy(task='multiclass', num_classes=out_channels) self.val_acc = Accuracy(task='multiclass', num_classes=out_channels) self.test_acc = Accuracy(task='multiclass', num_classes=out_channels) def forward(self, x: Tensor) -> Tensor: # Basic test to ensure that the dataset is not replicated: self.trainer.datamodule.data.x.add_(1) return 
self.lin(x) def training_step(self, data: Data, batch_idx: int): y_hat = self(data.x)[data.train_mask] y = data.y[data.train_mask] loss = F.cross_entropy(y_hat, y) self.train_acc(y_hat.softmax(dim=-1), y) self.log('loss', loss, batch_size=y.size(0)) self.log('train_acc', self.train_acc, batch_size=y.size(0)) return loss def validation_step(self, data: Data, batch_idx: int): y_hat = self(data.x)[data.val_mask] y = data.y[data.val_mask] self.val_acc(y_hat.softmax(dim=-1), y) self.log('val_acc', self.val_acc, batch_size=y.size(0)) def test_step(self, data: Data, batch_idx: int): y_hat = self(data.x)[data.test_mask] y = data.y[data.test_mask] self.test_acc(y_hat.softmax(dim=-1), y) self.log('test_acc', self.test_acc, batch_size=y.size(0)) def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.01) @onlyCUDA @onlyOnline @onlyFullTest @onlyNeighborSampler @withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0', 'scipy') @pytest.mark.parametrize('loader', ['full', 'neighbor']) @pytest.mark.parametrize('strategy_type', [None, 'ddp']) def test_lightning_node_data(get_dataset, strategy_type, loader): import pytorch_lightning as pl dataset = get_dataset(name='Cora') data = dataset[0] data_repr = ('Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], ' 'train_mask=[2708], val_mask=[2708], test_mask=[2708])') model = LinearNodeModule(dataset.num_features, dataset.num_classes) if strategy_type is None or loader == 'full': devices = 1 else: devices = torch.cuda.device_count() if strategy_type == 'ddp': strategy = pl.strategies.DDPStrategy(accelerator='gpu') else: strategy = pl.strategies.SingleDeviceStrategy(device='cuda:0') if loader == 'full': # Set reasonable defaults for full-batch training: batch_size = 1 num_workers = 0 else: batch_size = 32 num_workers = 3 kwargs, kwargs_repr = {}, '' if loader == 'neighbor': kwargs['num_neighbors'] = [5] kwargs_repr += 'num_neighbors=[5], ' trainer = pl.Trainer(strategy=strategy, devices=devices, 
max_epochs=5, log_every_n_steps=1) datamodule = LightningNodeData(data, loader=loader, batch_size=batch_size, num_workers=num_workers, **kwargs) old_x = data.x.clone().cpu() flag = loader != 'full' if has_package('pytorch_lightning>=2.5.0'): datamodule_repr = ( '{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n' '{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n' '{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n' '{Predict dataloader: ' + f'size={2708 if flag else 1}' + '}') else: datamodule_repr = (f'LightningNodeData(data={data_repr}, ' f'loader={loader}, batch_size={batch_size}, ' f'num_workers={num_workers}, {kwargs_repr}' f'pin_memory={flag}, ' f'persistent_workers={flag})') assert str(datamodule) == datamodule_repr trainer.fit(model, datamodule) trainer.test(model, datamodule) new_x = data.x.cpu() assert torch.all(new_x > old_x) # Ensure shared data. assert trainer.validate_loop._data_source.is_defined() assert trainer.test_loop._data_source.is_defined() class LinearHeteroNodeModule(LightningModule): def __init__(self, in_channels: int, out_channels: int): super().__init__() from torchmetrics import Accuracy self.lin = torch.nn.Linear(in_channels, out_channels) self.train_acc = Accuracy(task='multiclass', num_classes=out_channels) self.val_acc = Accuracy(task='multiclass', num_classes=out_channels) self.test_acc = Accuracy(task='multiclass', num_classes=out_channels) def forward(self, x: Tensor) -> Tensor: # Basic test to ensure that the dataset is not replicated: self.trainer.datamodule.data['paper'].x.add_(1) return self.lin(x) def training_step(self, data: HeteroData, batch_idx: int): y_hat = self(data['paper'].x)[data['paper'].train_mask] y = data['paper'].y[data['paper'].train_mask] loss = F.cross_entropy(y_hat, y) self.train_acc(y_hat.softmax(dim=-1), y) self.log('loss', loss, batch_size=y.size(0)) self.log('train_acc', self.train_acc, batch_size=y.size(0)) return loss def validation_step(self, data: HeteroData, 
batch_idx: int): y_hat = self(data['paper'].x)[data['paper'].val_mask] y = data['paper'].y[data['paper'].val_mask] self.val_acc(y_hat.softmax(dim=-1), y) self.log('val_acc', self.val_acc, batch_size=y.size(0)) def test_step(self, data: HeteroData, batch_idx: int): y_hat = self(data['paper'].x)[data['paper'].test_mask] y = data['paper'].y[data['paper'].test_mask] self.test_acc(y_hat.softmax(dim=-1), y) self.log('test_acc', self.test_acc, batch_size=y.size(0)) def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.01) @pytest.fixture def preserve_context(): num_threads = torch.get_num_threads() yield if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() torch.set_num_threads(num_threads) @onlyCUDA @onlyFullTest @onlyNeighborSampler @withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0') def test_lightning_hetero_node_data(preserve_context, get_dataset): import pytorch_lightning as pl data = get_dataset(name='hetero')[0] model = LinearHeteroNodeModule(data['paper'].num_features, int(data['paper'].y.max()) + 1) devices = torch.cuda.device_count() strategy = pl.strategies.DDPStrategy(accelerator='gpu') trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=5, log_every_n_steps=1) datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[5], batch_size=32, num_workers=3) assert isinstance(datamodule.graph_sampler, NeighborSampler) original_x = data['paper'].x.clone() trainer.fit(model, datamodule) trainer.test(model, datamodule) assert torch.all(data['paper'].x > original_x) # Ensure shared data. 
assert trainer.validate_loop._data_source.is_defined() assert trainer.test_loop._data_source.is_defined() @withPackage('pytorch_lightning') def test_lightning_data_custom_sampler(): class DummySampler(BaseSampler): def sample_from_edges(self, *args, **kwargs): pass def sample_from_nodes(self, *args, **kwargs): pass data = Data(num_nodes=2, edge_index=torch.tensor([[0, 1], [1, 0]])) datamodule = LightningNodeData(data, node_sampler=DummySampler(), input_train_nodes=torch.arange(2)) assert isinstance(datamodule.graph_sampler, DummySampler) datamodule = LightningLinkData( data, link_sampler=DummySampler(), input_train_edges=torch.tensor([[0, 1], [0, 1]])) assert isinstance(datamodule.graph_sampler, DummySampler) @onlyCUDA @onlyFullTest @onlyNeighborSampler @withPackage('pytorch_lightning') def test_lightning_hetero_link_data(): torch.manual_seed(12345) data = HeteroData() data['paper'].x = torch.arange(10) data['author'].x = torch.arange(10) data['term'].x = torch.arange(10) data['paper', 'author'].edge_index = get_random_edge_index(10, 10, 10) data['author', 'paper'].edge_index = get_random_edge_index(10, 10, 10) data['paper', 'term'].edge_index = get_random_edge_index(10, 10, 10) data['author', 'term'].edge_index = get_random_edge_index(10, 10, 10) datamodule = LightningLinkData( data, input_train_edges=('author', 'paper'), input_val_edges=('paper', 'author'), input_test_edges=('paper', 'term'), input_pred_edges=('author', 'term'), loader='neighbor', num_neighbors=[5], batch_size=32, num_workers=0, ) assert isinstance(datamodule.graph_sampler, NeighborSampler) assert isinstance(datamodule.eval_graph_sampler, NeighborSampler) for batch in datamodule.train_dataloader(): assert 'edge_label_index' in batch['author', 'paper'] for batch in datamodule.val_dataloader(): assert 'edge_label_index' in batch['paper', 'author'] for batch in datamodule.test_dataloader(): assert 'edge_label_index' in batch['paper', 'term'] for batch in datamodule.predict_dataloader(): assert 
'edge_label_index' in batch['author', 'term'] data['author'].time = torch.arange(data['author'].num_nodes) data['paper'].time = torch.arange(data['paper'].num_nodes) data['term'].time = torch.arange(data['term'].num_nodes) datamodule = LightningLinkData( data, input_train_edges=('author', 'paper'), input_train_time=torch.arange(data['author', 'paper'].num_edges), loader='neighbor', num_neighbors=[5], batch_size=32, num_workers=0, time_attr='time', ) for batch in datamodule.train_dataloader(): assert 'edge_label_index' in batch['author', 'paper'] assert 'edge_label_time' in batch['author', 'paper'] @onlyNeighborSampler @withPackage('pytorch_lightning') def test_lightning_hetero_link_data_custom_store(): torch.manual_seed(12345) feature_store = MyFeatureStore() graph_store = MyGraphStore() x = torch.arange(10) feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None) feature_store.put_tensor(x, group_name='author', attr_name='x', index=None) feature_store.put_tensor(x, group_name='term', attr_name='x', index=None) edge_index = get_random_edge_index(10, 10, 10) graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), edge_type=('paper', 'to', 'author'), layout='coo', size=(10, 10)) graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), edge_type=('author', 'to', 'paper'), layout='coo', size=(10, 10)) graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), edge_type=('paper', 'to', 'term'), layout='coo', size=(10, 10)) datamodule = LightningLinkData( (feature_store, graph_store), input_train_edges=('author', 'to', 'paper'), loader='neighbor', num_neighbors=[5], batch_size=32, num_workers=0, ) batch = next(iter(datamodule.train_dataloader())) assert 'edge_label_index' in batch['author', 'paper'] @onlyOnline @onlyNeighborSampler @withPackage('pytorch_lightning', 'scipy') def test_eval_loader_kwargs(get_dataset): data = get_dataset(name='Cora')[0] node_sampler = NeighborSampler(data, num_neighbors=[5]) datamodule 
= LightningNodeData( data, node_sampler=node_sampler, batch_size=32, eval_loader_kwargs=dict(num_neighbors=[-1], batch_size=64), ) assert datamodule.loader_kwargs['batch_size'] == 32 assert datamodule.graph_sampler.num_neighbors.values == [5] assert datamodule.eval_loader_kwargs['batch_size'] == 64 assert datamodule.eval_graph_sampler.num_neighbors.values == [-1] train_loader = datamodule.train_dataloader() assert math.ceil(int(data.train_mask.sum()) / 32) == len(train_loader) val_loader = datamodule.val_dataloader() assert math.ceil(int(data.val_mask.sum()) / 64) == len(val_loader) test_loader = datamodule.test_dataloader() assert math.ceil(int(data.test_mask.sum()) / 64) == len(test_loader) pred_loader = datamodule.predict_dataloader() assert math.ceil(data.num_nodes / 64) == len(pred_loader) ================================================ FILE: test/data/test_batch.py ================================================ import os.path as osp import numpy as np import pytest import torch import torch_geometric from torch_geometric import EdgeIndex, Index from torch_geometric.data import Batch, Data, HeteroData from torch_geometric.testing import get_random_edge_index, withPackage from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_edge_index, to_torch_sparse_tensor def test_batch_basic(): torch_geometric.set_debug(True) x = torch.tensor([1.0, 2.0, 3.0]) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) data1 = Data(x=x, y=1, edge_index=edge_index, string='1', array=['1', '2'], num_nodes=3) x = torch.tensor([1.0, 2.0]) edge_index = torch.tensor([[0, 1], [1, 0]]) data2 = Data(x=x, y=2, edge_index=edge_index, string='2', array=['3', '4', '5'], num_nodes=2) x = torch.tensor([1.0, 2.0, 3.0, 4.0]) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) data3 = Data(x=x, y=3, edge_index=edge_index, string='3', array=['6', '7', '8', '9'], num_nodes=4) batch = Batch.from_data_list([data1]) assert str(batch) == 
('DataBatch(x=[3], edge_index=[2, 4], y=[1], ' 'string=[1], array=[1], num_nodes=3, batch=[3], ' 'ptr=[2])') assert batch.num_graphs == len(batch) == 1 assert batch.x.tolist() == [1, 2, 3] assert batch.y.tolist() == [1] assert batch.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert batch.string == ['1'] assert batch.array == [['1', '2']] assert batch.num_nodes == 3 assert batch.batch.tolist() == [0, 0, 0] assert batch.ptr.tolist() == [0, 3] batch = Batch.from_data_list([data1, data2, data3], follow_batch=['string']) assert str(batch) == ('DataBatch(x=[9], edge_index=[2, 12], y=[3], ' 'string=[3], string_batch=[3], string_ptr=[4], ' 'array=[3], num_nodes=9, batch=[9], ptr=[4])') assert batch.num_graphs == len(batch) == 3 assert batch.x.tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4] assert batch.y.tolist() == [1, 2, 3] assert batch.edge_index.tolist() == [[0, 1, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8], [1, 0, 2, 1, 4, 3, 6, 5, 7, 6, 8, 7]] assert batch.string == ['1', '2', '3'] assert batch.string_batch.tolist() == [0, 1, 2] assert batch.string_ptr.tolist() == [0, 1, 2, 3] assert batch.array == [['1', '2'], ['3', '4', '5'], ['6', '7', '8', '9']] assert batch.num_nodes == 9 assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] assert batch.ptr.tolist() == [0, 3, 5, 9] assert str(batch[0]) == ("Data(x=[3], edge_index=[2, 4], y=[1], " "string='1', array=[2], num_nodes=3)") assert str(batch[1]) == ("Data(x=[2], edge_index=[2, 2], y=[1], " "string='2', array=[3], num_nodes=2)") assert str(batch[2]) == ("Data(x=[4], edge_index=[2, 6], y=[1], " "string='3', array=[4], num_nodes=4)") assert len(batch.index_select([1, 0])) == 2 assert len(batch.index_select(torch.tensor([1, 0]))) == 2 assert len(batch.index_select(torch.tensor([True, False]))) == 1 assert len(batch.index_select(np.array([1, 0], dtype=np.int64))) == 2 assert len(batch.index_select(np.array([True, False]))) == 1 assert len(batch[:2]) == 2 data_list = batch.to_data_list() assert len(data_list) == 3 assert 
len(data_list[0]) == 6 assert data_list[0].x.tolist() == [1, 2, 3] assert data_list[0].y.tolist() == [1] assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert data_list[0].string == '1' assert data_list[0].array == ['1', '2'] assert data_list[0].num_nodes == 3 assert len(data_list[1]) == 6 assert data_list[1].x.tolist() == [1, 2] assert data_list[1].y.tolist() == [2] assert data_list[1].edge_index.tolist() == [[0, 1], [1, 0]] assert data_list[1].string == '2' assert data_list[1].array == ['3', '4', '5'] assert data_list[1].num_nodes == 2 assert len(data_list[2]) == 6 assert data_list[2].x.tolist() == [1, 2, 3, 4] assert data_list[2].y.tolist() == [3] assert data_list[2].edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]] assert data_list[2].string == '3' assert data_list[2].array == ['6', '7', '8', '9'] assert data_list[2].num_nodes == 4 torch_geometric.set_debug(True) def test_index(): index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True) data1 = Data(index=index1, num_nodes=3) data2 = Data(index=index2, num_nodes=4) batch = Batch.from_data_list([data1, data2]) assert len(batch) == 2 assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1])) assert batch.ptr.equal(torch.tensor([0, 3, 7])) assert isinstance(batch.index, Index) assert batch.index.equal(torch.tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6])) assert batch.index.dim_size == 7 assert batch.index.is_sorted for i, index in enumerate([index1, index2]): data = batch[i] assert isinstance(data.index, Index) assert data.index.equal(index) assert data.index.dim_size == index.dim_size assert data.index.is_sorted == index.is_sorted def test_edge_index(): edge_index1 = EdgeIndex( [[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), sort_order='row', is_undirected=True, ) edge_index2 = EdgeIndex( [[1, 0, 2, 1, 3, 2], [0, 1, 1, 2, 2, 3]], sparse_size=(4, 4), sort_order='col', ) data1 = Data(edge_index=edge_index1) data2 = 
Data(edge_index=edge_index2) batch = Batch.from_data_list([data1, data2]) assert len(batch) == 2 assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1])) assert batch.ptr.equal(torch.tensor([0, 3, 7])) assert isinstance(batch.edge_index, EdgeIndex) assert batch.edge_index.equal( torch.tensor([ [0, 1, 1, 2, 4, 3, 5, 4, 6, 5], [1, 0, 2, 1, 3, 4, 4, 5, 5, 6], ])) assert batch.edge_index.sparse_size() == (7, 7) assert batch.edge_index.sort_order is None assert not batch.edge_index.is_undirected for i, edge_index in enumerate([edge_index1, edge_index2]): data = batch[i] assert isinstance(data.edge_index, EdgeIndex) assert data.edge_index.equal(edge_index) assert data.edge_index.sparse_size() == edge_index.sparse_size() assert data.edge_index.sort_order == edge_index.sort_order assert data.edge_index.is_undirected == edge_index.is_undirected @withPackage('torch_sparse') def test_batch_with_sparse_tensor(): x = SparseTensor.from_dense(torch.tensor([[1.0], [2.0], [3.0]])) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) adj = SparseTensor.from_edge_index(edge_index) data1 = Data(x=x, adj=adj) x = SparseTensor.from_dense(torch.tensor([[1.0], [2.0]])) edge_index = torch.tensor([[0, 1], [1, 0]]) adj = SparseTensor.from_edge_index(edge_index) data2 = Data(x=x, adj=adj) x = SparseTensor.from_dense(torch.tensor([[1.0], [2.0], [3.0], [4.0]])) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) adj = SparseTensor.from_edge_index(edge_index) data3 = Data(x=x, adj=adj) batch = Batch.from_data_list([data1]) assert str(batch) == ('DataBatch(x=[3, 1, nnz=3], adj=[3, 3, nnz=4], ' 'batch=[3], ptr=[2])') assert batch.num_graphs == len(batch) == 1 assert batch.x.to_dense().tolist() == [[1], [2], [3]] assert batch.adj.coo()[0].tolist() == [0, 1, 1, 2] assert batch.adj.coo()[1].tolist() == [1, 0, 2, 1] assert batch.batch.tolist() == [0, 0, 0] assert batch.ptr.tolist() == [0, 3] batch = Batch.from_data_list([data1, data2, data3]) assert str(batch) == 
('DataBatch(x=[9, 1, nnz=9], adj=[9, 9, nnz=12], ' 'batch=[9], ptr=[4])') assert batch.num_graphs == len(batch) == 3 assert batch.x.to_dense().view(-1).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4] assert batch.adj.coo()[0].tolist() == [0, 1, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8] assert batch.adj.coo()[1].tolist() == [1, 0, 2, 1, 4, 3, 6, 5, 7, 6, 8, 7] assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] assert batch.ptr.tolist() == [0, 3, 5, 9] assert str(batch[0]) == ("Data(x=[3, 1, nnz=3], adj=[3, 3, nnz=4])") assert str(batch[1]) == ("Data(x=[2, 1, nnz=2], adj=[2, 2, nnz=2])") assert str(batch[2]) == ("Data(x=[4, 1, nnz=4], adj=[4, 4, nnz=6])") data_list = batch.to_data_list() assert len(data_list) == 3 assert len(data_list[0]) == 2 assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]] assert data_list[0].adj.coo()[0].tolist() == [0, 1, 1, 2] assert data_list[0].adj.coo()[1].tolist() == [1, 0, 2, 1] assert len(data_list[1]) == 2 assert data_list[1].x.to_dense().tolist() == [[1], [2]] assert data_list[1].adj.coo()[0].tolist() == [0, 1] assert data_list[1].adj.coo()[1].tolist() == [1, 0] assert len(data_list[2]) == 2 assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]] assert data_list[2].adj.coo()[0].tolist() == [0, 1, 1, 2, 2, 3] assert data_list[2].adj.coo()[1].tolist() == [1, 0, 2, 1, 3, 2] def test_batch_with_torch_coo_tensor(): x = torch.tensor([[1.0], [2.0], [3.0]]).to_sparse_coo() data1 = Data(x=x) x = torch.tensor([[1.0], [2.0]]).to_sparse_coo() data2 = Data(x=x) x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]).to_sparse_coo() data3 = Data(x=x) batch = Batch.from_data_list([data1]) assert str(batch) == ('DataBatch(x=[3, 1], batch=[3], ptr=[2])') assert batch.num_graphs == len(batch) == 1 assert batch.x.to_dense().tolist() == [[1], [2], [3]] assert batch.batch.tolist() == [0, 0, 0] assert batch.ptr.tolist() == [0, 3] batch = Batch.from_data_list([data1, data2, data3]) assert str(batch) == ('DataBatch(x=[9, 1], batch=[9], ptr=[4])') assert 
batch.num_graphs == len(batch) == 3 assert batch.x.to_dense().view(-1).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4] assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] assert batch.ptr.tolist() == [0, 3, 5, 9] assert str(batch[0]) == ("Data(x=[3, 1])") assert str(batch[1]) == ("Data(x=[2, 1])") assert str(batch[2]) == ("Data(x=[4, 1])") data_list = batch.to_data_list() assert len(data_list) == 3 assert len(data_list[0]) == 1 assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]] assert len(data_list[1]) == 1 assert data_list[1].x.to_dense().tolist() == [[1], [2]] assert len(data_list[2]) == 1 assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]] def test_batching_with_new_dimension(): torch_geometric.set_debug(True) class MyData(Data): def __cat_dim__(self, key, value, *args, **kwargs): if key == 'foo': return None else: return super().__cat_dim__(key, value, *args, **kwargs) x1 = torch.tensor([1, 2, 3], dtype=torch.float) foo1 = torch.randn(4) y1 = torch.tensor(1) x2 = torch.tensor([1, 2], dtype=torch.float) foo2 = torch.randn(4) y2 = torch.tensor(2) batch = Batch.from_data_list( [MyData(x=x1, foo=foo1, y=y1), MyData(x=x2, foo=foo2, y=y2)]) assert str(batch) == ('MyDataBatch(x=[5], y=[2], foo=[2, 4], batch=[5], ' 'ptr=[3])') assert batch.num_graphs == len(batch) == 2 assert batch.x.tolist() == [1, 2, 3, 1, 2] assert batch.foo.size() == (2, 4) assert batch.foo[0].tolist() == foo1.tolist() assert batch.foo[1].tolist() == foo2.tolist() assert batch.y.tolist() == [1, 2] assert batch.batch.tolist() == [0, 0, 0, 1, 1] assert batch.ptr.tolist() == [0, 3, 5] assert batch.num_graphs == 2 data = batch[0] assert str(data) == ('MyData(x=[3], y=[1], foo=[4])') data = batch[1] assert str(data) == ('MyData(x=[2], y=[1], foo=[4])') torch_geometric.set_debug(True) def test_pickling(tmp_path): data = Data(x=torch.randn(5, 16)) batch = Batch.from_data_list([data, data, data, data]) assert id(batch._store._parent()) == id(batch) assert batch.num_nodes == 20 # 
filename = f'{random.randrange(sys.maxsize)}.pt' path = osp.join(tmp_path, 'batch.pt') torch.save(batch, path) assert id(batch._store._parent()) == id(batch) assert batch.num_nodes == 20 batch = torch.load(path, weights_only=False) assert id(batch._store._parent()) == id(batch) assert batch.num_nodes == 20 assert batch.__class__.__name__ == 'DataBatch' assert batch.num_graphs == len(batch) == 4 def test_recursive_batch(): data1 = Data( x={ '1': torch.randn(10, 32), '2': torch.randn(20, 48) }, edge_index=[ get_random_edge_index(30, 30, 50), get_random_edge_index(30, 30, 70) ], num_nodes=30, ) data2 = Data( x={ '1': torch.randn(20, 32), '2': torch.randn(40, 48) }, edge_index=[ get_random_edge_index(60, 60, 80), get_random_edge_index(60, 60, 90) ], num_nodes=60, ) batch = Batch.from_data_list([data1, data2]) assert batch.num_graphs == len(batch) == 2 assert batch.num_nodes == 90 assert torch.allclose(batch.x['1'], torch.cat([data1.x['1'], data2.x['1']], dim=0)) assert torch.allclose(batch.x['2'], torch.cat([data1.x['2'], data2.x['2']], dim=0)) assert (batch.edge_index[0].tolist() == torch.cat( [data1.edge_index[0], data2.edge_index[0] + 30], dim=1).tolist()) assert (batch.edge_index[1].tolist() == torch.cat( [data1.edge_index[1], data2.edge_index[1] + 30], dim=1).tolist()) assert batch.batch.size() == (90, ) assert batch.ptr.size() == (3, ) out1 = batch[0] assert len(out1) == 3 assert out1.num_nodes == 30 assert torch.allclose(out1.x['1'], data1.x['1']) assert torch.allclose(out1.x['2'], data1.x['2']) assert out1.edge_index[0].tolist(), data1.edge_index[0].tolist() assert out1.edge_index[1].tolist(), data1.edge_index[1].tolist() out2 = batch[1] assert len(out2) == 3 assert out2.num_nodes == 60 assert torch.allclose(out2.x['1'], data2.x['1']) assert torch.allclose(out2.x['2'], data2.x['2']) assert out2.edge_index[0].tolist(), data2.edge_index[0].tolist() assert out2.edge_index[1].tolist(), data2.edge_index[1].tolist() def test_batching_of_batches(): data = 
Data(x=torch.randn(2, 16)) batch = Batch.from_data_list([data, data]) batch = Batch.from_data_list([batch, batch]) assert batch.num_graphs == len(batch) == 2 assert batch.x[0:2].tolist() == data.x.tolist() assert batch.x[2:4].tolist() == data.x.tolist() assert batch.x[4:6].tolist() == data.x.tolist() assert batch.x[6:8].tolist() == data.x.tolist() assert batch.batch.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] def test_hetero_batch(): e1 = ('p', 'a') e2 = ('a', 'p') data1 = HeteroData() data1['p'].x = torch.randn(100, 128) data1['a'].x = torch.randn(200, 128) data1[e1].edge_index = get_random_edge_index(100, 200, 500) data1[e1].edge_attr = torch.randn(500, 32) data1[e2].edge_index = get_random_edge_index(200, 100, 400) data1[e2].edge_attr = torch.randn(400, 32) data2 = HeteroData() data2['p'].x = torch.randn(50, 128) data2['a'].x = torch.randn(100, 128) data2[e1].edge_index = get_random_edge_index(50, 100, 300) data2[e1].edge_attr = torch.randn(300, 32) data2[e2].edge_index = get_random_edge_index(100, 50, 200) data2[e2].edge_attr = torch.randn(200, 32) batch = Batch.from_data_list([data1, data2]) assert batch.num_graphs == len(batch) == 2 assert batch.num_nodes == 450 assert torch.allclose(batch['p'].x[:100], data1['p'].x) assert torch.allclose(batch['a'].x[:200], data1['a'].x) assert torch.allclose(batch['p'].x[100:], data2['p'].x) assert torch.allclose(batch['a'].x[200:], data2['a'].x) assert (batch[e1].edge_index.tolist() == torch.cat([ data1[e1].edge_index, data2[e1].edge_index + torch.tensor([[100], [200]]) ], 1).tolist()) assert torch.allclose( batch[e1].edge_attr, torch.cat([data1[e1].edge_attr, data2[e1].edge_attr], 0)) assert (batch[e2].edge_index.tolist() == torch.cat([ data1[e2].edge_index, data2[e2].edge_index + torch.tensor([[200], [100]]) ], 1).tolist()) assert torch.allclose( batch[e2].edge_attr, torch.cat([data1[e2].edge_attr, data2[e2].edge_attr], 0)) assert batch['p'].batch.size() == (150, ) assert batch['p'].ptr.size() == (3, ) assert 
batch['a'].batch.size() == (300, ) assert batch['a'].ptr.size() == (3, ) out1 = batch[0] assert len(out1) == 3 assert out1.num_nodes == 300 assert torch.allclose(out1['p'].x, data1['p'].x) assert torch.allclose(out1['a'].x, data1['a'].x) assert out1[e1].edge_index.tolist() == data1[e1].edge_index.tolist() assert torch.allclose(out1[e1].edge_attr, data1[e1].edge_attr) assert out1[e2].edge_index.tolist() == data1[e2].edge_index.tolist() assert torch.allclose(out1[e2].edge_attr, data1[e2].edge_attr) out2 = batch[1] assert len(out2) == 3 assert out2.num_nodes == 150 assert torch.allclose(out2['p'].x, data2['p'].x) assert torch.allclose(out2['a'].x, data2['a'].x) assert out2[e1].edge_index.tolist() == data2[e1].edge_index.tolist() assert torch.allclose(out2[e1].edge_attr, data2[e1].edge_attr) assert out2[e2].edge_index.tolist() == data2[e2].edge_index.tolist() assert torch.allclose(out2[e2].edge_attr, data2[e2].edge_attr) def test_pair_data_batching(): class PairData(Data): def __inc__(self, key, value, *args, **kwargs): if key == 'edge_index_s': return self.x_s.size(0) if key == 'edge_index_t': return self.x_t.size(0) return super().__inc__(key, value, *args, **kwargs) x_s = torch.randn(5, 16) edge_index_s = torch.tensor([ [0, 0, 0, 0], [1, 2, 3, 4], ]) x_t = torch.randn(4, 16) edge_index_t = torch.tensor([ [0, 0, 0], [1, 2, 3], ]) data = PairData(x_s=x_s, edge_index_s=edge_index_s, x_t=x_t, edge_index_t=edge_index_t) batch = Batch.from_data_list([data, data]) assert torch.allclose(batch.x_s, torch.cat([x_s, x_s], dim=0)) assert batch.edge_index_s.tolist() == [[0, 0, 0, 0, 5, 5, 5, 5], [1, 2, 3, 4, 6, 7, 8, 9]] assert torch.allclose(batch.x_t, torch.cat([x_t, x_t], dim=0)) assert batch.edge_index_t.tolist() == [[0, 0, 0, 4, 4, 4], [1, 2, 3, 5, 6, 7]] def test_batch_with_empty_list(): x = torch.randn(4, 1) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) data = Data(x=x, edge_index=edge_index, nontensor=[]) batch = Batch.from_data_list([data, data]) 
assert batch.nontensor == [[], []] assert batch[0].nontensor == [] assert batch[1].nontensor == [] def test_nested_follow_batch(): def tr(n, m): return torch.rand((n, m)) d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)], a={"aa": tr(11, 3)}, x=tr(10, 5)) d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)], a={"aa": tr(2, 3)}, x=tr(11, 5)) d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)], a={"aa": tr(4, 3)}, x=tr(9, 5)) d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)], a={"aa": tr(8, 3)}, x=tr(8, 5)) data_list = [d1, d2, d3, d4] batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a']) assert batch.xs[0].shape == (19, 3) assert batch.xs[1].shape == (56, 4) assert batch.xs[2].shape == (7, 2) assert batch.a['aa'].shape == (25, 3) assert len(batch.xs_batch) == 3 assert len(batch.a_batch) == 1 assert batch.xs_batch[0].tolist() == \ [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3] assert batch.xs_batch[1].tolist() == \ [0] * 11 + [1] * 14 + [2] * 15 + [3] * 16 assert batch.xs_batch[2].tolist() == \ [0] * 1 + [1] * 3 + [2] * 2 + [3] * 1 assert batch.a_batch['aa'].tolist() == \ [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8 @withPackage('torch>=2.0.0') @pytest.mark.parametrize('layout', [ torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, ]) def test_torch_sparse_batch(layout): x_dense = torch.randn(3, 4) x = x_dense.to_sparse(layout=layout) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_attr = torch.rand(4) adj = to_torch_sparse_tensor(edge_index, edge_attr, layout=layout) data = Data(x=x, adj=adj) batch = Batch.from_data_list([data, data]) assert batch.x.size() == (6, 4) assert batch.x.layout in {torch.sparse_coo, torch.sparse_csr} assert torch.equal(batch.x.to_dense(), torch.cat([x_dense, x_dense], 0)) assert batch.adj.size() == (6, 6) assert batch.adj.layout == layout out = to_edge_index(batch.adj.to_sparse(layout=torch.sparse_csr)) assert torch.equal(out[0], torch.cat([edge_index, edge_index + 3], 1)) assert torch.equal(out[1], torch.cat([edge_attr, edge_attr], 
0)) def test_torch_nested_batch(): from torch.nested import nested_tensor class MyData(Data): def __inc__(self, key, value, *args, **kwargs) -> int: return 2 x1 = nested_tensor([torch.randn(3), torch.randn(4)]) data1 = MyData(x=x1) assert str(data1) == 'MyData(x=[2, 4])' x2 = nested_tensor([torch.randn(3), torch.randn(4), torch.randn(5)]) data2 = MyData(x=x2) assert str(data2) == 'MyData(x=[3, 5])' batch = Batch.from_data_list([data1, data2]) assert str(batch) == 'MyDataBatch(x=[5, 5], batch=[5], ptr=[3])' expected = nested_tensor(list(x1.unbind() + (x2 + 2).unbind())) assert torch.equal( batch.x.to_padded_tensor(0.0), expected.to_padded_tensor(0.0), ) ================================================ FILE: test/data/test_data.py ================================================ import copy import pytest import torch import torch.multiprocessing as mp import torch_geometric from torch_geometric.data import Data from torch_geometric.data.storage import AttrType from torch_geometric.testing import get_random_tensor_frame, withPackage def test_data(): torch_geometric.set_debug(True) x = torch.tensor([[1, 3, 5], [2, 4, 6]], dtype=torch.float).t() edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]]) data = Data(x=x, edge_index=edge_index).to(torch.device('cpu')) data.validate(raise_on_error=True) N = data.num_nodes assert N == 3 assert data.node_attrs() == ['x'] assert data.edge_attrs() == ['edge_index'] assert data.x.tolist() == x.tolist() assert data['x'].tolist() == x.tolist() assert data.get('x').tolist() == x.tolist() assert data.get('y', 2) == 2 assert data.get('y', None) is None assert data.num_edge_types == 1 assert data.num_node_types == 1 assert next(data('x')) == ('x', x) assert sorted(data.keys()) == ['edge_index', 'x'] assert len(data) == 2 assert 'x' in data and 'edge_index' in data and 'pos' not in data data.apply_(lambda x: x.mul_(2), 'x') assert torch.allclose(data.x, x) data.requires_grad_('x') assert data.x.requires_grad is True D = 
data.to_dict() assert len(D) == 2 assert 'x' in D and 'edge_index' in D D = data.to_namedtuple() assert len(D) == 2 assert D.x is not None and D.edge_index is not None assert data.__cat_dim__('x', data.x) == 0 assert data.__cat_dim__('edge_index', data.edge_index) == -1 assert data.__inc__('x', data.x) == 0 assert data.__inc__('edge_index', data.edge_index) == data.num_nodes assert not data.x.is_contiguous() data.contiguous() assert data.x.is_contiguous() assert not data.is_coalesced() data = data.coalesce() assert data.is_coalesced() clone = data.clone() assert clone != data assert len(clone) == len(data) assert clone.x.data_ptr() != data.x.data_ptr() assert clone.x.tolist() == data.x.tolist() assert clone.edge_index.data_ptr() != data.edge_index.data_ptr() assert clone.edge_index.tolist() == data.edge_index.tolist() # Test `data.to_heterogeneous()`: out = data.to_heterogeneous() assert torch.allclose(data.x, out['0'].x) assert torch.allclose(data.edge_index, out['0', '0'].edge_index) data.edge_type = torch.tensor([0, 0, 1, 0]) out = data.to_heterogeneous() assert torch.allclose(data.x, out['0'].x) assert [store.num_edges for store in out.edge_stores] == [3, 1] data.edge_type = None data['x'] = x + 1 assert data.x.tolist() == (x + 1).tolist() assert str(data) == 'Data(x=[3, 2], edge_index=[2, 4])' dictionary = {'x': data.x, 'edge_index': data.edge_index} data = Data.from_dict(dictionary) assert sorted(data.keys()) == ['edge_index', 'x'] assert not data.has_isolated_nodes() assert not data.has_self_loops() assert data.is_undirected() assert not data.is_directed() assert data.num_nodes == 3 assert data.num_edges == 4 with pytest.warns(UserWarning, match='deprecated'): assert data.num_faces is None assert data.num_node_features == 2 assert data.num_features == 2 data.edge_attr = torch.randn(data.num_edges, 2) assert data.num_edge_features == 2 data.edge_attr = None data.x = None with pytest.warns(UserWarning, match='Unable to accurately infer'): assert data.num_nodes 
== 3 data.edge_index = None with pytest.warns(UserWarning, match='Unable to accurately infer'): assert data.num_nodes is None assert data.num_edges == 0 data.num_nodes = 4 assert data.num_nodes == 4 data = Data(x=x, attribute=x) assert len(data) == 2 assert data.x.tolist() == x.tolist() assert data.attribute.tolist() == x.tolist() face = torch.tensor([[0, 1], [1, 2], [2, 3]]) data = Data(num_nodes=4, face=face) with pytest.warns(UserWarning, match='deprecated'): assert data.num_faces == 2 assert data.num_nodes == 4 data = Data(title='test') assert str(data) == "Data(title='test')" assert data.num_node_features == 0 assert data.num_edge_features == 0 key = value = 'test_value' data[key] = value assert data[key] == value del data[value] del data[value] # Deleting unset attributes should work as well. assert data.get(key) is None assert data.get('title') == 'test' torch_geometric.set_debug(False) def test_data_attr_cache(): x = torch.randn(3, 16) edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]]) edge_attr = torch.randn(5, 4) y = torch.tensor([0]) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) assert data.is_node_attr('x') assert 'x' in data._store._cached_attr[AttrType.NODE] assert 'x' not in data._store._cached_attr[AttrType.EDGE] assert 'x' not in data._store._cached_attr[AttrType.OTHER] assert not data.is_node_attr('edge_index') assert 'edge_index' not in data._store._cached_attr[AttrType.NODE] assert 'edge_index' in data._store._cached_attr[AttrType.EDGE] assert 'edge_index' not in data._store._cached_attr[AttrType.OTHER] assert data.is_edge_attr('edge_attr') assert 'edge_attr' not in data._store._cached_attr[AttrType.NODE] assert 'edge_attr' in data._store._cached_attr[AttrType.EDGE] assert 'edge_attr' not in data._store._cached_attr[AttrType.OTHER] assert not data.is_edge_attr('y') assert 'y' not in data._store._cached_attr[AttrType.NODE] assert 'y' not in data._store._cached_attr[AttrType.EDGE] assert 'y' in 
data._store._cached_attr[AttrType.OTHER] def test_data_attr_cache_not_shared(): x = torch.rand((4, 4)) edge_index = torch.tensor([[0, 1, 2, 3, 0, 1], [0, 1, 2, 3, 0, 1]]) time = torch.arange(edge_index.size(1)) data = Data(x=x, edge_index=edge_index, time=time) assert data.is_node_attr('x') out = data.up_to(3.5) # This is expected behavior due to the ambiguity of between node-level and # edge-level tensors when they share the same number of nodes/edges. assert out.is_node_attr('time') assert not data.is_node_attr('time') def test_to_heterogeneous_empty_edge_index(): data = Data( x=torch.randn(5, 10), edge_index=torch.empty(2, 0, dtype=torch.long), ) hetero_data = data.to_heterogeneous() assert hetero_data.node_types == ['0'] assert hetero_data.edge_types == [] assert len(hetero_data) == 1 assert torch.equal(hetero_data['0'].x, data.x) hetero_data = data.to_heterogeneous( node_type_names=['0'], edge_type_names=[('0', 'to', '0')], ) assert hetero_data.node_types == ['0'] assert hetero_data.edge_types == [('0', 'to', '0')] assert len(hetero_data) == 2 assert torch.equal(hetero_data['0'].x, data.x) assert torch.equal(hetero_data['0', 'to', '0'].edge_index, data.edge_index) def test_data_subgraph(): x = torch.arange(5) y = torch.tensor([0.]) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3]]) edge_weight = torch.arange(edge_index.size(1)) data = Data(x=x, y=y, edge_index=edge_index, edge_weight=edge_weight, num_nodes=5) out = data.subgraph(torch.tensor([1, 2, 3])) assert len(out) == 5 assert torch.equal(out.x, torch.arange(1, 4)) assert torch.equal(out.y, data.y) assert out.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert torch.equal(out.edge_weight, edge_weight[torch.arange(2, 6)]) assert out.num_nodes == 3 # Test unordered selection: out = data.subgraph(torch.tensor([3, 1, 2])) assert len(out) == 5 assert torch.equal(out.x, torch.tensor([3, 1, 2])) assert torch.equal(out.y, data.y) assert out.edge_index.tolist() == [[1, 2, 2, 0], 
[2, 1, 0, 2]] assert torch.equal(out.edge_weight, edge_weight[torch.arange(2, 6)]) assert out.num_nodes == 3 out = data.subgraph(torch.tensor([False, False, False, True, True])) assert len(out) == 5 assert torch.equal(out.x, torch.arange(3, 5)) assert torch.equal(out.y, data.y) assert out.edge_index.tolist() == [[0, 1], [1, 0]] assert torch.equal(out.edge_weight, edge_weight[torch.arange(6, 8)]) assert out.num_nodes == 2 out = data.edge_subgraph(torch.tensor([1, 2, 3])) assert len(out) == 5 assert out.num_nodes == data.num_nodes assert torch.equal(out.x, data.x) assert torch.equal(out.y, data.y) assert out.edge_index.tolist() == [[1, 1, 2], [0, 2, 1]] assert torch.equal(out.edge_weight, edge_weight[torch.tensor([1, 2, 3])]) out = data.edge_subgraph( torch.tensor([False, True, True, True, False, False, False, False])) assert len(out) == 5 assert out.num_nodes == data.num_nodes assert torch.equal(out.x, data.x) assert torch.equal(out.y, data.y) assert out.edge_index.tolist() == [[1, 1, 2], [0, 2, 1]] assert torch.equal(out.edge_weight, edge_weight[torch.tensor([1, 2, 3])]) def test_data_subgraph_with_list_field(): x = torch.arange(5) y = list(range(5)) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3]]) data = Data(x=x, y=y, edge_index=edge_index) out = data.subgraph(torch.tensor([1, 2, 3])) assert len(out) == 3 assert out.x.tolist() == out.y == [1, 2, 3] out = data.subgraph(torch.tensor([False, True, True, True, False])) assert len(out) == 3 assert out.x.tolist() == out.y == [1, 2, 3] def test_data_empty_subgraph(): data = Data(x=torch.arange(5), y=torch.tensor(0.0)) out = data.subgraph(torch.tensor([1, 2, 3])) assert 'edge_index' not in out assert torch.equal(out.x, torch.arange(1, 4)) assert torch.equal(out.y, data.y) assert out.num_nodes == 3 def test_copy_data(): data = Data(x=torch.randn(20, 5)) out = copy.copy(data) assert id(data) != id(out) assert id(data._store) != id(out._store) assert len(data.stores) == len(out.stores) for 
store1, store2 in zip(data.stores, out.stores): assert id(store1) != id(store2) assert id(data) == id(store1._parent()) assert id(out) == id(store2._parent()) assert data.x.data_ptr() == out.x.data_ptr() out = copy.deepcopy(data) assert id(data) != id(out) assert id(data._store) != id(out._store) assert len(data.stores) == len(out.stores) for store1, store2 in zip(data.stores, out.stores): assert id(store1) != id(store2) assert id(data) == id(store1._parent()) assert id(out) == id(store2._parent()) assert data.x.data_ptr() != out.x.data_ptr() assert data.x.tolist() == out.x.tolist() def test_data_sort(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 2, 1, 3], [1, 2, 3, 0, 0, 0]]) edge_attr = torch.randn(6, 8) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) assert not data.is_sorted(sort_by_row=True) assert not data.is_sorted(sort_by_row=False) out = data.sort(sort_by_row=True) assert out.is_sorted(sort_by_row=True) assert not out.is_sorted(sort_by_row=False) assert torch.equal(out.x, data.x) assert out.edge_index.tolist() == [[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]] assert torch.equal( out.edge_attr, data.edge_attr[torch.tensor([0, 1, 2, 4, 3, 5])], ) out = data.sort(sort_by_row=False) assert not out.is_sorted(sort_by_row=True) assert out.is_sorted(sort_by_row=False) assert torch.equal(out.x, data.x) assert out.edge_index.tolist() == [[1, 2, 3, 0, 0, 0], [0, 0, 0, 1, 2, 3]] assert torch.equal( out.edge_attr, data.edge_attr[torch.tensor([4, 3, 5, 0, 1, 2])], ) def test_debug_data(): torch_geometric.set_debug(True) Data() Data(edge_index=torch.zeros((2, 0), dtype=torch.long), num_nodes=10) Data(face=torch.zeros((3, 0), dtype=torch.long), num_nodes=10) Data(edge_index=torch.tensor([[0, 1], [1, 0]]), edge_attr=torch.randn(2)) Data(x=torch.torch.randn(5, 3), num_nodes=5) Data(pos=torch.torch.randn(5, 3), num_nodes=5) Data(norm=torch.torch.randn(5, 3), num_nodes=5) torch_geometric.set_debug(False) def run(rank, data_list): for data in data_list: 
assert data.x.is_shared() data.x.add_(1) def test_data_share_memory(): data_list = [Data(x=torch.zeros(8)) for _ in range(10)] for data in data_list: assert not data.x.is_shared() assert torch.all(data.x == 0.0) mp.spawn(run, args=(data_list, ), nprocs=4, join=True) for data in data_list: assert data.x.is_shared() assert torch.all(data.x > 0.0) def test_data_setter_properties(): class MyData(Data): def __init__(self): super().__init__() self.my_attr1 = 1 self.my_attr2 = 2 @property def my_attr1(self): return self._my_attr1 @my_attr1.setter def my_attr1(self, value): self._my_attr1 = value data = MyData() assert data.my_attr2 == 2 assert 'my_attr1' not in data._store assert data.my_attr1 == 1 data.my_attr1 = 2 assert 'my_attr1' not in data._store assert data.my_attr1 == 2 def test_data_update(): data = Data(x=torch.arange(0, 5), y=torch.arange(5, 10)) other = Data(z=torch.arange(10, 15), x=torch.arange(15, 20)) data.update(other) assert len(data) == 3 assert torch.equal(data.x, torch.arange(15, 20)) assert torch.equal(data.y, torch.arange(5, 10)) assert torch.equal(data.z, torch.arange(10, 15)) def test_data_connected_components(): data = Data() data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]]) data.y = torch.tensor([[1.1, 1.2], [2.1, 2.2], [3.1, 3.2], [4.1, 4.2], [5.1, 5.2]]) data.edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long) split_data = data.connected_components() assert isinstance(split_data, list) assert len(split_data) == 3 assert torch.equal(split_data[0].x, torch.tensor([[1.0], [2.0]])) assert torch.equal(split_data[0].y, torch.tensor([[1.1, 1.2], [2.1, 2.2]])) assert torch.equal(split_data[0].edge_index, torch.tensor([[0, 1], [1, 0]])) assert torch.equal(split_data[1].x, torch.tensor([[3.0], [4.0]])) assert torch.equal(split_data[1].y, torch.tensor([[3.1, 3.2], [4.1, 4.2]])) assert torch.equal(split_data[1].edge_index, torch.tensor([[0, 1], [1, 0]])) assert torch.equal(split_data[2].x, torch.tensor([[5.0]])) assert 
torch.equal(split_data[2].y, torch.tensor([[5.1, 5.2]])) assert torch.equal(split_data[2].edge_index, torch.tensor([[], []], dtype=torch.long)) def test_data_find_parent(): # Case 1: Parent does not exist data = Data() data._parents = {} data._ranks = {} node = 1 assert data._find_parent(node) == node assert data._parents == {1: 1} assert data._ranks == {1: 0} # Case 2: Parent exists data._parents[node] = 0 assert data._find_parent(node) == 0 def test_data_union(): # Setup: two nodes in different sets data = Data() data._parents = {} data._ranks = {} node1 = 1 node2 = 2 # Initially, both nodes are their own parents with rank 0 assert data._find_parent(node1) == node1 assert data._find_parent(node2) == node2 data._ranks[node1] = 0 data._ranks[node2] = 0 # Union them: node2 should now point to node1, and node1's rank increases data._union(node1, node2) assert data._find_parent(node1) == node1 assert data._find_parent(node2) == node1 assert data._ranks[node1] == 1 # Add a third node with higher rank and union with node1 node3 = 3 data._parents[node3] = node3 data._ranks[node3] = 2 data._union(node1, node3) # node1's parent should now be node3, since node3 has higher rank assert data._find_parent(node1) == node3 assert data._find_parent(node3) == node3 # Add a fourth node with lower rank and union with node3 node4 = 4 data._parents[node4] = node4 data._ranks[node4] = 0 data._union(node3, node4) assert data._find_parent(node4) == node3 assert data._find_parent(node3) == node3 # Union of already connected nodes should not change anything prev_ranks = data._ranks.copy() prev_parents = data._parents.copy() data._union(node1, node3) assert data._ranks == prev_ranks assert data._parents == prev_parents # Feature Store ############################################################### def test_basic_feature_store(): data = Data() x = torch.randn(20, 20) data.not_a_tensor_attr = 10 # don't include, not a tensor attr data.bad_attr = torch.randn(10, 20) # don't include, bad cat_dim 
# Put tensor: assert data.put_tensor(copy.deepcopy(x), attr_name='x', index=None) assert torch.equal(data.x, x) # Put (modify) tensor slice: x[15:] = 0 data.put_tensor(0, attr_name='x', index=slice(15, None, None)) # Get tensor: out = data.get_tensor(attr_name='x', index=None) assert torch.equal(x, out) # Get tensor size: assert data.get_tensor_size(attr_name='x') == (20, 20) # Get tensor attrs: tensor_attrs = data.get_all_tensor_attrs() assert len(tensor_attrs) == 1 assert tensor_attrs[0].attr_name == 'x' # Remove tensor: assert 'x' in data.__dict__['_store'] data.remove_tensor(attr_name='x', index=None) assert 'x' not in data.__dict__['_store'] # Graph Store ################################################################# @withPackage('torch_sparse') def test_basic_graph_store(): r"""Test the core graph store API.""" data = Data() def assert_equal_tensor_tuple(expected, actual): assert len(expected) == len(actual) for i in range(len(expected)): assert torch.equal(expected[i], actual[i]) # We put all three tensor types: COO, CSR, and CSC, and we get them back # to confirm that `GraphStore` works as intended. 
coo = (torch.tensor([0, 1]), torch.tensor([1, 2])) csr = (torch.tensor([0, 1, 2, 2]), torch.tensor([1, 2])) csc = (torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2])) # Put: data.put_edge_index(coo, layout='coo', size=(3, 3)) data.put_edge_index(csr, layout='csr') data.put_edge_index(csc, layout='csc') # Get: assert_equal_tensor_tuple(coo, data.get_edge_index('coo')) assert_equal_tensor_tuple(csr, data.get_edge_index('csr')) assert_equal_tensor_tuple(csc, data.get_edge_index('csc')) # Get attrs: edge_attrs = data.get_all_edge_attrs() assert len(edge_attrs) == 3 # Remove: coo, csr, csc = edge_attrs data.remove_edge_index(coo) data.remove_edge_index(csr) data.remove_edge_index(csc) assert len(data.get_all_edge_attrs()) == 0 def test_data_generate_ids(): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]]) data = Data(x=x, edge_index=edge_index) assert len(data) == 2 data.generate_ids() assert len(data) == 4 assert data.n_id.tolist() == [0, 1, 2] assert data.e_id.tolist() == [0, 1, 2, 3, 4] @withPackage('torch_frame') def test_data_with_tensor_frame(): tf = get_random_tensor_frame(num_rows=10) data = Data(tf=tf, edge_index=torch.randint(0, 10, size=(2, 20))) # Test basic attributes: assert data.is_node_attr('tf') assert data.num_nodes == tf.num_rows assert data.num_edges == 20 assert data.num_node_features == tf.num_cols # Test subgraph: index = torch.tensor([1, 2, 3]) sub_data = data.subgraph(index) assert sub_data.num_nodes == 3 for key, value in sub_data.tf.feat_dict.items(): assert torch.allclose(value, tf.feat_dict[key][index]) mask = torch.tensor( [False, True, True, True, False, False, False, False, False, False]) data_sub = data.subgraph(mask) assert data_sub.num_nodes == 3 for key, value in sub_data.tf.feat_dict.items(): assert torch.allclose(value, tf.feat_dict[key][mask]) @pytest.mark.parametrize('num_nodes', [4]) @pytest.mark.parametrize('num_edges', [8]) def test_data_time_handling(num_nodes, num_edges): data = Data( 
x=torch.randn(num_nodes, 12), edge_index=torch.randint(0, num_nodes, (2, num_edges)), edge_attr=torch.rand((num_edges, 16)), time=torch.arange(num_edges), num_nodes=num_nodes, ) assert data.is_edge_attr('time') assert not data.is_node_attr('time') assert data.is_sorted_by_time() out = data.up_to(5) assert out.num_edges == 6 assert torch.allclose(out.x, data.x) assert torch.equal(out.edge_index, data.edge_index[:, :6]) assert torch.allclose(out.edge_attr, data.edge_attr[:6]) assert torch.equal(out.time, data.time[:6]) out = data.snapshot(2, 5) assert out.num_edges == 4 assert torch.allclose(out.x, data.x) assert torch.equal(out.edge_index, data.edge_index[:, 2:6]) assert torch.allclose(out.edge_attr, data.edge_attr[2:6, :]) assert torch.equal(out.time, data.time[2:6]) out = data.sort_by_time() assert data.is_sorted_by_time() out = data.concat(data) assert out.num_nodes == 8 assert not out.is_sorted_by_time() assert torch.allclose(out.x, torch.cat([data.x, data.x], dim=0)) assert torch.equal( out.edge_index, torch.cat([data.edge_index, data.edge_index], dim=1), ) assert torch.allclose( out.edge_attr, torch.cat([data.edge_attr, data.edge_attr], dim=0), ) assert torch.allclose(out.time, torch.cat([data.time, data.time], dim=0)) out = out.sort_by_time() assert torch.equal(out.time, data.time.repeat_interleave(2)) def test_data_inc(): data = Data(edge_index=torch.tensor([[0, 1], [1, 0]])) with pytest.warns(UserWarning, match="Unable to accurately infer"): assert data.__inc__('edge_index', data.edge_index) == 2 data = Data(index=torch.empty(2, 0, dtype=torch.long)) with pytest.raises(RuntimeError, match="Unable to infer"): with pytest.warns(UserWarning, match="Unable to accurately infer"): data.__inc__('index', data.edge_index) ================================================ FILE: test/data/test_database.py ================================================ import math import os.path as osp import pytest import torch from torch_geometric import EdgeIndex, Index from 
torch_geometric.data import Data, RocksDatabase, SQLiteDatabase from torch_geometric.data.database import TensorInfo from torch_geometric.profile import benchmark from torch_geometric.testing import has_package, withPackage AVAILABLE_DATABASES = [] if has_package('sqlite3'): AVAILABLE_DATABASES.append(SQLiteDatabase) if has_package('rocksdict'): AVAILABLE_DATABASES.append(RocksDatabase) @pytest.mark.parametrize('Database', AVAILABLE_DATABASES) @pytest.mark.parametrize('batch_size', [None, 1]) def test_database_single_tensor(tmp_path, Database, batch_size): kwargs = dict(path=osp.join(tmp_path, 'storage.db')) if Database == SQLiteDatabase: kwargs['name'] = 'test_table' db = Database(**kwargs) assert db.schema == {0: object} try: assert len(db) == 0 assert str(db) == f'{Database.__name__}(0)' except NotImplementedError: assert str(db) == f'{Database.__name__}()' data = torch.randn(5) db.insert(0, data) try: assert len(db) == 1 except NotImplementedError: pass assert torch.equal(db.get(0), data) indices = torch.tensor([1, 2]) data_list = torch.randn(2, 5) db.multi_insert(indices, data_list, batch_size=batch_size) try: assert len(db) == 3 except NotImplementedError: pass out_list = db.multi_get(indices, batch_size=batch_size) assert isinstance(out_list, list) assert len(out_list) == 2 assert torch.equal(out_list[0], data_list[0]) assert torch.equal(out_list[1], data_list[1]) db.close() @pytest.mark.parametrize('Database', AVAILABLE_DATABASES) def test_database_schema(tmp_path, Database): kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} path = osp.join(tmp_path, 'tuple_storage.db') schema = (int, float, str, dict(dtype=torch.float, size=(2, -1)), object) db = Database(path, schema=schema, **kwargs) assert db.schema == { 0: int, 1: float, 2: str, 3: TensorInfo(dtype=torch.float, size=(2, -1)), 4: object, } data1 = (1, 0.1, 'a', torch.randn(2, 8), Data(x=torch.randn(8))) data2 = (2, float('inf'), 'b', torch.randn(2, 16), Data(x=torch.randn(8))) data3 
= (3, float('NaN'), 'c', torch.randn(2, 32), Data(x=torch.randn(8))) db.insert(0, data1) db.multi_insert([1, 2], [data2, data3]) out1 = db.get(0) out2, out3 = db.multi_get([1, 2]) for out, data in zip([out1, out2, out3], [data1, data2, data3]): assert out[0] == data[0] if math.isnan(data[1]): assert math.isnan(out[1]) else: assert out[1] == data[1] assert out[2] == data[2] assert torch.equal(out[3], data[3]) assert isinstance(out[4], Data) and len(out[4]) == 1 assert torch.equal(out[4].x, data[4].x) db.close() path = osp.join(tmp_path, 'dict_storage.db') schema = { 'int': int, 'float': float, 'str': str, 'tensor': dict(dtype=torch.float, size=(2, -1)), 'data': object } db = Database(path, schema=schema, **kwargs) assert db.schema == { 'int': int, 'float': float, 'str': str, 'tensor': TensorInfo(dtype=torch.float, size=(2, -1)), 'data': object, } data1 = { 'int': 1, 'float': 0.1, 'str': 'a', 'tensor': torch.randn(2, 8), 'data': Data(x=torch.randn(1, 8)), } data2 = { 'int': 2, 'float': 0.2, 'str': 'b', 'tensor': torch.randn(2, 16), 'data': Data(x=torch.randn(2, 8)), } data3 = { 'int': 3, 'float': 0.3, 'str': 'c', 'tensor': torch.randn(2, 32), 'data': Data(x=torch.randn(3, 8)), } db.insert(0, data1) db.multi_insert([1, 2], [data2, data3]) out1 = db.get(0) out2, out3 = db.multi_get([1, 2]) for out, data in zip([out1, out2, out3], [data1, data2, data3]): assert out['int'] == data['int'] assert out['float'] == data['float'] assert out['str'] == data['str'] assert torch.equal(out['tensor'], data['tensor']) assert isinstance(out['data'], Data) and len(out['data']) == 1 assert torch.equal(out['data'].x, data['data'].x) db.close() @pytest.mark.parametrize('Database', AVAILABLE_DATABASES) def test_index(tmp_path, Database): kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} path = osp.join(tmp_path, 'tuple_storage.db') schema = dict(dtype=torch.long, is_index=True) db = Database(path, schema=schema, **kwargs) assert db.schema == { 0: 
TensorInfo(dtype=torch.long, is_index=True), } index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) index2 = Index([0, 1, 1, 2, 2, 3], dim_size=None, is_sorted=True) index3 = Index([], dtype=torch.long) db.insert(0, index1) db.multi_insert([1, 2], [index2, index3]) out1 = db.get(0) out2, out3 = db.multi_get([1, 2]) for out, index in zip([out1, out2, out3], [index1, index2, index3]): assert index.equal(out) assert index.dtype == out.dtype assert index.dim_size == out.dim_size assert index.is_sorted == out.is_sorted db.close() @pytest.mark.parametrize('Database', AVAILABLE_DATABASES) def test_edge_index(tmp_path, Database): kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} path = osp.join(tmp_path, 'tuple_storage.db') schema = dict(dtype=torch.long, is_edge_index=True) db = Database(path, schema=schema, **kwargs) assert db.schema == { 0: TensorInfo(dtype=torch.long, size=(2, -1), is_edge_index=True), } adj1 = EdgeIndex( [[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), sort_order='row', is_undirected=True, ) adj2 = EdgeIndex( [[1, 0, 2, 1, 3, 2], [0, 1, 1, 2, 2, 3]], sparse_size=(4, 4), sort_order='col', ) adj3 = EdgeIndex([[], []], dtype=torch.long) db.insert(0, adj1) db.multi_insert([1, 2], [adj2, adj3]) out1 = db.get(0) out2, out3 = db.multi_get([1, 2]) for out, adj in zip([out1, out2, out3], [adj1, adj2, adj3]): assert adj.equal(out) assert adj.dtype == out.dtype assert adj.sparse_size() == out.sparse_size() assert adj.sort_order == out.sort_order assert adj.is_undirected == out.is_undirected db.close() @withPackage('sqlite3') def test_database_syntactic_sugar(tmp_path): path = osp.join(tmp_path, 'storage.db') db = SQLiteDatabase(path, name='test_table') data = torch.randn(5, 16) db[0] = data[0] db[1:3] = data[1:3] db[torch.tensor([3, 4])] = data[torch.tensor([3, 4])] assert len(db) == 5 assert torch.equal(db[0], data[0]) assert torch.equal(torch.stack(db[:3], dim=0), data[:3]) assert torch.equal(torch.stack(db[3:], dim=0), data[3:]) 
assert torch.equal(torch.stack(db[1::2], dim=0), data[1::2]) assert torch.equal(torch.stack(db[[4, 3]], dim=0), data[[4, 3]]) assert torch.equal( torch.stack(db[torch.tensor([4, 3])], dim=0), data[torch.tensor([4, 3])], ) assert torch.equal( torch.stack(db[torch.tensor([4, 4])], dim=0), data[torch.tensor([4, 4])], ) if __name__ == '__main__': import argparse import tempfile import time parser = argparse.ArgumentParser() parser.add_argument('--numel', type=int, default=100_000) parser.add_argument('--batch_size', type=int, default=256) args = parser.parse_args() data = torch.randn(args.numel, 128) tmp_dir = tempfile.TemporaryDirectory() path = osp.join(tmp_dir.name, 'sqlite.db') sqlite_db = SQLiteDatabase(path, name='test_table') t = time.perf_counter() sqlite_db.multi_insert(range(args.numel), data, batch_size=100, log=True) print(f'Initialized SQLiteDB in {time.perf_counter() - t:.2f} seconds') path = osp.join(tmp_dir.name, 'rocks.db') rocks_db = RocksDatabase(path) t = time.perf_counter() rocks_db.multi_insert(range(args.numel), data, batch_size=100, log=True) print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds') def in_memory_get(data): index = torch.randint(0, args.numel, (args.batch_size, )) return data[index] def db_get(db): index = torch.randint(0, args.numel, (args.batch_size, )) return db[index] benchmark( funcs=[in_memory_get, db_get, db_get], func_names=['In-Memory', 'SQLite', 'RocksDB'], args=[(data, ), (sqlite_db, ), (rocks_db, )], num_steps=50, num_warmups=5, ) tmp_dir.cleanup() ================================================ FILE: test/data/test_datapipes.py ================================================ import pytest import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import DatasetAdapter from torch_geometric.loader import DataLoader from torch_geometric.testing import withPackage from torch_geometric.utils import to_smiles @pytest.fixture() def dataset_adapter() -> DatasetAdapter: x = 
torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]]) data = Data(x=x, edge_index=edge_index) return DatasetAdapter([data, data, data, data]) def test_dataset_adapter(dataset_adapter): loader = DataLoader(dataset_adapter, batch_size=2) batch = next(iter(loader)) assert batch.x.shape == (6, 8) assert len(loader) == 2 # Test sharding: dataset_adapter.apply_sharding(2, 0) assert len([data for data in dataset_adapter]) == 2 assert dataset_adapter.is_shardable() def test_datapipe_batch_graphs(dataset_adapter): dp = dataset_adapter.batch_graphs(batch_size=2) assert len(dp) == 2 batch = next(iter(dp)) assert batch.x.shape == (6, 8) def test_functional_transform(dataset_adapter): assert next(iter(dataset_adapter)).is_directed() dataset_adapter = dataset_adapter.to_undirected() assert next(iter(dataset_adapter)).is_undirected() @withPackage('rdkit') def test_datapipe_parse_smiles(): smiles = 'F/C=C/F' dp = DatasetAdapter([smiles]) dp = dp.parse_smiles() assert to_smiles(next(iter(dp))) == smiles dp = DatasetAdapter([{'abc': smiles, 'cba': '1.0'}]) dp = dp.parse_smiles(smiles_key='abc', target_key='cba') assert to_smiles(next(iter(dp))) == smiles ================================================ FILE: test/data/test_dataset.py ================================================ import copy import pytest import torch from torch_geometric import EdgeIndex, Index from torch_geometric.data import Data, HeteroData, InMemoryDataset from torch_geometric.datasets import KarateClub from torch_geometric.testing import withPackage from torch_geometric.transforms import BaseTransform from torch_geometric.typing import SparseTensor class MyTestDataset(InMemoryDataset): def __init__(self, data_list, transform=None): super().__init__(None, transform=transform) self.data, self.slices = self.collate(data_list) class MyStoredTestDataset(InMemoryDataset): def __init__(self, root, data_list, transform=None): self.data_list = data_list super().__init__(root, transform=transform) 
self.load(self.processed_paths[0], data_cls=data_list[0].__class__) @property def processed_file_names(self) -> str: return 'data.pt' def process(self): self.save(self.data_list, self.processed_paths[0]) def test_in_memory_dataset(): x1 = torch.tensor([[1.0], [1.0], [1.0]]) x2 = torch.tensor([[2.0], [2.0], [2.0]]) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) face = torch.tensor([[0], [1], [2]]) data1 = Data(x1, edge_index, face=face, test_int=1, test_str='1') data1.num_nodes = 10 data2 = Data(x2, edge_index, face=face, test_int=2, test_str='2') data2.num_nodes = 5 dataset = MyTestDataset([data1, data2]) assert str(dataset) == 'MyTestDataset(2)' assert len(dataset) == 2 assert len(dataset[0]) == 6 assert dataset[0].num_nodes == 10 assert dataset[0].x.tolist() == x1.tolist() assert dataset[0].edge_index.tolist() == edge_index.tolist() assert dataset[0].face.tolist() == face.tolist() assert dataset[0].test_int == 1 assert dataset[0].test_str == '1' assert len(dataset[1]) == 6 assert dataset[1].num_nodes == 5 assert dataset[1].x.tolist() == x2.tolist() assert dataset[1].edge_index.tolist() == edge_index.tolist() assert dataset[1].face.tolist() == face.tolist() assert dataset[1].test_int == 2 assert dataset[1].test_str == '2' with pytest.warns(UserWarning, match="internal storage format"): dataset.data # noqa: B018 assert torch.equal(dataset.x, torch.cat([x1, x2], dim=0)) assert dataset.edge_index.tolist() == [ [0, 1, 1, 2, 10, 11, 11, 12], [1, 0, 2, 1, 11, 10, 12, 11], ] assert torch.equal(dataset[1:].x, x2) def test_stored_in_memory_dataset(tmp_path): x1 = torch.tensor([[1.0], [1.0], [1.0]]) x2 = torch.tensor([[2.0], [2.0], [2.0], [2.0]]) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) data1 = Data(x1, edge_index, num_nodes=3, test_int=1, test_str='1') data2 = Data(x2, edge_index, num_nodes=4, test_int=2, test_str='2') dataset = MyStoredTestDataset(tmp_path, [data1, data2]) assert dataset._data.num_nodes == 7 assert dataset._data._num_nodes == [3, 
4] assert torch.equal(dataset[0].x, x1) assert torch.equal(dataset[0].edge_index, edge_index) assert dataset[0].num_nodes == 3 assert torch.equal(dataset[0].test_int, torch.tensor([1])) assert dataset[0].test_str == '1' assert torch.equal(dataset[1].x, x2) assert torch.equal(dataset[1].edge_index, edge_index) assert dataset[1].num_nodes == 4 assert torch.equal(dataset[1].test_int, torch.tensor([2])) assert dataset[1].test_str == '2' def test_stored_hetero_in_memory_dataset(tmp_path): x1 = torch.tensor([[1.0], [1.0], [1.0]]) x2 = torch.tensor([[2.0], [2.0], [2.0], [2.0]]) data1 = HeteroData() data1['paper'].x = x1 data1['paper'].num_nodes = 3 data2 = HeteroData() data2['paper'].x = x2 data2['paper'].num_nodes = 4 dataset = MyStoredTestDataset(tmp_path, [data1, data2]) assert dataset._data['paper'].num_nodes == 7 assert dataset._data['paper']._num_nodes == [3, 4] assert torch.equal(dataset[0]['paper'].x, x1) assert dataset[0]['paper'].num_nodes == 3 assert torch.equal(dataset[1]['paper'].x, x2) assert dataset[1]['paper'].num_nodes == 4 def test_index(tmp_path): index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True) data1 = Data(batch=index1) data2 = Data(batch=index2) dataset = MyTestDataset([data1, data2]) assert len(dataset) == 2 for data, index in zip(dataset, [index1, index2]): assert isinstance(data.batch, Index) assert data.batch.equal(index) assert data.batch.dim_size == index.dim_size assert data.batch.is_sorted == index.is_sorted dataset = MyStoredTestDataset(tmp_path, [data1, data2]) assert len(dataset) == 2 for data, index in zip(dataset, [index1, index2]): assert isinstance(data.batch, Index) assert data.batch.equal(index) assert data.batch.dim_size == index.dim_size assert data.batch.is_sorted == index.is_sorted def test_edge_index(tmp_path): edge_index1 = EdgeIndex( [[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), sort_order='row', is_undirected=True, ) edge_index2 = EdgeIndex( [[1, 0, 2, 
1, 3, 2], [0, 1, 1, 2, 2, 3]], sparse_size=(4, 4), sort_order='col', ) data1 = Data(edge_index=edge_index1) data2 = Data(edge_index=edge_index2) dataset = MyTestDataset([data1, data2]) assert len(dataset) == 2 for data, edge_index in zip(dataset, [edge_index1, edge_index2]): assert isinstance(data.edge_index, EdgeIndex) assert data.edge_index.equal(edge_index) assert data.edge_index.sparse_size() == edge_index.sparse_size() assert data.edge_index.sort_order == edge_index.sort_order assert data.edge_index.is_undirected == edge_index.is_undirected dataset = MyStoredTestDataset(tmp_path, [data1, data2]) assert len(dataset) == 2 for data, edge_index in zip(dataset, [edge_index1, edge_index2]): assert isinstance(data.edge_index, EdgeIndex) assert data.edge_index.equal(edge_index) assert data.edge_index.sparse_size() == edge_index.sparse_size() assert data.edge_index.sort_order == edge_index.sort_order assert data.edge_index.is_undirected == edge_index.is_undirected def test_in_memory_num_classes(): dataset = MyTestDataset([Data(), Data()]) assert dataset.num_classes == 0 dataset = MyTestDataset([Data(y=0), Data(y=1)]) assert dataset.num_classes == 2 dataset = MyTestDataset([Data(y=1.5), Data(y=2.5), Data(y=3.5)]) with pytest.warns(UserWarning, match="unique elements"): assert dataset.num_classes == 3 dataset = MyTestDataset([ Data(y=torch.tensor([[0, 1, 0, 1]])), Data(y=torch.tensor([[1, 0, 0, 0]])), Data(y=torch.tensor([[0, 0, 1, 0]])), ]) assert dataset.num_classes == 4 # Test when `__getitem__` returns a tuple of data objects. 
def transform(data): copied_data = copy.copy(data) copied_data.y += 1 return data, copied_data, 'foo' dataset = MyTestDataset([Data(y=0), Data(y=1)], transform=transform) assert dataset.num_classes == 3 def test_in_memory_dataset_copy(): data_list = [Data(x=torch.randn(5, 16)) for _ in range(4)] dataset = MyTestDataset(data_list) copied_dataset = dataset.copy() assert id(copied_dataset) != id(dataset) assert len(copied_dataset) == len(dataset) == 4 for copied_data, data in zip(copied_dataset, dataset): assert torch.equal(copied_data.x, data.x) copied_dataset = dataset.copy([1, 2]) assert len(copied_dataset) == 2 assert torch.equal(copied_dataset[0].x, data_list[1].x) assert torch.equal(copied_dataset[1].x, data_list[2].x) def test_to_datapipe(): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) data = Data(x=x, edge_index=edge_index) dataset = MyTestDataset([data, data]) dp = dataset.to_datapipe() assert isinstance(dp, torch.utils.data.IterDataPipe) assert len(dp) == 2 assert torch.equal(dataset[0].x, list(dp)[0].x) assert torch.equal(dataset[0].edge_index, list(dp)[0].edge_index) assert torch.equal(dataset[1].x, list(dp)[1].x) assert torch.equal(dataset[1].edge_index, list(dp)[1].edge_index) @withPackage('torch_sparse') def test_in_memory_sparse_tensor_dataset(): x = torch.randn(11, 16) adj = SparseTensor( row=torch.tensor([4, 1, 3, 2, 2, 3]), col=torch.tensor([1, 3, 2, 3, 3, 2]), sparse_sizes=(11, 11), ) data = Data(x=x, adj=adj) dataset = MyTestDataset([data, data]) assert len(dataset) == 2 assert torch.allclose(dataset[0].x, x) assert dataset[0].adj.sparse_sizes() == (11, 11) assert torch.allclose(dataset[1].x, x) assert dataset[1].adj.sparse_sizes() == (11, 11) def test_collate_with_new_dimension(): class MyData(Data): def __cat_dim__(self, key, value, *args, **kwargs): if key == 'foo': return None else: return super().__cat_dim__(key, value, *args, **kwargs) x = torch.tensor([1, 2, 3], dtype=torch.float) foo = torch.randn(4) y = 
torch.tensor(1) data = MyData(x=x, foo=foo, y=y) dataset = MyTestDataset([data, data]) assert str(dataset) == 'MyTestDataset(2)' assert len(dataset) == 2 data1 = dataset[0] assert len(data1) == 3 assert data1.x.tolist() == x.tolist() assert data1.foo.tolist() == foo.tolist() assert data1.y.tolist() == [1] data2 = dataset[0] assert len(data2) == 3 assert data2.x.tolist() == x.tolist() assert data2.foo.tolist() == foo.tolist() assert data2.y.tolist() == [1] def test_hetero_in_memory_dataset(): data1 = HeteroData() data1.y = torch.randn(5) data1['paper'].x = torch.randn(10, 16) data1['paper', 'paper'].edge_index = torch.randint(0, 10, (2, 30)).long() data2 = HeteroData() data2.y = torch.randn(5) data2['paper'].x = torch.randn(10, 16) data2['paper', 'paper'].edge_index = torch.randint(0, 10, (2, 30)).long() dataset = MyTestDataset([data1, data2]) assert str(dataset) == 'MyTestDataset(2)' assert len(dataset) == 2 assert len(dataset[0]) == 3 assert dataset[0].y.tolist() == data1.y.tolist() assert dataset[0]['paper'].x.tolist() == data1['paper'].x.tolist() assert (dataset[0]['paper', 'paper'].edge_index.tolist() == data1[ 'paper', 'paper'].edge_index.tolist()) assert len(dataset[1]) == 3 assert dataset[1].y.tolist() == data2.y.tolist() assert dataset[1]['paper'].x.tolist() == data2['paper'].x.tolist() assert (dataset[1]['paper', 'paper'].edge_index.tolist() == data2[ 'paper', 'paper'].edge_index.tolist()) def test_override_behavior(): class DS1(InMemoryDataset): def __init__(self): self.enter_download = False self.enter_process = False super().__init__() def _download(self): self.enter_download = True def _process(self): self.enter_process = True def download(self): pass def process(self): pass class DS2(InMemoryDataset): def __init__(self): self.enter_download = False self.enter_process = False super().__init__() def _download(self): self.enter_download = True def _process(self): self.enter_process = True def process(self): pass class DS3(InMemoryDataset): def 
__init__(self): self.enter_download = False self.enter_process = False super().__init__() def _download(self): self.enter_download = True def _process(self): self.enter_process = True class DS4(DS1): pass ds = DS1() assert ds.enter_download assert ds.enter_process ds = DS2() assert not ds.enter_download assert ds.enter_process ds = DS3() assert not ds.enter_download assert not ds.enter_process ds = DS4() assert ds.enter_download assert ds.enter_process def test_lists_of_tensors_in_memory_dataset(): def tr(n, m): return torch.rand((n, m)) d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)]) d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)]) d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)]) d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)]) data_list = [d1, d2, d3, d4] dataset = MyTestDataset(data_list) assert len(dataset) == 4 assert dataset[0].xs[1].size() == (11, 4) assert dataset[0].xs[2].size() == (1, 2) assert dataset[1].xs[0].size() == (5, 3) assert dataset[2].xs[1].size() == (15, 4) assert dataset[3].xs[1].size() == (16, 4) @withPackage('torch_sparse') def test_lists_of_sparse_tensors(): e1 = torch.tensor([[4, 1, 3, 2, 2, 3], [1, 3, 2, 3, 3, 2]]) e2 = torch.tensor([[0, 1, 4, 7, 2, 9], [7, 2, 2, 1, 4, 7]]) e3 = torch.tensor([[3, 5, 1, 2, 3, 3], [5, 0, 2, 1, 3, 7]]) e4 = torch.tensor([[0, 1, 9, 2, 0, 3], [1, 1, 2, 1, 3, 2]]) adj1 = SparseTensor.from_edge_index(e1, sparse_sizes=(11, 11)) adj2 = SparseTensor.from_edge_index(e2, sparse_sizes=(22, 22)) adj3 = SparseTensor.from_edge_index(e3, sparse_sizes=(12, 12)) adj4 = SparseTensor.from_edge_index(e4, sparse_sizes=(15, 15)) d1 = Data(adj_test=[adj1, adj2]) d2 = Data(adj_test=[adj3, adj4]) data_list = [d1, d2] dataset = MyTestDataset(data_list) assert len(dataset) == 2 assert dataset[0].adj_test[0].sparse_sizes() == (11, 11) assert dataset[0].adj_test[1].sparse_sizes() == (22, 22) assert dataset[1].adj_test[0].sparse_sizes() == (12, 12) assert dataset[1].adj_test[1].sparse_sizes() == (15, 15) def 
test_file_names_as_property_and_method(): class MyTestDataset(InMemoryDataset): def __init__(self): super().__init__('/tmp/MyTestDataset') @property def raw_file_names(self): return ['test_file'] def download(self): pass MyTestDataset() class MyTestDataset(InMemoryDataset): def __init__(self): super().__init__('/tmp/MyTestDataset') def raw_file_names(self): return ['test_file'] def download(self): pass MyTestDataset() @withPackage('sqlite3') def test_to_on_disk_dataset(tmp_path): class MyTransform(BaseTransform): def forward(self, data: Data) -> Data: data.z = 'test_str' return data in_memory_dataset = KarateClub(transform=MyTransform()) with pytest.raises(ValueError, match="root directory of 'KarateClub'"): in_memory_dataset.to_on_disk_dataset() on_disk_dataset = in_memory_dataset.to_on_disk_dataset(tmp_path, log=False) assert str(on_disk_dataset) == 'OnDiskKarateClub()' assert on_disk_dataset.schema == { 'x': dict(dtype=torch.float32, size=(-1, 34)), 'edge_index': dict(dtype=torch.int64, size=(2, -1)), 'y': dict(dtype=torch.int64, size=(-1, )), 'train_mask': dict(dtype=torch.bool, size=(-1, )), } assert in_memory_dataset.transform == on_disk_dataset.transform data1 = in_memory_dataset[0] data2 = on_disk_dataset[0] assert len(data1) == len(data2) assert torch.allclose(data1.x, data2.x) assert torch.equal(data1.edge_index, data2.edge_index) assert torch.equal(data1.y, data2.y) assert torch.equal(data1.train_mask, data2.train_mask) assert data1.z == data2.z on_disk_dataset.close() ================================================ FILE: test/data/test_dataset_summary.py ================================================ import torch from torch import Tensor from torch_geometric.data.summary import Stats, Summary from torch_geometric.datasets import FakeDataset, FakeHeteroDataset from torch_geometric.testing import withPackage def check_stats(stats: Stats, expected: Tensor): expected = expected.to(torch.float) assert stats.mean == float(expected.mean()) assert stats.std 
== float(expected.std()) assert stats.min == float(expected.min()) assert stats.quantile25 == float(expected.quantile(0.25)) assert stats.median == float(expected.median()) assert stats.quantile75 == float(expected.quantile(0.75)) assert stats.max == float(expected.max()) def test_dataset_summary(): dataset = FakeDataset(num_graphs=10) num_nodes = torch.tensor([data.num_nodes for data in dataset]) num_edges = torch.tensor([data.num_edges for data in dataset]) summary = dataset.get_summary() assert summary.name == 'FakeDataset' assert summary.num_graphs == 10 check_stats(summary.num_nodes, num_nodes) check_stats(summary.num_edges, num_edges) @withPackage('tabulate') def test_dataset_summary_representation(): dataset = FakeDataset(num_graphs=10) summary1 = Summary.from_dataset(dataset, per_type=False) summary2 = Summary.from_dataset(dataset, per_type=True) assert str(summary1) == str(summary2) @withPackage('tabulate') def test_dataset_summary_hetero(): dataset1 = FakeHeteroDataset(num_graphs=10) summary1 = Summary.from_dataset(dataset1, per_type=False) dataset2 = [data.to_homogeneous() for data in dataset1] summary2 = Summary.from_dataset(dataset2) summary2.name = 'FakeHeteroDataset' assert summary1 == summary2 assert str(summary1) == str(summary2) @withPackage('tabulate') def test_dataset_summary_hetero_representation_length(): dataset = FakeHeteroDataset(num_graphs=10) summary = Summary.from_dataset(dataset) num_lines = len(str(summary).splitlines()) stats_len = len(Stats.__dataclass_fields__) len_header_and_border = 5 num_tables = 3 # general, stats per node type, stats per edge type assert num_lines == num_tables * (stats_len + len_header_and_border) def test_dataset_summary_hetero_per_type_check(): dataset = FakeHeteroDataset(num_graphs=10) exp_num_nodes = torch.tensor([data.num_nodes for data in dataset]) exp_num_edges = torch.tensor([data.num_edges for data in dataset]) summary = dataset.get_summary() assert summary.name == 'FakeHeteroDataset' assert 
summary.num_graphs == 10 check_stats(summary.num_nodes, exp_num_nodes) check_stats(summary.num_edges, exp_num_edges) num_nodes_per_type = {} for node_type in dataset.node_types: num_nodes = [data[node_type].num_nodes for data in dataset] num_nodes_per_type[node_type] = torch.tensor(num_nodes) assert len(summary.num_nodes_per_type) == len(dataset.node_types) for node_type, stats in summary.num_nodes_per_type.items(): check_stats(stats, num_nodes_per_type[node_type]) num_edges_per_type = {} for edge_type in dataset.edge_types: num_edges = [data[edge_type].num_edges for data in dataset] num_edges_per_type[edge_type] = torch.tensor(num_edges) assert len(summary.num_edges_per_type) == len(dataset.edge_types) for edge_type, stats in summary.num_edges_per_type.items(): check_stats(stats, num_edges_per_type[edge_type]) ================================================ FILE: test/data/test_feature_store.py ================================================ from dataclasses import dataclass import pytest import torch from torch_geometric.data import TensorAttr from torch_geometric.data.feature_store import AttrView, _FieldStatus from torch_geometric.testing import MyFeatureStore @dataclass class MyTensorAttrNoGroupName(TensorAttr): def __init__(self, attr_name=_FieldStatus.UNSET, index=_FieldStatus.UNSET): # Treat group_name as optional, and move it to the end super().__init__(None, attr_name, index) class MyFeatureStoreNoGroupName(MyFeatureStore): def __init__(self): super().__init__() self._tensor_attr_cls = MyTensorAttrNoGroupName def test_feature_store(): store = MyFeatureStore() tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) group_name = 'A' attr_name = 'feat' index = torch.tensor([0, 1, 2]) attr = TensorAttr(group_name, attr_name, index) assert TensorAttr(group_name).update(attr) == attr # Normal API: store.put_tensor(tensor, attr) assert torch.equal(store.get_tensor(attr), tensor) assert torch.equal( store.get_tensor(group_name, attr_name, 
index=torch.tensor([0, 2])), tensor[torch.tensor([0, 2])], ) assert store.update_tensor(tensor + 1, attr) assert torch.equal(store.get_tensor(attr), tensor + 1) store.remove_tensor(attr) with pytest.raises(KeyError): _ = store.get_tensor(attr) # Views: view = store.view(group_name=group_name) view.attr_name = attr_name view['index'] = index assert view != "not a 'AttrView' object" assert view == AttrView(store, TensorAttr(group_name, attr_name, index)) assert str(view) == ("AttrView(store=MyFeatureStore(), " "attr=TensorAttr(group_name='A', attr_name='feat', " "index=tensor([0, 1, 2])))") # Indexing: store[group_name, attr_name, index] = tensor # Fully-specified forms, all of which produce a tensor output assert torch.equal(store[group_name, attr_name, index], tensor) assert torch.equal(store[group_name, attr_name, None], tensor) assert torch.equal(store[group_name, attr_name, :], tensor) assert torch.equal(store[group_name][attr_name][:], tensor) assert torch.equal(store[group_name].feat[:], tensor) assert torch.equal(store.view().A.feat[:], tensor) with pytest.raises(AttributeError) as exc_info: _ = store.view(group_name=group_name, index=None).feat.A print(exc_info) # Partially-specified forms, which produce an AttrView object assert store[group_name] == store.view(TensorAttr(group_name=group_name)) assert store[group_name].feat == store.view( TensorAttr(group_name=group_name, attr_name=attr_name)) # Partially-specified forms, when called, produce a Tensor output # from the `TensorAttr` that has been partially specified. 
store[group_name, None, None] = tensor assert isinstance(store[group_name], AttrView) assert torch.equal(store[group_name](), tensor) # Deletion: del store[group_name, attr_name, index] with pytest.raises(KeyError): _ = store[group_name, attr_name, index] del store[group_name] with pytest.raises(KeyError): _ = store[group_name]() def test_feature_store_override(): store = MyFeatureStoreNoGroupName() tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) attr_name = 'feat' index = torch.tensor([0, 1, 2]) # Only use attr_name and index, in that order: store[attr_name, index] = tensor # A few assertions to ensure group_name is not needed: assert isinstance(store[attr_name], AttrView) assert torch.equal(store[attr_name, index], tensor) assert torch.equal(store[attr_name][index], tensor) assert torch.equal(store[attr_name][:], tensor) assert torch.equal(store[attr_name, :], tensor) ================================================ FILE: test/data/test_graph_store.py ================================================ import pytest import torch from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout from torch_geometric.testing import MyGraphStore, get_random_edge_index from torch_geometric.utils import ( to_torch_coo_tensor, to_torch_csc_tensor, to_torch_csr_tensor, ) def test_graph_store(): graph_store = MyGraphStore() assert str(graph_store) == 'MyGraphStore()' coo = torch.tensor([0, 1]), torch.tensor([1, 2]) csr = torch.tensor([0, 1, 2]), torch.tensor([1, 2]) csc = torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2]) graph_store['edge_type', 'coo'] = coo graph_store['edge_type', 'csr'] = csr graph_store['edge_type', 'csc'] = csc assert torch.equal(graph_store['edge_type', 'coo'][0], coo[0]) assert torch.equal(graph_store['edge_type', 'coo'][1], coo[1]) assert torch.equal(graph_store['edge_type', 'csr'][0], csr[0]) assert torch.equal(graph_store['edge_type', 'csr'][1], csr[1]) assert torch.equal(graph_store['edge_type', 'csc'][0], csc[0]) assert 
torch.equal(graph_store['edge_type', 'csc'][1], csc[1]) assert len(graph_store.get_all_edge_attrs()) == 3 del graph_store['edge_type', 'coo'] with pytest.raises(KeyError): graph_store['edge_type', 'coo'] with pytest.raises(KeyError): graph_store['edge_type_2', 'coo'] def test_graph_store_conversion(): graph_store = MyGraphStore() edge_index = get_random_edge_index(100, 100, 300) adj = to_torch_coo_tensor(edge_index, size=(100, 100)) coo = (adj.indices()[0], adj.indices()[1]) adj = to_torch_csr_tensor(edge_index, size=(100, 100)) csr = (adj.crow_indices(), adj.col_indices()) adj = to_torch_csc_tensor(edge_index, size=(100, 100)) csc = (adj.row_indices(), adj.ccol_indices()) graph_store.put_edge_index(coo, ('v', '1', 'v'), 'coo', size=(100, 100)) graph_store.put_edge_index(csr, ('v', '2', 'v'), 'csr', size=(100, 100)) graph_store.put_edge_index(csc, ('v', '3', 'v'), 'csc', size=(100, 100)) # Convert to COO: row_dict, col_dict, perm_dict = graph_store.coo() assert len(row_dict) == len(col_dict) == len(perm_dict) == 3 for row, col, perm in zip(row_dict.values(), col_dict.values(), perm_dict.values()): assert torch.equal(row.sort()[0], coo[0].sort()[0]) assert torch.equal(col.sort()[0], coo[1].sort()[0]) assert perm is None # Convert to CSR: row_dict, col_dict, perm_dict = graph_store.csr() assert len(row_dict) == len(col_dict) == len(perm_dict) == 3 for row, col in zip(row_dict.values(), col_dict.values()): assert torch.equal(row, csr[0]) assert torch.equal(col.sort()[0], csr[1].sort()[0]) # Convert to CSC: row_dict, col_dict, perm_dict = graph_store.csc() assert len(row_dict) == len(col_dict) == len(perm_dict) == 3 for row, col in zip(row_dict.values(), col_dict.values()): assert torch.equal(row.sort()[0], csc[0].sort()[0]) assert torch.equal(col, csc[1]) # Ensure that 'edge_types' parameters work as intended: out = graph_store.coo([('v', '1', 'v')]) assert torch.equal(list(out[0].values())[0], coo[0]) assert torch.equal(list(out[1].values())[0], coo[1]) # Ensure that 
'store' parameter works as intended: key = EdgeAttr(edge_type=('v', '1', 'v'), layout=EdgeLayout.CSR, is_sorted=False, size=(100, 100)) with pytest.raises(KeyError): graph_store[key] out = graph_store.csr([('v', '1', 'v')], store=True) assert torch.equal(list(out[0].values())[0], csr[0]) assert torch.equal(list(out[1].values())[0].sort()[0], csr[1].sort()[0]) out = graph_store[key] assert torch.equal(out[0], csr[0]) assert torch.equal(out[1].sort()[0], csr[1].sort()[0]) ================================================ FILE: test/data/test_hetero_data.py ================================================ import copy import warnings import pytest import torch from torch_geometric.data import HeteroData from torch_geometric.data.storage import EdgeStorage from torch_geometric.testing import ( get_random_edge_index, get_random_tensor_frame, withPackage, ) from torch_geometric.typing import TensorFrame x_paper = torch.randn(10, 16) x_author = torch.randn(5, 32) x_conference = torch.randn(5, 8) idx_paper = torch.randint(x_paper.size(0), (100, ), dtype=torch.long) idx_author = torch.randint(x_author.size(0), (100, ), dtype=torch.long) idx_conference = torch.randint(x_conference.size(0), (100, ), dtype=torch.long) edge_index_paper_paper = torch.stack([idx_paper[:50], idx_paper[:50]], dim=0) edge_index_paper_author = torch.stack([idx_paper[:30], idx_author[:30]], dim=0) edge_index_author_paper = torch.stack([idx_author[:30], idx_paper[:30]], dim=0) edge_index_paper_conference = torch.stack( [idx_paper[:25], idx_conference[:25]], dim=0) edge_attr_paper_paper = torch.randn(edge_index_paper_paper.size(1), 8) edge_attr_author_paper = torch.randn(edge_index_author_paper.size(1), 8) def test_init_hetero_data(): data = HeteroData() data['v1'].x = 1 data['paper'].x = x_paper data['author'].x = x_author data['paper', 'paper'].edge_index = edge_index_paper_paper data['paper', 'author'].edge_index = edge_index_paper_author data['author', 'paper'].edge_index = edge_index_author_paper 
with pytest.warns(UserWarning, match="{'v1'} are isolated"): data.validate(raise_on_error=True) assert len(data) == 2 assert data.node_types == ['v1', 'paper', 'author'] assert len(data.node_stores) == 3 assert len(data.node_items()) == 3 assert len(data.edge_types) == 3 assert len(data.edge_stores) == 3 assert len(data.edge_items()) == 3 data = HeteroData( v1={'x': 1}, paper={'x': x_paper}, author={'x': x_author}, paper__paper={'edge_index': edge_index_paper_paper}, paper__author={'edge_index': edge_index_paper_author}, author__paper={'edge_index': edge_index_author_paper}, ) assert len(data) == 2 assert data.node_types == ['v1', 'paper', 'author'] assert len(data.node_stores) == 3 assert len(data.node_items()) == 3 assert len(data.edge_types) == 3 assert len(data.edge_stores) == 3 assert len(data.edge_items()) == 3 data = HeteroData({ 'v1': { 'x': 1 }, 'paper': { 'x': x_paper }, 'author': { 'x': x_author }, ('paper', 'paper'): { 'edge_index': edge_index_paper_paper }, ('paper', 'author'): { 'edge_index': edge_index_paper_author }, ('author', 'paper'): { 'edge_index': edge_index_author_paper }, }) assert len(data) == 2 assert data.node_types == ['v1', 'paper', 'author'] assert len(data.node_stores) == 3 assert len(data.node_items()) == 3 assert len(data.edge_types) == 3 assert len(data.edge_stores) == 3 assert len(data.edge_items()) == 3 def test_hetero_data_to_from_dict(): data = HeteroData() data.global_id = '1' data['v1'].x = torch.randn(5, 16) data['v2'].y = torch.randn(4, 16) data['v1', 'v2'].edge_index = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]) out = HeteroData.from_dict(data.to_dict()) assert out.global_id == data.global_id assert torch.equal(out['v1'].x, data['v1'].x) assert torch.equal(out['v2'].y, data['v2'].y) assert torch.equal(out['v1', 'v2'].edge_index, data['v1', 'v2'].edge_index) def test_hetero_data_functions(): data = HeteroData() with pytest.raises(KeyError, match="did not find any occurrences of it"): data.collect('x') data['paper'].x = 
x_paper data['author'].x = x_author data['paper', 'paper'].edge_index = edge_index_paper_paper data['paper', 'author'].edge_index = edge_index_paper_author data['author', 'paper'].edge_index = edge_index_author_paper data['paper', 'paper'].edge_attr = edge_attr_paper_paper assert len(data) == 3 assert sorted(data.keys()) == ['edge_attr', 'edge_index', 'x'] assert 'x' in data and 'edge_index' in data and 'edge_attr' in data assert data.num_nodes == 15 assert data.num_edges == 110 assert data.node_attrs() == ['x'] assert sorted(data.edge_attrs()) == ['edge_attr', 'edge_index'] assert data.num_node_features == {'paper': 16, 'author': 32} assert data.num_edge_features == { ('paper', 'to', 'paper'): 8, ('paper', 'to', 'author'): 0, ('author', 'to', 'paper'): 0, } node_types, edge_types = data.metadata() assert node_types == ['paper', 'author'] assert edge_types == [ ('paper', 'to', 'paper'), ('paper', 'to', 'author'), ('author', 'to', 'paper'), ] x_dict = data.collect('x') assert len(x_dict) == 2 assert x_dict['paper'].tolist() == x_paper.tolist() assert x_dict['author'].tolist() == x_author.tolist() assert x_dict == data.x_dict data.y = 0 assert data['y'] == 0 and data.y == 0 assert len(data) == 4 assert sorted(data.keys()) == ['edge_attr', 'edge_index', 'x', 'y'] del data['paper', 'author'] node_types, edge_types = data.metadata() assert node_types == ['paper', 'author'] assert edge_types == [('paper', 'to', 'paper'), ('author', 'to', 'paper')] assert len(data.to_dict()) == 5 assert len(data.to_namedtuple()) == 5 assert data.to_namedtuple().y == 0 assert len(data.to_namedtuple().paper) == 1 def test_hetero_data_set_value_dict(): data = HeteroData() data.set_value_dict('x', { 'paper': torch.randn(4, 16), 'author': torch.randn(8, 32), }) assert data.node_types == ['paper', 'author'] assert data.edge_types == [] assert data['paper'].x.size() == (4, 16) assert data['author'].x.size() == (8, 32) def test_hetero_data_rename(): data = HeteroData() data['paper'].x = x_paper 
data['author'].x = x_author data['paper', 'paper'].edge_index = edge_index_paper_paper data['paper', 'author'].edge_index = edge_index_paper_author data['author', 'paper'].edge_index = edge_index_author_paper data = data.rename('paper', 'article') assert data.node_types == ['author', 'article'] assert data.edge_types == [ ('article', 'to', 'article'), ('article', 'to', 'author'), ('author', 'to', 'article'), ] assert data['article'].x.tolist() == x_paper.tolist() edge_index = data['article', 'article'].edge_index assert edge_index.tolist() == edge_index_paper_paper.tolist() def test_dangling_types(): data = HeteroData() data['src', 'to', 'dst'].edge_index = torch.randint(0, 10, (2, 20)) with pytest.raises(ValueError, match="do not exist as node types"): data.validate() data = HeteroData() data['node'].num_nodes = 10 with pytest.warns(UserWarning, match="{'node'} are isolated"): data.validate() def test_hetero_data_subgraph(): data = HeteroData() data.num_node_types = 3 data['paper'].x = x_paper data['paper'].name = 'paper' data['paper'].num_nodes = x_paper.size(0) data['author'].x = x_author data['author'].num_nodes = x_author.size(0) data['conf'].x = x_conference data['conf'].num_nodes = x_conference.size(0) data['paper', 'paper'].edge_index = edge_index_paper_paper data['paper', 'paper'].edge_attr = edge_attr_paper_paper data['paper', 'paper'].name = 'cites' data['author', 'paper'].edge_index = edge_index_author_paper data['paper', 'author'].edge_index = edge_index_paper_author data['paper', 'conf'].edge_index = edge_index_paper_conference subset = { 'paper': torch.randperm(x_paper.size(0))[:4], 'author': torch.randperm(x_author.size(0))[:2], 'conf': torch.randperm(x_conference.size(0))[:2], } out = data.subgraph(subset) out.validate(raise_on_error=True) assert out.num_node_types == data.num_node_types assert out.node_types == ['paper', 'author', 'conf'] for key in out.node_types: assert len(out[key]) == len(data[key]) assert torch.allclose(out[key].x, 
data[key].x[subset[key]]) assert out[key].num_nodes == subset[key].size(0) if key == 'paper': assert out['paper'].name == 'paper' # Construct correct edge index manually: node_mask = {} # for each node type a mask of nodes in the subgraph node_map = {} # for each node type a map from old node id to new node id for key in out.node_types: node_mask[key] = torch.zeros((data[key].num_nodes, ), dtype=torch.bool) node_map[key] = torch.zeros((data[key].num_nodes, ), dtype=torch.long) node_mask[key][subset[key]] = True node_map[key][subset[key]] = torch.arange(subset[key].size(0)) edge_mask = {} # for each edge type a mask of edges in the subgraph subgraph_edge_index = { } # for each edge type the edge index of the subgraph for key in out.edge_types: edge_mask[key] = (node_mask[key[0]][data[key].edge_index[0]] & node_mask[key[-1]][data[key].edge_index[1]]) subgraph_edge_index[key] = data[key].edge_index[:, edge_mask[key]] subgraph_edge_index[key][0] = node_map[key[0]][subgraph_edge_index[key] [0]] subgraph_edge_index[key][1] = node_map[key[-1]][ subgraph_edge_index[key][1]] assert out.edge_types == [ ('paper', 'to', 'paper'), ('author', 'to', 'paper'), ('paper', 'to', 'author'), ('paper', 'to', 'conf'), ] for key in out.edge_types: assert len(out[key]) == len(data[key]) assert torch.equal(out[key].edge_index, subgraph_edge_index[key]) if key == ('paper', 'to', 'paper'): assert torch.allclose(out[key].edge_attr, data[key].edge_attr[edge_mask[key]]) assert out[key].name == 'cites' # Test for bool and long in `subset_dict`. 
author_mask = torch.zeros((x_author.size(0), ), dtype=torch.bool) author_mask[subset['author']] = True subset_mixed = { 'paper': subset['paper'], 'author': author_mask, } out = data.subgraph(subset_mixed) out.validate(raise_on_error=True) assert out.num_node_types == data.num_node_types assert out.node_types == ['paper', 'author', 'conf'] assert out['paper'].num_nodes == subset['paper'].size(0) assert out['author'].num_nodes == subset['author'].size(0) assert out['conf'].num_nodes == data['conf'].num_nodes assert out.edge_types == [ ('paper', 'to', 'paper'), ('author', 'to', 'paper'), ('paper', 'to', 'author'), ('paper', 'to', 'conf'), ] out = data.node_type_subgraph(['paper', 'author']) assert out.node_types == ['paper', 'author'] assert out.edge_types == [('paper', 'to', 'paper'), ('author', 'to', 'paper'), ('paper', 'to', 'author')] out = data.edge_type_subgraph([('paper', 'author')]) assert out.node_types == ['paper', 'author'] assert out.edge_types == [('paper', 'to', 'author')] subset = { ('paper', 'to', 'paper'): torch.arange(4), } out = data.edge_subgraph(subset) assert out.node_types == data.node_types assert out.edge_types == data.edge_types assert data['paper'] == out['paper'] assert data['author'] == out['author'] assert data['paper', 'author'] == out['paper', 'author'] assert data['author', 'paper'] == out['author', 'paper'] assert out['paper', 'paper'].num_edges == 4 assert out['paper', 'paper'].edge_index.size() == (2, 4) assert out['paper', 'paper'].edge_attr.size() == (4, 8) def test_hetero_data_empty_subgraph(): data = HeteroData() data.num_node_types = 3 data['paper'].x = torch.arange(5) data['author'].x = torch.arange(5) data['paper', 'author'].edge_weight = torch.arange(5) out = data.subgraph(subset_dict={ 'paper': torch.tensor([1, 2, 3]), 'author': torch.tensor([1, 2, 3]), }) assert torch.equal(out['paper'].x, torch.arange(1, 4)) assert out['paper'].num_nodes == 3 assert torch.equal(out['author'].x, torch.arange(1, 4)) assert 
out['author'].num_nodes == 3 assert 'edge_index' not in out['paper', 'author'] assert torch.equal(out['paper', 'author'].edge_weight, torch.arange(5)) def test_copy_hetero_data(): data = HeteroData() data['paper'].x = x_paper data['paper', 'to', 'paper'].edge_index = edge_index_paper_paper out = copy.copy(data) assert id(data) != id(out) assert len(data.stores) == len(out.stores) for store1, store2 in zip(data.stores, out.stores): assert id(store1) != id(store2) assert id(data) == id(store1._parent()) assert id(out) == id(store2._parent()) assert out['paper']._key == 'paper' assert data['paper'].x.data_ptr() == out['paper'].x.data_ptr() assert out['to']._key == ('paper', 'to', 'paper') assert data['to'].edge_index.data_ptr() == out['to'].edge_index.data_ptr() out = copy.deepcopy(data) assert id(data) != id(out) assert len(data.stores) == len(out.stores) for store1, store2 in zip(data.stores, out.stores): assert id(store1) != id(store2) assert id(out) == id(out['paper']._parent()) assert out['paper']._key == 'paper' assert data['paper'].x.data_ptr() != out['paper'].x.data_ptr() assert data['paper'].x.tolist() == out['paper'].x.tolist() assert id(out) == id(out['to']._parent()) assert out['to']._key == ('paper', 'to', 'paper') assert data['to'].edge_index.data_ptr() != out['to'].edge_index.data_ptr() assert data['to'].edge_index.tolist() == out['to'].edge_index.tolist() def test_to_homogeneous_and_vice_versa(): data = HeteroData() data['paper'].x = torch.randn(100, 128) data['paper'].y = torch.randint(0, 10, (100, )) data['author'].x = torch.randn(200, 128) data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 250) data['paper', 'paper'].edge_weight = torch.randn(250, ) data['paper', 'paper'].edge_attr = torch.randn(250, 64) data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 500) data['paper', 'author'].edge_weight = torch.randn(500, ) data['paper', 'author'].edge_attr = torch.randn(500, 64) data['author', 'paper'].edge_index = 
get_random_edge_index(200, 100, 1000) data['author', 'paper'].edge_weight = torch.randn(1000, ) data['author', 'paper'].edge_attr = torch.randn(1000, 64) out = data.to_homogeneous() assert len(out) == 7 assert out.num_nodes == 300 assert out.num_edges == 1750 assert out.num_node_features == 128 assert out.num_edge_features == 64 assert out.node_type.size() == (300, ) assert out.node_type.min() == 0 assert out.node_type.max() == 1 assert out.edge_type.size() == (1750, ) assert out.edge_type.min() == 0 assert out.edge_type.max() == 2 assert len(out._node_type_names) == 2 assert len(out._edge_type_names) == 3 assert out.y.size() == (300, ) assert torch.allclose(out.y[:100], data['paper'].y) assert torch.all(out.y[100:] == -1) assert 'y' not in data['author'] out = out.to_heterogeneous() assert len(out) == 5 assert torch.allclose(data['paper'].x, out['paper'].x) assert torch.allclose(data['author'].x, out['author'].x) assert torch.allclose(data['paper'].y, out['paper'].y) edge_index1 = data['paper', 'paper'].edge_index edge_index2 = out['paper', 'paper'].edge_index assert edge_index1.tolist() == edge_index2.tolist() assert torch.allclose( data['paper', 'paper'].edge_weight, out['paper', 'paper'].edge_weight, ) assert torch.allclose( data['paper', 'paper'].edge_attr, out['paper', 'paper'].edge_attr, ) edge_index1 = data['paper', 'author'].edge_index edge_index2 = out['paper', 'author'].edge_index assert edge_index1.tolist() == edge_index2.tolist() assert torch.allclose( data['paper', 'author'].edge_weight, out['paper', 'author'].edge_weight, ) assert torch.allclose( data['paper', 'author'].edge_attr, out['paper', 'author'].edge_attr, ) edge_index1 = data['author', 'paper'].edge_index edge_index2 = out['author', 'paper'].edge_index assert edge_index1.tolist() == edge_index2.tolist() assert torch.allclose( data['author', 'paper'].edge_weight, out['author', 'paper'].edge_weight, ) assert torch.allclose( data['author', 'paper'].edge_attr, out['author', 'paper'].edge_attr, ) 
out = data.to_homogeneous() node_type = out.node_type edge_type = out.edge_type del out.node_type del out.edge_type del out._edge_type_names del out._node_type_names out = out.to_heterogeneous(node_type, edge_type) assert len(out) == 5 assert torch.allclose(data['paper'].x, out['0'].x) assert torch.allclose(data['author'].x, out['1'].x) edge_index1 = data['paper', 'paper'].edge_index edge_index2 = out['0', '0'].edge_index assert edge_index1.tolist() == edge_index2.tolist() assert torch.allclose( data['paper', 'paper'].edge_weight, out['0', '0'].edge_weight, ) assert torch.allclose( data['paper', 'paper'].edge_attr, out['0', '0'].edge_attr, ) edge_index1 = data['paper', 'author'].edge_index edge_index2 = out['0', '1'].edge_index assert edge_index1.tolist() == edge_index2.tolist() assert torch.allclose( data['paper', 'author'].edge_weight, out['0', '1'].edge_weight, ) assert torch.allclose( data['paper', 'author'].edge_attr, out['0', '1'].edge_attr, ) edge_index1 = data['author', 'paper'].edge_index edge_index2 = out['1', '0'].edge_index assert edge_index1.tolist() == edge_index2.tolist() assert torch.allclose( data['author', 'paper'].edge_weight, out['1', '0'].edge_weight, ) assert torch.allclose( data['author', 'paper'].edge_attr, out['1', '0'].edge_attr, ) data = HeteroData() data['paper'].num_nodes = 100 data['author'].num_nodes = 200 out = data.to_homogeneous(add_node_type=False) assert len(out) == 1 assert out.num_nodes == 300 out = data.to_homogeneous().to_heterogeneous() assert len(out) == 1 assert out['paper'].num_nodes == 100 assert out['author'].num_nodes == 200 def test_to_homogeneous_padding(): data = HeteroData() data['paper'].x = torch.randn(100, 128) data['author'].x = torch.randn(50, 64) out = data.to_homogeneous() assert len(out) == 2 assert out.node_type.size() == (150, ) assert out.node_type[:100].abs().sum() == 0 assert out.node_type[100:].sub(1).abs().sum() == 0 assert out.x.size() == (150, 128) assert torch.equal(out.x[:100], data['paper'].x) 
assert torch.equal(out.x[100:, :64], data['author'].x) assert out.x[100:, 64:].abs().sum() == 0 def test_hetero_data_to_canonical(): data = HeteroData() assert isinstance(data['user', 'product'], EdgeStorage) assert len(data.edge_types) == 1 assert isinstance(data['user', 'to', 'product'], EdgeStorage) assert len(data.edge_types) == 1 data = HeteroData() assert isinstance(data['user', 'buys', 'product'], EdgeStorage) assert isinstance(data['user', 'clicks', 'product'], EdgeStorage) assert len(data.edge_types) == 2 with pytest.raises(TypeError, match="missing 1 required"): data['user', 'product'] def test_hetero_data_invalid_names(): data = HeteroData() with pytest.warns(UserWarning, match="single underscores"): data['my test', 'a__b', 'my test'].edge_attr = torch.randn(10, 16) with warnings.catch_warnings(): # No warning should be raised afterwards: warnings.simplefilter('error') data['my test', 'a__c', 'my test'].edge_attr = torch.randn(10, 16) assert data.edge_types == [ ('my test', 'a__b', 'my test'), ('my test', 'a__c', 'my test'), ] def test_hetero_data_update(): data = HeteroData() data['paper'].x = torch.arange(0, 5) data['paper'].y = torch.arange(5, 10) data['author'].x = torch.arange(10, 15) other = HeteroData() other['paper'].x = torch.arange(15, 20) other['author'].y = torch.arange(20, 25) other['paper', 'paper'].edge_index = torch.randint(5, (2, 20)) data.update(other) assert len(data) == 3 assert torch.equal(data['paper'].x, torch.arange(15, 20)) assert torch.equal(data['paper'].y, torch.arange(5, 10)) assert torch.equal(data['author'].x, torch.arange(10, 15)) assert torch.equal(data['author'].y, torch.arange(20, 25)) assert torch.equal(data['paper', 'paper'].edge_index, other['paper', 'paper'].edge_index) def test_hetero_data_connected_components(): data = HeteroData() data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]]) data["red"].y = torch.tensor([1, 2, 3, 4, 5]) data["red"].z = torch.tensor([[1.1, 1.2], [2.1, 2.2], [3.1, 3.2], [4.1, 
4.2], [5.1, 5.2]]) data["blue"].x = torch.tensor([[6.0], [7.0]]) data["red", "to", "red"].edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long) data["red", "with", "red"].edge_index = torch.tensor([[1], [1]], dtype=torch.long) data["red", "to", "blue"].edge_index = torch.tensor([[0], [0]]) split_data = data.connected_components() assert isinstance(split_data, list) assert len(split_data) == 4 assert isinstance(split_data[0], HeteroData) assert isinstance(split_data[1], HeteroData) assert isinstance(split_data[2], HeteroData) assert isinstance(split_data[3], HeteroData) assert split_data[0].node_types == ['red', 'blue'] assert split_data[1].node_types == ['red', 'blue'] assert split_data[2].node_types == ['red', 'blue'] assert split_data[3].node_types == ['red', 'blue'] assert split_data[0].edge_types == [('red', 'to', 'red'), ('red', 'with', 'red'), ('red', 'to', 'blue')] assert split_data[1].edge_types == [('red', 'to', 'red'), ('red', 'with', 'red'), ('red', 'to', 'blue')] assert split_data[2].edge_types == [('red', 'to', 'red'), ('red', 'with', 'red'), ('red', 'to', 'blue')] assert split_data[3].edge_types == [('red', 'to', 'red'), ('red', 'with', 'red'), ('red', 'to', 'blue')] assert torch.equal(split_data[0]["red"].x, torch.tensor([[1.0], [2.0]])) assert torch.equal(split_data[0]["red"].y, torch.tensor([1, 2])) assert torch.equal(split_data[0]["red"].z, torch.tensor([[1.1, 1.2], [2.1, 2.2]])) assert torch.equal(split_data[0]["blue"].x, torch.tensor([[6.0]])) assert torch.equal(split_data[1]["red"].x, torch.tensor([[3.0], [4.0]])) assert torch.equal(split_data[1]["red"].y, torch.tensor([3, 4])) assert torch.equal(split_data[1]["red"].z, torch.tensor([[3.1, 3.2], [4.1, 4.2]])) assert torch.equal(split_data[1]["blue"].x, torch.empty((0, 1))) assert torch.equal(split_data[2]["red"].x, torch.tensor([[5.0]])) assert torch.equal(split_data[2]["red"].y, torch.tensor([5])) assert torch.equal(split_data[2]["red"].z, torch.tensor([[5.1, 5.2]])) assert 
torch.equal(split_data[2]["blue"].x, torch.empty((0, 1))) assert torch.equal(split_data[3]["red"].x, torch.empty((0, 1))) assert torch.equal(split_data[3]["red"].y, torch.empty((0, ), dtype=torch.int64)) assert torch.equal(split_data[3]["red"].z, torch.empty((0, 2))) assert torch.equal(split_data[3]["blue"].x, torch.tensor([[7.0]])) assert torch.equal(split_data[0]["red", "to", "red"].edge_index, torch.tensor([[0, 1], [1, 0]])) assert torch.equal(split_data[0]["red", "with", "red"].edge_index, torch.tensor([[1], [1]])) assert torch.equal(split_data[0]["red", "to", "blue"].edge_index, torch.tensor([[0], [0]])) assert torch.equal(split_data[1]["red", "to", "red"].edge_index, torch.tensor([[0, 1], [1, 0]])) assert torch.equal(split_data[1]["red", "with", "red"].edge_index, torch.empty((2, 0), dtype=torch.long)) assert torch.equal(split_data[1]["red", "to", "blue"].edge_index, torch.empty((2, 0), dtype=torch.long)) assert torch.equal(split_data[2]["red", "to", "red"].edge_index, torch.empty((2, 0), dtype=torch.long)) assert torch.equal(split_data[2]["red", "with", "red"].edge_index, torch.empty((2, 0), dtype=torch.long)) assert torch.equal(split_data[2]["red", "to", "blue"].edge_index, torch.empty((2, 0), dtype=torch.long)) assert torch.equal(split_data[3]["red", "to", "red"].edge_index, torch.empty((2, 0), dtype=torch.long)) assert torch.equal(split_data[3]["red", "with", "red"].edge_index, torch.empty((2, 0), dtype=torch.long)) assert torch.equal(split_data[3]["red", "to", "blue"].edge_index, torch.empty((2, 0), dtype=torch.long)) def test_hetero_data_connected_components_single_component(): data = HeteroData() data["red"].x = torch.tensor([[1.0], [2.0]]) data["red"].y = torch.tensor([1, 2]) data["red"].z = torch.tensor([[1.1, 1.2], [2.1, 2.2]]) data["blue"].x = torch.tensor([[3.0]]) data["red", "to", "red"].edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long) data["red", "to", "blue"].edge_index = torch.tensor([[0], [0]]) split_data = 
data.connected_components() assert isinstance(split_data, list) assert len(split_data) == 1 def test_hetero_data_find_parent(): # Case 1: Parent does not exist data = HeteroData() data._parents = {} data._ranks = {} node = ('paper', 1) assert data._find_parent(node) == node assert data._parents == {node: node} assert data._ranks == {node: 0} # Case 2: Parent exists data._parents[node] = ('paper', 0) assert data._find_parent(node) == ('paper', 0) def test_hetero_data_union(): # Setup: two nodes in different sets data = HeteroData() data._parents = {} data._ranks = {} node1 = ('paper', 1) node2 = ('paper', 2) # Initially, both nodes are their own parents with rank 0 assert data._find_parent(node1) == node1 assert data._find_parent(node2) == node2 data._ranks[node1] = 0 data._ranks[node2] = 0 # Union them: node2 should now point to node1, and node1's rank increases data._union(node1, node2) assert data._find_parent(node1) == node1 assert data._find_parent(node2) == node1 assert data._ranks[node1] == 1 # Add a third node with higher rank and union with node1 node3 = ('paper', 3) data._parents[node3] = node3 data._ranks[node3] = 2 data._union(node1, node3) # node1's parent should now be node3, since node3 has higher rank assert data._find_parent(node1) == node3 assert data._find_parent(node3) == node3 # Add a fourth node with lower rank and union with node3 node4 = ('paper', 4) data._parents[node4] = node4 data._ranks[node4] = 0 data._union(node3, node4) assert data._find_parent(node4) == node3 assert data._find_parent(node3) == node3 # Union of already connected nodes should not change anything prev_ranks = data._ranks.copy() prev_parents = data._parents.copy() data._union(node1, node3) assert data._ranks == prev_ranks assert data._parents == prev_parents # Feature Store ############################################################### def test_basic_feature_store(): data = HeteroData() x = torch.randn(20, 20) # Put tensor: assert data.put_tensor(copy.deepcopy(x), 
group_name='paper', attr_name='x', index=None) assert torch.equal(data['paper'].x, x) # Put (modify) tensor slice: x[15:] = 0 data.put_tensor(0, group_name='paper', attr_name='x', index=slice(15, None, None)) # Get tensor: out = data.get_tensor(group_name='paper', attr_name='x', index=None) assert torch.equal(x, out) # Get tensor size: assert data.get_tensor_size(group_name='paper', attr_name='x') == (20, 20) # Get tensor attrs: data['paper'].num_nodes = 20 # don't include, not a tensor attr data['paper'].bad_attr = torch.randn(10, 20) # don't include, bad cat_dim tensor_attrs = data.get_all_tensor_attrs() assert len(tensor_attrs) == 1 assert tensor_attrs[0].group_name == 'paper' assert tensor_attrs[0].attr_name == 'x' # Remove tensor: assert 'x' in data['paper'].__dict__['_mapping'] data.remove_tensor(group_name='paper', attr_name='x', index=None) assert 'x' not in data['paper'].__dict__['_mapping'] @withPackage('torch_frame') def test_hetero_data_with_tensor_frame(): data = HeteroData() data['paper'].tf = get_random_tensor_frame(num_rows=x_paper.size(0)) data['author'].tf = get_random_tensor_frame(num_rows=x_author.size(0)) data['author', 'paper'].edge_index = edge_index_author_paper # Basic functionality: assert set(data.node_attrs()) == {'tf'} assert data.num_nodes == x_paper.size(0) + x_author.size(0) assert data.num_node_features['paper'] == 5 assert data.num_node_features['author'] == 5 # Test subgraph: subset = { 'paper': torch.tensor([1, 2, 3, 4]), 'author': torch.tensor([0, 1, 2, 3]), } out = data.subgraph(subset) assert set(out.node_attrs()) == {'tf'} assert out.num_nodes == 8 for key, value in out['paper'].tf.feat_dict.items(): assert value.size(0) == 4 assert torch.allclose(value, data['paper'].tf.feat_dict[key][1:5]) for key, value in out['author'].tf.feat_dict.items(): assert value.size(0) == 4 assert torch.allclose(value, data['author'].tf.feat_dict[key][0:4]) # Test conversion to homogenous graphs and back: for node_attrs in [None, ['tf']]: out = 
data.to_homogeneous(node_attrs=node_attrs) assert isinstance(out.tf, TensorFrame) assert len(out.tf) == data.num_nodes assert out.num_nodes == data.num_nodes assert out.num_node_features == 5 for key, value in out.tf.feat_dict.items(): assert torch.allclose( value, torch.cat([ data['paper'].tf.feat_dict[key], data['author'].tf.feat_dict[key], ], dim=0), ) out = out.to_heterogeneous() for node_type in data.node_types: for key, value in data[node_type].tf.feat_dict.items(): assert torch.allclose(value, out[node_type].tf.feat_dict[key]) # Graph Store ################################################################# @withPackage('torch_sparse') def test_basic_graph_store(): data = HeteroData() def assert_equal_tensor_tuple(expected, actual): assert len(expected) == len(actual) for i in range(len(expected)): assert torch.equal(expected[i], actual[i]) # We put all three tensor types: COO, CSR, and CSC, and we get them back # to confirm that `GraphStore` works as intended. coo = (torch.tensor([0, 1]), torch.tensor([1, 2])) csr = (torch.tensor([0, 1, 2, 2]), torch.tensor([1, 2])) csc = (torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2])) # Put: data.put_edge_index(coo, layout='coo', edge_type=('a', 'to', 'b'), size=(3, 3)) data.put_edge_index(csr, layout='csr', edge_type=('a', 'to', 'c'), size=(3, 3)) data.put_edge_index(csc, layout='csc', edge_type=('b', 'to', 'c'), size=(3, 3)) # Get: assert_equal_tensor_tuple( coo, data.get_edge_index(layout='coo', edge_type=('a', 'to', 'b'))) assert_equal_tensor_tuple( csr, data.get_edge_index(layout='csr', edge_type=('a', 'to', 'c'))) assert_equal_tensor_tuple( csc, data.get_edge_index(layout='csc', edge_type=('b', 'to', 'c'))) # Get attrs: edge_attrs = data.get_all_edge_attrs() assert len(edge_attrs) == 3 # Remove: coo, csr, csc = edge_attrs data.remove_edge_index(coo) data.remove_edge_index(csr) data.remove_edge_index(csc) assert len(data.get_all_edge_attrs()) == 0 def test_generate_ids(): data = HeteroData() data['paper'].x = 
torch.randn(100, 128) data['author'].x = torch.randn(200, 128) data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 300) data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 400) assert len(data) == 2 data.generate_ids() assert len(data) == 4 assert data['paper'].n_id.tolist() == list(range(100)) assert data['author'].n_id.tolist() == list(range(200)) assert data['paper', 'author'].e_id.tolist() == list(range(300)) assert data['author', 'paper'].e_id.tolist() == list(range(400)) def test_invalid_keys(): data = HeteroData() data['paper'].x = torch.randn(10, 128) data['paper'].node_attrs = ['y'] data['paper', 'paper'].edge_index = get_random_edge_index(10, 10, 20) data['paper', 'paper'].edge_attrs = ['edge_attr'] assert data['paper'].node_attrs() == ['x'] assert data['paper']['node_attrs'] == ['y'] assert data['paper', 'paper'].edge_attrs() == ['edge_index'] assert data['paper', 'paper']['edge_attrs'] == ['edge_attr'] out = data.to_homogeneous() assert set(out.node_attrs()) == {'x', 'node_type'} assert set(out.edge_attrs()) == {'edge_index', 'edge_type'} ================================================ FILE: test/data/test_hypergraph_data.py ================================================ import pytest import torch import torch_geometric from torch_geometric.data.hypergraph_data import HyperGraphData from torch_geometric.loader import DataLoader def test_hypergraph_data(): torch_geometric.set_debug(True) x = torch.tensor([[1, 3, 5, 7], [2, 4, 6, 8], [7, 8, 9, 10]], dtype=torch.float).t() edge_index = torch.tensor([[0, 1, 2, 1, 2, 3, 0, 2, 3], [0, 0, 0, 1, 1, 1, 2, 2, 2]]) data = HyperGraphData(x=x, edge_index=edge_index).to(torch.device('cpu')) data.validate(raise_on_error=True) assert data.num_nodes == 4 assert data.num_edges == 3 assert data.node_attrs() == ['x'] assert data.edge_attrs() == ['edge_index'] assert data.x.tolist() == x.tolist() assert data['x'].tolist() == x.tolist() assert data.get('x').tolist() == x.tolist() assert 
data.get('y', 2) == 2 assert data.get('y', None) is None assert sorted(data.keys()) == ['edge_index', 'x'] assert len(data) == 2 assert 'x' in data and 'edge_index' in data and 'pos' not in data D = data.to_dict() assert len(D) == 2 assert 'x' in D and 'edge_index' in D D = data.to_namedtuple() assert len(D) == 2 assert D.x is not None and D.edge_index is not None assert data.__cat_dim__('x', data.x) == 0 assert data.__cat_dim__('edge_index', data.edge_index) == -1 assert data.__inc__('x', data.x) == 0 assert torch.equal(data.__inc__('edge_index', data.edge_index), torch.tensor([[data.num_nodes], [data.num_edges]])) data_list = [data, data] loader = DataLoader(data_list, batch_size=2) batch = next(iter(loader)) batched_edge_index = batch.edge_index assert batched_edge_index.tolist() == [[ 0, 1, 2, 1, 2, 3, 0, 2, 3, 4, 5, 6, 5, 6, 7, 4, 6, 7 ], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]] assert not data.x.is_contiguous() data.contiguous() assert data.x.is_contiguous() assert not data.is_coalesced() data = data.coalesce() assert data.is_coalesced() clone = data.clone() assert clone != data assert len(clone) == len(data) assert clone.x.data_ptr() != data.x.data_ptr() assert clone.x.tolist() == data.x.tolist() assert clone.edge_index.data_ptr() != data.edge_index.data_ptr() assert clone.edge_index.tolist() == data.edge_index.tolist() data['x'] = x + 1 assert data.x.tolist() == (x + 1).tolist() assert str(data) == 'HyperGraphData(x=[4, 3], edge_index=[2, 9])' dictionary = {'x': data.x, 'edge_index': data.edge_index} data = HyperGraphData.from_dict(dictionary) assert sorted(data.keys()) == ['edge_index', 'x'] assert not data.has_isolated_nodes() # assert not data.has_self_loops() # assert data.is_undirected() # assert not data.is_directed() assert data.num_nodes == 4 assert data.num_edges == 3 with pytest.warns(UserWarning, match='deprecated'): assert data.num_faces is None assert data.num_node_features == 3 assert data.num_features == 3 data.edge_attr = 
torch.randn(data.num_edges, 2) assert data.num_edge_features == 2 assert data.is_edge_attr('edge_attr') data.edge_attr = None data.x = None with pytest.warns(UserWarning, match='Unable to accurately infer'): assert data.num_nodes == 4 data.edge_index = None with pytest.warns(UserWarning, match='Unable to accurately infer'): assert data.num_nodes is None assert data.num_edges == 0 data.num_nodes = 4 assert data.num_nodes == 4 data = HyperGraphData(x=x, attribute=x) assert len(data) == 2 assert data.x.tolist() == x.tolist() assert data.attribute.tolist() == x.tolist() face = torch.tensor([[0, 1], [1, 2], [2, 3]]) data = HyperGraphData(num_nodes=4, face=face) with pytest.warns(UserWarning, match='deprecated'): assert data.num_faces == 2 assert data.num_nodes == 4 data = HyperGraphData(title='test') assert str(data) == "HyperGraphData(title='test')" assert data.num_node_features == 0 # assert data.num_edge_features == 0 key = value = 'test_value' data[key] = value assert data[key] == value del data[value] del data[value] # Deleting unset attributes should work as well. 
assert data.get(key) is None assert data.get('title') == 'test' torch_geometric.set_debug(False) def test_hypergraphdata_subgraph(): x = torch.arange(5) y = torch.tensor([0.]) edge_index = torch.tensor([[0, 1, 3, 2, 4, 0, 3, 4, 2, 1, 2, 3], [0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3]]) edge_attr = torch.rand(4, 2) data = HyperGraphData(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr, num_nodes=5) out = data.subgraph(torch.tensor([1, 2, 4])) assert len(out) == 5 assert torch.equal(out.x, torch.tensor([1, 2, 4])) assert torch.equal(out.y, data.y) assert out.edge_index.tolist() == [[1, 2, 2, 1, 0, 1], [0, 0, 1, 1, 2, 2]] assert torch.equal(out.edge_attr, edge_attr[[1, 2, 3]]) assert out.num_nodes == 3 # Test unordered selection: out = data.subgraph(torch.tensor([3, 1, 2])) assert len(out) == 5 assert torch.equal(out.x, torch.tensor([3, 1, 2])) assert torch.equal(out.y, data.y) assert out.edge_index.tolist() == [[0, 2, 0, 2, 1, 2, 0], [0, 0, 1, 1, 2, 2, 2]] assert torch.equal(out.edge_attr, edge_attr[[1, 2, 3]]) assert out.num_nodes == 3 out = data.subgraph(torch.tensor([False, False, False, True, True])) assert len(out) == 5 assert torch.equal(out.x, torch.arange(3, 5)) assert torch.equal(out.y, data.y) assert out.edge_index.tolist() == [[0, 1, 0, 1], [0, 0, 1, 1]] assert torch.equal(out.edge_attr, edge_attr[[1, 2]]) assert out.num_nodes == 2 ================================================ FILE: test/data/test_inherit.py ================================================ import torch from torch_geometric.data import Data, Dataset, InMemoryDataset class MyData(Data): def __init__(self, x=None, edge_index=None, arg=None): super().__init__(x=x, edge_index=edge_index, arg=arg) def random(self): return torch.randn(list(self.x.size()) + list(self.arg.size())) class MyInMemoryDataset(InMemoryDataset): def __init__(self): super().__init__('/tmp/MyInMemoryDataset') x = torch.randn(4, 5) edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]]) arg = torch.randn(4, 3) data_list = 
[MyData(x, edge_index, arg) for _ in range(10)] self.data, self.slices = self.collate(data_list) def _download(self): pass def _process(self): pass class MyDataset(Dataset): def __init__(self): super().__init__('/tmp/MyDataset') def _download(self): pass def _process(self): pass def len(self): return 10 def get(self, idx): x = torch.randn(4, 5) edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]]) arg = torch.randn(4, 3) return MyData(x, edge_index, arg) def test_inherit(): dataset = MyDataset() assert len(dataset) == 10 data = dataset[0] assert data.random().size() == (4, 5, 4, 3) dataset = MyInMemoryDataset() assert len(dataset) == 10 data = dataset[0] assert data.random().size() == (4, 5, 4, 3) ================================================ FILE: test/data/test_on_disk_dataset.py ================================================ import os.path as osp from typing import Any, Dict import torch from torch_geometric.data import Data, OnDiskDataset from torch_geometric.testing import withPackage @withPackage('sqlite3') def test_pickle(tmp_path): dataset = OnDiskDataset(tmp_path) assert len(dataset) == 0 assert str(dataset) == 'OnDiskDataset(0)' assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db')) data_list = [ Data( x=torch.randn(5, 8), edge_index=torch.randint(0, 5, (2, 16)), num_nodes=5, ) for _ in range(4) ] dataset.append(data_list[0]) assert len(dataset) == 1 dataset.extend(data_list[1:]) assert len(dataset) == 4 out = dataset.get(0) assert torch.equal(out.x, data_list[0].x) assert torch.equal(out.edge_index, data_list[0].edge_index) assert out.num_nodes == data_list[0].num_nodes out_list = dataset.multi_get([1, 2, 3]) for out, data in zip(out_list, data_list[1:]): assert torch.equal(out.x, data.x) assert torch.equal(out.edge_index, data.edge_index) assert out.num_nodes == data.num_nodes dataset.close() # Test persistence of datasets: dataset = OnDiskDataset(tmp_path) assert len(dataset) == 4 out = dataset.get(0) assert torch.equal(out.x, data_list[0].x) 
assert torch.equal(out.edge_index, data_list[0].edge_index) assert out.num_nodes == data_list[0].num_nodes dataset.close() @withPackage('sqlite3') def test_custom_schema(tmp_path): class CustomSchemaOnDiskDataset(OnDiskDataset): def __init__(self, root: str): schema = { 'x': dict(dtype=torch.float, size=(-1, 8)), 'edge_index': dict(dtype=torch.long, size=(2, -1)), 'num_nodes': int, } self.serialize_count = 0 self.deserialize_count = 0 super().__init__(root, schema=schema) def serialize(self, data: Data) -> Dict[str, Any]: self.serialize_count += 1 return data.to_dict() def deserialize(self, mapping: Dict[str, Any]) -> Any: self.deserialize_count += 1 return Data.from_dict(mapping) dataset = CustomSchemaOnDiskDataset(tmp_path) assert len(dataset) == 0 assert str(dataset) == 'CustomSchemaOnDiskDataset(0)' assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db')) data_list = [ Data( x=torch.randn(5, 8), edge_index=torch.randint(0, 5, (2, 16)), num_nodes=5, ) for _ in range(4) ] dataset.append(data_list[0]) assert dataset.serialize_count == 1 assert len(dataset) == 1 dataset.extend(data_list[1:]) assert dataset.serialize_count == 4 assert len(dataset) == 4 out = dataset.get(0) assert dataset.deserialize_count == 1 assert torch.equal(out.x, data_list[0].x) assert torch.equal(out.edge_index, data_list[0].edge_index) assert out.num_nodes == data_list[0].num_nodes out_list = dataset.multi_get([1, 2, 3]) assert dataset.deserialize_count == 4 for out, data in zip(out_list, data_list[1:]): assert torch.equal(out.x, data.x) assert torch.equal(out.edge_index, data.edge_index) assert out.num_nodes == data.num_nodes dataset.close() ================================================ FILE: test/data/test_remote_backend_utils.py ================================================ import pytest import torch from torch_geometric.data import HeteroData from torch_geometric.data.remote_backend_utils import num_nodes, size from torch_geometric.testing import ( MyFeatureStore, 
MyGraphStore, get_random_edge_index, ) @pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData]) @pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData]) def test_num_nodes_size(FeatureStore, GraphStore): feature_store = FeatureStore() graph_store = GraphStore() # Infer num nodes from features: x = torch.arange(100) feature_store.put_tensor(x, group_name='x', attr_name='x', index=None) assert num_nodes(feature_store, graph_store, 'x') == 100 # Infer num nodes and size from edges: xy = get_random_edge_index(100, 50, 20) graph_store.put_edge_index(xy, edge_type=('x', 'to', 'y'), layout='coo', size=(100, 50)) assert num_nodes(feature_store, graph_store, 'y') == 50 assert size(feature_store, graph_store, ('x', 'to', 'y')) == (100, 50) # Throw an error if we cannot infer for an unknown node type: with pytest.raises(ValueError, match="Unable to accurately infer"): _ = num_nodes(feature_store, graph_store, 'z') ================================================ FILE: test/data/test_storage.py ================================================ import copy from typing import Any import pytest import torch from torch_geometric.data.storage import BaseStorage def test_base_storage(): storage = BaseStorage() assert storage._mapping == {} storage.x = torch.zeros(1) storage.y = torch.ones(1) assert len(storage) == 2 assert storage._mapping == {'x': torch.zeros(1), 'y': torch.ones(1)} assert storage.x is not None assert storage.y is not None assert torch.allclose(storage.get('x', None), storage.x) assert torch.allclose(storage.get('y', None), storage.y) assert storage.get('z', 2) == 2 assert storage.get('z', None) is None assert len(list(storage.keys('x', 'y', 'z'))) == 2 assert len(list(storage.keys('x', 'y', 'z'))) == 2 assert len(list(storage.values('x', 'y', 'z'))) == 2 assert len(list(storage.items('x', 'y', 'z'))) == 2 del storage.y assert len(storage) == 1 assert storage.x is not None storage = BaseStorage({'x': torch.zeros(1)}) assert len(storage) == 1 
assert storage.x is not None storage = BaseStorage(x=torch.zeros(1)) assert len(storage) == 1 assert storage.x is not None storage = BaseStorage(x=torch.zeros(1)) copied_storage = copy.copy(storage) assert storage == copied_storage assert id(storage) != id(copied_storage) assert storage.x.data_ptr() == copied_storage.x.data_ptr() assert int(storage.x) == 0 assert int(copied_storage.x) == 0 deepcopied_storage = copy.deepcopy(storage) assert storage == deepcopied_storage assert id(storage) != id(deepcopied_storage) assert storage.x.data_ptr() != deepcopied_storage.x.data_ptr() assert int(storage.x) == 0 assert int(deepcopied_storage.x) == 0 with pytest.raises(AttributeError, match="has no attribute 'asdf'"): storage.asdf # noqa: B018 def test_storage_tensor_methods(): x = torch.randn(5) storage = BaseStorage({'x': x}) storage = storage.clone() assert storage.x.data_ptr() != x.data_ptr() storage = storage.contiguous() assert storage.x.is_contiguous() storage = storage.to('cpu') assert storage.x.device == torch.device('cpu') storage = storage.cpu() assert storage.x.device == torch.device('cpu') if torch.cuda.is_available(): storage = storage.pin_memory() assert storage.x.is_pinned() storage = storage.share_memory_() assert storage.x.is_shared storage = storage.detach_() assert not storage.x.requires_grad storage = storage.detach() assert not storage.x.requires_grad storage = storage.requires_grad_() assert storage.x.requires_grad def test_setter_and_getter(): class MyStorage(BaseStorage): @property def my_property(self) -> Any: return self._my_property @my_property.setter def my_property(self, value: Any): self._my_property = value storage = MyStorage() storage.my_property = 'hello' assert storage.my_property == 'hello' assert storage._my_property == storage._my_property ================================================ FILE: test/data/test_temporal.py ================================================ import copy import torch from torch_geometric.data import TemporalData 
def get_temporal_data(num_events, msg_channels):
    """Build a synthetic ``TemporalData`` object with ``num_events`` events.

    Sources are ``0..num_events-1``, destinations are
    ``num_events..2*num_events-1``, timestamps are spaced ``1000`` apart,
    messages are random with ``msg_channels`` columns, and labels are random
    integers drawn from ``{0, 1}``.
    """
    return TemporalData(
        src=torch.arange(num_events),
        dst=torch.arange(num_events, num_events * 2),
        t=torch.arange(0, num_events * 1000, step=1000),
        msg=torch.randn(num_events, msg_channels),
        y=torch.randint(0, 2, (num_events, )),
    )


def test_temporal_data():
    """Check ``TemporalData`` repr, derived properties, copying and deletion."""
    data = get_temporal_data(num_events=3, msg_channels=16)
    assert str(data) == ("TemporalData(src=[3], dst=[3], t=[3], "
                         "msg=[3, 16], y=[3])")

    # src covers 0..2 and dst covers 3..5; the test expects 6 nodes.
    assert data.num_nodes == 6
    assert data.num_events == data.num_edges == len(data) == 3

    assert data.src.tolist() == [0, 1, 2]
    assert data['src'].tolist() == [0, 1, 2]

    # `edge_index` is derived from (src, dst). An explicit assignment
    # overrides it, and deleting that assignment restores the derived value.
    assert data.edge_index.tolist() == [[0, 1, 2], [3, 4, 5]]
    data.edge_index = 'edge_index'
    assert data.edge_index == 'edge_index'
    del data.edge_index
    assert data.edge_index.tolist() == [[0, 1, 2], [3, 4, 5]]

    # The derived `edge_index` is not reported as a stored key.
    assert sorted(data.keys()) == ['dst', 'msg', 'src', 't', 'y']
    assert sorted(data.to_dict().keys()) == sorted(data.keys())

    data_tuple = data.to_namedtuple()
    assert len(data_tuple) == 5
    assert data_tuple.src is not None
    assert data_tuple.dst is not None
    assert data_tuple.t is not None
    assert data_tuple.msg is not None
    assert data_tuple.y is not None

    # Batching hints: concatenate along dim 0; `src` is incremented by 6,
    # which equals `num_nodes` asserted above.
    assert data.__cat_dim__('src', data.src) == 0
    assert data.__inc__('src', data.src) == 6

    # Both `clone()` and `copy.deepcopy` must produce tensors backed by new
    # storage (different data_ptr) with equal contents.
    clone = data.clone()
    assert clone != data
    assert len(clone) == len(data)
    assert clone.src.data_ptr() != data.src.data_ptr()
    assert clone.src.tolist() == data.src.tolist()
    assert clone.dst.data_ptr() != data.dst.data_ptr()
    assert clone.dst.tolist() == data.dst.tolist()

    deepcopy = copy.deepcopy(data)
    assert deepcopy != data
    assert len(deepcopy) == len(data)
    assert deepcopy.src.data_ptr() != data.src.data_ptr()
    assert deepcopy.src.tolist() == data.src.tolist()
    assert deepcopy.dst.data_ptr() != data.dst.data_ptr()
    assert deepcopy.dst.tolist() == data.dst.tolist()

    key = value = 'test_value'
    data[key] = value
    assert data[key] == value
    assert data.test_value == value
    del data[key]
    del data[key]  # Deleting unset attributes should work as well.

    assert data.get(key, 10) == 10

    # Iterating `data` yields one item per event (3 here).
    assert len([event for event in data]) == 3

    # Calling `data()` yields one entry per stored attribute (5 here).
    assert len([attr for attr in data()]) == 5

    assert data.size() == (2, 5)

    del data.src
    assert 'src' not in data


def test_train_val_test_split():
    """Check that the temporal split sizes are 65/20/15 and time-ordered."""
    data = get_temporal_data(num_events=100, msg_channels=16)

    train_data, val_data, test_data = data.train_val_test_split(
        val_ratio=0.2, test_ratio=0.15)

    assert len(train_data) == 65
    assert len(val_data) == 20
    assert len(test_data) == 15

    # The splits must not overlap in time: train < val < test.
    assert train_data.t.max() < val_data.t.min()
    assert val_data.t.max() < test_data.t.min()


def test_temporal_indexing():
    """Check int, slice, index-tensor and bool-mask indexing of events."""
    data = get_temporal_data(num_events=10, msg_channels=16)

    # Integer indexing returns a length-1 `TemporalData`, not a scalar row.
    elem = data[0]
    assert isinstance(elem, TemporalData)
    assert len(elem) == 1
    assert elem.src.tolist() == data.src[0:1].tolist()
    assert elem.dst.tolist() == data.dst[0:1].tolist()
    assert elem.t.tolist() == data.t[0:1].tolist()
    assert elem.msg.tolist() == data.msg[0:1].tolist()
    assert elem.y.tolist() == data.y[0:1].tolist()

    subset = data[0:5]
    assert isinstance(subset, TemporalData)
    assert len(subset) == 5
    assert subset.src.tolist() == data.src[0:5].tolist()
    assert subset.dst.tolist() == data.dst[0:5].tolist()
    assert subset.t.tolist() == data.t[0:5].tolist()
    assert subset.msg.tolist() == data.msg[0:5].tolist()
    assert subset.y.tolist() == data.y[0:5].tolist()

    # With 10 events, indices [0, 4, 8] are the same rows as the slice 0::4.
    index = [0, 4, 8]
    subset = data[torch.tensor(index)]
    assert isinstance(subset, TemporalData)
    assert len(subset) == 3
    assert subset.src.tolist() == data.src[0::4].tolist()
    assert subset.dst.tolist() == data.dst[0::4].tolist()
    assert subset.t.tolist() == data.t[0::4].tolist()
    assert subset.msg.tolist() == data.msg[0::4].tolist()
    assert subset.y.tolist() == data.y[0::4].tolist()

    # An alternating boolean mask selects the same rows as the slice 0::2.
    mask = [True, False, True, False, True, False, True, False, True, False]
    subset = data[torch.tensor(mask)]
    assert isinstance(subset, TemporalData)
    assert len(subset) == 5
    assert subset.src.tolist() == data.src[0::2].tolist()
    assert subset.dst.tolist() == data.dst[0::2].tolist()
    assert subset.t.tolist() == data.t[0::2].tolist()
    assert subset.msg.tolist() == data.msg[0::2].tolist()
    assert subset.y.tolist() == data.y[0::2].tolist()
================================================
FILE: test/data/test_view.py
================================================
from torch_geometric.data.storage import BaseStorage


def test_views():
    """Check the keys/values/items views of ``BaseStorage``, unfiltered and
    filtered by key names (unknown names such as ``'foo'`` are skipped)."""
    storage = BaseStorage(x=1, y=2, z=3)

    assert str(storage.keys()) == "KeysView({'x': 1, 'y': 2, 'z': 3})"
    assert len(storage.keys()) == 3
    assert list(storage.keys()) == ['x', 'y', 'z']

    assert str(storage.values()) == "ValuesView({'x': 1, 'y': 2, 'z': 3})"
    assert len(storage.values()) == 3
    assert list(storage.values()) == [1, 2, 3]

    assert str(storage.items()) == "ItemsView({'x': 1, 'y': 2, 'z': 3})"
    assert len(storage.items()) == 3
    assert list(storage.items()) == [('x', 1), ('y', 2), ('z', 3)]

    # Filtered views: 'foo' is not stored, so only 'x' and 'z' remain.
    args = ['x', 'z', 'foo']

    assert str(storage.keys(*args)) == "KeysView({'x': 1, 'z': 3})"
    assert len(storage.keys(*args)) == 2
    assert list(storage.keys(*args)) == ['x', 'z']

    assert str(storage.values(*args)) == "ValuesView({'x': 1, 'z': 3})"
    assert len(storage.values(*args)) == 2
    assert list(storage.values(*args)) == [1, 3]

    assert str(storage.items(*args)) == "ItemsView({'x': 1, 'z': 3})"
    assert len(storage.items(*args)) == 2
    assert list(storage.items(*args)) == [('x', 1), ('z', 3)]
================================================
FILE: test/datasets/graph_generator/test_ba_graph.py
================================================
from torch_geometric.datasets.graph_generator import BAGraph


def test_ba_graph():
    """Check ``BAGraph`` repr, node count and an upper bound on edges."""
    graph_generator = BAGraph(num_nodes=300, num_edges=5)
    assert str(graph_generator) == 'BAGraph(num_nodes=300, num_edges=5)'

    data = graph_generator()
    assert len(data) == 2
    assert data.num_nodes == 300
    assert data.num_edges <= 2 * 300 * 5
================================================
FILE: test/datasets/graph_generator/test_er_graph.py
================================================
from torch_geometric.datasets.graph_generator import ERGraph


def test_er_graph():
    """Check ``ERGraph`` repr and that the random edge count falls within
    loose bounds (between 0.05 and 0.15 of 300 * 300)."""
    graph_generator = ERGraph(num_nodes=300, edge_prob=0.1)
    assert str(graph_generator) == 'ERGraph(num_nodes=300, edge_prob=0.1)'

    data = graph_generator()
    assert len(data) == 2
    assert data.num_nodes == 300
    assert data.num_edges >= 300 * 300 * 0.05
    assert data.num_edges <= 300 * 300 * 0.15
================================================
FILE: test/datasets/graph_generator/test_grid_graph.py
================================================
from torch_geometric.datasets.graph_generator import GridGraph


def test_grid_graph():
    """Check ``GridGraph`` repr and exact node/edge counts for a 10x10 grid."""
    graph_generator = GridGraph(height=10, width=10)
    assert str(graph_generator) == 'GridGraph(height=10, width=10)'

    data = graph_generator()
    assert len(data) == 2
    assert data.num_nodes == 100
    assert data.num_edges == 784
================================================
FILE: test/datasets/graph_generator/test_tree_graph.py
================================================
import pytest

from torch_geometric.datasets.graph_generator import TreeGraph


@pytest.mark.parametrize('undirected', [False, True])
def test_tree_graph(undirected):
    """Check a depth-2, branch-2 ``TreeGraph`` (7 nodes). Per-node depths are
    exposed as ``data.depth``; the undirected variant also contains the
    reverse of every parent->child edge."""
    graph_generator = TreeGraph(depth=2, branch=2, undirected=undirected)
    assert str(graph_generator) == (f'TreeGraph(depth=2, branch=2, '
                                    f'undirected={undirected})')

    data = graph_generator()
    assert len(data) == 3
    assert data.num_nodes == 7
    assert data.depth.tolist() == [0, 1, 1, 2, 2, 2, 2]
    if not undirected:
        assert data.edge_index.tolist() == [
            [0, 0, 1, 1, 2, 2],
            [1, 2, 3, 4, 5, 6],
        ]
    else:
        assert data.edge_index.tolist() == [
            [0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6],
            [1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2],
        ]
================================================
FILE: test/datasets/motif_generator/test_custom_motif.py
================================================
import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.datasets.motif_generator import CustomMotif
from torch_geometric.testing import withPackage


def test_custom_motif_pyg_data():
    """A PyG ``Data`` structure is returned unchanged by ``CustomMotif``."""
    structure = Data(
        num_nodes=3,
        edge_index=torch.tensor([[0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]]),
    )

    motif_generator = CustomMotif(structure)
    assert str(motif_generator) == 'CustomMotif()'

    assert structure == motif_generator()


@withPackage('networkx')
def test_custom_motif_networkx():
    """A networkx graph is accepted as the motif structure."""
    import networkx as nx

    # 5 nodes, 10 (undirected) networkx edges.
    structure = nx.gnm_random_graph(5, 10, seed=2000)

    motif_generator = CustomMotif(structure)
    assert str(motif_generator) == 'CustomMotif()'

    out = motif_generator()
    assert len(out) == 2
    assert out.num_nodes == 5
    # The test expects 20 edges for the 10-edge networkx graph (presumably
    # each undirected edge stored in both directions).
    assert out.num_edges == 20


def test_custom_motif_unknown():
    """An unsupported structure type raises ``ValueError``."""
    with pytest.raises(ValueError, match="motif structure of type"):
        CustomMotif(structure='unknown')
================================================
FILE: test/datasets/motif_generator/test_cycle_motif.py
================================================
from torch_geometric.datasets.motif_generator import CycleMotif


def test_cycle_motif():
    """Check the exact edge list of a 5-node ``CycleMotif``: node i connects
    to i-1 and i+1 (mod 5), in both directions."""
    motif_generator = CycleMotif(5)
    assert str(motif_generator) == 'CycleMotif(5)'

    motif = motif_generator()
    assert len(motif) == 2
    assert motif.num_nodes == 5
    assert motif.num_edges == 10
    assert motif.edge_index.tolist() == [
        [0, 0, 1, 1, 2, 2, 3, 3, 4, 4],
        [1, 4, 0, 2, 1, 3, 2, 4, 0, 3],
    ]
================================================
FILE: test/datasets/motif_generator/test_grid_motif.py
================================================
from torch_geometric.datasets.motif_generator import GridMotif


def test_grid_motif():
    """Check ``GridMotif`` sizes, edge index range and label range (0..2)."""
    motif_generator = GridMotif()
    assert str(motif_generator) == 'GridMotif()'

    motif = motif_generator()
    assert len(motif) == 3
    assert motif.num_nodes == 9
    assert motif.num_edges == 24
    assert motif.edge_index.size() == (2, 24)
    assert motif.edge_index.min() == 0
    assert motif.edge_index.max() == 8
    assert motif.y.size() == (9, )
    assert motif.y.min() == 0
    assert motif.y.max() == 2
================================================
FILE: test/datasets/motif_generator/test_house_motif.py
================================================
from torch_geometric.datasets.motif_generator import HouseMotif


def test_house_motif():
    """Check ``HouseMotif`` sizes (5 nodes, 12 edges) and label range."""
    motif_generator = HouseMotif()
    assert str(motif_generator) == 'HouseMotif()'

    motif = motif_generator()
    assert len(motif) == 3
    assert motif.num_nodes == 5
    assert motif.num_edges == 12
    assert motif.y.min() == 0 and motif.y.max() == 2
================================================
FILE: test/datasets/test_ba_shapes.py
================================================
import pytest


def test_ba_shapes(get_dataset):
    """Loading ``BAShapes`` emits a deprecation warning; check its shapes."""
    with pytest.warns(UserWarning, match="is deprecated"):
        dataset = get_dataset(name='BAShapes')

    assert str(dataset) == 'BAShapes()'
    assert len(dataset) == 1
    assert dataset.num_features == 10
    assert dataset.num_classes == 4

    data = dataset[0]
    assert len(data) == 5
    # The edge count is only lower-bounded.
    assert data.edge_index.size(1) >= 1120
    assert data.x.size() == (700, 10)
    assert data.y.size() == (700, )
    assert data.expl_mask.sum() == 60
    assert data.edge_label.sum() == 960
================================================
FILE: test/datasets/test_bzr.py
================================================
from torch_geometric.testing import onlyFullTest, onlyOnline


@onlyOnline
@onlyFullTest
def test_bzr(get_dataset):
    """Check BZR dataset size, feature/label counts and repr."""
    dataset = get_dataset(name='BZR')
    assert len(dataset) == 405
    assert dataset.num_features == 53
    assert dataset.num_node_labels == 53
    assert dataset.num_node_attributes == 3
    assert dataset.num_classes == 2
    assert str(dataset) == 'BZR(405)'
    assert len(dataset[0]) == 3


@onlyOnline
@onlyFullTest
def test_bzr_with_node_attr(get_dataset):
    """With ``use_node_attr=True`` the 3 node attributes are added to the
    53 label features (56 in total)."""
    dataset = get_dataset(name='BZR', use_node_attr=True)
    assert dataset.num_features == 56
    assert dataset.num_node_labels == 53
    assert dataset.num_node_attributes == 3
================================================
FILE: test/datasets/test_elliptic.py
================================================
from torch_geometric.testing import onlyFullTest, onlyOnline


@onlyOnline
@onlyFullTest
def test_elliptic_bitcoin_dataset(get_dataset):
    """Check EllipticBitcoinDataset shapes, masks and label counts."""
    dataset = get_dataset(name='EllipticBitcoinDataset')
    assert str(dataset) == 'EllipticBitcoinDataset()'
    assert len(dataset) == 1
    assert dataset.num_features == 165
    assert dataset.num_classes == 2

    data = dataset[0]
    assert len(data) == 5
    assert data.x.size() == (203769, 165)
    assert data.edge_index.size() == (2, 234355)
    assert data.y.size() == (203769, )

    assert data.train_mask.size() == (203769, )
    assert data.train_mask.sum() > 0
    assert data.test_mask.size() == (203769, )
    assert data.test_mask.sum() > 0
    # Train and test masks together cover 4545 + 42019 nodes, of which
    # 4545 carry label 1 (3462 in train, 1083 in test).
    assert data.train_mask.sum() + data.test_mask.sum() == 4545 + 42019
    assert data.y[data.train_mask].sum() == 3462
    assert data.y[data.test_mask].sum() == 1083
    assert data.y[data.train_mask].sum() + data.y[data.test_mask].sum() == 4545
    # Labels of the masked nodes are binary.
    assert data.y[data.test_mask | data.train_mask].min() == 0
    assert data.y[data.test_mask | data.train_mask].max() == 1
================================================
FILE: test/datasets/test_enzymes.py
================================================
import pytest
import torch

from torch_geometric.loader import DataListLoader, DataLoader, DenseDataLoader
from torch_geometric.testing import onlyOnline
from torch_geometric.transforms import ToDense


@onlyOnline
def test_enzymes(get_dataset):
    """Check ENZYMES dataset indexing and its use with the sparse, list and
    dense data loaders."""
    dataset = get_dataset(name='ENZYMES')
    assert len(dataset) == 600
    assert dataset.num_features == 3
    assert dataset.num_classes == 6
    assert str(dataset) == 'ENZYMES(600)'

    assert len(dataset[0]) == 3
    assert len(dataset.shuffle()) == 600
    # With `return_perm=True`, `shuffle` returns a 2-tuple (presumably
    # the shuffled dataset and its permutation).
    assert len(dataset.shuffle(return_perm=True)) == 2
    assert len(dataset[:100]) == 100
    # Float slice bounds select a fraction of the dataset (0.1..0.2 of
    # 600 -> 60 graphs).
    assert len(dataset[0.1:0.2]) == 60
    assert len(dataset[torch.arange(100, dtype=torch.long)]) == 100
    mask = torch.zeros(600, dtype=torch.bool)
    mask[:100] = 1
    assert len(dataset[mask]) == 100

    # batch_size == len(dataset): the loop body runs on one full batch.
    loader = DataLoader(dataset, batch_size=len(dataset))
    for batch in loader:
        assert batch.num_graphs == len(batch) == 600

        avg_num_nodes = batch.num_nodes / batch.num_graphs
        assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63

        # Halved because the graphs are undirected (asserted below).
        avg_num_edges = batch.num_edges / (2 * batch.num_graphs)
        assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14

        assert list(batch.x.size()) == [batch.num_nodes, 3]
        assert list(batch.y.size()) == [batch.num_graphs]
        assert batch.y.max() + 1 == 6
        assert list(batch.batch.size()) == [batch.num_nodes]
        assert batch.ptr.numel() == batch.num_graphs + 1

        assert batch.has_isolated_nodes()
        assert not batch.has_self_loops()
        assert batch.is_undirected()

    loader = DataListLoader(dataset, batch_size=len(dataset))
    for data_list in loader:
        assert len(data_list) == 600

    # Pad every graph to 126 nodes so the dense loader can stack them.
    dataset.transform = ToDense(num_nodes=126)
    loader = DenseDataLoader(dataset, batch_size=len(dataset))
    for data in loader:
        assert list(data.x.size()) == [600, 126, 3]
        assert list(data.adj.size()) == [600, 126, 126]
        assert list(data.mask.size()) == [600, 126]
        assert list(data.y.size()) == [600, 1]


@onlyOnline
def test_enzymes_with_node_attr(get_dataset):
    """With ``use_node_attr=True`` ENZYMES exposes 21 node features."""
    dataset = get_dataset(name='ENZYMES', use_node_attr=True)
    assert dataset.num_node_features == 21
    assert dataset.num_features == 21
    assert dataset.num_edge_features == 0


@onlyOnline
def test_cleaned_enzymes(get_dataset):
    """The cleaned ENZYMES variant contains 595 graphs."""
    dataset = get_dataset(name='ENZYMES', cleaned=True)
    assert len(dataset) == 595
================================================
FILE: test/datasets/test_explainer_dataset.py
================================================
import pytest
import torch

from torch_geometric import seed_everything
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph
from torch_geometric.datasets.motif_generator import HouseMotif


@pytest.mark.parametrize('graph_generator', [
    pytest.param(BAGraph(num_nodes=80, num_edges=5), id='BAGraph'),
])
@pytest.mark.parametrize('motif_generator', [
    pytest.param(HouseMotif(), id='HouseMotif'),
    'house',
])
def test_explainer_dataset_ba_house(graph_generator, motif_generator):
    """Attach two house motifs to a BA graph; the motif may be passed as a
    ``HouseMotif`` instance or as the string ``'house'``."""
    dataset = ExplainerDataset(graph_generator, motif_generator, num_motifs=2)
    assert str(dataset) == ('ExplainerDataset(1, graph_generator='
                            'BAGraph(num_nodes=80, num_edges=5), '
                            'motif_generator=HouseMotif(), num_motifs=2)')
    assert len(dataset) == 1

    data = dataset[0]
    assert len(data) == 4
    # 80 base nodes plus 2 motifs of 5 nodes each.
    assert data.num_nodes == 80 + (2 * 5)
    assert data.edge_index.min() >= 0
    assert data.edge_index.max() < data.num_nodes
    assert data.y.min() == 0 and data.y.max() == 3
    assert data.node_mask.size() == (data.num_nodes, )
    assert data.edge_mask.size() == (data.num_edges, )
    # The masks mark exactly the motif nodes (2 * 5) and motif edges (2 * 12).
    assert data.node_mask.min() == 0 and data.node_mask.max() == 1
    assert data.node_mask.sum() == 2 * 5
    assert data.edge_mask.min() == 0 and data.edge_mask.max() == 1
    assert data.edge_mask.sum() == 2 * 12


def test_explainer_dataset_reproducibility():
    """The same seed yields identical generated edge indices."""
    seed_everything(12345)
    data1 = ExplainerDataset(BAGraph(num_nodes=80, num_edges=5), HouseMotif(),
                             num_motifs=2)[0]

    seed_everything(12345)
    data2 = ExplainerDataset(BAGraph(num_nodes=80, num_edges=5), HouseMotif(),
                             num_motifs=2)[0]

    assert torch.equal(data1.edge_index, data2.edge_index)
================================================
FILE: test/datasets/test_fake.py
================================================
import pytest

from torch_geometric.datasets import FakeDataset, FakeHeteroDataset


@pytest.mark.parametrize('num_graphs', [1, 10])
@pytest.mark.parametrize('edge_dim', [0, 1, 4])
@pytest.mark.parametrize('task', ['node', 'graph', 'auto'])
def test_fake_dataset(num_graphs, edge_dim, task):
    """Check ``FakeDataset`` attributes across graph counts, edge-feature
    widths and task types."""
    dataset = FakeDataset(num_graphs, edge_dim=edge_dim, task=task,
                          global_features=3)

    if num_graphs > 1:
        assert str(dataset) == f'FakeDataset({num_graphs})'
    else:
        assert str(dataset) == 'FakeDataset()'

    assert len(dataset) == num_graphs

    data = dataset[0]

    assert data.num_features == 64

    # edge_dim == 1 produces `edge_weight`, edge_dim > 1 produces
    # `edge_attr`, and edge_dim == 0 adds neither.
    if edge_dim == 0:
        assert len(data) == 4
    elif edge_dim == 1:
        assert len(data) == 5
        assert data.edge_weight.size() == (data.num_edges, )
        assert data.edge_weight.min() >= 0 and data.edge_weight.max() < 1
    else:
        assert len(data) == 5
        assert data.edge_attr.size() == (data.num_edges, edge_dim)
        assert data.edge_attr.min() >= 0 and data.edge_attr.max() < 1

    # 'auto' behaves like 'node' for a single graph and like 'graph' for
    # several graphs.
    assert data.y.min() >= 0 and data.y.max() < 10
    if task == 'node' or (task == 'auto' and num_graphs == 1):
        assert data.y.size() == (data.num_nodes, )
    else:
        assert data.y.size() == (1, )

    assert data.global_features.size() == (3, )


@pytest.mark.parametrize('num_graphs', [1, 10])
@pytest.mark.parametrize('edge_dim', [0, 1, 4])
@pytest.mark.parametrize('task', ['node', 'graph', 'auto'])
def test_fake_hetero_dataset(num_graphs, edge_dim, task):
    """Heterogeneous counterpart of ``test_fake_dataset``. Node-level labels
    are checked on node type ``'v0'`` only."""
    dataset = FakeHeteroDataset(num_graphs, edge_dim=edge_dim, task=task,
                                global_features=3)

    if num_graphs > 1:
        assert str(dataset) == f'FakeHeteroDataset({num_graphs})'
    else:
        assert str(dataset) == 'FakeHeteroDataset()'

    assert len(dataset) == num_graphs

    data = dataset[0]

    for store in data.node_stores:
        assert store.num_features > 0
        if task == 'node' or (task == 'auto' and num_graphs == 1):
            if store._key == 'v0':
                assert store.y.min() >= 0 and store.y.max() < 10
                assert store.y.size() == (store.num_nodes, )

    # NOTE(review): the `len(data)` checks below do not depend on `store`
    # and are repeated for every edge store.
    for store in data.edge_stores:
        if edge_dim == 0:
            assert len(data) == 4
        elif edge_dim == 1:
            assert len(data) == 5
            assert store.edge_weight.size() == (store.num_edges, )
            assert store.edge_weight.min() >= 0 and store.edge_weight.max() < 1
        else:
            assert len(data) == 5
            assert store.edge_attr.size() == (store.num_edges, edge_dim)
            assert store.edge_attr.min() >= 0 and store.edge_attr.max() < 1

    if task == 'graph' or (task == 'auto' and num_graphs > 1):
        assert data.y.min() >= 0 and data.y.max() < 10
        assert data.y.size() == (1, )

    assert data.global_features.size() == (3, )
================================================
FILE: test/datasets/test_git_mol_dataset.py
================================================
from typing import Tuple

import pytest

from torch_geometric.datasets import GitMolDataset
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('torchvision', 'rdkit', 'PIL')
@pytest.mark.parametrize('split', [
    (0, 3610),
    (1, 451),
    (2, 451),
])
def test_git_mol_dataset(split: Tuple[int, int]) -> None:
    """Check per-split sizes and sample shapes of ``GitMolDataset``.

    ``split`` is ``(split_index, expected_length)``.
    """
    dataset = GitMolDataset(root='./data/GITMol', split=split[0])

    assert len(dataset) == split[1]
    assert dataset[0].image.size() == (1, 3, 224, 224)
    assert dataset[0].num_node_features == 9
    assert dataset[0].num_edge_features == 3
================================================
FILE: test/datasets/test_imdb_binary.py
================================================
from torch_geometric.testing import onlyFullTest, onlyOnline


@onlyOnline
@onlyFullTest
def test_imdb_binary(get_dataset):
    """Check IMDB-BINARY size, the absence of node features, and the first
    graph."""
    dataset = get_dataset(name='IMDB-BINARY')
    assert len(dataset) == 1000
    assert dataset.num_features == 0
    assert dataset.num_classes == 2
    assert str(dataset) == 'IMDB-BINARY(1000)'

    data = dataset[0]
    assert len(data) == 3
    assert data.edge_index.size() == (2, 146)
    assert data.y.size() == (1, )
    assert data.num_nodes == 20
================================================
FILE: test/datasets/test_infection_dataset.py
================================================
import torch

from torch_geometric import seed_everything
from torch_geometric.data import Data
from torch_geometric.datasets import InfectionDataset
from torch_geometric.datasets.graph_generator import ERGraph, GraphGenerator


class DummyGraph(GraphGenerator):
    """Deterministic generator: a 10-node path 0-1-...-9 with every edge
    stored in both directions."""
    def __call__(self) -> Data:
        edge_index = torch.tensor([
            [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9],
            [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8],
        ])
        return Data(num_nodes=10, edge_index=edge_index)


def test_infection_dataset():
    """Check the infection features, labels and edge mask on the
    deterministic path graph with a fixed seed."""
    seed_everything(12345)

    graph_generator = DummyGraph()
    dataset = InfectionDataset(graph_generator, num_infected_nodes=2,
                               max_path_length=2)
    assert str(dataset) == ('InfectionDataset(1, '
                            'graph_generator=DummyGraph(), '
                            'num_infected_nodes=2, '
                            'max_path_length=2)')
    assert len(dataset) == 1

    data = dataset[0]
    assert len(data) == 4
    # Two-column features: column 1 flags the 2 infected nodes and
    # column 0 the 8 others.
    assert data.x.size() == (10, 2)
    assert data.x[:, 0].sum() == 8 and data.x[:, 1].sum() == 2
    assert torch.equal(data.edge_index, graph_generator().edge_index)
    assert data.y.size() == (10, )

    # With `seed=12345`, node 0 and node 7 will be infected:
    assert data.x[0].tolist() == [0, 1]  # First infected node.
    assert data.x[7].tolist() == [0, 1]  # Second infected node.
    assert data.y.tolist() == [0, 1, 2, 3, 3, 2, 1, 0, 1, 2]
    assert data.edge_mask.tolist() == [
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0
    ]


def test_infection_dataset_reproducibility():
    """The same seed yields identical edge masks on a random ER graph."""
    graph_generator = ERGraph(num_nodes=500, edge_prob=0.004)

    seed_everything(12345)
    dataset1 = InfectionDataset(graph_generator, num_infected_nodes=50,
                                max_path_length=5)

    seed_everything(12345)
    dataset2 = InfectionDataset(graph_generator, num_infected_nodes=50,
                                max_path_length=5)

    assert torch.equal(dataset1[0].edge_mask, dataset2[0].edge_mask)
================================================
FILE: test/datasets/test_karate.py
================================================
def test_karate(get_dataset):
    """Check KarateClub sizes: 34 nodes, 156 edges and 4 training nodes."""
    dataset = get_dataset(name='KarateClub')
    assert str(dataset) == 'KarateClub()'
    assert len(dataset) == 1
    assert dataset.num_features == 34
    assert dataset.num_classes == 4

    assert len(dataset[0]) == 4
    assert dataset[0].edge_index.size() == (2, 156)
    assert dataset[0].x.size() == (34, 34)
    assert dataset[0].y.size() == (34, )
    assert dataset[0].train_mask.size() == (34, )
    assert dataset[0].train_mask.sum().item() == 4
================================================
FILE: test/datasets/test_medshapenet.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.datasets import MedShapeNet
from torch_geometric.testing import withPackage


@withPackage('MedShapeNet')
def test_medshapenet():
    """Check that a ``size=1`` MedShapeNet sample has tensor ``pos``,
    ``face`` and ``y`` attributes of the expected lengths."""
    dataset = MedShapeNet(root="./data/MedShapeNet", size=1)

    assert str(dataset) == f'MedShapeNet({len(dataset)})'

    assert isinstance(dataset[0], Data)
    assert dataset.num_classes == 8

    assert isinstance(dataset[0].pos, torch.Tensor)
    assert len(dataset[0].pos) > 0

    # `len(face) == 3`: the first dimension has size 3 (presumably
    # triangle vertex indices).
    assert isinstance(dataset[0].face, torch.Tensor)
    assert len(dataset[0].face) == 3

    assert isinstance(dataset[0].y, torch.Tensor)
    assert len(dataset[0].y) == 1
================================================ FILE: test/datasets/test_molecule_gpt_dataset.py ================================================ from torch_geometric.datasets import MoleculeGPTDataset from torch_geometric.testing import onlyOnline, withPackage @onlyOnline @withPackage('transformers', 'sentencepiece', 'accelerate', 'rdkit') def test_molecule_gpt_dataset(): dataset = MoleculeGPTDataset( root='./data/MoleculeGPT', num_units=10, ) assert str(dataset) == f'MoleculeGPTDataset({len(dataset)})' assert dataset.num_edge_features == 4 assert dataset.num_node_features == 6 ================================================ FILE: test/datasets/test_mutag.py ================================================ from torch_geometric.testing import onlyOnline @onlyOnline def test_mutag(get_dataset): dataset = get_dataset(name='MUTAG') assert len(dataset) == 188 assert dataset.num_features == 7 assert dataset.num_classes == 2 assert str(dataset) == 'MUTAG(188)' assert len(dataset[0]) == 4 assert dataset[0].edge_attr.size(1) == 4 @onlyOnline def test_mutag_with_node_attr(get_dataset): dataset = get_dataset(name='MUTAG', use_node_attr=True) assert dataset.num_features == 7 ================================================ FILE: test/datasets/test_planetoid.py ================================================ from torch_geometric.loader import DataLoader from torch_geometric.testing import onlyOnline, withPackage @onlyOnline @withPackage('scipy') def test_citeseer(get_dataset): dataset = get_dataset(name='CiteSeer') loader = DataLoader(dataset, batch_size=len(dataset)) assert len(dataset) == 1 assert str(dataset) == 'CiteSeer()' for batch in loader: assert batch.num_graphs == len(batch) == 1 assert batch.num_nodes == 3327 assert batch.num_edges / 2 == 4552 assert list(batch.x.size()) == [batch.num_nodes, 3703] assert list(batch.y.size()) == [batch.num_nodes] assert batch.y.max() + 1 == 6 assert batch.train_mask.sum() == 6 * 20 assert batch.val_mask.sum() == 500 assert 
batch.test_mask.sum() == 1000 assert (batch.train_mask & batch.val_mask & batch.test_mask).sum() == 0 assert list(batch.batch.size()) == [batch.num_nodes] assert batch.ptr.tolist() == [0, batch.num_nodes] assert batch.has_isolated_nodes() assert not batch.has_self_loops() assert batch.is_undirected() @onlyOnline @withPackage('scipy') def test_citeseer_with_full_split(get_dataset): dataset = get_dataset(name='CiteSeer', split='full') data = dataset[0] assert data.val_mask.sum() == 500 assert data.test_mask.sum() == 1000 assert data.train_mask.sum() == data.num_nodes - 1500 assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0 @onlyOnline @withPackage('scipy') def test_citeseer_with_random_split(get_dataset): dataset = get_dataset( name='CiteSeer', split='random', num_train_per_class=11, num_val=29, num_test=41, ) data = dataset[0] # from torch_geometric import EdgeIndex # assert isinstance(data.edge_index, EdgeIndex) # assert data.edge_index.sparse_size() == (data.num_nodes, data.num_nodes) # assert data.edge_index.is_undirected # assert data.edge_index.is_sorted_by_col assert data.train_mask.sum() == dataset.num_classes * 11 assert data.val_mask.sum() == 29 assert data.test_mask.sum() == 41 assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0 ================================================ FILE: test/datasets/test_protein_mpnn_dataset.py ================================================ from torch_geometric.datasets import ProteinMPNNDataset from torch_geometric.testing import onlyLinux, onlyOnline, withPackage @onlyLinux @onlyOnline @withPackage('pandas') def test_protein_mpnn_dataset(): dataset = ProteinMPNNDataset(root='./data/ProteinMPNN') assert len(dataset) == 150 assert dataset[0].x.size() == (229, 4, 3) assert dataset[0].chain_seq_label.size() == (229, ) assert dataset[0].mask.size() == (229, ) assert dataset[0].chain_mask_all.size() == (229, ) assert dataset[0].residue_idx.size() == (229, ) assert 
dataset[0].chain_encoding_all.size() == (229, ) ================================================ FILE: test/datasets/test_snap_dataset.py ================================================ from torch_geometric.testing import onlyFullTest, onlyOnline @onlyOnline @onlyFullTest def test_ego_facebook_snap_dataset(get_dataset): import warnings import torch from packaging import version if version.parse(torch.__version__) >= version.parse("2.2.0"): try: from torch.serialization import add_safe_globals from torch_geometric.datasets.snap_dataset import EgoData add_safe_globals([EgoData]) except ImportError: warnings.warn( "add_safe_globals is expected but not found in " "torch.serialization.", stacklevel=2) else: warnings.warn( "add_safe_globals is not available in this version " "of PyTorch; continuing without it.", stacklevel=2) dataset = get_dataset(name='ego-facebook') assert str(dataset) == 'SNAP-ego-facebook(10)' assert len(dataset) == 10 @onlyOnline @onlyFullTest def test_soc_slashdot_snap_dataset(get_dataset): dataset = get_dataset(name='soc-Slashdot0811') assert str(dataset) == 'SNAP-soc-slashdot0811(1)' assert len(dataset) == 1 @onlyOnline @onlyFullTest def test_wiki_vote_snap_dataset(get_dataset): dataset = get_dataset(name='wiki-vote') assert str(dataset) == 'SNAP-wiki-vote(1)' assert len(dataset) == 1 ================================================ FILE: test/datasets/test_suite_sparse.py ================================================ from torch_geometric.testing import onlyFullTest, onlyOnline @onlyOnline @onlyFullTest def test_suite_sparse_dataset(get_dataset): dataset = get_dataset(group='DIMACS10', name='citationCiteseer') assert str(dataset) == ('SuiteSparseMatrixCollection(' 'group=DIMACS10, name=citationCiteseer)') assert len(dataset) == 1 @onlyOnline @onlyFullTest def test_illc1850_suite_sparse_dataset(get_dataset): dataset = get_dataset(group='HB', name='illc1850') assert str(dataset) == ('SuiteSparseMatrixCollection(' 'group=HB, name=illc1850)') 
assert len(dataset) == 1 ================================================ FILE: test/datasets/test_tag_dataset.py ================================================ from torch_geometric.datasets import TAGDataset from torch_geometric.testing import onlyFullTest, withPackage @onlyFullTest @withPackage('ogb') def test_tag_dataset() -> None: from ogb.nodeproppred import PygNodePropPredDataset root = './data/ogb' hf_model = 'prajjwal1/bert-tiny' token_on_disk = True dataset = PygNodePropPredDataset('ogbn-arxiv', root=root) tag_dataset = TAGDataset(root, dataset, hf_model, token_on_disk=token_on_disk) assert 169343 == tag_dataset[0].num_nodes \ == len(tag_dataset.text) \ == len(tag_dataset.llm_explanation) assert 1166243 == tag_dataset[0].num_edges ================================================ FILE: test/datasets/test_teeth3ds.py ================================================ from torch_geometric.data import Data from torch_geometric.datasets import Teeth3DS from torch_geometric.testing import withPackage @withPackage('trimesh', 'fpsample') def test_teeth3ds(tmp_path) -> None: dataset = Teeth3DS(root=tmp_path, split='sample', train=True) assert len(dataset) > 0 data = dataset[0] assert isinstance(data, Data) assert data.pos.size(1) == 3 assert data.x.size(0) == data.pos.size(0) assert data.y.size(0) == data.pos.size(0) assert isinstance(data.jaw, str) ================================================ FILE: test/datasets/test_web_qsp_dataset.py ================================================ import os import random import string import pytest from torch_geometric.datasets import WebQSPDataset from torch_geometric.datasets.web_qsp_dataset import KGQABaseDataset from torch_geometric.testing import onlyFullTest, onlyOnline, withPackage @pytest.mark.skip(reason="Times out") @onlyOnline @onlyFullTest @withPackage("datasets", "pandas") def test_web_qsp_dataset(tmp_path): # Split for this dataset is 2826 train | 246 val | 1628 test # default split is train dataset_val = 
WebQSPDataset(root=tmp_path, split="val") assert len(dataset_val) == 246 assert str(dataset_val) == "WebQSPDataset(246)" class MockSentenceTransformer: def __init__(self, *args, **kwargs): pass def to(self, device): return self def eval(self): return self def encode(self, sentences, batch_size=None, output_device=None): import torch def string_to_tensor(s: str) -> torch.Tensor: return torch.ones(1024).float() if isinstance(sentences, str): return string_to_tensor(sentences) return torch.stack([string_to_tensor(s) for s in sentences]) def create_mock_graphs(tmp_path: str, train_size: int, val_size: int, test_size: int, num_nodes: int, num_edge_types: int, num_trips: int, seed: int = 42): random.seed(seed) strkeys = string.ascii_letters + string.digits qa_strkeys = string.ascii_letters + string.digits + " " def create_mock_triplets(num_nodes: int, num_edges: int, num_trips: int): nodes = list( {"".join(random.sample(strkeys, 10)) for i in range(num_nodes)}) edges = list( {"".join(random.sample(strkeys, 10)) for i in range(num_edges)}) triplets = [] for _ in range(num_trips): h = random.randint(0, num_nodes - 1) t = random.randint(0, num_nodes - 1) r = random.randint(0, num_edge_types - 1) triplets.append((nodes[h], edges[r], nodes[t])) return triplets train_triplets = [ create_mock_triplets(num_nodes, num_edge_types, num_trips) for _ in range(train_size) ] val_triplets = [ create_mock_triplets(num_nodes, num_edge_types, num_trips) for _ in range(val_size) ] test_triplets = [ create_mock_triplets(num_nodes, num_edge_types, num_trips) for _ in range(test_size) ] train_questions = [ "".join(random.sample(qa_strkeys, 10)) for _ in range(train_size) ] val_questions = [ "".join(random.sample(qa_strkeys, 10)) for _ in range(val_size) ] test_questions = [ "".join(random.sample(qa_strkeys, 10)) for _ in range(test_size) ] train_answers = [ "".join(random.sample(qa_strkeys, 10)) for _ in range(train_size) ] val_answers = [ "".join(random.sample(qa_strkeys, 10)) for _ in 
range(val_size) ] test_answers = [ "".join(random.sample(qa_strkeys, 10)) for _ in range(test_size) ] train_graphs = { "graph": train_triplets, "question": train_questions, "answer": train_answers } val_graphs = { "graph": val_triplets, "question": val_questions, "answer": val_answers } test_graphs = { "graph": test_triplets, "question": test_questions, "answer": test_answers } from datasets import Dataset, DatasetDict, load_from_disk ds_train = Dataset.from_dict(train_graphs, split="train") ds_val = Dataset.from_dict(val_graphs, split="validation") ds_test = Dataset.from_dict(test_graphs, split="test") ds = DatasetDict({ "train": ds_train, "validation": ds_val, "test": ds_test }) def mock_load_dataset(path: str): # Save the dataset and then load it to emulate downloading from HF DATASET_CACHE_DIR = os.path.join(tmp_path, ".cache/huggingface/datasets", path) os.makedirs(DATASET_CACHE_DIR, exist_ok=True) ds.save_to_disk(DATASET_CACHE_DIR) dataset_remote = load_from_disk(DATASET_CACHE_DIR) return dataset_remote return mock_load_dataset, ds @pytest.mark.rag @withPackage("datasets", "pandas") def test_kgqa_base_dataset(tmp_path, monkeypatch): num_nodes = 500 num_edge_types = 25 num_trips = 5000 # Mock the dataset graphs mock_load_dataset_func, expected_result = create_mock_graphs( tmp_path, train_size=10, val_size=5, test_size=5, num_nodes=num_nodes, num_edge_types=num_edge_types, num_trips=num_trips) import datasets monkeypatch.setattr(datasets, "load_dataset", mock_load_dataset_func) # Mock the SentenceTransformer import torch_geometric.datasets.web_qsp_dataset monkeypatch.setattr(torch_geometric.datasets.web_qsp_dataset, "SentenceTransformer", MockSentenceTransformer) dataset_train = KGQABaseDataset(root=tmp_path, dataset_name="TestDataset", split="train", use_pcst=False) assert len(dataset_train) == 10 assert str(dataset_train) == "KGQABaseDataset(10)" for graph in dataset_train: assert graph.x.shape == (num_nodes, 1024) assert graph.edge_index.shape == (2, 
num_trips) assert graph.edge_attr.shape == ( num_trips, 1024) # Reminder: edge_attr encodes the entire triplet dataset_val = KGQABaseDataset(root=tmp_path, dataset_name="TestDataset", split="val", use_pcst=False) assert len(dataset_val) == 5 assert str(dataset_val) == "KGQABaseDataset(5)" dataset_test = KGQABaseDataset(root=tmp_path, dataset_name="TestDataset", split="test", use_pcst=False) assert len(dataset_test) == 5 assert str(dataset_test) == "KGQABaseDataset(5)" # TODO(zaristei): More rigorous tests to validate that values are correct # TODO(zaristei): Proper tests for PCST and CWQ ================================================ FILE: test/distributed/test_dist_link_neighbor_loader.py ================================================ import socket from typing import Tuple import pytest import torch import torch.multiprocessing as mp from torch_geometric.data import Data, HeteroData from torch_geometric.datasets import FakeDataset, FakeHeteroDataset from torch_geometric.distributed import ( DistContext, DistLinkNeighborLoader, DistNeighborSampler, LocalFeatureStore, LocalGraphStore, Partitioner, ) from torch_geometric.testing import onlyDistributedTest, withMETIS from torch_geometric.testing.distributed import ProcArgs, assert_run_mproc def create_dist_data(tmp_path: str, rank: int): graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank) feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank) return feat_store, graph_store def dist_link_neighbor_loader_homo( world_size: int, tmp_path: str, rank: int, master_addr: str, master_port: int, num_workers: int, async_sampling: bool, neg_ratio: float, ): part_data = create_dist_data(tmp_path, rank) current_ctx = DistContext( rank=rank, global_rank=rank, world_size=world_size, global_world_size=world_size, group_name='dist-loader-test', ) edge_label_index = part_data[1].get_edge_index(None, 'coo') edge_label = torch.randint(high=2, size=(edge_label_index.size(1), )) loader = DistLinkNeighborLoader( 
data=part_data, edge_label_index=(None, edge_label_index), edge_label=edge_label if neg_ratio is not None else None, num_neighbors=[1], batch_size=10, num_workers=num_workers, master_addr=master_addr, master_port=master_port, current_ctx=current_ctx, concurrency=10, drop_last=True, async_sampling=async_sampling, ) assert str(loader).startswith('DistLinkNeighborLoader') assert str(mp.current_process().pid) in str(loader) assert isinstance(loader.dist_sampler, DistNeighborSampler) assert not part_data[0].meta['is_hetero'] for batch in loader: assert isinstance(batch, Data) assert batch.n_id.size() == (batch.num_nodes, ) assert batch.edge_index.min() >= 0 assert batch.edge_index.max() < batch.num_nodes assert loader.channel.empty() def dist_link_neighbor_loader_hetero( world_size: int, tmp_path: str, rank: int, master_addr: str, master_port: int, num_workers: int, async_sampling: bool, neg_ratio: float, edge_type: Tuple[str, str, str], ): part_data = create_dist_data(tmp_path, rank) current_ctx = DistContext( rank=rank, global_rank=rank, world_size=world_size, global_world_size=world_size, group_name="dist-loader-test", ) edge_label_index = part_data[1].get_edge_index(edge_type, 'coo') edge_label = torch.randint(high=2, size=(edge_label_index.size(1), )) loader = DistLinkNeighborLoader( data=part_data, edge_label_index=(edge_type, edge_label_index), edge_label=edge_label if neg_ratio is not None else None, num_neighbors=[1], batch_size=10, num_workers=num_workers, master_addr=master_addr, master_port=master_port, current_ctx=current_ctx, concurrency=10, drop_last=True, async_sampling=async_sampling, ) assert str(loader).startswith('DistLinkNeighborLoader') assert str(mp.current_process().pid) in str(loader) assert isinstance(loader.dist_sampler, DistNeighborSampler) assert part_data[0].meta['is_hetero'] for batch in loader: assert isinstance(batch, HeteroData) assert len(batch.node_types) == 2 for node_type in batch.node_types: assert torch.equal(batch[node_type].x, 
batch.x_dict[node_type]) assert batch.x_dict[node_type].size(0) >= 0 assert batch[node_type].n_id.size(0) == batch[node_type].num_nodes assert len(batch.edge_types) == 4 for key in batch.edge_types: if key[-1] == 'v0': assert batch[key].num_sampled_edges[0] > 0 assert batch[key].edge_attr.size(0) == batch[key].num_edges else: assert batch[key].num_sampled_edges[0] == 0 assert loader.channel.empty() @withMETIS @onlyDistributedTest @pytest.mark.parametrize('num_parts', [2]) @pytest.mark.parametrize('num_workers', [0]) @pytest.mark.parametrize('async_sampling', [True]) @pytest.mark.parametrize('neg_ratio', [None]) def test_dist_link_neighbor_loader_homo( tmp_path, num_parts, num_workers, async_sampling, neg_ratio, ): addr = '127.0.0.1' mp_context = torch.multiprocessing.get_context('spawn') with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('', 0)) port = s.getsockname()[1] data = FakeDataset( num_graphs=1, avg_num_nodes=100, avg_degree=3, edge_dim=2, )[0] partitioner = Partitioner(data, num_parts, tmp_path) partitioner.generate_partition() procs = [ ProcArgs( target=dist_link_neighbor_loader_homo, args=(tmp_path, part, addr, port, num_workers, async_sampling, neg_ratio), ) for part in range(num_parts) ] assert_run_mproc(mp_context, procs) @withMETIS @onlyDistributedTest @pytest.mark.parametrize('num_parts', [2]) @pytest.mark.parametrize('num_workers', [0]) @pytest.mark.parametrize('async_sampling', [True]) @pytest.mark.parametrize('neg_ratio', [None]) @pytest.mark.parametrize('edge_type', [('v0', 'e0', 'v0')]) def test_dist_link_neighbor_loader_hetero( tmp_path, num_parts, num_workers, async_sampling, neg_ratio, edge_type, ): mp_context = torch.multiprocessing.get_context('spawn') addr = '127.0.0.1' with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('', 0)) port = s.getsockname()[1] data = 
FakeHeteroDataset( num_graphs=1, avg_num_nodes=100, avg_degree=3, num_node_types=2, num_edge_types=4, edge_dim=2, )[0] partitioner = Partitioner(data, num_parts, tmp_path) partitioner.generate_partition() procs = [ ProcArgs( target=dist_link_neighbor_loader_hetero, args=(tmp_path, part, addr, port, num_workers, async_sampling, neg_ratio, edge_type), ) for part in range(num_parts) ] assert_run_mproc(mp_context, procs) ================================================ FILE: test/distributed/test_dist_link_neighbor_sampler.py ================================================ import atexit import socket from typing import Optional import pytest import torch import torch_geometric.transforms as T from torch_geometric.data import Data from torch_geometric.datasets import FakeHeteroDataset from torch_geometric.distributed import ( DistNeighborSampler, LocalFeatureStore, LocalGraphStore, Partitioner, ) from torch_geometric.distributed.dist_context import DistContext from torch_geometric.distributed.event_loop import ConcurrentEventLoop from torch_geometric.distributed.rpc import init_rpc, shutdown_rpc from torch_geometric.sampler import EdgeSamplerInput, NeighborSampler from torch_geometric.sampler.neighbor_sampler import edge_sample from torch_geometric.testing import onlyDistributedTest, withMETIS from torch_geometric.testing.distributed import ProcArgs, assert_run_mproc from torch_geometric.typing import EdgeType def create_data(rank, world_size, time_attr: Optional[str] = None): if rank == 0: # Partition 0: node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9]) edge_index = torch.tensor([ # Sorted by destination. [1, 2, 3, 4, 5, 0, 0], [0, 1, 2, 3, 4, 4, 9], ]) else: # Partition 1: node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9]) edge_index = torch.tensor([ # Sorted by destination. 
[5, 6, 7, 8, 9, 5, 0], [4, 5, 6, 7, 8, 9, 9], ]) feature_store = LocalFeatureStore.from_data(node_id) graph_store = LocalGraphStore.from_data( edge_id=None, edge_index=edge_index, num_nodes=10, is_sorted=True, ) graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) graph_store.meta.update({'num_parts': 2}) graph_store.partition_idx = rank graph_store.num_partitions = world_size edge_index = torch.tensor([ # Create reference data: [1, 2, 3, 4, 5, 0, 5, 6, 7, 8, 9, 0], [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9], ]) data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10) if time_attr == 'time': # Create node-level time data: data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4]) feature_store.put_tensor(data.time, group_name=None, attr_name='time') elif time_attr == 'edge_time': # Create edge-level time data: data.edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 11]) if rank == 0: edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 11]) if rank == 1: edge_time = torch.tensor([4, 7, 7, 7, 7, 7, 11]) feature_store.put_tensor(edge_time, group_name=None, attr_name=time_attr) return (feature_store, graph_store), data def create_hetero_data(tmp_path: str, rank: int): graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank) other_graph_store = LocalGraphStore.from_partition(tmp_path, int(not rank)) feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank) return (feature_store, graph_store), other_graph_store def dist_link_neighbor_sampler( world_size: int, rank: int, master_port: int, disjoint: bool = False, ): dist_data, data = create_data(rank, world_size) current_ctx = DistContext( rank=rank, global_rank=rank, world_size=world_size, global_world_size=world_size, group_name='dist-sampler-test', ) dist_sampler = DistNeighborSampler( data=dist_data, current_ctx=current_ctx, num_neighbors=[-1, -1], shuffle=False, disjoint=disjoint, ) # Close RPC & worker group at exit: atexit.register(shutdown_rpc) init_rpc( current_ctx=current_ctx, 
master_addr='localhost', master_port=master_port, ) dist_sampler.init_sampler_instance() dist_sampler.register_sampler_rpc() dist_sampler.event_loop = ConcurrentEventLoop(2) dist_sampler.event_loop.start_loop() if rank == 0: # Seed nodes: input_row = torch.tensor([1, 6], dtype=torch.int64) input_col = torch.tensor([2, 7], dtype=torch.int64) else: input_row = torch.tensor([4, 9], dtype=torch.int64) input_col = torch.tensor([5, 0], dtype=torch.int64) inputs = EdgeSamplerInput( input_id=None, row=input_row, col=input_col, input_type=None, ) # evaluate distributed edge sample function out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample( inputs, dist_sampler.node_sample, data.num_nodes, disjoint)) sampler = NeighborSampler(data=data, num_neighbors=[-1, -1], disjoint=disjoint) # Evaluate edge sample function: out = edge_sample( inputs, sampler._sample, data.num_nodes, disjoint, node_time=None, neg_sampling=None, ) # Compare distributed output with single machine output: assert torch.equal(out_dist.node, out.node) assert torch.equal(out_dist.row, out.row) assert torch.equal(out_dist.col, out.col) if disjoint: assert torch.equal(out_dist.batch, out.batch) assert out_dist.num_sampled_nodes == out.num_sampled_nodes assert out_dist.num_sampled_edges == out.num_sampled_edges def dist_link_neighbor_sampler_temporal( world_size: int, rank: int, master_port: int, seed_time: torch.tensor = None, temporal_strategy: str = 'uniform', time_attr: str = 'time', ): dist_data, data = create_data(rank, world_size, time_attr) current_ctx = DistContext( rank=rank, global_rank=rank, world_size=world_size, global_world_size=world_size, group_name='dist-sampler-test', ) num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1] dist_sampler = DistNeighborSampler( data=dist_data, current_ctx=current_ctx, num_neighbors=num_neighbors, shuffle=False, disjoint=True, temporal_strategy=temporal_strategy, time_attr=time_attr, ) # Close RPC & worker group at exit: 
atexit.register(shutdown_rpc) init_rpc( current_ctx=current_ctx, master_addr='localhost', master_port=master_port, ) dist_sampler.init_sampler_instance() dist_sampler.register_sampler_rpc() dist_sampler.event_loop = ConcurrentEventLoop(2) dist_sampler.event_loop.start_loop() if rank == 0: # Seed nodes: input_row = torch.tensor([1, 6], dtype=torch.int64) input_col = torch.tensor([2, 7], dtype=torch.int64) else: input_row = torch.tensor([4, 9], dtype=torch.int64) input_col = torch.tensor([5, 0], dtype=torch.int64) inputs = EdgeSamplerInput( input_id=None, row=input_row, col=input_col, time=seed_time, ) # Evaluate distributed edge sample function out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample( inputs, dist_sampler.node_sample, data.num_nodes, disjoint=True, node_time=seed_time, neg_sampling=None)) sampler = NeighborSampler( data=data, num_neighbors=num_neighbors, disjoint=True, temporal_strategy=temporal_strategy, time_attr=time_attr, ) # Evaluate edge sample function out = edge_sample( inputs, sampler._sample, data.num_nodes, disjoint=True, node_time=seed_time, neg_sampling=None, ) # Compare distributed output with single machine output assert torch.equal(out_dist.node, out.node) assert torch.equal(out_dist.row, out.row) assert torch.equal(out_dist.col, out.col) assert torch.equal(out_dist.batch, out.batch) assert out_dist.num_sampled_nodes == out.num_sampled_nodes assert out_dist.num_sampled_edges == out.num_sampled_edges def dist_link_neighbor_sampler_hetero( world_size: int, data: FakeHeteroDataset, tmp_path: str, rank: int, master_port: int, input_type: EdgeType, disjoint: bool = False, ): dist_data, other_graph_store = create_hetero_data(tmp_path, rank) current_ctx = DistContext( rank=rank, global_rank=rank, world_size=world_size, global_world_size=world_size, group_name='dist-sampler-test', ) dist_sampler = DistNeighborSampler( data=dist_data, current_ctx=current_ctx, rpc_worker_names={}, num_neighbors=[-1], shuffle=False, 
disjoint=disjoint, ) # close RPC & worker group at exit: atexit.register(shutdown_rpc) init_rpc( current_ctx=current_ctx, master_addr='localhost', master_port=master_port, ) dist_sampler.init_sampler_instance() dist_sampler.register_sampler_rpc() dist_sampler.event_loop = ConcurrentEventLoop(2) dist_sampler.event_loop.start_loop() # Create input rows/cols such that pairs belong to different partitions. # Edge from the current partition: edge_label_index1 = dist_data[1]._edge_index[(input_type, 'coo')] row_0 = edge_label_index1[0][0] col_0 = edge_label_index1[1][0] # Edge from the other partition: edge_label_index2 = other_graph_store._edge_index[(input_type, 'coo')] row_1 = edge_label_index2[0][0] col_1 = edge_label_index2[1][0] # Seed edges: input_row = torch.tensor([row_0, row_1]) input_col = torch.tensor([col_0, col_1]) inputs = EdgeSamplerInput( input_id=None, row=input_row, col=input_col, input_type=input_type, ) # Evaluate distributed `node_sample` function: out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample( inputs, dist_sampler.node_sample, data.num_nodes, disjoint)) sampler = NeighborSampler( data=data, num_neighbors=[-1], disjoint=disjoint, ) # Evaluate edge sample function: out = edge_sample( inputs, sampler._sample, data.num_nodes, disjoint, ) # Compare distributed output with single machine output: for k in data.node_types: assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0]) assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k] if disjoint: assert torch.equal( out_dist.batch[k].sort()[0], out.batch[k].sort()[0], ) def dist_link_neighbor_sampler_temporal_hetero( world_size: int, data: FakeHeteroDataset, tmp_path: str, rank: int, master_port: int, input_type: EdgeType, seed_time: torch.tensor = None, temporal_strategy: str = 'uniform', time_attr: str = 'time', ): dist_data, other_graph_store = create_hetero_data(tmp_path, rank) current_ctx = DistContext( rank=rank, global_rank=rank, 
world_size=world_size, global_world_size=world_size, group_name='dist-sampler-test', ) dist_sampler = DistNeighborSampler( data=dist_data, current_ctx=current_ctx, rpc_worker_names={}, num_neighbors=[-1], shuffle=False, disjoint=True, temporal_strategy=temporal_strategy, time_attr=time_attr, ) # close RPC & worker group at exit: atexit.register(shutdown_rpc) init_rpc( current_ctx=current_ctx, master_addr='localhost', master_port=master_port, ) dist_sampler.init_sampler_instance() dist_sampler.register_sampler_rpc() dist_sampler.event_loop = ConcurrentEventLoop(2) dist_sampler.event_loop.start_loop() # Create input rows/cols such that pairs belong to different partitions. # Edge from the current partition: edge_label_index1 = dist_data[1]._edge_index[(input_type, 'coo')] row_0 = edge_label_index1[0][0] col_0 = edge_label_index1[1][0] # Edge from the other partition: edge_label_index2 = other_graph_store._edge_index[(input_type, 'coo')] row_1 = edge_label_index2[0][0] col_1 = edge_label_index2[1][0] # Seed nodes: input_row = torch.tensor([row_0, row_1], dtype=torch.int64) input_col = torch.tensor([col_0, col_1], dtype=torch.int64) inputs = EdgeSamplerInput( input_id=None, row=input_row, col=input_col, time=seed_time, input_type=input_type, ) # Evaluate distributed node sample function: out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample( inputs, dist_sampler.node_sample, data.num_nodes, disjoint=True)) sampler = NeighborSampler( data=data, num_neighbors=[-1], disjoint=True, temporal_strategy=temporal_strategy, time_attr=time_attr, ) # Evaluate edge sample function: out = edge_sample( inputs, sampler._sample, data.num_nodes, disjoint=True, node_time=seed_time, neg_sampling=None, ) # Compare distributed output with single machine output: for k in data.node_types: assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0]) assert torch.equal(out_dist.batch[k].sort()[0], out.batch[k].sort()[0]) assert out_dist.num_sampled_nodes[k] == 
out.num_sampled_nodes[k] @onlyDistributedTest @pytest.mark.parametrize('disjoint', [False, True]) def test_dist_link_neighbor_sampler(disjoint): mp_context = torch.multiprocessing.get_context('spawn') with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('', 0)) port = s.getsockname()[1] procs = [ ProcArgs(target=dist_link_neighbor_sampler, args=(0, port, disjoint)), ProcArgs(target=dist_link_neighbor_sampler, args=(1, port, disjoint)), ] assert_run_mproc(mp_context, procs) @onlyDistributedTest @pytest.mark.parametrize('seed_time', [None, torch.tensor([3, 6])]) @pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) def test_dist_link_neighbor_sampler_temporal(seed_time, temporal_strategy): mp_context = torch.multiprocessing.get_context('spawn') with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('', 0)) port = s.getsockname()[1] procs = [ ProcArgs( target=dist_link_neighbor_sampler_temporal, args=(0, port, seed_time, temporal_strategy, 'time'), ), ProcArgs( target=dist_link_neighbor_sampler_temporal, args=(1, port, seed_time, temporal_strategy, 'time'), ), ] assert_run_mproc(mp_context, procs) @onlyDistributedTest @pytest.mark.parametrize('seed_time', [[1, 1], [3, 7]]) @pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) def test_dist_link_neighbor_sampler_edge_level_temporal( seed_time, temporal_strategy, ): seed_time = torch.tensor(seed_time) mp_context = torch.multiprocessing.get_context('spawn') with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('', 0)) port = s.getsockname()[1] procs = [ ProcArgs( target=dist_link_neighbor_sampler_temporal, args=(0, port, seed_time, temporal_strategy, 'edge_time'), ), ProcArgs( target=dist_link_neighbor_sampler_temporal, args=(1, port, 
seed_time, temporal_strategy, 'edge_time'), ), ] assert_run_mproc(mp_context, procs) @withMETIS @onlyDistributedTest @pytest.mark.parametrize('disjoint', [False, True]) def test_dist_link_neighbor_sampler_hetero(tmp_path, disjoint): mp_context = torch.multiprocessing.get_context('spawn') with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('', 0)) port = s.getsockname()[1] data = FakeHeteroDataset( num_graphs=1, avg_num_nodes=100, avg_degree=3, num_node_types=2, num_edge_types=4, edge_dim=2, )[0] data = T.ToUndirected()(data) procs = [ ProcArgs( target=dist_link_neighbor_sampler_hetero, args=(data, tmp_path, 0, port, ('v0', 'e0', 'v0'), disjoint), ), ProcArgs( target=dist_link_neighbor_sampler_hetero, args=(data, tmp_path, 1, port, ('v1', 'e0', 'v0'), disjoint), ), ] partitioner = Partitioner(data, len(procs), tmp_path) partitioner.generate_partition() assert_run_mproc(mp_context, procs) @withMETIS @onlyDistributedTest @pytest.mark.parametrize('seed_time', [None, [0, 0], [3, 3]]) @pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) def test_dist_link_neighbor_sampler_temporal_hetero( tmp_path, seed_time, temporal_strategy, ): if seed_time is not None: seed_time = torch.tensor(seed_time) mp_context = torch.multiprocessing.get_context('spawn') with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('', 0)) port = s.getsockname()[1] data = FakeHeteroDataset( num_graphs=1, avg_num_nodes=100, avg_degree=3, num_node_types=2, num_edge_types=4, edge_dim=2, )[0] data = T.ToUndirected()(data) # Add time information to the data: data['v0'].time = torch.ones(data['v0'].num_nodes, dtype=torch.int64) data['v1'].time = torch.full((data['v1'].num_nodes, ), 2).long() procs = [ ProcArgs( target=dist_link_neighbor_sampler_temporal_hetero, args=(data, tmp_path, 0, port, ('v0', 'e0', 'v0'), seed_time, 
temporal_strategy, 'time'), ), ProcArgs( target=dist_link_neighbor_sampler_temporal_hetero, args=(data, tmp_path, 1, port, ('v1', 'e0', 'v0'), seed_time, temporal_strategy, 'time'), ), ] partitioner = Partitioner(data, len(procs), tmp_path) partitioner.generate_partition() assert_run_mproc(mp_context, procs) @withMETIS @onlyDistributedTest @pytest.mark.parametrize('seed_time', [[0, 0], [3, 3]]) @pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) def test_dist_link_neighbor_sampler_edge_level_temporal_hetero( tmp_path, seed_time, temporal_strategy, ): seed_time = torch.tensor(seed_time) mp_context = torch.multiprocessing.get_context('spawn') with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('', 0)) port = s.getsockname()[1] data = FakeHeteroDataset( num_graphs=1, avg_num_nodes=100, avg_degree=3, num_node_types=2, num_edge_types=4, edge_dim=2, )[0] data = T.ToUndirected()(data) # Add time information to the data: for i, edge_type in enumerate(data.edge_types): data[edge_type].edge_time = torch.full( # (data[edge_type].num_edges, ), i, dtype=torch.int64) procs = [ ProcArgs( target=dist_link_neighbor_sampler_temporal_hetero, args=(data, tmp_path, 0, port, ('v0', 'e0', 'v0'), seed_time, temporal_strategy, 'edge_time'), ), ProcArgs( target=dist_link_neighbor_sampler_temporal_hetero, args=(data, tmp_path, 1, port, ('v0', 'e0', 'v1'), seed_time, temporal_strategy, 'edge_time'), ), ] partitioner = Partitioner(data, len(procs), tmp_path) partitioner.generate_partition() assert_run_mproc(mp_context, procs) ================================================ FILE: test/distributed/test_dist_neighbor_loader.py ================================================ import socket import warnings import pytest import torch import torch.multiprocessing as mp from torch_geometric.data import Data, HeteroData from torch_geometric.datasets import FakeDataset, FakeHeteroDataset from 
torch_geometric.distributed import (
    DistContext,
    DistNeighborLoader,
    DistNeighborSampler,
    LocalFeatureStore,
    LocalGraphStore,
    Partitioner,
)
from torch_geometric.testing import onlyDistributedTest, withMETIS
from torch_geometric.testing.distributed import ProcArgs, assert_run_mproc


def create_dist_data(tmp_path: str, rank: int):
    """Load partition ``rank`` from ``tmp_path``.

    Returns:
        A ``(feature_store, graph_store)`` tuple for that partition.
    """
    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
    feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
    return feat_store, graph_store


def dist_neighbor_loader_homo(
    world_size: int,
    tmp_path: str,
    rank: int,
    master_addr: str,
    master_port: int,
    num_workers: int,
    async_sampling: bool,
):
    """Worker process: iterate a homogeneous ``DistNeighborLoader``.

    The loader is seeded with the global ids of all nodes held by this rank's
    feature store. Each mini-batch is checked for size (10 seeds, with
    ``drop_last=True``), for local index ranges, and for the local-to-global
    edge mapping via ``n_id`` / ``e_id``.

    The tests pass every argument except ``world_size``; it is presumably
    prepended by ``assert_run_mproc`` (TODO confirm).
    """
    part_data = create_dist_data(tmp_path, rank)
    input_nodes = part_data[0].get_global_id(None)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-loader-test',
    )

    loader = DistNeighborLoader(
        part_data,
        num_neighbors=[1],
        batch_size=10,
        num_workers=num_workers,
        input_nodes=input_nodes,
        master_addr=master_addr,
        master_port=master_port,
        current_ctx=current_ctx,
        concurrency=10,
        drop_last=True,
        async_sampling=async_sampling,
    )

    # This partition's edge_index (COO), used to map sampled ``e_id``s back
    # to global node pairs:
    edge_index = part_data[1]._edge_index[(None, 'coo')]

    assert str(loader).startswith('DistNeighborLoader')
    assert str(mp.current_process().pid) in str(loader)
    assert isinstance(loader.dist_sampler, DistNeighborSampler)
    assert not part_data[0].meta['is_hetero']

    for batch in loader:
        assert isinstance(batch, Data)
        assert batch.n_id.size() == (batch.num_nodes, )
        assert batch.input_id.numel() == batch.batch_size == 10
        assert batch.edge_index.min() >= 0
        assert batch.edge_index.max() < batch.num_nodes
        # Local edges mapped to global ids must equal the stored edges:
        assert torch.equal(
            batch.n_id[batch.edge_index],
            edge_index[:, batch.e_id],
        )

    # All produced batches must have been consumed:
    assert loader.channel.empty()


def dist_neighbor_loader_hetero(
    world_size: int,
    tmp_path: str,
    rank: int,
    master_addr: str,
    master_port: int,
    num_workers: int,
    async_sampling: bool,
):
    """Worker process: iterate a heterogeneous ``DistNeighborLoader``.

    The loader is seeded with the global ids of all ``'v0'`` nodes in this
    rank's partition. Each mini-batch is checked for node/edge types and
    feature sizes. The edge-id mapping is checked too, but mismatches are a
    known flake and only emit a warning.
    """
    part_data = create_dist_data(tmp_path, rank)
    input_nodes = ('v0',
part_data[0].get_global_id('v0'))

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-loader-test',
    )

    loader = DistNeighborLoader(
        part_data,
        num_neighbors=[1],
        batch_size=10,
        num_workers=num_workers,
        input_nodes=input_nodes,
        master_addr=master_addr,
        master_port=master_port,
        current_ctx=current_ctx,
        concurrency=10,
        drop_last=True,
        async_sampling=async_sampling,
    )

    assert str(loader).startswith('DistNeighborLoader')
    assert str(mp.current_process().pid) in str(loader)
    assert isinstance(loader.dist_sampler, DistNeighborSampler)
    assert part_data[0].meta['is_hetero']

    for batch in loader:
        assert isinstance(batch, HeteroData)
        assert batch['v0'].input_id.numel() == batch['v0'].batch_size == 10

        assert len(batch.node_types) == 2
        for node_type in batch.node_types:
            assert torch.equal(batch[node_type].x, batch.x_dict[node_type])
            # NOTE(review): always true; acts only as a smoke check.
            assert batch.x_dict[node_type].size(0) >= 0
            assert batch[node_type].n_id.size(0) == batch[node_type].num_nodes

        assert len(batch.edge_types) == 4
        for edge_type in batch.edge_types:
            num_edges = batch[edge_type].edge_index.size(1)

            if num_edges > 0:  # Test edge mapping:
                assert batch[edge_type].edge_attr.size(0) == num_edges
                src, _, dst = edge_type
                edge_index = part_data[1]._edge_index[(edge_type, "coo")]
                # Sampled local edges translated to global node ids:
                global_edge_index1 = torch.stack([
                    batch[src].n_id[batch[edge_type].edge_index[0]],
                    batch[dst].n_id[batch[edge_type].edge_index[1]],
                ], dim=0)

                # TODO There is a current known flake, which we need to fix:
                # ``e_id`` can point outside this partition's edge_index or
                # resolve to different node pairs; warn instead of failing.
                e_id = batch[edge_type].e_id
                if e_id.numel() > 0 and e_id.max() >= edge_index.size(1):
                    warnings.warn("Known test flake", stacklevel=2)
                else:
                    global_edge_index2 = edge_index[:, e_id]
                    if not torch.equal(global_edge_index1, global_edge_index2):
                        warnings.warn("Known test flake", stacklevel=2)

    # All produced batches must have been consumed:
    assert loader.channel.empty()


@withMETIS
@onlyDistributedTest
@pytest.mark.parametrize('num_parts', [2])
@pytest.mark.parametrize('num_workers', [0])
@pytest.mark.parametrize('async_sampling', [True])
def test_dist_neighbor_loader_homo(
    tmp_path,
    num_parts,
    num_workers,
    async_sampling,
):
    """Partition a homogeneous ``FakeDataset`` graph into ``num_parts``
    parts and run one ``dist_neighbor_loader_homo`` process per part. All
    processes use the RPC master at ``127.0.0.1:port``.
    """
    mp_context = torch.multiprocessing.get_context('spawn')
    addr = '127.0.0.1'

    # Let the OS pick a free TCP port; the probe socket closes on exit from
    # the ``with`` block.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]

    data = FakeDataset(
        num_graphs=1,
        avg_num_nodes=100,
        avg_degree=3,
        edge_dim=2,
    )[0]

    partitioner = Partitioner(data, num_parts, tmp_path)
    partitioner.generate_partition()

    # ``world_size`` is not passed here (see ``dist_neighbor_loader_homo``).
    procs = [
        ProcArgs(
            target=dist_neighbor_loader_homo,
            args=(tmp_path, part, addr, port, num_workers, async_sampling),
        ) for part in range(num_parts)
    ]
    assert_run_mproc(mp_context, procs)


@withMETIS
@onlyDistributedTest
@pytest.mark.parametrize('num_parts', [2])
@pytest.mark.parametrize('num_workers', [0])
@pytest.mark.parametrize('async_sampling', [True])
def test_dist_neighbor_loader_hetero(
    tmp_path,
    num_parts,
    num_workers,
    async_sampling,
):
    """Partition a ``FakeHeteroDataset`` graph (2 node types, 4 edge types)
    into ``num_parts`` parts and run one ``dist_neighbor_loader_hetero``
    process per part.
    """
    mp_context = torch.multiprocessing.get_context('spawn')
    addr = '127.0.0.1'

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]

    data = FakeHeteroDataset(
        num_graphs=1,
        avg_num_nodes=100,
        avg_degree=3,
        num_node_types=2,
        num_edge_types=4,
        edge_dim=2,
    )[0]

    partitioner = Partitioner(data, num_parts, tmp_path)
    partitioner.generate_partition()

    procs = [
        ProcArgs(
            target=dist_neighbor_loader_hetero,
            args=(tmp_path, part, addr, port, num_workers, async_sampling),
        ) for part in range(num_parts)
    ]
    assert_run_mproc(mp_context, procs)


================================================
FILE: test/distributed/test_dist_neighbor_sampler.py
================================================
import atexit
import socket
from typing import Optional

import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.datasets import FakeHeteroDataset
from torch_geometric.distributed import (
    DistNeighborSampler,
LocalFeatureStore,
    LocalGraphStore,
    Partitioner,
)
from torch_geometric.distributed.dist_context import DistContext
from torch_geometric.distributed.event_loop import ConcurrentEventLoop
from torch_geometric.distributed.rpc import init_rpc, shutdown_rpc
from torch_geometric.sampler import NeighborSampler, NodeSamplerInput
from torch_geometric.sampler.neighbor_sampler import node_sample
from torch_geometric.testing import onlyDistributedTest, withMETIS
from torch_geometric.testing.distributed import ProcArgs, assert_run_mproc


def create_data(rank: int, world_size: int, time_attr: Optional[str] = None):
    """Build a hand-made 10-node graph split over two partitions.

    ``node_pb`` assigns nodes 0-4 to partition 0 and nodes 5-9 to
    partition 1. Each rank's local stores hold 7 node ids and 7 edges,
    including some nodes owned by the other partition.

    Args:
        rank: Partition index. Only ``0`` and ``1`` are handled; for
            ``time_attr='edge_time'`` any other rank leaves ``edge_time``
            undefined.
        world_size: Stored as ``graph_store.num_partitions``.
        time_attr: ``None``, ``'time'`` (node-level timestamps) or
            ``'edge_time'`` (edge-level timestamps).

    Returns:
        ``((feature_store, graph_store), data)``, where ``data`` is the
        full 12-edge reference graph used for single-machine sampling.
    """
    if rank == 0:  # Partition 0:
        node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9])
        edge_index = torch.tensor([  # Sorted by destination.
            [1, 2, 3, 4, 5, 0, 0],
            [0, 1, 2, 3, 4, 4, 9],
        ])
    else:  # Partition 1:
        node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9])
        edge_index = torch.tensor([  # Sorted by destination.
            [5, 6, 7, 8, 9, 5, 0],
            [4, 5, 6, 7, 8, 9, 9],
        ])

    feature_store = LocalFeatureStore.from_data(node_id)
    graph_store = LocalGraphStore.from_data(
        edge_id=None,
        edge_index=edge_index,
        num_nodes=10,
        is_sorted=True,
    )

    graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    graph_store.meta.update({'num_parts': 2})
    graph_store.partition_idx = rank
    graph_store.num_partitions = world_size

    # Reference graph: union of both partitions' edges.
    edge_index = torch.tensor([  # Create reference data:
        [1, 2, 3, 4, 5, 0, 5, 6, 7, 8, 9, 0],
        [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9],
    ])
    data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10)

    if time_attr == 'time':  # Create node-level time data:
        # The full 10-entry time vector goes into every rank's store.
        data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])
        feature_store.put_tensor(data.time, group_name=None,
                                 attr_name=time_attr)
    elif time_attr == 'edge_time':  # Create edge-level time data:
        # Reference times follow the columns of the reference edge_index;
        # each rank's times follow the columns of its local edge_index.
        data.edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 11])

        if rank == 0:
            edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 11])
        if rank == 1:
            edge_time = torch.tensor([4, 7, 7, 7, 7, 7, 11])

        feature_store.put_tensor(edge_time, group_name=None,
attr_name=time_attr)

    return (feature_store, graph_store), data


def create_hetero_data(
    tmp_path: str,
    rank: int,
):
    """Load partition ``rank`` from ``tmp_path``.

    Returns:
        A ``(feature_store, graph_store)`` tuple for that partition.
    """
    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
    feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
    return feature_store, graph_store


def dist_neighbor_sampler(
    world_size: int,
    rank: int,
    master_port: int,
    disjoint: bool = False,
):
    """Worker process: compare distributed and single-machine sampling.

    Samples two hops with ``num_neighbors=[-1, -1]`` around two seed nodes,
    one from each partition. The result of ``DistNeighborSampler`` on this
    rank's partition is compared with ``NeighborSampler`` on the full
    reference graph from ``create_data``.
    """
    dist_data, data = create_data(rank, world_size)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-sampler-test',
    )

    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        num_neighbors=[-1, -1],
        shuffle=False,
        disjoint=disjoint,
    )

    # Close RPC & worker group at exit:
    atexit.register(shutdown_rpc)

    init_rpc(
        current_ctx=current_ctx,
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.init_sampler_instance()
    dist_sampler.register_sampler_rpc()
    dist_sampler.event_loop = ConcurrentEventLoop(2)
    dist_sampler.event_loop.start_loop()

    # Per ``node_pb``, seed 1/4 belong to partition 0 and 6/9 to partition 1.
    if rank == 0:  # Seed nodes:
        input_node = torch.tensor([1, 6])
    else:
        input_node = torch.tensor([4, 9])

    inputs = NodeSamplerInput(input_id=None, node=input_node)

    # Evaluate distributed node sample function:
    out_dist = dist_sampler.event_loop.run_task(
        coro=dist_sampler.node_sample(inputs))

    sampler = NeighborSampler(
        data=data,
        num_neighbors=[-1, -1],
        disjoint=disjoint,
    )

    # Evaluate node sample function:
    out = node_sample(inputs, sampler._sample)

    # Compare distributed output with single machine output:
    assert torch.equal(out_dist.node, out.node)
    assert torch.equal(out_dist.row, out.row)
    assert torch.equal(out_dist.col, out.col)
    if disjoint:
        assert torch.equal(out_dist.batch, out.batch)
    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
    assert out_dist.num_sampled_edges == out.num_sampled_edges


def dist_neighbor_sampler_temporal(
    world_size: int,
    rank: int,
    master_port: int,
    seed_time: Optional[torch.Tensor] = None,
    temporal_strategy: str = 'uniform',
    time_attr: str
= 'time',
):
    """Worker process: compare temporal distributed and single-machine
    sampling on the ``create_data`` graph.

    Sampling is always disjoint. With ``temporal_strategy='uniform'`` all
    neighbors are sampled for two hops; otherwise one neighbor per hop.
    ``time_attr`` selects node-level (``'time'``) or edge-level
    (``'edge_time'``) timestamps.
    """
    dist_data, data = create_data(rank, world_size, time_attr)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-sampler-test',
    )

    num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1]
    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        num_neighbors=num_neighbors,
        shuffle=False,
        disjoint=True,
        temporal_strategy=temporal_strategy,
        time_attr=time_attr,
    )

    # Close RPC & worker group at exit:
    atexit.register(shutdown_rpc)

    init_rpc(
        current_ctx=current_ctx,
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.init_sampler_instance()
    dist_sampler.register_sampler_rpc()
    dist_sampler.event_loop = ConcurrentEventLoop(2)
    dist_sampler.event_loop.start_loop()

    if rank == 0:  # Seed nodes:
        input_node = torch.tensor([1, 6], dtype=torch.int64)
    else:
        input_node = torch.tensor([4, 9], dtype=torch.int64)

    inputs = NodeSamplerInput(
        input_id=None,
        node=input_node,
        time=seed_time,
    )

    # Evaluate distributed node sample function:
    out_dist = dist_sampler.event_loop.run_task(
        coro=dist_sampler.node_sample(inputs))

    sampler = NeighborSampler(
        data=data,
        num_neighbors=num_neighbors,
        disjoint=True,
        temporal_strategy=temporal_strategy,
        time_attr=time_attr,
    )

    # Evaluate node sample function:
    out = node_sample(inputs, sampler._sample)

    # Compare distributed output with single machine output:
    assert torch.equal(out_dist.node, out.node)
    assert torch.equal(out_dist.row, out.row)
    assert torch.equal(out_dist.col, out.col)
    assert torch.equal(out_dist.batch, out.batch)
    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
    assert out_dist.num_sampled_edges == out.num_sampled_edges


def dist_neighbor_sampler_hetero(
    world_size: int,
    data: FakeHeteroDataset,
    tmp_path: str,
    rank: int,
    master_port: int,
    input_type: str,
    disjoint: bool = False,
):
    """Worker process: compare heterogeneous distributed sampling on this
    rank's on-disk partition with single-machine sampling on ``data``.

    NOTE(review): callers pass ``FakeHeteroDataset(...)[0]``, i.e. a single
    graph (presumably ``HeteroData``) rather than the dataset class named in
    the annotation.
    """
    dist_data = create_hetero_data(tmp_path, rank)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
global_world_size=world_size,
        group_name='dist-sampler-test',
    )

    num_neighbors = [-1, -1]
    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        rpc_worker_names={},
        num_neighbors=num_neighbors,
        shuffle=False,
        disjoint=disjoint,
    )

    # Close RPC & worker group at exit:
    atexit.register(shutdown_rpc)

    init_rpc(
        current_ctx=current_ctx,
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.init_sampler_instance()
    dist_sampler.register_sampler_rpc()
    dist_sampler.event_loop = ConcurrentEventLoop(2)
    dist_sampler.event_loop.start_loop()

    # Create input nodes such that each belongs to a different partition:
    # take the first node of ``input_type`` owned by partition 0 and by 1.
    node_pb_list = dist_data[1].node_pb[input_type].tolist()
    node_0 = node_pb_list.index(0)
    node_1 = node_pb_list.index(1)
    input_node = torch.tensor([node_0, node_1], dtype=torch.int64)

    inputs = NodeSamplerInput(
        input_id=None,
        node=input_node,
        input_type=input_type,
    )

    # Evaluate distributed node sample function:
    out_dist = dist_sampler.event_loop.run_task(
        coro=dist_sampler.node_sample(inputs))

    sampler = NeighborSampler(
        data=data,
        num_neighbors=num_neighbors,
        disjoint=disjoint,
    )

    # Evaluate node sample function:
    out = node_sample(inputs, sampler._sample)

    # Compare distributed output with single machine output. Node ids are
    # compared as sorted sets, presumably because sampling order may differ:
    for k in data.node_types:
        assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])
        assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k]
        if disjoint:
            assert torch.equal(
                out_dist.batch[k].sort()[0],
                out.batch[k].sort()[0],
            )


def dist_neighbor_sampler_temporal_hetero(
    world_size: int,
    data: FakeHeteroDataset,
    tmp_path: str,
    rank: int,
    master_port: int,
    input_type: str,
    seed_time: Optional[torch.Tensor] = None,
    temporal_strategy: str = 'uniform',
    time_attr: str = 'time',
):
    """Worker process: temporal, disjoint variant of
    ``dist_neighbor_sampler_hetero`` with ``num_neighbors=[-1, -1]``.

    ``time_attr`` selects node-level (``'time'``) or edge-level
    (``'edge_time'``) timestamps; ``seed_time`` is forwarded to the sampler
    input.
    """
    dist_data = create_hetero_data(tmp_path, rank)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-sampler-test',
    )

    dist_sampler = DistNeighborSampler(
        data=dist_data,
current_ctx=current_ctx,
        rpc_worker_names={},
        num_neighbors=[-1, -1],
        shuffle=False,
        disjoint=True,
        temporal_strategy=temporal_strategy,
        time_attr=time_attr,
    )

    # Close RPC & worker group at exit:
    atexit.register(shutdown_rpc)

    init_rpc(
        current_ctx=current_ctx,
        rpc_worker_names={},
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.init_sampler_instance()
    dist_sampler.register_sampler_rpc()
    dist_sampler.event_loop = ConcurrentEventLoop(2)
    dist_sampler.event_loop.start_loop()

    # Create input nodes such that each belongs to a different partition:
    node_pb_list = dist_data[1].node_pb[input_type].tolist()
    node_0 = node_pb_list.index(0)
    node_1 = node_pb_list.index(1)
    input_node = torch.tensor([node_0, node_1], dtype=torch.int64)

    inputs = NodeSamplerInput(
        input_id=None,
        node=input_node,
        time=seed_time,
        input_type=input_type,
    )

    # Evaluate distributed node sample function:
    out_dist = dist_sampler.event_loop.run_task(
        coro=dist_sampler.node_sample(inputs))

    sampler = NeighborSampler(
        data=data,
        num_neighbors=[-1, -1],
        disjoint=True,
        temporal_strategy=temporal_strategy,
        time_attr=time_attr,
    )

    # Evaluate node sample function:
    out = node_sample(inputs, sampler._sample)

    # Compare distributed output with single machine output:
    for k in data.node_types:
        assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])
        assert torch.equal(out_dist.batch[k].sort()[0], out.batch[k].sort()[0])
        assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k]


@onlyDistributedTest
@pytest.mark.parametrize('disjoint', [False, True])
def test_dist_neighbor_sampler(disjoint):
    """Run ``dist_neighbor_sampler`` for ranks 0 and 1 on the hand-made
    ``create_data`` graph; no on-disk partitioning is involved.
    """
    mp_context = torch.multiprocessing.get_context('spawn')

    # Let the OS pick a free TCP port; the probe socket closes on exit from
    # the ``with`` block.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]

    procs = [
        ProcArgs(target=dist_neighbor_sampler, args=(0, port, disjoint)),
        ProcArgs(target=dist_neighbor_sampler, args=(1, port, disjoint)),
    ]
    assert_run_mproc(mp_context,
procs)


@onlyDistributedTest
@pytest.mark.parametrize('seed_time', [None, torch.tensor([3, 6])])
@pytest.mark.parametrize('temporal_strategy', ['uniform'])
def test_dist_neighbor_sampler_temporal(seed_time, temporal_strategy):
    """Run ``dist_neighbor_sampler_temporal`` with node-level time
    (``'time'``) for ranks 0 and 1.
    """
    mp_context = torch.multiprocessing.get_context('spawn')

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]

    procs = [
        ProcArgs(
            target=dist_neighbor_sampler_temporal,
            args=(0, port, seed_time, temporal_strategy, 'time'),
        ),
        ProcArgs(
            target=dist_neighbor_sampler_temporal,
            args=(1, port, seed_time, temporal_strategy, 'time'),
        ),
    ]
    assert_run_mproc(mp_context, procs)


@onlyDistributedTest
@pytest.mark.parametrize('seed_time', [[3, 7]])
@pytest.mark.parametrize('temporal_strategy', ['last'])
def test_dist_neighbor_sampler_edge_level_temporal(
    seed_time,
    temporal_strategy,
):
    """Run ``dist_neighbor_sampler_temporal`` with edge-level time
    (``'edge_time'``) for ranks 0 and 1.
    """
    seed_time = torch.tensor(seed_time)

    mp_context = torch.multiprocessing.get_context('spawn')

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]

    procs = [
        ProcArgs(
            target=dist_neighbor_sampler_temporal,
            args=(0, port, seed_time, temporal_strategy, 'edge_time'),
        ),
        ProcArgs(
            target=dist_neighbor_sampler_temporal,
            args=(1, port, seed_time, temporal_strategy, 'edge_time'),
        ),
    ]
    assert_run_mproc(mp_context, procs)


@withMETIS
@onlyDistributedTest
@pytest.mark.parametrize('disjoint', [False, True])
def test_dist_neighbor_sampler_hetero(tmp_path, disjoint):
    """Partition a ``FakeHeteroDataset`` graph in two and run
    ``dist_neighbor_sampler_hetero`` with seed type ``'v0'`` on rank 0 and
    ``'v1'`` on rank 1.
    """
    mp_context = torch.multiprocessing.get_context('spawn')

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]

    data = FakeHeteroDataset(
        num_graphs=1,
        avg_num_nodes=100,
        avg_degree=3,
        num_node_types=2,
        num_edge_types=4,
        edge_dim=2,
    )[0]

    procs = [
        ProcArgs(
target=dist_neighbor_sampler_hetero,
            args=(data, tmp_path, 0, port, 'v0', disjoint),
        ),
        ProcArgs(
            target=dist_neighbor_sampler_hetero,
            args=(data, tmp_path, 1, port, 'v1', disjoint),
        ),
    ]

    # Write one partition per process to disk before spawning them:
    partitioner = Partitioner(data, len(procs), tmp_path)
    partitioner.generate_partition()

    assert_run_mproc(mp_context, procs)


@withMETIS
@onlyDistributedTest
@pytest.mark.parametrize('seed_time', [None, [0, 0], [2, 2]])
@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
def test_dist_neighbor_sampler_temporal_hetero(
    tmp_path,
    seed_time,
    temporal_strategy,
):
    """Run ``dist_neighbor_sampler_temporal_hetero`` with node-level time:
    every ``'v0'`` node has time 1 and every ``'v1'`` node time 2.
    """
    if seed_time is not None:
        seed_time = torch.tensor(seed_time)

    mp_context = torch.multiprocessing.get_context('spawn')

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]

    data = FakeHeteroDataset(
        num_graphs=1,
        avg_num_nodes=100,
        avg_degree=3,
        num_node_types=2,
        num_edge_types=4,
        edge_dim=2,
    )[0]

    data['v0'].time = torch.full((data.num_nodes_dict['v0'], ), 1,
                                 dtype=torch.int64)
    data['v1'].time = torch.full((data.num_nodes_dict['v1'], ), 2,
                                 dtype=torch.int64)

    procs = [
        ProcArgs(
            target=dist_neighbor_sampler_temporal_hetero,
            args=(data, tmp_path, 0, port, 'v0', seed_time,
                  temporal_strategy, 'time'),
        ),
        ProcArgs(
            target=dist_neighbor_sampler_temporal_hetero,
            args=(data, tmp_path, 1, port, 'v1', seed_time,
                  temporal_strategy, 'time'),
        ),
    ]

    partitioner = Partitioner(data, len(procs), tmp_path)
    partitioner.generate_partition()

    assert_run_mproc(mp_context, procs)


@withMETIS
@onlyDistributedTest
@pytest.mark.parametrize('seed_time', [[0, 0], [1, 2]])
@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
def test_dist_neighbor_sampler_edge_level_temporal_hetero(
    tmp_path,
    seed_time,
    temporal_strategy,
):
    """Run ``dist_neighbor_sampler_temporal_hetero`` with edge-level time:
    every edge of the ``i``-th edge type gets timestamp ``i``.
    """
    seed_time = torch.tensor(seed_time)

    mp_context = torch.multiprocessing.get_context('spawn')

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]

    data = FakeHeteroDataset(
        num_graphs=1,
        avg_num_nodes=100,
        avg_degree=3,
        num_node_types=2,
        num_edge_types=4,
        edge_dim=2,
    )[0]

    for i, edge_type in enumerate(data.edge_types):
        data[edge_type].edge_time = torch.full(
            (data[edge_type].edge_index.size(1), ), i, dtype=torch.int64)

    procs = [
        ProcArgs(
            target=dist_neighbor_sampler_temporal_hetero,
            args=(data, tmp_path, 0, port, 'v0', seed_time,
                  temporal_strategy, 'edge_time'),
        ),
        ProcArgs(
            target=dist_neighbor_sampler_temporal_hetero,
            args=(data, tmp_path, 1, port, 'v1', seed_time,
                  temporal_strategy, 'edge_time'),
        ),
    ]

    partitioner = Partitioner(data, len(procs), tmp_path)
    partitioner.generate_partition()

    assert_run_mproc(mp_context, procs)


================================================
FILE: test/distributed/test_dist_utils.py
================================================
import torch

from torch_geometric.distributed.utils import remove_duplicates
from torch_geometric.sampler import SamplerOutput
from torch_geometric.testing import onlyDistributedTest


@onlyDistributedTest
def test_remove_duplicates():
    """``remove_duplicates`` returns the sampled nodes not already in
    ``node`` (``src``) and ``node`` extended by them, in first-seen order.
    """
    node = torch.tensor([0, 1, 2, 3])
    out_node = torch.tensor([0, 4, 1, 5, 1, 6, 2, 7, 3, 8])
    out = SamplerOutput(out_node, None, None, None)

    src, node, _, _ = remove_duplicates(out, node)

    assert src.tolist() == [4, 5, 6, 7, 8]
    assert node.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8]


@onlyDistributedTest
def test_remove_duplicates_disjoint():
    """In disjoint mode, duplicates are judged per (node, batch) pair.

    For example, node 6 is sampled twice for batch 2 but kept only once,
    and the batch ids of new nodes are returned as ``src_batch``.
    """
    node = torch.tensor([0, 1, 2, 3])
    batch = torch.tensor([0, 1, 2, 3])
    out_node = torch.tensor([0, 4, 1, 5, 1, 6, 2, 6, 7, 3, 8])
    out_batch = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
    out = SamplerOutput(out_node, None, None, None, out_batch)

    src, node, src_batch, batch = remove_duplicates(out, node, batch,
                                                    disjoint=True)

    assert src.tolist() == [4, 5, 6, 7, 8]
    assert node.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8]
    assert src_batch.tolist() == [0, 1, 2, 3, 3]
    assert batch.tolist() == [0, 1, 2, 3, 0, 1, 2, 3, 3]
================================================
FILE: test/distributed/test_local_feature_store.py
================================================
import torch

from torch_geometric.distributed import LocalFeatureStore
from torch_geometric.testing import onlyDistributedTest


@onlyDistributedTest
def test_local_feature_store_global_id():
    """Features stored for a subset of global ids can be fetched back by
    global id, independently of their local storage order.
    """
    store = LocalFeatureStore()

    # Row ``i`` is filled with the value ``i``:
    feat = torch.tensor([
        [0.0, 0.0, 0.0],
        [1.0, 1.0, 1.0],
        [2.0, 2.0, 2.0],
        [3.0, 3.0, 3.0],
        [4.0, 4.0, 4.0],
        [5.0, 5.0, 5.0],
        [6.0, 6.0, 6.0],
        [7.0, 7.0, 7.0],
        [8.0, 8.0, 8.0],
    ])

    paper_global_id = torch.tensor([1, 2, 3, 5, 8, 4])
    paper_feat = feat[paper_global_id]

    store.put_global_id(paper_global_id, group_name='paper')
    store.put_tensor(paper_feat, group_name='paper', attr_name='feat')

    out = store.get_tensor_from_global_id(group_name='paper', attr_name='feat',
                                          index=torch.tensor([3, 8, 4]))
    assert torch.equal(out, feat[torch.tensor([3, 8, 4])])


@onlyDistributedTest
def test_local_feature_store_utils():
    """A single ``put_tensor`` registers exactly one tensor attribute with
    no index, and ``get_tensor_size`` reports its shape.
    """
    store = LocalFeatureStore()

    feat = torch.tensor([
        [0.0, 0.0, 0.0],
        [1.0, 1.0, 1.0],
        [2.0, 2.0, 2.0],
        [3.0, 3.0, 3.0],
        [4.0, 4.0, 4.0],
        [5.0, 5.0, 5.0],
        [6.0, 6.0, 6.0],
        [7.0, 7.0, 7.0],
        [8.0, 8.0, 8.0],
    ])

    paper_global_id = torch.tensor([1, 2, 3, 5, 8, 4])
    paper_feat = feat[paper_global_id]

    store.put_tensor(paper_feat, group_name='paper', attr_name='feat')

    assert len(store.get_all_tensor_attrs()) == 1
    attr = store.get_all_tensor_attrs()[0]
    assert attr.group_name == 'paper'
    assert attr.attr_name == 'feat'
    assert attr.index is None
    assert store.get_tensor_size(attr) == (6, 3)


@onlyDistributedTest
def test_homogeneous_feature_store():
    """``from_data`` registers ``x`` and ``y`` under group ``None`` and
    ``edge_attr`` under group ``(None, None)``, and stores the given node and
    edge global ids.
    """
    node_id = torch.randperm(6)
    x = torch.randn(6, 32)
    y = torch.randint(0, 2, (6, ))
    edge_id = torch.randperm(12)
    edge_attr = torch.randn(12, 16)

    store = LocalFeatureStore.from_data(node_id, x, y, edge_id, edge_attr)

    assert len(store.get_all_tensor_attrs()) == 3
    attrs = store.get_all_tensor_attrs()

    assert attrs[0].group_name is None
    assert attrs[0].attr_name == 'x'
    assert
attrs[1].group_name is None
    assert attrs[1].attr_name == 'y'
    assert attrs[2].group_name == (None, None)
    assert attrs[2].attr_name == 'edge_attr'

    assert torch.equal(store.get_global_id(group_name=None), node_id)
    assert torch.equal(store.get_tensor(group_name=None, attr_name='x'), x)
    assert torch.equal(store.get_tensor(group_name=None, attr_name='y'), y)
    assert torch.equal(store.get_global_id(group_name=(None, None)), edge_id)
    assert torch.equal(
        store.get_tensor(group_name=(None, None), attr_name='edge_attr'),
        edge_attr,
    )


@onlyDistributedTest
def test_heterogeneous_feature_store():
    """``from_hetero_data`` registers ``x``/``y`` under the node type and
    ``edge_attr`` under the edge type, and stores per-type global ids.
    """
    node_type = 'paper'
    edge_type = ('paper', 'to', 'paper')
    node_id_dict = {node_type: torch.randperm(6)}
    x_dict = {node_type: torch.randn(6, 32)}
    y_dict = {node_type: torch.randint(0, 2, (6, ))}
    edge_id_dict = {edge_type: torch.randperm(12)}
    edge_attr_dict = {edge_type: torch.randn(12, 16)}

    store = LocalFeatureStore.from_hetero_data(
        node_id_dict,
        x_dict,
        y_dict,
        edge_id_dict,
        edge_attr_dict,
    )

    assert len(store.get_all_tensor_attrs()) == 3
    attrs = store.get_all_tensor_attrs()

    assert attrs[0].group_name == node_type
    assert attrs[0].attr_name == 'x'
    assert attrs[1].group_name == node_type
    assert attrs[1].attr_name == 'y'
    assert attrs[2].group_name == edge_type
    assert attrs[2].attr_name == 'edge_attr'

    assert torch.equal(
        store.get_global_id(group_name=node_type),
        node_id_dict[node_type],
    )
    assert torch.equal(
        store.get_tensor(group_name=node_type, attr_name='x'),
        x_dict[node_type],
    )
    assert torch.equal(
        store.get_tensor(group_name=node_type, attr_name='y'),
        y_dict[node_type],
    )
    assert torch.equal(
        store.get_global_id(group_name=edge_type),
        edge_id_dict[edge_type],
    )
    assert torch.equal(
        store.get_tensor(group_name=edge_type, attr_name='edge_attr'),
        edge_attr_dict[edge_type],
    )


================================================
FILE: test/distributed/test_local_graph_store.py
================================================
import torch

from torch_geometric.distributed import LocalGraphStore
from
torch_geometric.testing import get_random_edge_index, onlyDistributedTest


@onlyDistributedTest
def test_local_graph_store():
    """Round-trip ``put``/``get``/``remove`` of an edge index and edge ids
    on an empty ``LocalGraphStore``.
    """
    graph_store = LocalGraphStore()

    edge_index = get_random_edge_index(100, 100, 300)
    # NOTE(review): 6 edge ids for 300 edges. The test relies on the store
    # not checking that the two sizes match.
    edge_id = torch.tensor([1, 2, 3, 5, 8, 4])

    graph_store.put_edge_index(
        edge_index,
        edge_type=None,
        layout='coo',
        size=(100, 100),
    )

    graph_store.put_edge_id(
        edge_id,
        edge_type=None,
        layout='coo',
        size=(100, 100),
    )

    assert len(graph_store.get_all_edge_attrs()) == 1
    edge_attr = graph_store.get_all_edge_attrs()[0]
    assert torch.equal(graph_store.get_edge_index(edge_attr), edge_index)
    assert torch.equal(graph_store.get_edge_id(edge_attr), edge_id)
    assert not graph_store.is_sorted

    graph_store.remove_edge_index(edge_attr)
    graph_store.remove_edge_id(edge_attr)
    assert len(graph_store.get_all_edge_attrs()) == 0


@onlyDistributedTest
def test_homogeneous_graph_store():
    """``from_data`` with an already destination-sorted edge index keeps
    edges and ids unchanged and records a sorted COO edge attribute.
    """
    edge_id = torch.randperm(300)
    edge_index = get_random_edge_index(100, 100, 300)
    # Only the destination row is sorted, which is enough for a random graph:
    edge_index[1] = torch.sort(edge_index[1])[0]

    graph_store = LocalGraphStore.from_data(
        edge_id,
        edge_index,
        num_nodes=100,
        is_sorted=True,
    )

    assert len(graph_store.get_all_edge_attrs()) == 1
    edge_attr = graph_store.get_all_edge_attrs()[0]
    assert edge_attr.edge_type is None
    assert edge_attr.layout.value == 'coo'
    assert edge_attr.is_sorted
    assert edge_attr.size == (100, 100)

    assert torch.equal(
        graph_store.get_edge_id(edge_type=None, layout='coo'),
        edge_id,
    )
    assert torch.equal(
        graph_store.get_edge_index(edge_type=None, layout='coo'),
        edge_index,
    )


@onlyDistributedTest
def test_heterogeneous_graph_store():
    """Same as the homogeneous case, but through ``from_hetero_data`` with a
    single ``('paper', 'to', 'paper')`` edge type.
    """
    edge_type = ('paper', 'to', 'paper')
    edge_id_dict = {edge_type: torch.randperm(300)}
    edge_index = get_random_edge_index(100, 100, 300)
    edge_index[1] = torch.sort(edge_index[1])[0]
    edge_index_dict = {edge_type: edge_index}

    graph_store = LocalGraphStore.from_hetero_data(
        edge_id_dict,
        edge_index_dict,
        num_nodes_dict={'paper': 100},
        is_sorted=True,
    )

    assert len(graph_store.get_all_edge_attrs()) == 1
    edge_attr =
graph_store.get_all_edge_attrs()[0]
    assert edge_attr.edge_type == edge_type
    assert edge_attr.layout.value == 'coo'
    assert edge_attr.is_sorted
    assert edge_attr.size == (100, 100)

    assert torch.equal(
        graph_store.get_edge_id(edge_type, layout='coo'),
        edge_id_dict[edge_type],
    )
    assert torch.equal(
        graph_store.get_edge_index(edge_type, layout='coo'),
        edge_index_dict[edge_type],
    )


@onlyDistributedTest
def test_sorted_graph_store():
    """With ``is_sorted=False`` the store reorders edges by destination
    node and permutes ``edge_id`` the same way (homogeneous and
    heterogeneous constructors).
    """
    # Expected result after sorting by the destination row:
    edge_index_sorted = torch.tensor([[1, 7, 5, 6, 1], [0, 0, 1, 1, 2]])
    edge_id_sorted = torch.tensor([0, 1, 2, 3, 4])

    # The same edges/ids in unsorted order:
    edge_index = torch.tensor([[1, 5, 7, 1, 6], [0, 1, 0, 2, 1]])
    edge_id = torch.tensor([0, 2, 1, 4, 3])

    graph_store = LocalGraphStore.from_data(
        edge_id,
        edge_index,
        num_nodes=8,
        is_sorted=False,
    )
    assert torch.equal(
        graph_store.get_edge_index(edge_type=None, layout='coo'),
        edge_index_sorted,
    )
    assert torch.equal(
        graph_store.get_edge_id(edge_type=None, layout='coo'),
        edge_id_sorted,
    )

    edge_type = ('paper', 'to', 'paper')
    edge_index_dict = {edge_type: edge_index}
    edge_id_dict = {edge_type: edge_id}

    graph_store = LocalGraphStore.from_hetero_data(
        edge_id_dict,
        edge_index_dict,
        num_nodes_dict={'paper': 8},
        is_sorted=False,
    )
    assert torch.equal(
        graph_store.get_edge_index(edge_type, layout='coo'),
        edge_index_sorted,
    )
    assert torch.equal(
        graph_store.get_edge_id(edge_type, layout='coo'),
        edge_id_sorted,
    )


================================================
FILE: test/distributed/test_partition.py
================================================
import os.path as osp

import torch

from torch_geometric.datasets import FakeDataset, FakeHeteroDataset
from torch_geometric.distributed import (
    LocalFeatureStore,
    LocalGraphStore,
    Partitioner,
)
from torch_geometric.io import fs
from torch_geometric.testing import onlyDistributedTest, withMETIS
from torch_geometric.typing import EdgeTypeStr


@withMETIS
@onlyDistributedTest
def test_partition_data(tmp_path):
    """Partitioning a homogeneous graph in two writes node/edge maps,
    ``META.json`` and per-part graph and node-feature files, and the node
    features are split consistently with the stored global ids.
    """
    data = FakeDataset()[0]

    partitioner = Partitioner(data, num_parts=2, root=tmp_path)
partitioner.generate_partition()

    # One partition id per node / per edge:
    node_map_path = osp.join(tmp_path, 'node_map.pt')
    assert osp.exists(node_map_path)
    node_map = fs.torch_load(node_map_path)
    assert node_map.numel() == data.num_nodes

    edge_map_path = osp.join(tmp_path, 'edge_map.pt')
    assert osp.exists(edge_map_path)
    edge_map = fs.torch_load(edge_map_path)
    assert edge_map.numel() == data.num_edges

    meta_path = osp.join(tmp_path, 'META.json')
    assert osp.exists(meta_path)

    # Each part's graph file must contain all four keys:
    graph0_path = osp.join(tmp_path, 'part_0', 'graph.pt')
    assert osp.exists(graph0_path)
    graph0 = fs.torch_load(graph0_path)
    assert len({'edge_id', 'row', 'col', 'size'} & set(graph0.keys())) == 4

    graph1_path = osp.join(tmp_path, 'part_1', 'graph.pt')
    assert osp.exists(graph1_path)
    graph1 = fs.torch_load(graph1_path)
    assert len({'edge_id', 'row', 'col', 'size'} & set(graph1.keys())) == 4

    node_feats0_path = osp.join(tmp_path, 'part_0', 'node_feats.pt')
    assert osp.exists(node_feats0_path)
    node_feats0 = fs.torch_load(node_feats0_path)

    node_feats1_path = osp.join(tmp_path, 'part_1', 'node_feats.pt')
    assert osp.exists(node_feats1_path)
    node_feats1 = fs.torch_load(node_feats1_path)

    # Every node's features live in exactly one part, indexed by global id:
    assert (node_feats0['feats']['x'].size(0) +
            node_feats1['feats']['x'].size(0) == data.num_nodes)
    assert torch.equal(data.x[node_feats0['global_id']],
                       node_feats0['feats']['x'])
    assert torch.equal(data.x[node_feats1['global_id']],
                       node_feats1['feats']['x'])


@withMETIS
@onlyDistributedTest
def test_partition_hetero_data(tmp_path):
    """Partitioning a heterogeneous graph writes per-type node/edge maps and
    graph, node-feature and edge-feature files for every part.
    """
    data = FakeHeteroDataset()[0]

    num_parts = 2
    partitioner = Partitioner(data, num_parts=num_parts, root=tmp_path)
    partitioner.generate_partition()

    meta_path = osp.join(tmp_path, 'META.json')
    assert osp.exists(meta_path)

    for edge_type, num_edges in data.num_edges_dict.items():
        assert len(edge_type) == 3
        edge_name = EdgeTypeStr(edge_type)
        edge_map_path = osp.join(tmp_path, 'edge_map', f'{edge_name}.pt')
        assert osp.exists(edge_map_path)
        edge_map = fs.torch_load(edge_map_path)
        assert edge_map.numel() == num_edges

    for node_type, num_nodes in
data.num_nodes_dict.items(): node_map_path = osp.join(tmp_path, 'node_map', f'{node_type}.pt') assert osp.exists(node_map_path) node_map = fs.torch_load(node_map_path) assert node_map.numel() == num_nodes for pid in range(num_parts): graph_path = osp.join(tmp_path, f'part_{pid}', 'graph.pt') assert osp.exists(graph_path) node_feats_path = osp.join(tmp_path, f'part_{pid}', 'node_feats.pt') assert osp.exists(node_feats_path) edge_feats_path = osp.join(tmp_path, f'part_{pid}', 'edge_feats.pt') assert osp.exists(edge_feats_path) @withMETIS @onlyDistributedTest def test_partition_data_temporal(tmp_path): data = FakeDataset()[0] data.time = torch.arange(data.num_nodes) partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() node_feats0_path = osp.join(tmp_path, 'part_0', 'node_feats.pt') assert osp.exists(node_feats0_path) node_feats0 = fs.torch_load(node_feats0_path) node_feats1_path = osp.join(tmp_path, 'part_1', 'node_feats.pt') assert osp.exists(node_feats1_path) node_feats1 = fs.torch_load(node_feats1_path) assert torch.equal(data.time, node_feats0['time']) assert torch.equal(data.time, node_feats1['time']) @withMETIS @onlyDistributedTest def test_partition_data_edge_level_temporal(tmp_path): data = FakeDataset(edge_dim=2)[0] data.edge_time = torch.arange(data.num_edges) partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() edge_feats0_path = osp.join(tmp_path, 'part_0', 'edge_feats.pt') assert osp.exists(edge_feats0_path) edge_feats0 = fs.torch_load(edge_feats0_path) edge_feats1_path = osp.join(tmp_path, 'part_1', 'edge_feats.pt') assert osp.exists(edge_feats1_path) edge_feats1 = fs.torch_load(edge_feats1_path) assert torch.equal(data.edge_time[edge_feats0['global_id']], edge_feats0['edge_time']) assert torch.equal(data.edge_time[edge_feats1['global_id']], edge_feats1['edge_time']) @withMETIS @onlyDistributedTest def test_partition_hetero_data_temporal(tmp_path): data = 
FakeHeteroDataset()[0] for key in data.node_types: data[key].time = torch.arange(data[key].num_nodes) partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() node_feats0_path = osp.join(tmp_path, 'part_0', 'node_feats.pt') assert osp.exists(node_feats0_path) node_feats0 = fs.torch_load(node_feats0_path) node_feats1_path = osp.join(tmp_path, 'part_1', 'node_feats.pt') assert osp.exists(node_feats1_path) node_feats1 = fs.torch_load(node_feats1_path) for key in data.node_types: assert torch.equal(data[key].time, node_feats0[key]['time']) assert torch.equal(data[key].time, node_feats1[key]['time']) @withMETIS @onlyDistributedTest def test_partition_hetero_data_edge_level_temporal(tmp_path): data = FakeHeteroDataset(edge_dim=2)[0] for key in data.edge_types: data[key].edge_time = torch.arange(data[key].num_edges) partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() edge_feats0_path = osp.join(tmp_path, 'part_0', 'edge_feats.pt') assert osp.exists(edge_feats0_path) edge_feats0 = fs.torch_load(edge_feats0_path) edge_feats1_path = osp.join(tmp_path, 'part_1', 'edge_feats.pt') assert osp.exists(edge_feats1_path) edge_feats1 = fs.torch_load(edge_feats1_path) for key in data.edge_types: assert torch.equal( data[key].edge_time[edge_feats0[key]['global_id']], edge_feats0[key]['edge_time'], ) assert torch.equal( data[key].edge_time[edge_feats1[key]['global_id']], edge_feats1[key]['edge_time'], ) @withMETIS @onlyDistributedTest def test_from_partition_data(tmp_path): data = FakeDataset()[0] partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() graph_store1 = LocalGraphStore.from_partition(tmp_path, pid=0) graph_store2 = LocalGraphStore.from_partition(tmp_path, pid=1) attr1 = graph_store1.get_all_edge_attrs()[0] (row1, col1) = graph_store1.get_edge_index(attr1) attr2 = graph_store2.get_all_edge_attrs()[0] (row2, col2) = graph_store2.get_edge_index(attr2) assert 
row1.size(0) + row2.size(0) == data.num_edges feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0) feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1) node_attr1 = feat_store1.get_all_tensor_attrs()[0] assert node_attr1.attr_name == 'x' x1 = feat_store1.get_tensor(node_attr1) id1 = feat_store1.get_global_id(node_attr1.group_name) node_attr2 = feat_store2.get_all_tensor_attrs()[0] assert node_attr2.attr_name == 'x' x2 = feat_store2.get_tensor(node_attr2) id2 = feat_store2.get_global_id(node_attr2.group_name) assert x1.size(0) + x2.size(0) == data.num_nodes assert torch.allclose(data.x[id1], x1) assert torch.allclose(data.x[id2], x2) @withMETIS @onlyDistributedTest def test_from_partition_hetero_data(tmp_path): data = FakeHeteroDataset()[0] partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() graph_store1 = LocalGraphStore.from_partition(tmp_path, pid=0) graph_store2 = LocalGraphStore.from_partition(tmp_path, pid=1) attrs1 = graph_store1.get_all_edge_attrs() attrs2 = graph_store2.get_all_edge_attrs() assert len(data.edge_types) == len(attrs1) == len(attrs2) node_types = set() for attr in attrs1: node_types.add(attr.edge_type[0]) node_types.add(attr.edge_type[2]) assert node_types == set(data.node_types) node_types = set() for attr in attrs2: node_types.add(attr.edge_type[0]) node_types.add(attr.edge_type[2]) assert node_types == set(data.node_types) @withMETIS @onlyDistributedTest def test_from_partition_temporal_data(tmp_path): data = FakeDataset()[0] data.time = torch.arange(data.num_nodes) partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0) feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1) time_attr1 = feat_store1.get_all_tensor_attrs()[1] assert time_attr1.attr_name == 'time' time1 = feat_store1.get_tensor(time_attr1) time_attr2 = feat_store2.get_all_tensor_attrs()[1] assert 
time_attr2.attr_name == 'time' time2 = feat_store2.get_tensor(time_attr2) assert time1.size(0) == data.num_nodes assert time2.size(0) == data.num_nodes assert torch.equal(time1, data.time) assert torch.equal(time2, data.time) @withMETIS @onlyDistributedTest def test_from_partition_edge_level_temporal_data(tmp_path): data = FakeDataset(edge_dim=2)[0] data.edge_time = torch.arange(data.num_edges) partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0) feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1) time_attr1 = feat_store1.get_all_tensor_attrs()[2] assert time_attr1.attr_name == 'edge_time' time1 = feat_store1.get_tensor(time_attr1) time_attr2 = feat_store2.get_all_tensor_attrs()[2] assert time_attr2.attr_name == 'edge_time' time2 = feat_store2.get_tensor(time_attr2) edge_id1 = feat_store1.get_global_id(group_name=(None, None)) edge_id2 = feat_store2.get_global_id(group_name=(None, None)) assert time1.size(0) + time2.size(0) == data.edge_index.size(1) assert torch.equal(data.edge_time[edge_id1], time1) assert torch.equal(data.edge_time[edge_id2], time2) @withMETIS @onlyDistributedTest def test_from_partition_hetero_temporal_data(tmp_path): data = FakeHeteroDataset()[0] for key in data.node_types: data[key].time = torch.arange(data[key].num_nodes) partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0) feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1) attrs1 = feat_store1.get_all_tensor_attrs() attrs2 = feat_store2.get_all_tensor_attrs() times1 = { attr.group_name: feat_store1.get_tensor(attr) for attr in attrs1 if attr.attr_name == 'time' } times2 = { attr.group_name: feat_store2.get_tensor(attr) for attr in attrs2 if attr.attr_name == 'time' } for key in data.node_types: assert times1[key].size(0) == data[key].num_nodes assert 
times2[key].size(0) == data[key].num_nodes assert torch.equal(times1[key], data[key].time) assert torch.equal(times2[key], data[key].time) @withMETIS @onlyDistributedTest def test_from_partition_hetero_edge_level_temporal_data(tmp_path): data = FakeHeteroDataset(edge_dim=2)[0] for key in data.edge_types: data[key].edge_time = torch.arange(data[key].num_edges) partitioner = Partitioner(data, num_parts=2, root=tmp_path) partitioner.generate_partition() feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0) feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1) attrs1 = feat_store1.get_all_tensor_attrs() attrs2 = feat_store2.get_all_tensor_attrs() times1 = { attr.group_name: feat_store1.get_tensor(attr) for attr in attrs1 if attr.attr_name == 'edge_time' } times2 = { attr.group_name: feat_store2.get_tensor(attr) for attr in attrs2 if attr.attr_name == 'edge_time' } for key in data.edge_types: edge_id1 = feat_store1.get_global_id(group_name=key) edge_id2 = feat_store2.get_global_id(group_name=key) assert times1[key].size(0) + times2[key].size(0) == data[key].num_edges assert torch.equal(data[key].edge_time[edge_id1], times1[key]) assert torch.equal(data[key].edge_time[edge_id2], times2[key]) ================================================ FILE: test/distributed/test_rpc.py ================================================ import socket import torch import torch_geometric.distributed.rpc as rpc from torch_geometric.distributed import LocalFeatureStore from torch_geometric.distributed.dist_context import DistContext from torch_geometric.distributed.rpc import RPCRouter from torch_geometric.testing import onlyDistributedTest def run_rpc_feature_test( world_size: int, rank: int, feature: LocalFeatureStore, partition_book: torch.Tensor, master_port: int, ): # 1) Initialize the context info: current_ctx = DistContext( rank=rank, global_rank=rank, world_size=world_size, global_world_size=world_size, group_name='dist-feature-test', ) rpc.init_rpc( 
current_ctx=current_ctx, master_addr='localhost', master_port=master_port, ) # 2) Collect all workers: partition_to_workers = rpc.rpc_partition_to_workers( current_ctx, world_size, rank) assert partition_to_workers == [ ['dist-feature-test-0'], ['dist-feature-test-1'], ] # 3) Find the mapping between worker and partition ID: rpc_router = RPCRouter(partition_to_workers) assert rpc_router.get_to_worker(partition_idx=0) == 'dist-feature-test-0' assert rpc_router.get_to_worker(partition_idx=1) == 'dist-feature-test-1' meta = { 'edge_types': None, 'is_hetero': False, 'node_types': None, 'num_parts': 2, } feature.num_partitions = world_size feature.partition_idx = rank feature.node_feat_pb = partition_book feature.meta = meta feature.local_only = False feature.set_rpc_router(rpc_router) # Global node IDs: global_id0 = torch.arange(128 * 2) global_id1 = torch.arange(128 * 2) + 128 * 2 # Lookup the features from stores including locally and remotely: tensor0 = feature.lookup_features(global_id0) tensor1 = feature.lookup_features(global_id1) # Expected searched results: cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2]) cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)]) # Verify.. 
assert torch.allclose(cpu_tensor0, tensor0.wait()) assert torch.allclose(cpu_tensor1, tensor1.wait()) rpc.shutdown_rpc() assert rpc.rpc_is_initialized() is False @onlyDistributedTest def test_dist_feature_lookup(): cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2]) cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)]) # Global node IDs: global_id0 = torch.arange(128 * 2) global_id1 = torch.arange(128 * 2) + 128 * 2 # Set the partition book for two features (partition 0 and 1): partition_book = torch.cat([ torch.zeros(128 * 2, dtype=torch.long), torch.ones(128 * 2, dtype=torch.long), ]) # Put the test tensor into the different feature stores with IDs: feature0 = LocalFeatureStore() feature0.put_global_id(global_id0, group_name=None) feature0.put_tensor(cpu_tensor0, group_name=None, attr_name='x') feature1 = LocalFeatureStore() feature1.put_global_id(global_id1, group_name=None) feature1.put_tensor(cpu_tensor1, group_name=None, attr_name='x') mp_context = torch.multiprocessing.get_context('spawn') with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.settimeout(1) s.bind(('127.0.0.1', 0)) port = s.getsockname()[1] w0 = mp_context.Process(target=run_rpc_feature_test, args=(2, 0, feature0, partition_book, port)) w1 = mp_context.Process(target=run_rpc_feature_test, args=(2, 1, feature1, partition_book, port)) w0.start() w1.start() w0.join() w1.join() ================================================ FILE: test/explain/algorithm/test_attention_explainer.py ================================================ import pytest import torch from torch_geometric.explain import ( AttentionExplainer, Explainer, HeteroExplanation, ) from torch_geometric.explain.config import ( ExplanationType, MaskType, ModelConfig, ModelMode, ) from torch_geometric.nn import ( AttentiveFP, GATConv, GATv2Conv, TransformerConv, to_hetero, ) from torch_geometric.nn.conv import HeteroConv class 
AttentionGNN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GATConv(3, 16, heads=4) self.conv2 = GATv2Conv(4 * 16, 16, heads=2) self.conv3 = TransformerConv(2 * 16, 7, heads=1) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) x = self.conv3(x, edge_index) return x class HeteroAttentionGNN(torch.nn.Module): def __init__(self, metadata, model_config=None): super().__init__() self.model_config = model_config # Create a single BaseGNN that uses all three attention mechanisms class BaseGNN(torch.nn.Module): def __init__(self): super().__init__() # Use different attention mechanisms in sequence self.conv1 = GATConv((-1, -1), 16, heads=2, add_self_loops=False) self.conv2 = GATv2Conv((-1, -1), 16, heads=2, add_self_loops=False) self.conv3 = TransformerConv((-1, -1), 32, heads=1) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index).relu() x = self.conv3(x, edge_index) return x # Convert to heterogeneous model with a single to_hetero call self.gnn = to_hetero(BaseGNN(), metadata, debug=False) # Output dimension based on model config out_channels = 7 if (model_config and model_config.mode == ModelMode.multiclass_classification) else 1 self.lin = torch.nn.Linear(32, out_channels) def forward(self, x_dict, edge_index_dict, **kwargs): # Process through the heterogeneous GNN out_dict = self.gnn(x_dict, edge_index_dict) # Project paper node embeddings for classification/regression x = self.lin(out_dict['paper']) # Apply appropriate output transformation based on model config if self.model_config: if self.model_config.mode == ModelMode.binary_classification: if self.model_config.return_type == 'probs': x = x.sigmoid() elif self.model_config.mode == ModelMode.multiclass_classification: if self.model_config.return_type == 'probs': x = x.softmax(dim=-1) elif self.model_config.return_type == 'log_probs': x = x.log_softmax(dim=-1) return x class 
HeteroConvAttentionGNN(torch.nn.Module): def __init__(self, metadata, model_config=None): super().__init__() self.model_config = model_config # Determine output channels based on model_config self.out_channels = 1 if (model_config and model_config.mode == ModelMode.multiclass_classification): self.out_channels = 7 # Initialize node type-specific layers self.lin_dict = torch.nn.ModuleDict() self.initialized = False # Create a dictionary of attention-based convolutions for each edge # type conv_dict = {} for edge_type in metadata[1]: # metadata[1] contains edge types src_type, _, dst_type = edge_type if src_type == dst_type: # For same node type, use GATConv with add_self_loops=False # Use concat=False to avoid dimension issues conv_dict[edge_type] = GATConv( (-1, -1), 32, heads=2, add_self_loops=False, concat=False) else: # For different node types, use GATv2Conv with # add_self_loops=False Use concat=False to avoid dimension # issues conv_dict[edge_type] = GATv2Conv( (-1, -1), 32, heads=2, add_self_loops=False, concat=False) # Create the HeteroConv layer self.conv = HeteroConv(conv_dict, aggr='sum') # Output layer will be initialized in forward pass self.out_lin = None def _initialize_layers(self, x_dict): """Initialize layers with correct dimensions when we first see the data. 
""" if not self.initialized: # Initialize input projections for node_type, x in x_dict.items(): in_channels = x.size(-1) self.lin_dict[node_type] = torch.nn.Linear(in_channels, 32).to(x.device) # Initialize output projection self.out_lin = torch.nn.Linear(32, self.out_channels).to( x_dict['paper'].device) self.initialized = True def forward(self, x_dict, edge_index_dict): # Initialize layers if not done yet self._initialize_layers(x_dict) # Apply node type-specific transformations h_dict = {} for node_type, x in x_dict.items(): h_dict[node_type] = self.lin_dict[node_type](x).relu_() # Apply heterogeneous convolution out_dict = self.conv(h_dict, edge_index_dict) # Final transformation for paper nodes out = self.out_lin(out_dict['paper']) # Apply transformations based on model_config if available if self.model_config: if self.model_config.mode == ModelMode.binary_classification: if self.model_config.return_type == 'probs': out = out.sigmoid() elif self.model_config.mode == ModelMode.multiclass_classification: if self.model_config.return_type == 'probs': out = out.softmax(dim=-1) elif self.model_config.return_type == 'log_probs': out = out.log_softmax(dim=-1) return out x = torch.randn(8, 3) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ]) edge_attr = torch.randn(edge_index.size(1), 5) batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2]) @pytest.mark.parametrize('index', [None, 2, torch.arange(3)]) def test_attention_explainer(index, check_explanation): explainer = Explainer( model=AttentionGNN(), algorithm=AttentionExplainer(), explanation_type='model', edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='raw', ), ) explanation = explainer(x, edge_index, index=index) check_explanation(explanation, None, explainer.edge_mask_type) @pytest.mark.parametrize('explanation_type', [e for e in ExplanationType]) @pytest.mark.parametrize('node_mask_type', [m 
for m in MaskType]) def test_attention_explainer_supports(explanation_type, node_mask_type): with pytest.raises(ValueError, match="not support the given explanation"): Explainer( model=AttentionGNN(), algorithm=AttentionExplainer(), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='raw', ), ) def test_attention_explainer_attentive_fp(check_explanation): model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2) explainer = Explainer( model=model, algorithm=AttentionExplainer(), explanation_type='model', edge_mask_type='object', model_config=dict( mode='binary_classification', task_level='node', return_type='raw', ), ) explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch) check_explanation(explanation, None, explainer.edge_mask_type) @pytest.mark.parametrize('index', [None, 2, torch.arange(3)]) def test_attention_explainer_hetero(index, hetero_data, check_explanation_hetero): # Create model configuration model_config = ModelConfig( mode='multiclass_classification', task_level='node', return_type='raw', ) # Get metadata from hetero_data metadata = hetero_data.metadata() # Create the hetero attention model model = HeteroAttentionGNN(metadata, model_config) # Create the explainer explainer = Explainer( model=model, algorithm=AttentionExplainer(), explanation_type='model', edge_mask_type='object', model_config=model_config, ) # Generate the explanation explanation = explainer( hetero_data.x_dict, hetero_data.edge_index_dict, index=index, ) # Check that the explanation is correct assert isinstance(explanation, HeteroExplanation) check_explanation_hetero(explanation, None, explainer.edge_mask_type, hetero_data) @pytest.mark.parametrize('index', [None, 2, torch.arange(3)]) def test_attention_explainer_hetero_conv(index, hetero_data, check_explanation_hetero): """Test AttentionExplainer with HeteroConv using 
attention-based layers.""" # Create model configuration model_config = ModelConfig( mode='multiclass_classification', task_level='node', return_type='raw', ) # Get metadata from hetero_data metadata = hetero_data.metadata() # Create the hetero conv attention model model = HeteroConvAttentionGNN(metadata, model_config) # Create the explainer explainer = Explainer( model=model, algorithm=AttentionExplainer(), explanation_type='model', edge_mask_type='object', model_config=model_config, ) # Generate the explanation explanation = explainer( hetero_data.x_dict, hetero_data.edge_index_dict, index=index, ) # Check that the explanation is correct assert isinstance(explanation, HeteroExplanation) check_explanation_hetero(explanation, None, explainer.edge_mask_type, hetero_data) ================================================ FILE: test/explain/algorithm/test_captum.py ================================================ import pytest import torch from torch_geometric.data import Data, HeteroData from torch_geometric.explain.algorithm.captum import to_captum_input from torch_geometric.nn import GAT, GCN, SAGEConv from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.models import to_captum_model from torch_geometric.testing import withPackage x = torch.randn(8, 3, requires_grad=True) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]]) GCN = GCN(3, 16, 2, 7, dropout=0.5) GAT = GAT(3, 16, 2, 7, heads=2, concat=False) mask_types = ['edge', 'node_and_edge', 'node'] methods = [ 'Saliency', 'InputXGradient', 'Deconvolution', 'FeatureAblation', 'ShapleyValueSampling', 'IntegratedGradients', 'GradientShap', 'Occlusion', 'GuidedBackprop', 'KernelShap', 'Lime', ] @pytest.mark.parametrize('mask_type', mask_types) @pytest.mark.parametrize('model', [GCN, GAT]) @pytest.mark.parametrize('output_idx', [None, 1]) def test_to_captum(model, mask_type, output_idx): captum_model = to_captum_model(model, 
mask_type=mask_type, output_idx=output_idx) pre_out = model(x, edge_index) if mask_type == 'node': mask = x * 0.0 out = captum_model(mask.unsqueeze(0), edge_index) elif mask_type == 'edge': mask = torch.ones(edge_index.shape[1], dtype=torch.float, requires_grad=True) * 0.5 out = captum_model(mask.unsqueeze(0), x, edge_index) elif mask_type == 'node_and_edge': node_mask = x * 0.0 edge_mask = torch.ones(edge_index.shape[1], dtype=torch.float, requires_grad=True) * 0.5 out = captum_model(node_mask.unsqueeze(0), edge_mask.unsqueeze(0), edge_index) if output_idx is not None: assert out.shape == (1, 7) assert torch.any(out != pre_out[[output_idx]]) else: assert out.shape == (8, 7) assert torch.any(out != pre_out) @withPackage('captum', 'sklearn') @pytest.mark.parametrize('mask_type', mask_types) @pytest.mark.parametrize('method', methods) def test_captum_attribution_methods(mask_type, method): from captum import attr # noqa captum_model = to_captum_model(GCN, mask_type, 0) explainer = getattr(attr, method)(captum_model) data = Data(x, edge_index) input, additional_forward_args = to_captum_input(data.x, data.edge_index, mask_type) if mask_type == 'node': sliding_window_shapes = (3, 3) elif mask_type == 'edge': sliding_window_shapes = (5, ) elif mask_type == 'node_and_edge': sliding_window_shapes = ((3, 3), (5, )) if method == 'IntegratedGradients': attributions, delta = explainer.attribute( input, target=0, internal_batch_size=1, additional_forward_args=additional_forward_args, return_convergence_delta=True) elif method == 'GradientShap': attributions, delta = explainer.attribute( input, target=0, return_convergence_delta=True, baselines=input, n_samples=1, additional_forward_args=additional_forward_args) elif method == 'DeepLiftShap' or method == 'DeepLift': attributions, delta = explainer.attribute( input, target=0, return_convergence_delta=True, baselines=input, additional_forward_args=additional_forward_args) elif method == 'Occlusion': attributions = 
explainer.attribute( input, target=0, sliding_window_shapes=sliding_window_shapes, additional_forward_args=additional_forward_args) else: attributions = explainer.attribute( input, target=0, additional_forward_args=additional_forward_args) if mask_type == 'node': assert attributions[0].shape == (1, 8, 3) elif mask_type == 'edge': assert attributions[0].shape == (1, 14) else: assert attributions[0].shape == (1, 8, 3) assert attributions[1].shape == (1, 14) def test_custom_explain_message(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) conv = SAGEConv(8, 32) def explain_message(self, inputs, x_i, x_j): assert isinstance(self, SAGEConv) assert inputs.size() == (6, 8) assert inputs.size() == x_i.size() == x_j.size() assert torch.allclose(inputs, x_j) self.x_i = x_i self.x_j = x_j return inputs conv.explain_message = explain_message.__get__(conv, MessagePassing) conv.explain = True conv(x, edge_index) assert torch.allclose(conv.x_i, x[edge_index[1]]) assert torch.allclose(conv.x_j, x[edge_index[0]]) @withPackage('captum') @pytest.mark.parametrize('mask_type', ['node', 'edge', 'node_and_edge']) def test_to_captum_input(mask_type): num_nodes = x.shape[0] num_node_feats = x.shape[1] num_edges = edge_index.shape[1] # Check for Data: data = Data(x, edge_index) args = 'test_args' inputs, additional_forward_args = to_captum_input(data.x, data.edge_index, mask_type, args) if mask_type == 'node': assert len(inputs) == 1 assert inputs[0].shape == (1, num_nodes, num_node_feats) assert len(additional_forward_args) == 2 assert torch.allclose(additional_forward_args[0], edge_index) elif mask_type == 'edge': assert len(inputs) == 1 assert inputs[0].shape == (1, num_edges) assert inputs[0].sum() == num_edges assert len(additional_forward_args) == 3 assert torch.allclose(additional_forward_args[0], x) assert torch.allclose(additional_forward_args[1], edge_index) else: assert len(inputs) == 2 assert inputs[0].shape == (1, num_nodes, 
num_node_feats) assert inputs[1].shape == (1, num_edges) assert inputs[1].sum() == num_edges assert len(additional_forward_args) == 2 assert torch.allclose(additional_forward_args[0], edge_index) # Check for HeteroData: data = HeteroData() x2 = torch.rand(8, 3) data['paper'].x = x data['author'].x = x2 data['paper', 'to', 'author'].edge_index = edge_index data['author', 'to', 'paper'].edge_index = edge_index.flip([0]) inputs, additional_forward_args = to_captum_input(data.x_dict, data.edge_index_dict, mask_type, args) if mask_type == 'node': assert len(inputs) == 2 assert inputs[0].shape == (1, num_nodes, num_node_feats) assert inputs[1].shape == (1, num_nodes, num_node_feats) assert len(additional_forward_args) == 2 for key in data.edge_types: torch.allclose(additional_forward_args[0][key], data[key].edge_index) elif mask_type == 'edge': assert len(inputs) == 2 assert inputs[0].shape == (1, num_edges) assert inputs[1].shape == (1, num_edges) assert inputs[1].sum() == inputs[0].sum() == num_edges assert len(additional_forward_args) == 3 for key in data.node_types: torch.allclose(additional_forward_args[0][key], data[key].x) for key in data.edge_types: torch.allclose(additional_forward_args[1][key], data[key].edge_index) else: assert len(inputs) == 4 assert inputs[0].shape == (1, num_nodes, num_node_feats) assert inputs[1].shape == (1, num_nodes, num_node_feats) assert inputs[2].shape == (1, num_edges) assert inputs[3].shape == (1, num_edges) assert inputs[3].sum() == inputs[2].sum() == num_edges assert len(additional_forward_args) == 2 for key in data.edge_types: torch.allclose(additional_forward_args[0][key], data[key].edge_index) ================================================ FILE: test/explain/algorithm/test_captum_explainer.py ================================================ from typing import Optional import pytest import torch from torch_geometric.explain import Explainer, Explanation from torch_geometric.explain.algorithm import CaptumExplainer from 
torch_geometric.explain.config import ( MaskType, ModelConfig, ModelMode, ModelReturnType, ModelTaskLevel, ) from torch_geometric.nn import GCNConv, global_add_pool from torch_geometric.testing import withPackage methods = [ 'Saliency', 'InputXGradient', 'Deconvolution', 'ShapleyValueSampling', 'IntegratedGradients', 'GuidedBackprop', ] unsupported_methods = [ 'FeatureAblation', 'Occlusion', 'DeepLift', 'DeepLiftShap', 'GradientShap', 'KernelShap', 'Lime', ] class GCN(torch.nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() self.model_config = model_config if model_config.mode == ModelMode.multiclass_classification: out_channels = 7 else: out_channels = 1 self.conv1 = GCNConv(3, 16) self.conv2 = GCNConv(16, out_channels) # Add unused parameter: self.param = torch.nn.Parameter(torch.empty(1)) def forward(self, x, edge_index, batch=None, edge_label_index=None): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) if self.model_config.task_level == ModelTaskLevel.graph: x = global_add_pool(x, batch) elif self.model_config.task_level == ModelTaskLevel.edge: assert edge_label_index is not None x = x[edge_label_index[0]] * x[edge_label_index[1]] if self.model_config.mode == ModelMode.binary_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.sigmoid() elif self.model_config.mode == ModelMode.multiclass_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.softmax(dim=-1) elif self.model_config.return_type == ModelReturnType.log_probs: x = x.log_softmax(dim=-1) return x node_mask_types = [MaskType.attributes, None] edge_mask_types = [MaskType.object, None] task_levels = [ModelTaskLevel.node, ModelTaskLevel.edge, ModelTaskLevel.graph] indices = [1, torch.arange(2)] def check_explanation( explanation: Explanation, node_mask_type: Optional[MaskType], edge_mask_type: Optional[MaskType], ): if node_mask_type == MaskType.attributes: assert explanation.node_mask.size() == 
explanation.x.size() elif node_mask_type is None: assert 'node_mask' not in explanation if edge_mask_type == MaskType.object: assert explanation.edge_mask.size() == (explanation.num_edges, ) elif edge_mask_type is None: assert 'edge_mask' not in explanation @withPackage('captum') @pytest.mark.parametrize('method', unsupported_methods) def test_unsupported_methods(method): model_config = ModelConfig(mode='regression', task_level='node') with pytest.raises(ValueError, match="does not support attribution"): Explainer( GCN(model_config), algorithm=CaptumExplainer(method), explanation_type='model', edge_mask_type='object', node_mask_type='attributes', model_config=model_config, ) @withPackage('captum') @pytest.mark.parametrize('method', ['IntegratedGradients']) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('task_level', task_levels) @pytest.mark.parametrize('index', indices) def test_captum_explainer_binary_classification( method, data, node_mask_type, edge_mask_type, task_level, index, ): if node_mask_type is None and edge_mask_type is None: return batch = torch.tensor([0, 0, 1, 1]) edge_label_index = torch.tensor([[0, 1, 2], [2, 3, 1]]) model_config = ModelConfig( mode='binary_classification', task_level=task_level, return_type='probs', ) explainer = Explainer( GCN(model_config), algorithm=CaptumExplainer(method), explanation_type='model', edge_mask_type=edge_mask_type, node_mask_type=node_mask_type, model_config=model_config, ) explanation = explainer( data.x, data.edge_index, index=index, batch=batch, edge_label_index=edge_label_index, ) check_explanation(explanation, node_mask_type, edge_mask_type) @withPackage('captum') @pytest.mark.parametrize('method', methods) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('task_level', task_levels) @pytest.mark.parametrize('index', 
indices) def test_captum_explainer_multiclass_classification( method, data, node_mask_type, edge_mask_type, task_level, index, ): if node_mask_type is None and edge_mask_type is None: return batch = torch.tensor([0, 0, 1, 1]) edge_label_index = torch.tensor([[0, 1, 2], [2, 3, 1]]) model_config = ModelConfig( mode='multiclass_classification', task_level=task_level, return_type='probs', ) explainer = Explainer( GCN(model_config), algorithm=CaptumExplainer(method), explanation_type='model', edge_mask_type=edge_mask_type, node_mask_type=node_mask_type, model_config=model_config, ) explanation = explainer( data.x, data.edge_index, index=index, batch=batch, edge_label_index=edge_label_index, ) check_explanation(explanation, node_mask_type, edge_mask_type) @withPackage('captum') @pytest.mark.parametrize( 'method', [m for m in methods if m != 'ShapleyValueSampling'], ) @pytest.mark.parametrize( 'node_mask_type', [nm for nm in node_mask_types if nm is not None], ) @pytest.mark.parametrize( 'edge_mask_type', [em for em in edge_mask_types if em is not None], ) @pytest.mark.parametrize('index', [1, torch.arange(2)]) def test_captum_hetero_data(method, node_mask_type, edge_mask_type, index, hetero_data, hetero_model): model_config = ModelConfig(mode='regression', task_level='node') explainer = Explainer( hetero_model(hetero_data.metadata()), algorithm=CaptumExplainer(method), edge_mask_type=edge_mask_type, node_mask_type=node_mask_type, model_config=model_config, explanation_type='model', ) explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict, index=index) explanation.validate(raise_on_error=True) @withPackage('captum') @pytest.mark.parametrize('node_mask_type', [ MaskType.object, MaskType.common_attributes, ]) def test_captum_explainer_supports(node_mask_type): model_config = ModelConfig( mode='multiclass_classification', task_level='node', return_type='probs', ) with pytest.raises(ValueError, match="not support the given explanation"): Explainer( 
GCN(model_config), algorithm=CaptumExplainer('IntegratedGradients'), edge_mask_type=MaskType.object, node_mask_type=node_mask_type, model_config=model_config, explanation_type='model', ) ================================================ FILE: test/explain/algorithm/test_captum_hetero.py ================================================ import pytest from torch_geometric.explain.algorithm.captum import ( CaptumHeteroModel, captum_output_to_dicts, to_captum_input, ) from torch_geometric.nn import to_captum_model from torch_geometric.testing import withPackage mask_types = [ 'node', 'edge', 'node_and_edge', ] methods = [ 'Saliency', 'InputXGradient', 'Deconvolution', 'FeatureAblation', 'ShapleyValueSampling', 'IntegratedGradients', 'GradientShap', 'Occlusion', 'GuidedBackprop', 'KernelShap', 'Lime', ] @withPackage('captum', 'sklearn') @pytest.mark.parametrize('mask_type', mask_types) @pytest.mark.parametrize('method', methods) def test_captum_attribution_methods_hetero(mask_type, method, hetero_data, hetero_model): from captum import attr # noqa data = hetero_data metadata = data.metadata() model = hetero_model(metadata) captum_model = to_captum_model(model, mask_type, 0, metadata) explainer = getattr(attr, method)(captum_model) assert isinstance(captum_model, CaptumHeteroModel) inputs, additional_forward_args = to_captum_input( data.x_dict, data.edge_index_dict, mask_type, 'additional_arg', ) if mask_type == 'node': sliding_window_shapes = ((3, 3), (3, 3)) elif mask_type == 'edge': sliding_window_shapes = ((5, ), (5, ), (5, )) else: sliding_window_shapes = ((3, 3), (3, 3), (5, ), (5, ), (5, )) if method == 'IntegratedGradients': attributions, delta = explainer.attribute( inputs, target=0, internal_batch_size=1, additional_forward_args=additional_forward_args, return_convergence_delta=True) elif method == 'GradientShap': attributions, delta = explainer.attribute( inputs, target=0, return_convergence_delta=True, baselines=inputs, n_samples=1, 
additional_forward_args=additional_forward_args) elif method == 'DeepLiftShap' or method == 'DeepLift': attributions, delta = explainer.attribute( inputs, target=0, return_convergence_delta=True, baselines=inputs, additional_forward_args=additional_forward_args) elif method == 'Occlusion': attributions = explainer.attribute( inputs, target=0, sliding_window_shapes=sliding_window_shapes, additional_forward_args=additional_forward_args) else: attributions = explainer.attribute( inputs, target=0, additional_forward_args=additional_forward_args) if mask_type == 'node': assert len(attributions) == len(metadata[0]) x_attr_dict, _ = captum_output_to_dicts(attributions, mask_type, metadata) for node_type in metadata[0]: num_nodes = data[node_type].num_nodes num_node_feats = data[node_type].x.shape[1] assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats) elif mask_type == 'edge': assert len(attributions) == len(metadata[1]) _, edge_attr_dict = captum_output_to_dicts(attributions, mask_type, metadata) for edge_type in metadata[1]: num_edges = data[edge_type].edge_index.shape[1] assert edge_attr_dict[edge_type].shape == (num_edges, ) else: assert len(attributions) == len(metadata[0]) + len(metadata[1]) x_attr_dict, edge_attr_dict = captum_output_to_dicts( attributions, mask_type, metadata) for edge_type in metadata[1]: num_edges = data[edge_type].edge_index.shape[1] assert edge_attr_dict[edge_type].shape == (num_edges, ) for node_type in metadata[0]: num_nodes = data[node_type].num_nodes num_node_feats = data[node_type].x.shape[1] assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats) ================================================ FILE: test/explain/algorithm/test_explain_algorithm_utils.py ================================================ import torch from torch_geometric.explain.algorithm.utils import ( clear_masks, set_hetero_masks, ) from torch_geometric.nn import GCNConv, HeteroConv, SAGEConv, to_hetero class HeteroModel(torch.nn.Module): def 
__init__(self): super().__init__() self.conv1 = HeteroConv({ ('paper', 'to', 'paper'): GCNConv(-1, 32), ('author', 'to', 'paper'): SAGEConv((-1, -1), 32), ('paper', 'to', 'author'): SAGEConv((-1, -1), 32), }) self.conv2 = HeteroConv({ ('paper', 'to', 'paper'): GCNConv(-1, 32), ('author', 'to', 'paper'): SAGEConv((-1, -1), 32), ('paper', 'to', 'author'): SAGEConv((-1, -1), 32), }) class GraphSAGE(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SAGEConv((-1, -1), 32) self.conv2 = SAGEConv((-1, -1), 32) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() return self.conv2(x, edge_index) def test_set_clear_mask(hetero_data): edge_mask_dict = { ('paper', 'to', 'paper'): torch.ones(200), ('author', 'to', 'paper'): torch.ones(100), ('paper', 'to', 'author'): torch.ones(100), } model = HeteroModel() set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict) for edge_type in hetero_data.edge_types: # Check that masks are correctly set: assert torch.allclose(model.conv1.convs[edge_type]._edge_mask, edge_mask_dict[edge_type]) assert model.conv1.convs[edge_type].explain clear_masks(model) for edge_type in hetero_data.edge_types: assert model.conv1.convs[edge_type]._edge_mask is None assert not model.conv1.convs[edge_type].explain model = to_hetero(GraphSAGE(), hetero_data.metadata(), debug=False) set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict) for edge_type in hetero_data.edge_types: # Check that masks are correctly set: str_edge_type = '__'.join(edge_type) assert torch.allclose(model.conv1[str_edge_type]._edge_mask, edge_mask_dict[edge_type]) assert model.conv1[str_edge_type].explain clear_masks(model) for edge_type in hetero_data.edge_types: str_edge_type = '__'.join(edge_type) assert model.conv1[str_edge_type]._edge_mask is None assert not model.conv1[str_edge_type].explain ================================================ FILE: test/explain/algorithm/test_gnn_explainer.py 
================================================ import pytest import torch from torch_geometric.explain import Explainer, GNNExplainer, HeteroExplanation from torch_geometric.explain.config import ( ExplanationType, MaskType, ModelConfig, ModelMode, ModelReturnType, ModelTaskLevel, ) from torch_geometric.nn import ( AttentiveFP, ChebConv, GCNConv, TransformerConv, global_add_pool, ) class GNN(torch.nn.Module): def __init__(self, Conv, model_config: ModelConfig): super().__init__() self.model_config = model_config if model_config.mode == ModelMode.multiclass_classification: out_channels = 7 else: out_channels = 1 self.conv1 = Conv(3, 16) self.conv2 = Conv(16, out_channels) # Add unused parameter: self.param = torch.nn.Parameter(torch.empty(1)) def forward(self, x, edge_index, batch=None, edge_label_index=None): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) if self.model_config.task_level == ModelTaskLevel.graph: x = global_add_pool(x, batch) elif self.model_config.task_level == ModelTaskLevel.edge: assert edge_label_index is not None x = x[edge_label_index[0]] * x[edge_label_index[1]] if self.model_config.mode == ModelMode.binary_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.sigmoid() elif self.model_config.mode == ModelMode.multiclass_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.softmax(dim=-1) elif self.model_config.return_type == ModelReturnType.log_probs: x = x.log_softmax(dim=-1) return x node_mask_types = [ MaskType.object, MaskType.common_attributes, MaskType.attributes, ] edge_mask_types = [MaskType.object, None] explanation_types = [ExplanationType.model, ExplanationType.phenomenon] task_levels = [ModelTaskLevel.node, ModelTaskLevel.edge, ModelTaskLevel.graph] indices = [None, 2, torch.arange(3)] x = torch.randn(8, 3) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ]) edge_attr = 
torch.randn(edge_index.size(1), 5) batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2]) edge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]]) @pytest.mark.parametrize('Conv', [GCNConv, TransformerConv]) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('explanation_type', explanation_types) @pytest.mark.parametrize('task_level', task_levels) @pytest.mark.parametrize('return_type', [ ModelReturnType.probs, ModelReturnType.raw, ]) @pytest.mark.parametrize('index', indices) def test_gnn_explainer_binary_classification( Conv, edge_mask_type, node_mask_type, explanation_type, task_level, return_type, index, check_explanation, ): model_config = ModelConfig( mode='binary_classification', task_level=task_level, return_type=return_type, ) model = GNN(Conv, model_config) target = None if explanation_type == ExplanationType.phenomenon: with torch.no_grad(): out = model(x, edge_index, batch, edge_label_index) if model_config.return_type == ModelReturnType.raw: target = (out > 0).long().view(-1) if model_config.return_type == ModelReturnType.probs: target = (out > 0.5).long().view(-1) explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=2), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, model_config=model_config, ) explanation = explainer( x, edge_index, target=target, index=index, batch=batch, edge_label_index=edge_label_index, ) assert explainer.algorithm.node_mask is None assert explainer.algorithm.edge_mask is None check_explanation(explanation, node_mask_type, edge_mask_type) @pytest.mark.parametrize('Conv', [GCNConv]) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('explanation_type', explanation_types) @pytest.mark.parametrize('task_level', task_levels) @pytest.mark.parametrize('return_type', [ ModelReturnType.log_probs, 
ModelReturnType.probs, ModelReturnType.raw, ]) @pytest.mark.parametrize('index', indices) def test_gnn_explainer_multiclass_classification( Conv, edge_mask_type, node_mask_type, explanation_type, task_level, return_type, index, check_explanation, ): model_config = ModelConfig( mode='multiclass_classification', task_level=task_level, return_type=return_type, ) model = GNN(Conv, model_config) target = None if explanation_type == ExplanationType.phenomenon: with torch.no_grad(): target = model(x, edge_index, batch, edge_label_index).argmax(-1) explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=2), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, model_config=model_config, ) explanation = explainer( x, edge_index, target=target, index=index, batch=batch, edge_label_index=edge_label_index, ) assert explainer.algorithm.node_mask is None assert explainer.algorithm.edge_mask is None check_explanation(explanation, node_mask_type, edge_mask_type) @pytest.mark.parametrize('Conv', [GCNConv]) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('explanation_type', explanation_types) @pytest.mark.parametrize('task_level', task_levels) @pytest.mark.parametrize('index', indices) def test_gnn_explainer_regression( Conv, edge_mask_type, node_mask_type, explanation_type, task_level, index, check_explanation, ): model_config = ModelConfig( mode='regression', task_level=task_level, ) model = GNN(Conv, model_config) target = None if explanation_type == ExplanationType.phenomenon: with torch.no_grad(): target = model(x, edge_index, batch, edge_label_index) explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=2), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, model_config=model_config, ) explanation = explainer( x, edge_index, target=target, index=index, batch=batch, 
edge_label_index=edge_label_index, ) assert explainer.algorithm.node_mask is None assert explainer.algorithm.edge_mask is None check_explanation(explanation, node_mask_type, edge_mask_type) def test_gnn_explainer_cheb_conv(check_explanation): explainer = Explainer( model=ChebConv(3, 1, K=2), algorithm=GNNExplainer(epochs=2), explanation_type='model', node_mask_type='object', edge_mask_type='object', model_config=dict( mode='binary_classification', task_level='node', return_type='raw', ), ) explanation = explainer(x, edge_index) assert explainer.algorithm.node_mask is None assert explainer.algorithm.edge_mask is None check_explanation(explanation, MaskType.object, MaskType.object) def test_gnn_explainer_attentive_fp(check_explanation): model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2) explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=2), explanation_type='model', node_mask_type='object', edge_mask_type='object', model_config=dict( mode='binary_classification', task_level='node', return_type='raw', ), ) explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch) assert explainer.algorithm.node_mask is None assert explainer.algorithm.edge_mask is None check_explanation(explanation, MaskType.object, MaskType.object) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('explanation_type', explanation_types) @pytest.mark.parametrize('task_level', task_levels) @pytest.mark.parametrize('return_type', [ ModelReturnType.log_probs, ModelReturnType.probs, ModelReturnType.raw, ]) @pytest.mark.parametrize('index', indices) def test_gnn_explainer_hetero( node_mask_type, edge_mask_type, explanation_type, task_level, return_type, index, hetero_data, hetero_model, check_explanation_hetero, ): if node_mask_type is None and edge_mask_type is None: return model_config = ModelConfig( mode='multiclass_classification', task_level=task_level, 
return_type=return_type, ) metadata = hetero_data.metadata() model = hetero_model(metadata, model_config) target = None if explanation_type == ExplanationType.phenomenon: with torch.no_grad(): target = model(hetero_data.x_dict, hetero_data.edge_index_dict).argmax(-1) explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=2), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, model_config=model_config, ) explanation = explainer( hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index, ) assert isinstance(explanation, HeteroExplanation) check_explanation_hetero(explanation, node_mask_type, edge_mask_type, hetero_data) ================================================ FILE: test/explain/algorithm/test_graphmask_explainer.py ================================================ import pytest import torch from torch_geometric.explain import Explainer, Explanation, GraphMaskExplainer from torch_geometric.explain.config import ( MaskType, ModelConfig, ModelMode, ModelReturnType, ModelTaskLevel, ) from torch_geometric.nn import GCNConv, global_add_pool class GCN(torch.nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() self.model_config = model_config if model_config.mode == ModelMode.multiclass_classification: out_channels = 7 else: out_channels = 1 self.conv1 = GCNConv(3, 16) self.conv2 = GCNConv(16, out_channels) def forward(self, x, edge_index, batch=None, edge_label_index=None): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) if self.model_config.task_level == ModelTaskLevel.graph: x = global_add_pool(x, batch) elif self.model_config.task_level == ModelTaskLevel.edge: assert edge_label_index is not None x = x[edge_label_index[0]] * x[edge_label_index[1]] if self.model_config.mode == ModelMode.binary_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.sigmoid() elif self.model_config.mode == 
ModelMode.multiclass_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.softmax(dim=-1) elif self.model_config.return_type == ModelReturnType.log_probs: x = x.log_softmax(dim=-1) return x def check_explanation( edge_mask_type: MaskType, node_mask_type: MaskType, explanation: Explanation, ): if node_mask_type == MaskType.attributes: assert explanation.node_mask.size() == explanation.x.size() assert explanation.node_mask.min() >= 0 assert explanation.node_mask.max() <= 1 elif node_mask_type == MaskType.object: assert explanation.node_mask.size() == (explanation.num_nodes, 1) assert explanation.node_mask.min() >= 0 assert explanation.node_mask.max() <= 1 elif node_mask_type == MaskType.common_attributes: assert explanation.node_mask.size() == (1, explanation.num_features) assert explanation.node_mask.min() >= 0 assert explanation.node_mask.max() <= 1 if edge_mask_type == MaskType.object: assert explanation.edge_mask.size() == (explanation.num_edges, ) assert explanation.edge_mask.min() >= 0 assert explanation.edge_mask.max() <= 1 node_mask_types = [ MaskType.object, MaskType.common_attributes, MaskType.attributes, ] edge_mask_types = [ MaskType.object, None, ] x = torch.randn(8, 3) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ]) batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2]) edge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]]) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) @pytest.mark.parametrize('task_level', ['node', 'edge', 'graph']) @pytest.mark.parametrize('return_type', ['probs', 'raw']) @pytest.mark.parametrize('index', [None, 2, torch.arange(3)]) def test_graph_mask_explainer_binary_classification( edge_mask_type, node_mask_type, explanation_type, task_level, return_type, index, ): model_config = ModelConfig( 
mode='binary_classification', task_level=task_level, return_type=return_type, ) model = GCN(model_config) target = None if explanation_type == 'phenomenon': with torch.no_grad(): out = model(x, edge_index, batch, edge_label_index) if model_config.return_type == ModelReturnType.raw: target = (out > 0).long().view(-1) if model_config.return_type == ModelReturnType.probs: target = (out > 0.5).long().view(-1) explainer = Explainer( model=model, algorithm=GraphMaskExplainer(2, epochs=5, log=False), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, model_config=model_config, ) explanation = explainer( x, edge_index, target=target, index=index, batch=batch, edge_label_index=edge_label_index, ) check_explanation(edge_mask_type, node_mask_type, explanation) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) @pytest.mark.parametrize('task_level', ['node', 'edge', 'graph']) @pytest.mark.parametrize('return_type', ['log_probs', 'probs', 'raw']) @pytest.mark.parametrize('index', [None, 2, torch.arange(3)]) def test_graph_mask_explainer_multiclass_classification( edge_mask_type, node_mask_type, explanation_type, task_level, return_type, index, ): model_config = ModelConfig( mode='multiclass_classification', task_level=task_level, return_type=return_type, ) model = GCN(model_config) target = None if explanation_type == 'phenomenon': with torch.no_grad(): target = model(x, edge_index, batch, edge_label_index).argmax(-1) explainer = Explainer( model=model, algorithm=GraphMaskExplainer(2, epochs=5, log=False), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, model_config=model_config, ) explanation = explainer( x, edge_index, target=target, index=index, batch=batch, edge_label_index=edge_label_index, ) check_explanation(edge_mask_type, node_mask_type, 
explanation) @pytest.mark.parametrize('edge_mask_type', edge_mask_types) @pytest.mark.parametrize('node_mask_type', node_mask_types) @pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) @pytest.mark.parametrize('task_level', ['node', 'edge', 'graph']) @pytest.mark.parametrize('index', [None, 2, torch.arange(3)]) def test_graph_mask_explainer_regression( edge_mask_type, node_mask_type, explanation_type, task_level, index, ): model_config = ModelConfig( mode='regression', task_level=task_level, ) model = GCN(model_config) target = None if explanation_type == 'phenomenon': with torch.no_grad(): target = model(x, edge_index, batch, edge_label_index) explainer = Explainer( model=model, algorithm=GraphMaskExplainer(2, epochs=5, log=False), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, model_config=model_config, ) explanation = explainer( x, edge_index, target=target, index=index, batch=batch, edge_label_index=edge_label_index, ) check_explanation(edge_mask_type, node_mask_type, explanation) ================================================ FILE: test/explain/algorithm/test_pg_explainer.py ================================================ import pytest import torch from torch_geometric.explain import Explainer, HeteroExplanation, PGExplainer from torch_geometric.explain.config import ( ExplanationType, ModelConfig, ModelMode, ModelReturnType, ModelTaskLevel, ) from torch_geometric.nn import GCNConv, global_add_pool from torch_geometric.testing import withCUDA class GCN(torch.nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() self.model_config = model_config if model_config.mode == ModelMode.multiclass_classification: out_channels = 7 else: out_channels = 1 self.conv1 = GCNConv(3, 16) self.conv2 = GCNConv(16, out_channels) def forward(self, x, edge_index, batch=None, edge_label_index=None): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) if 
self.model_config.task_level == ModelTaskLevel.graph: x = global_add_pool(x, batch) return x @withCUDA @pytest.mark.parametrize('mode', [ ModelMode.binary_classification, ModelMode.multiclass_classification, ModelMode.regression, ]) def test_pg_explainer_node(device, check_explanation, mode): x = torch.randn(8, 3, device=device) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ], device=device) if mode == ModelMode.binary_classification: target = torch.randint(2, (x.size(0), ), device=device) elif mode == ModelMode.multiclass_classification: target = torch.randint(7, (x.size(0), ), device=device) elif mode == ModelMode.regression: target = torch.randn((x.size(0), 1), device=device) model_config = ModelConfig(mode=mode, task_level='node', return_type='raw') model = GCN(model_config).to(device) explainer = Explainer( model=model, algorithm=PGExplainer(epochs=2).to(device), explanation_type='phenomenon', edge_mask_type='object', model_config=model_config, ) with pytest.raises(ValueError, match="not yet fully trained"): explainer(x, edge_index, target=target) explainer.algorithm.reset_parameters() for epoch in range(2): for index in range(x.size(0)): loss = explainer.algorithm.train(epoch, model, x, edge_index, target=target, index=index) assert loss >= 0.0 explanation = explainer(x, edge_index, target=target, index=0) check_explanation(explanation, None, explainer.edge_mask_type) @withCUDA @pytest.mark.parametrize('mode', [ ModelMode.binary_classification, ModelMode.multiclass_classification, ModelMode.regression, ]) def test_pg_explainer_graph(device, check_explanation, mode): x = torch.randn(8, 3, device=device) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ], device=device) if mode == ModelMode.binary_classification: target = torch.randint(2, (1, ), device=device) elif mode == ModelMode.multiclass_classification: target = 
torch.randint(7, (1, ), device=device) elif mode == ModelMode.regression: target = torch.randn((1, 1), device=device) model_config = ModelConfig(mode=mode, task_level='graph', return_type='raw') model = GCN(model_config).to(device) explainer = Explainer( model=model, algorithm=PGExplainer(epochs=2).to(device), explanation_type='phenomenon', edge_mask_type='object', model_config=model_config, ) with pytest.raises(ValueError, match="not yet fully trained"): explainer(x, edge_index, target=target) explainer.algorithm.reset_parameters() for epoch in range(2): loss = explainer.algorithm.train(epoch, model, x, edge_index, target=target) assert loss >= 0.0 explanation = explainer(x, edge_index, target=target) check_explanation(explanation, None, explainer.edge_mask_type) @withCUDA @pytest.mark.parametrize('mode', [ ModelMode.binary_classification, ModelMode.multiclass_classification, ModelMode.regression, ]) @pytest.mark.parametrize('task_level', [ ModelTaskLevel.node, ModelTaskLevel.graph, ]) def test_pg_explainer_hetero(device, hetero_data, hetero_model, check_explanation_hetero, mode, task_level): # Move data to device hetero_data = hetero_data.to(device) # Prepare target based on mode and task level index = 0 if task_level == ModelTaskLevel.node else None # Create model config model_config = ModelConfig( mode=mode, task_level=task_level, return_type=ModelReturnType.raw, ) # Create and initialize model metadata = hetero_data.metadata() model = hetero_model(metadata, model_config).to(device) with torch.no_grad(): raw_output = model(hetero_data.x_dict, hetero_data.edge_index_dict) if mode == ModelMode.multiclass_classification: # For multiclass, use class indices (long tensor) target = raw_output.argmax(dim=-1) elif mode == ModelMode.binary_classification: # For binary, convert to binary targets (long tensor) target = (raw_output > 0).long() else: # regression # For regression, use raw outputs (float tensor) target = raw_output.float() # Create explainer explainer = 
Explainer( model=model, algorithm=PGExplainer(epochs=2).to(device), explanation_type=ExplanationType.phenomenon, edge_mask_type='object', model_config=model_config, ) # Should raise error when not fully trained with pytest.raises(ValueError, match="not yet fully trained"): explainer( hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index if task_level == ModelTaskLevel.node else None, ) # Train the explainer explainer.algorithm.reset_parameters() for epoch in range(2): if task_level == ModelTaskLevel.node: # For node-level, train on a single node loss = explainer.algorithm.train( epoch, model, hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index, ) else: # For graph-level, train on the whole graph loss = explainer.algorithm.train( epoch, model, hetero_data.x_dict, hetero_data.edge_index_dict, target=target, ) assert isinstance(loss, float) # Get explanation explanation = explainer( hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index if task_level == ModelTaskLevel.node else None, ) # Check if the explanation is valid assert isinstance(explanation, HeteroExplanation) # Run through the standard explanation checker check_explanation_hetero(explanation, None, explainer.edge_mask_type, hetero_data) def test_pg_explainer_supports(): # Test unsupported model task level: with pytest.raises(ValueError, match="not support the given explanation"): model_config = ModelConfig( mode='binary_classification', task_level='edge', return_type='raw', ) Explainer( model=GCN(model_config), algorithm=PGExplainer(epochs=2), explanation_type='phenomenon', edge_mask_type='object', model_config=model_config, ) # Test unsupported explanation type: with pytest.raises(ValueError, match="not support the given explanation"): model_config = ModelConfig( mode='binary_classification', task_level='node', return_type='raw', ) Explainer( model=GCN(model_config), algorithm=PGExplainer(epochs=2), explanation_type='model', 
edge_mask_type='object', model_config=model_config, ) # Test unsupported node mask: with pytest.raises(ValueError, match="not support the given explanation"): model_config = ModelConfig( mode='binary_classification', task_level='node', return_type='raw', ) Explainer( model=GCN(model_config), algorithm=PGExplainer(epochs=2), explanation_type='model', node_mask_type='object', edge_mask_type='object', model_config=model_config, ) @withCUDA @pytest.mark.parametrize('conv_type', ['HGTConv', 'HANConv']) @pytest.mark.parametrize('mode', [ ModelMode.binary_classification, ModelMode.multiclass_classification, ]) @pytest.mark.parametrize('task_level', [ ModelTaskLevel.node, ModelTaskLevel.graph, ]) def test_pg_explainer_native_hetero(device, hetero_data, hetero_model_native, check_explanation_hetero, conv_type, mode, task_level): """Test PGExplainer with native heterogeneous GNNs (not created by to_hetero). """ # Move data to device hetero_data = hetero_data.to(device) # Create model config model_config = ModelConfig( mode=mode, task_level=task_level, return_type=ModelReturnType.raw, ) # Create and initialize model metadata = hetero_data.metadata() model = hetero_model_native(metadata, model_config, conv_type=conv_type).to(device) # Generate target with torch.no_grad(): raw_output = model(hetero_data.x_dict, hetero_data.edge_index_dict) if mode == ModelMode.multiclass_classification: # For multiclass, use class indices (long tensor) target = raw_output.argmax(dim=-1) else: # binary classification # For binary, convert to binary targets (long tensor) target = (raw_output > 0).long() # Setup index for node-level tasks index = 0 if task_level == ModelTaskLevel.node else None # Create explainer explainer = Explainer( model=model, algorithm=PGExplainer(epochs=2).to(device), explanation_type=ExplanationType.phenomenon, edge_mask_type='object', model_config=model_config, ) # Should raise error when not fully trained with pytest.raises(ValueError, match="not yet fully trained"): 
explainer( hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index if task_level == ModelTaskLevel.node else None, ) # Train the explainer explainer.algorithm.reset_parameters() for epoch in range(2): if task_level == ModelTaskLevel.node: # For node-level, train on a single node loss = explainer.algorithm.train( epoch, model, hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index, ) else: # For graph-level, train on the whole graph loss = explainer.algorithm.train( epoch, model, hetero_data.x_dict, hetero_data.edge_index_dict, target=target, ) assert isinstance(loss, float) # Get explanation explanation = explainer( hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index if task_level == ModelTaskLevel.node else None, ) # Check if the explanation is valid assert isinstance(explanation, HeteroExplanation) # Run through the standard explanation checker check_explanation_hetero(explanation, None, explainer.edge_mask_type, hetero_data) @withCUDA @pytest.mark.parametrize('mode', [ ModelMode.binary_classification, ModelMode.multiclass_classification, ]) @pytest.mark.parametrize('task_level', [ ModelTaskLevel.node, ModelTaskLevel.graph, ]) def test_pg_explainer_hetero_conv(device, hetero_data, hetero_model_custom, check_explanation_hetero, mode, task_level): """Test PGExplainer with the built-in HeteroConv model.""" # Move data to device hetero_data = hetero_data.to(device) # Create model config model_config = ModelConfig( mode=mode, task_level=task_level, return_type=ModelReturnType.raw, ) # Create and initialize model metadata = hetero_data.metadata() model = hetero_model_custom(metadata, model_config).to(device) # Generate target with torch.no_grad(): raw_output = model(hetero_data.x_dict, hetero_data.edge_index_dict) if mode == ModelMode.multiclass_classification: # For multiclass, use class indices (long tensor) target = raw_output.argmax(dim=-1) else: # binary classification # For binary, convert to binary 
targets (long tensor) target = (raw_output > 0).long() # Setup index for node-level tasks index = 0 if task_level == ModelTaskLevel.node else None # Create explainer explainer = Explainer( model=model, algorithm=PGExplainer(epochs=2).to(device), explanation_type=ExplanationType.phenomenon, edge_mask_type='object', model_config=model_config, ) # Should raise error when not fully trained with pytest.raises(ValueError, match="not yet fully trained"): explainer( hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index if task_level == ModelTaskLevel.node else None, ) # Train the explainer explainer.algorithm.reset_parameters() for epoch in range(2): if task_level == ModelTaskLevel.node: # For node-level, train on a single node loss = explainer.algorithm.train( epoch, model, hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index, ) else: # For graph-level, train on the whole graph loss = explainer.algorithm.train( epoch, model, hetero_data.x_dict, hetero_data.edge_index_dict, target=target, ) assert isinstance(loss, float) # Get explanation explanation = explainer( hetero_data.x_dict, hetero_data.edge_index_dict, target=target, index=index if task_level == ModelTaskLevel.node else None, ) # Check if the explanation is valid assert isinstance(explanation, HeteroExplanation) # Run through the standard explanation checker check_explanation_hetero(explanation, None, explainer.edge_mask_type, hetero_data) ================================================ FILE: test/explain/conftest.py ================================================ from typing import Optional import pytest import torch from torch_geometric.data import Data, HeteroData from torch_geometric.explain import Explanation, HeteroExplanation from torch_geometric.explain.config import ( MaskType, ModelConfig, ModelMode, ModelReturnType, ModelTaskLevel, ) from torch_geometric.nn import ( HANConv, HGTConv, SAGEConv, global_add_pool, to_hetero, ) from torch_geometric.nn.conv 
import GCNConv, HeteroConv from torch_geometric.testing import get_random_edge_index @pytest.fixture() def data(): return Data( x=torch.randn(4, 3), edge_index=get_random_edge_index(4, 4, num_edges=6), edge_attr=torch.randn(6, 3), ) @pytest.fixture() def hetero_data(): data = HeteroData() data['paper'].x = torch.randn(8, 16) data['author'].x = torch.randn(10, 8) data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10) data['paper', 'paper'].edge_attr = torch.randn(10, 16) data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10) data['paper', 'author'].edge_attr = torch.randn(10, 8) data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10) data['author', 'paper'].edge_attr = torch.randn(10, 8) return data @pytest.fixture() def hetero_model(): return HeteroSAGE @pytest.fixture() def hetero_model_custom(): return HeteroConvModel class GraphSAGE(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SAGEConv((-1, -1), 32) self.conv2 = SAGEConv((-1, -1), 32) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() return self.conv2(x, edge_index) class HeteroSAGE(torch.nn.Module): def __init__(self, metadata, model_config: Optional[ModelConfig] = None): super().__init__() self.model_config = model_config self.graph_sage = to_hetero(GraphSAGE(), metadata, debug=False) # Determine output channels based on model_config out_channels = 1 if (model_config and model_config.mode == ModelMode.multiclass_classification): out_channels = 7 self.lin = torch.nn.Linear(32, out_channels) def forward(self, x_dict, edge_index_dict, additonal_arg=None) -> torch.Tensor: x = self.lin(self.graph_sage(x_dict, edge_index_dict)['paper']) # Apply transformations based on model_config if available if self.model_config: if self.model_config.mode == ModelMode.binary_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.sigmoid() elif self.model_config.mode == ModelMode.multiclass_classification: if 
self.model_config.return_type == ModelReturnType.probs: x = x.softmax(dim=-1) elif (self.model_config.return_type == ModelReturnType.log_probs): x = x.log_softmax(dim=-1) return x @pytest.fixture() def check_explanation(): def _check_explanation( explanation: Explanation, node_mask_type: Optional[MaskType], edge_mask_type: Optional[MaskType], ): if node_mask_type == MaskType.attributes: assert explanation.node_mask.size() == explanation.x.size() assert explanation.node_mask.min() >= 0 assert explanation.node_mask.max() <= 1 elif node_mask_type == MaskType.object: assert explanation.node_mask.size() == (explanation.num_nodes, 1) assert explanation.node_mask.min() >= 0 assert explanation.node_mask.max() <= 1 elif node_mask_type == MaskType.common_attributes: assert explanation.node_mask.size() == (1, explanation.x.size(-1)) assert explanation.node_mask.min() >= 0 assert explanation.node_mask.max() <= 1 elif node_mask_type is None: assert 'node_mask' not in explanation if edge_mask_type == MaskType.object: assert explanation.edge_mask.size() == (explanation.num_edges, ) assert explanation.edge_mask.min() >= 0 assert explanation.edge_mask.max() <= 1 elif edge_mask_type is None: assert 'edge_mask' not in explanation return _check_explanation @pytest.fixture() def check_explanation_hetero(): def _check_explanation_hetero( explanation: HeteroExplanation, node_mask_type: Optional[MaskType], edge_mask_type: Optional[MaskType], hetero_data: HeteroData, ): # Validate the explanation explanation.validate(raise_on_error=True) # Check node masks for different node types if node_mask_type is not None: for node_type in hetero_data.node_types: assert explanation[node_type].get('node_mask') is not None assert explanation[node_type].get('node_mask').min() >= 0 assert explanation[node_type].get('node_mask').max() <= 1 # Check dimensions based on mask type if node_mask_type == MaskType.attributes: mask = explanation[node_type].get('node_mask') assert mask.size() == 
hetero_data.x_dict[node_type].size() elif node_mask_type == MaskType.object: mask = explanation[node_type].get('node_mask') assert mask.size() == ( hetero_data.x_dict[node_type].size(0), 1) elif node_mask_type == MaskType.common_attributes: mask = explanation[node_type].get('node_mask') assert mask.size() == ( 1, hetero_data.x_dict[node_type].size(1)) # Check edge masks for different edge types if edge_mask_type is not None: for edge_type in hetero_data.edge_types: assert explanation[edge_type].get('edge_mask') is not None assert explanation[edge_type].get('edge_mask').min() >= 0 assert explanation[edge_type].get('edge_mask').max() <= 1 return _check_explanation_hetero class NativeHeteroGNN(torch.nn.Module): def __init__(self, metadata, model_config: Optional[ModelConfig] = None, conv_type: str = 'HGTConv', hidden_channels: int = 32): super().__init__() self.model_config = model_config self.conv_type = conv_type self.hidden_channels = hidden_channels self.metadata = metadata # Determine output size based on model_config self.out_channels = 1 if (model_config and model_config.mode == ModelMode.multiclass_classification): self.out_channels = 7 # Initialize dictionaries to store the layers self.lin_dict = torch.nn.ModuleDict() self.initialized = False # Heterogeneous convolution layer if conv_type == 'HGTConv': self.conv = HGTConv(hidden_channels, hidden_channels, metadata, heads=2) elif conv_type == 'HANConv': self.conv = HANConv(hidden_channels, hidden_channels, metadata, heads=2) else: raise ValueError(f"Unsupported conv_type: {conv_type}") # Output projection will be initialized in forward pass self.out_lin = None def _initialize_layers(self, x_dict): """Initialize layers with correct dimensions when we first see the data. 
""" if not self.initialized: # Initialize input projections for node_type, x in x_dict.items(): in_channels = x.size(-1) self.lin_dict[node_type] = torch.nn.Linear( in_channels, self.hidden_channels).to(x.device) # Initialize output projection self.out_lin = torch.nn.Linear(self.hidden_channels, self.out_channels).to(x.device) self.initialized = True def forward(self, x_dict, edge_index_dict): # Initialize layers if not done yet self._initialize_layers(x_dict) # Apply input projections x_dict = { node_type: self.lin_dict[node_type](x).relu_() for node_type, x in x_dict.items() } # Apply heterogeneous convolution x_dict = self.conv(x_dict, edge_index_dict) # Get paper node features for prediction x = x_dict['paper'] # Apply output projection out = self.out_lin(x) # For graph-level tasks, perform global pooling if (self.model_config and self.model_config.task_level == ModelTaskLevel.graph): # Since we don't have batch information in the fixture, # we'll treat the whole graph as a single graph batch_size = x.size(0) batch = torch.zeros(batch_size, dtype=torch.long, device=x.device) out = global_add_pool(out, batch) return out @pytest.fixture() def hetero_model_native(): return NativeHeteroGNN class HeteroConvModel(torch.nn.Module): def __init__(self, metadata, model_config: Optional[ModelConfig] = None): super().__init__() self.model_config = model_config # Create a HeteroConv model conv_dict = {} for edge_type in metadata[1]: # metadata[1] contains edge types src_type, _, dst_type = edge_type if src_type == dst_type: conv_dict[edge_type] = GCNConv(-1, 32) else: # For different node types, use SAGEConv conv_dict[edge_type] = SAGEConv((-1, -1), 32) self.conv = HeteroConv(conv_dict, aggr='sum') # Determine output channels based on model_config out_channels = 1 if (model_config and model_config.mode == ModelMode.multiclass_classification): out_channels = 7 # Output layer self.out_lin = torch.nn.Linear(32, out_channels) def forward(self, x_dict, edge_index_dict): # Apply 
heterogeneous convolution out_dict = self.conv(x_dict, edge_index_dict) # Final transformation for paper nodes out = self.out_lin(out_dict['paper']) # Apply transformations based on model_config if available if self.model_config: if self.model_config.mode == ModelMode.binary_classification: if self.model_config.return_type == ModelReturnType.probs: out = out.sigmoid() elif self.model_config.mode == ModelMode.multiclass_classification: if self.model_config.return_type == ModelReturnType.probs: out = out.softmax(dim=-1) elif (self.model_config.return_type == ModelReturnType.log_probs): out = out.log_softmax(dim=-1) return out ================================================ FILE: test/explain/metric/test_basic_metric.py ================================================ import warnings import torch from torch_geometric.explain import groundtruth_metrics from torch_geometric.testing import withPackage @withPackage('torchmetrics>=0.10.0') def test_groundtruth_metrics(): pred_mask = torch.rand(10) target_mask = torch.rand(10) accuracy, recall, precision, f1_score, auroc = groundtruth_metrics( pred_mask, target_mask) assert accuracy >= 0.0 and accuracy <= 1.0 assert recall >= 0.0 and recall <= 1.0 assert precision >= 0.0 and precision <= 1.0 assert f1_score >= 0.0 and f1_score <= 1.0 assert auroc >= 0.0 and auroc <= 1.0 @withPackage('torchmetrics>=0.10.0') def test_perfect_groundtruth_metrics(): pred_mask = target_mask = torch.rand(10) accuracy, recall, precision, f1_score, auroc = groundtruth_metrics( pred_mask, target_mask) assert round(accuracy, 6) == 1.0 assert round(recall, 6) == 1.0 assert round(precision, 6) == 1.0 assert round(f1_score, 6) == 1.0 assert round(auroc, 6) == 1.0 @withPackage('torchmetrics>=0.10.0') def test_groundtruth_true_negative(): warnings.filterwarnings('ignore', '.*No positive samples in targets.*') pred_mask = target_mask = torch.zeros(10) accuracy, recall, precision, f1_score, auroc = groundtruth_metrics( pred_mask, target_mask) assert 
round(accuracy, 6) == 1.0 assert round(recall, 6) == 0.0 assert round(precision, 6) == 0.0 assert round(f1_score, 6) == 0.0 assert round(auroc, 6) == 0.0 ================================================ FILE: test/explain/metric/test_faithfulness.py ================================================ import pytest import torch from torch_geometric.explain import ( DummyExplainer, Explainer, ModelConfig, unfaithfulness, ) class DummyModel(torch.nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() self.model_config = model_config def forward(self, x, edge_index): if self.model_config.return_type.value == 'probs': x = x.softmax(dim=-1) elif self.model_config.return_type.value == 'log_probs': x = x.log_softmax(dim=-1) return x @pytest.mark.parametrize('top_k', [None, 2]) @pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) @pytest.mark.parametrize('node_mask_type', ['common_attributes', 'attributes']) @pytest.mark.parametrize('return_type', ['raw', 'probs', 'log_probs']) def test_unfaithfulness(top_k, explanation_type, node_mask_type, return_type): x = torch.randn(8, 4) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ]) model_config = ModelConfig( mode='multiclass_classification', task_level='node', return_type=return_type, ) explainer = Explainer( DummyModel(model_config), algorithm=DummyExplainer(), explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type='object', model_config=model_config, ) target = None if explanation_type == 'phenomenon': target = torch.randint(0, x.size(1), (x.size(0), )) explanation = explainer(x, edge_index, target=target, index=torch.arange(4)) metric = unfaithfulness(explainer, explanation, top_k) assert metric >= 0. and metric <= 1. 
================================================ FILE: test/explain/metric/test_fidelity.py ================================================ import pytest import torch from torch_geometric.explain import ( DummyExplainer, Explainer, characterization_score, fidelity, fidelity_curve_auc, ) class DummyModel(torch.nn.Module): def forward(self, x, edge_index): return x @pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) def test_fidelity(explanation_type): x = torch.randn(8, 4) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ]) explainer = Explainer( DummyModel(), algorithm=DummyExplainer(), explanation_type=explanation_type, node_mask_type='object', edge_mask_type='object', model_config=dict( mode='multiclass_classification', return_type='raw', task_level='node', ), ) target = None if explanation_type == 'phenomenon': target = torch.randint(0, x.size(1), (x.size(0), )) explanation = explainer(x, edge_index, target=target, index=torch.arange(4)) pos_fidelity, neg_fidelity = fidelity(explainer, explanation) assert pos_fidelity == 0.0 and neg_fidelity == 0.0 def test_characterization_score(): out = characterization_score( pos_fidelity=torch.tensor([1.0, 0.6, 0.5, 1.0]), neg_fidelity=torch.tensor([0.0, 0.2, 0.5, 1.0]), pos_weight=0.2, neg_weight=0.8, ) assert out.tolist() == [1.0, 0.75, 0.5, 0.0] def test_fidelity_curve_auc(): pos_fidelity = torch.tensor([1.0, 1.0, 0.5, 1.0]) neg_fidelity = torch.tensor([0.0, 0.5, 0.5, 0.9]) x = torch.tensor([0, 1, 2, 3]) out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 4) assert out == 8.5 x = torch.tensor([10, 11, 12, 13]) out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 4) assert out == 8.5 x = torch.tensor([0, 1, 2, 5]) out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 4) assert out == 19.5 ================================================ FILE: test/explain/test_explain_config.py 
================================================ import pytest from torch_geometric.explain.config import ExplainerConfig, ThresholdConfig @pytest.mark.parametrize('threshold_pairs', [ ('hard', 0.5, True), ('hard', 1.1, False), ('hard', -1, False), ('topk', 1, True), ('topk', 0, False), ('topk', -1, False), ('topk', 0.5, False), ('invalid', None, False), ('hard', None, False), ]) def test_threshold_config(threshold_pairs): threshold_type, threshold_value, valid = threshold_pairs if valid: threshold = ThresholdConfig(threshold_type, threshold_value) assert threshold.type.value == threshold_type assert threshold.value == threshold_value else: with pytest.raises(ValueError): ThresholdConfig(threshold_type, threshold_value) @pytest.mark.parametrize('explanation_type', [ 'model', 'phenomenon', 'invalid', ]) @pytest.mark.parametrize('mask_type', [ None, 'object', 'common_attributes', 'attributes', 'invalid', ]) def test_configuration_config(explanation_type, mask_type): if (explanation_type != 'invalid' and mask_type is not None and mask_type != 'invalid'): config = ExplainerConfig(explanation_type, mask_type, None) assert config.explanation_type.value == explanation_type assert config.node_mask_type.value == mask_type else: with pytest.raises(ValueError): ExplainerConfig(explanation_type, mask_type, mask_type) ================================================ FILE: test/explain/test_explainer.py ================================================ import pytest import torch from torch_geometric.explain import DummyExplainer, Explainer, Explanation from torch_geometric.explain.config import ExplanationType class DummyModel(torch.nn.Module): def forward(self, x, edge_index): return x.mean().view(-1) def test_get_prediction(data): model = DummyModel() assert model.training explainer = Explainer( model, algorithm=DummyExplainer(), explanation_type='phenomenon', node_mask_type='object', model_config=dict( mode='regression', task_level='graph', ), ) pred = 
explainer.get_prediction(data.x, data.edge_index) assert model.training assert pred.size() == (1, ) @pytest.mark.parametrize('target', [None, torch.randn(2)]) @pytest.mark.parametrize('explanation_type', [x for x in ExplanationType]) def test_forward(data, target, explanation_type): model = DummyModel() assert model.training explainer = Explainer( model, algorithm=DummyExplainer(), explanation_type=explanation_type, node_mask_type='attributes', model_config=dict( mode='regression', task_level='graph', ), ) if target is None and explanation_type == ExplanationType.phenomenon: with pytest.raises(ValueError): explainer(data.x, data.edge_index, target=target) else: explanation = explainer( data.x, data.edge_index, target=target if explanation_type == ExplanationType.phenomenon else None, ) assert model.training assert isinstance(explanation, Explanation) assert 'x' in explanation assert 'edge_index' in explanation assert 'target' in explanation assert 'node_mask' in explanation.available_explanations assert explanation.node_mask.size() == data.x.size() @pytest.mark.parametrize('threshold_value', [0.2, 0.5, 0.8]) @pytest.mark.parametrize('node_mask_type', ['object', 'attributes']) def test_hard_threshold(data, threshold_value, node_mask_type): explainer = Explainer( DummyModel(), algorithm=DummyExplainer(), explanation_type='model', node_mask_type=node_mask_type, edge_mask_type='object', model_config=dict( mode='regression', task_level='graph', ), threshold_config=('hard', threshold_value), ) explanation = explainer(data.x, data.edge_index) assert 'node_mask' in explanation.available_explanations assert 'edge_mask' in explanation.available_explanations for key in explanation.available_explanations: mask = explanation[key] assert set(mask.unique().tolist()).issubset({0, 1}) @pytest.mark.parametrize('threshold_value', [1, 5, 10]) @pytest.mark.parametrize('threshold_type', ['topk', 'topk_hard']) @pytest.mark.parametrize('node_mask_type', ['object', 'attributes']) def 
test_topk_threshold(data, threshold_value, threshold_type, node_mask_type): explainer = Explainer( DummyModel(), algorithm=DummyExplainer(), explanation_type='model', node_mask_type=node_mask_type, edge_mask_type='object', model_config=dict( mode='regression', task_level='graph', ), threshold_config=(threshold_type, threshold_value), ) explanation = explainer(data.x, data.edge_index) assert 'node_mask' in explanation.available_explanations assert 'edge_mask' in explanation.available_explanations for key in explanation.available_explanations: mask = explanation[key] if threshold_type == 'topk': assert (mask > 0).sum() == min(mask.numel(), threshold_value) assert ((mask == 0).sum() == mask.numel() - min(mask.numel(), threshold_value)) else: assert (mask == 1).sum() == min(mask.numel(), threshold_value) assert ((mask == 0).sum() == mask.numel() - min(mask.numel(), threshold_value)) ================================================ FILE: test/explain/test_explanation.py ================================================ import os.path as osp from typing import Optional, Union import pytest import torch from torch_geometric.data import Data from torch_geometric.explain import Explanation from torch_geometric.explain.config import MaskType from torch_geometric.testing import withPackage def create_random_explanation( data: Data, node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None, ): if node_mask_type is not None: node_mask_type = MaskType(node_mask_type) if edge_mask_type is not None: edge_mask_type = MaskType(edge_mask_type) if node_mask_type == MaskType.object: node_mask = torch.rand(data.x.size(0), 1) elif node_mask_type == MaskType.common_attributes: node_mask = torch.rand(1, data.x.size(1)) elif node_mask_type == MaskType.attributes: node_mask = torch.rand_like(data.x) else: node_mask = None if edge_mask_type == MaskType.object: edge_mask = torch.rand(data.edge_index.size(1)) else: edge_mask = None return 
Explanation( # Create explanation. node_mask=node_mask, edge_mask=edge_mask, x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, ) @pytest.mark.parametrize('node_mask_type', [None, 'object', 'common_attributes', 'attributes']) @pytest.mark.parametrize('edge_mask_type', [None, 'object']) def test_available_explanations(data, node_mask_type, edge_mask_type): expected = [] if node_mask_type is not None: expected.append('node_mask') if edge_mask_type is not None: expected.append('edge_mask') explanation = create_random_explanation( data, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, ) assert set(explanation.available_explanations) == set(expected) def test_validate_explanation(data): explanation = create_random_explanation(data) explanation.validate(raise_on_error=True) with pytest.raises(ValueError, match="with 5 nodes"): explanation = create_random_explanation(data, node_mask_type='object') explanation.x = torch.randn(5, 5) explanation.validate(raise_on_error=True) with pytest.raises(ValueError, match="with 4 features"): explanation = create_random_explanation(data, 'attributes') explanation.x = torch.randn(4, 4) explanation.validate(raise_on_error=True) with pytest.raises(ValueError, match="with 7 edges"): explanation = create_random_explanation(data, edge_mask_type='object') explanation.edge_index = torch.randint(0, 4, (2, 7)) explanation.validate(raise_on_error=True) def test_node_mask(data): node_mask = torch.tensor([[1.], [0.], [1.], [0.]]) explanation = Explanation( node_mask=node_mask, x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, ) explanation.validate(raise_on_error=True) out = explanation.get_explanation_subgraph() assert out.node_mask.size() == (2, 1) assert (out.node_mask > 0.0).sum() == 2 assert out.x.size() == (2, 3) assert out.edge_index.size(1) <= 6 assert out.edge_index.size(1) == out.edge_attr.size(0) out = explanation.get_complement_subgraph() assert out.node_mask.size() == (2, 1) assert (out.node_mask 
== 0.0).sum() == 2 assert out.x.size() == (2, 3) assert out.edge_index.size(1) <= 6 assert out.edge_index.size(1) == out.edge_attr.size(0) def test_edge_mask(data): edge_mask = torch.tensor([1., 0., 1., 0., 0., 1.]) explanation = Explanation( edge_mask=edge_mask, x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, ) explanation.validate(raise_on_error=True) out = explanation.get_explanation_subgraph() assert out.x.size() == (4, 3) assert out.edge_mask.size() == (3, ) assert (out.edge_mask > 0.0).sum() == 3 assert out.edge_index.size() == (2, 3) assert out.edge_attr.size() == (3, 3) out = explanation.get_complement_subgraph() assert out.x.size() == (4, 3) assert out.edge_mask.size() == (3, ) assert (out.edge_mask == 0.0).sum() == 3 assert out.edge_index.size() == (2, 3) assert out.edge_attr.size() == (3, 3) @withPackage('matplotlib', 'pandas') @pytest.mark.parametrize('top_k', [2, None]) @pytest.mark.parametrize('node_mask_type', [None, 'attributes']) def test_visualize_feature_importance(tmp_path, data, top_k, node_mask_type): explanation = create_random_explanation(data, node_mask_type) path = osp.join(tmp_path, 'feature_importance.png') if node_mask_type is None: with pytest.raises(ValueError, match="node_mask' is not"): explanation.visualize_feature_importance(path, top_k=top_k) else: explanation.visualize_feature_importance(path, top_k=top_k) assert osp.exists(path) ================================================ FILE: test/explain/test_hetero_explainer.py ================================================ import pytest import torch from torch_geometric.explain import ( DummyExplainer, Explainer, HeteroExplanation, ) from torch_geometric.explain.config import ExplanationType class DummyModel(torch.nn.Module): def forward(self, x_dict, edge_index_dict, *args) -> torch.Tensor: return x_dict['paper'].mean().view(-1) def test_get_prediction(hetero_data): model = DummyModel() assert model.training explainer = Explainer( model, algorithm=DummyExplainer(), 
explanation_type='phenomenon', node_mask_type='object', model_config=dict( mode='regression', task_level='graph', ), ) pred = explainer.get_prediction(hetero_data.x_dict, hetero_data.edge_index_dict) assert model.training assert pred.size() == (1, ) @pytest.mark.parametrize('target', [None, torch.randn(2)]) @pytest.mark.parametrize('explanation_type', [x for x in ExplanationType]) def test_forward(hetero_data, target, explanation_type): model = DummyModel() explainer = Explainer( model, algorithm=DummyExplainer(), explanation_type=explanation_type, node_mask_type='attributes', model_config=dict( mode='regression', task_level='graph', ), ) if target is None and explanation_type == ExplanationType.phenomenon: with pytest.raises(ValueError): explainer(hetero_data.x_dict, hetero_data.edge_index_dict, target=target) else: explanation = explainer( hetero_data.x_dict, hetero_data.edge_index_dict, target=target if explanation_type == ExplanationType.phenomenon else None, ) assert model.training assert isinstance(explanation, HeteroExplanation) assert 'node_mask' in explanation.available_explanations for key in explanation.node_types: expected_size = hetero_data[key].x.size() assert explanation[key].node_mask.size() == expected_size @pytest.mark.parametrize('threshold_value', [0.2, 0.5, 0.8]) @pytest.mark.parametrize('node_mask_type', ['object', 'attributes']) def test_hard_threshold(hetero_data, threshold_value, node_mask_type): explainer = Explainer( DummyModel(), algorithm=DummyExplainer(), explanation_type='model', node_mask_type=node_mask_type, edge_mask_type='object', model_config=dict( mode='regression', task_level='graph', ), threshold_config=('hard', threshold_value), ) explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict) assert 'node_mask' in explanation.available_explanations assert 'edge_mask' in explanation.available_explanations for key in explanation.available_explanations: for mask in explanation.collect(key).values(): assert 
set(mask.unique().tolist()).issubset({0, 1}) @pytest.mark.parametrize('threshold_value', [1, 5, 10]) @pytest.mark.parametrize('threshold_type', ['topk', 'topk_hard']) @pytest.mark.parametrize('node_mask_type', ['object', 'attributes']) def test_topk_threshold(hetero_data, threshold_value, threshold_type, node_mask_type): explainer = Explainer( DummyModel(), algorithm=DummyExplainer(), explanation_type='model', node_mask_type=node_mask_type, edge_mask_type='object', model_config=dict( mode='regression', task_level='graph', ), threshold_config=(threshold_type, threshold_value), ) explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict) assert 'node_mask' in explanation.available_explanations assert 'edge_mask' in explanation.available_explanations for key in explanation.available_explanations: for mask in explanation.collect(key).values(): if threshold_type == 'topk': assert (mask > 0).sum() == min(mask.numel(), threshold_value) assert ((mask == 0).sum() == mask.numel() - min(mask.numel(), threshold_value)) else: assert (mask == 1).sum() == min(mask.numel(), threshold_value) assert ((mask == 0).sum() == mask.numel() - min(mask.numel(), threshold_value)) ================================================ FILE: test/explain/test_hetero_explanation.py ================================================ import os.path as osp from typing import Optional, Union import pytest import torch from torch_geometric.data import HeteroData from torch_geometric.explain import HeteroExplanation from torch_geometric.explain.config import MaskType from torch_geometric.testing import withPackage def create_random_explanation( hetero_data: HeteroData, node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None, ): if node_mask_type is not None: node_mask_type = MaskType(node_mask_type) if edge_mask_type is not None: edge_mask_type = MaskType(edge_mask_type) out = HeteroExplanation() for key in ['paper', 'author']: out[key].x = 
hetero_data[key].x if node_mask_type == MaskType.object: out[key].node_mask = torch.rand(hetero_data[key].num_nodes, 1) elif node_mask_type == MaskType.common_attributes: out[key].node_mask = torch.rand(1, hetero_data[key].num_features) elif node_mask_type == MaskType.attributes: out[key].node_mask = torch.rand_like(hetero_data[key].x) for key in [('paper', 'paper'), ('paper', 'author')]: out[key].edge_index = hetero_data[key].edge_index out[key].edge_attr = hetero_data[key].edge_attr if edge_mask_type == MaskType.object: out[key].edge_mask = torch.rand(hetero_data[key].num_edges) return out @pytest.mark.parametrize('node_mask_type', [None, 'object', 'common_attributes', 'attributes']) @pytest.mark.parametrize('edge_mask_type', [None, 'object']) def test_available_explanations(hetero_data, node_mask_type, edge_mask_type): expected = [] if node_mask_type: expected.append('node_mask') if edge_mask_type: expected.append('edge_mask') explanation = create_random_explanation( hetero_data, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, ) assert set(explanation.available_explanations) == set(expected) def test_validate_explanation(hetero_data): explanation = create_random_explanation(hetero_data) explanation.validate(raise_on_error=True) with pytest.raises(ValueError, match="with 8 nodes"): explanation = create_random_explanation(hetero_data) explanation['paper'].node_mask = torch.rand(5, 5) explanation.validate(raise_on_error=True) with pytest.raises(ValueError, match="with 5 features"): explanation = create_random_explanation(hetero_data, 'attributes') explanation['paper'].x = torch.randn(8, 5) explanation.validate(raise_on_error=True) with pytest.raises(ValueError, match="with 10 edges"): explanation = create_random_explanation(hetero_data) explanation['paper', 'paper'].edge_mask = torch.randn(5) explanation.validate(raise_on_error=True) def test_node_mask(): explanation = HeteroExplanation() explanation['paper'].node_mask = torch.tensor([[1.], [0.], 
[1.], [1.]]) explanation['author'].node_mask = torch.tensor([[1.], [0.], [1.], [1.]]) with pytest.warns(UserWarning, match="are isolated"): explanation.validate(raise_on_error=True) out = explanation.get_explanation_subgraph() assert out['paper'].node_mask.size() == (3, 1) assert out['author'].node_mask.size() == (3, 1) out = explanation.get_complement_subgraph() assert out['paper'].node_mask.size() == (1, 1) assert out['author'].node_mask.size() == (1, 1) def test_edge_mask(): explanation = HeteroExplanation() explanation['paper'].num_nodes = 4 explanation['author'].num_nodes = 4 explanation['paper', 'author'].edge_index = torch.tensor([ [0, 1, 2, 3], [0, 1, 2, 3], ]) explanation['paper', 'author'].edge_mask = torch.tensor([1., 0., 1., 1.]) out = explanation.get_explanation_subgraph() assert out['paper'].num_nodes == 4 assert out['author'].num_nodes == 4 assert out['paper', 'author'].edge_mask.size() == (3, ) assert torch.equal(out['paper', 'author'].edge_index, torch.tensor([[0, 2, 3], [0, 2, 3]])) out = explanation.get_complement_subgraph() assert out['paper'].num_nodes == 4 assert out['author'].num_nodes == 4 assert out['paper', 'author'].edge_mask.size() == (1, ) assert torch.equal(out['paper', 'author'].edge_index, torch.tensor([[1], [1]])) @withPackage('matplotlib', 'pandas') @pytest.mark.parametrize('top_k', [2, None]) @pytest.mark.parametrize('node_mask_type', [None, 'attributes']) def test_visualize_feature_importance( top_k, node_mask_type, tmp_path, hetero_data, ): explanation = create_random_explanation( hetero_data, node_mask_type=node_mask_type, ) path = osp.join(tmp_path, 'feature_importance.png') if node_mask_type is None: with pytest.raises(KeyError, match="Tried to collect 'node_mask'"): explanation.visualize_feature_importance(path, top_k=top_k) else: explanation.visualize_feature_importance(path, top_k=top_k) assert osp.exists(path) @withPackage('matplotlib', 'networkx') def test_hetero_visualize_graph(tmp_path, hetero_data): # Create 
explanation with both node and edge masks explanation = create_random_explanation(hetero_data, node_mask_type='object', edge_mask_type='object') path = osp.join(tmp_path, 'explanation_graph.png') # Test with default parameters explanation.visualize_graph(path=path) assert osp.exists(path) # Test with custom visualization parameters explanation.visualize_graph(path=path, node_size_range=(20, 400), node_opacity_range=(0.3, 0.9), edge_width_range=(0.2, 3.0), edge_opacity_range=(0.3, 0.9)) assert osp.exists(path) # Test with node labels node_labels = { 'paper': [f'Paper {i}' for i in range(hetero_data['paper'].num_nodes)], 'author': [f'Author {i}' for i in range(hetero_data['author'].num_nodes)], } explanation.visualize_graph(path=path, node_labels=node_labels) assert osp.exists(path) # Test with invalid number of labels invalid_labels = { 'paper': ['Paper 0'], # Too few labels 'author': ['Author 0', 'Author 1'], # Too few labels } with pytest.raises(ValueError, match="Number of labels"): explanation.visualize_graph(node_labels=invalid_labels) # Test with invalid node type in labels invalid_labels = { 'paper': [f'Paper {i}' for i in range(hetero_data['paper'].num_nodes)], 'author': [f'Author {i}' for i in range(hetero_data['author'].num_nodes)], 'invalid_type': ['Invalid 0', 'Invalid 1'], # Invalid node type } with pytest.raises(ValueError, match="Node type"): explanation.visualize_graph(node_labels=invalid_labels) ================================================ FILE: test/graphgym/example_node.yml ================================================ tensorboard_each_run: false tensorboard_agg: false dataset: format: PyG name: Cora task: node task_type: classification node_encoder: false node_encoder_name: Atom edge_encoder: false edge_encoder_name: Bond train: batch_size: 128 eval_period: 2 ckpt_period: 100 enable_ckpt: false skip_train_eval: true sampler: full_batch model: type: gnn loss_fun: cross_entropy edge_decoding: dot graph_pooling: add gnn: layers_pre_mp: 2 
layers_mp: 2 layers_post_mp: 1 dim_inner: 16 layer_type: gcnconv stage_type: stack batchnorm: false act: prelu dropout: 0.1 agg: mean normalize_adj: false optim: optimizer: adam base_lr: 0.01 max_epoch: 6 ================================================ FILE: test/graphgym/test_config.py ================================================ from dataclasses import dataclass from torch_geometric.graphgym.config import from_config @dataclass class MyConfig: a: int b: int = 4 def my_func(a: int, b: int = 2) -> str: return f'a={a},b={b}' def test_from_config(): assert my_func(a=1) == 'a=1,b=2' assert my_func.__name__ == from_config(my_func).__name__ assert from_config(my_func)(cfg=MyConfig(a=1)) == 'a=1,b=4' assert from_config(my_func)(cfg=MyConfig(a=1, b=1)) == 'a=1,b=1' assert from_config(my_func)(2, cfg=MyConfig(a=1, b=3)) == 'a=2,b=3' assert from_config(my_func)(cfg=MyConfig(a=1), b=3) == 'a=1,b=3' ================================================ FILE: test/graphgym/test_graphgym.py ================================================ import os.path as osp import warnings from collections import namedtuple import pytest import torch from torch_geometric import seed_everything from torch_geometric.graphgym import register from torch_geometric.graphgym.checkpoint import get_ckpt_dir from torch_geometric.graphgym.config import ( cfg, dump_cfg, load_cfg, set_out_dir, set_run_dir, ) from torch_geometric.graphgym.loader import create_loader from torch_geometric.graphgym.logger import set_printing from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage from torch_geometric.graphgym.models.head import GNNNodeHead from torch_geometric.graphgym.train import GraphGymDataModule, train from torch_geometric.graphgym.utils import ( agg_runs, auto_select_device, params_count, ) from torch_geometric.testing import onlyLinux, onlyOnline, withPackage num_trivial_metric_calls = 0 Args = namedtuple('Args', 
['cfg_file', 'opts']) root = osp.join(osp.dirname(osp.realpath(__file__))) args = Args(osp.join(root, 'example_node.yml'), []) def trivial_metric(true, pred, task_type): global num_trivial_metric_calls num_trivial_metric_calls += 1 return 1 @onlyOnline @withPackage('yacs', 'pytorch_lightning') @pytest.mark.parametrize('auto_resume', [True, False]) @pytest.mark.parametrize('skip_train_eval', [True, False]) @pytest.mark.parametrize('use_trivial_metric', [True, False]) def test_run_single_graphgym(tmp_path, capfd, auto_resume, skip_train_eval, use_trivial_metric): warnings.filterwarnings('ignore', ".*does not have many workers.*") warnings.filterwarnings('ignore', ".*lower value for log_every_n_steps.*") load_cfg(cfg, args) cfg.out_dir = osp.join(tmp_path, 'out_dir') cfg.run_dir = osp.join(tmp_path, 'run_dir') cfg.dataset.dir = osp.join('/', 'tmp', 'pyg_test_datasets', 'Planetoid') cfg.train.auto_resume = auto_resume set_out_dir(cfg.out_dir, args.cfg_file) dump_cfg(cfg) set_printing() seed_everything(cfg.seed) auto_select_device() set_run_dir(cfg.out_dir) cfg.train.skip_train_eval = skip_train_eval cfg.train.enable_ckpt = use_trivial_metric and skip_train_eval if use_trivial_metric: if 'trivial' not in register.metric_dict: register.register_metric('trivial', trivial_metric) global num_trivial_metric_calls num_trivial_metric_calls = 0 cfg.metric_best = 'trivial' cfg.custom_metrics = ['trivial'] else: cfg.metric_best = 'auto' cfg.custom_metrics = [] datamodule = GraphGymDataModule() assert len(datamodule.loaders) == 3 model = create_model() assert isinstance(model, torch.nn.Module) assert isinstance(model.encoder, FeatureEncoder) assert isinstance(model.mp, GNNStackStage) assert isinstance(model.post_mp, GNNNodeHead) assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp optimizer, scheduler = model.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adam) assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR) cfg.params 
= params_count(model) assert cfg.params == 23883 train(model, datamodule, logger=True, trainer_config={"enable_progress_bar": False}) assert osp.isdir(get_ckpt_dir()) is cfg.train.enable_ckpt agg_runs(cfg.out_dir, cfg.metric_best) out, _ = capfd.readouterr() assert "train: {'epoch': 0," in out assert "val: {'epoch': 0," in out assert "train: {'epoch': 5," in out assert "val: {'epoch': 5," in out @onlyOnline @withPackage('yacs', 'pytorch_lightning') def test_graphgym_module(tmp_path): import pytorch_lightning as pl load_cfg(cfg, args) cfg.out_dir = osp.join(tmp_path, 'out_dir') cfg.run_dir = osp.join(tmp_path, 'run_dir') cfg.dataset.dir = osp.join('/', 'tmp', 'pyg_test_datasets', 'Planetoid') set_out_dir(cfg.out_dir, args.cfg_file) dump_cfg(cfg) set_printing() seed_everything(cfg.seed) auto_select_device() set_run_dir(cfg.out_dir) loaders = create_loader() assert len(loaders) == 3 model = create_model() assert isinstance(model, pl.LightningModule) optimizer, scheduler = model.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adam) assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR) cfg.params = params_count(model) assert cfg.params == 23883 keys = {"loss", "true", "pred_score", "step_end_time"} # test training step batch = next(iter(loaders[0])) batch.to(model.device) outputs = model.training_step(batch) assert keys == set(outputs.keys()) assert isinstance(outputs["loss"], torch.Tensor) # test validation step batch = next(iter(loaders[1])) batch.to(model.device) outputs = model.validation_step(batch) assert keys == set(outputs.keys()) assert isinstance(outputs["loss"], torch.Tensor) # test test step batch = next(iter(loaders[2])) batch.to(model.device) outputs = model.test_step(batch) assert keys == set(outputs.keys()) assert isinstance(outputs["loss"], torch.Tensor) @pytest.fixture def destroy_process_group(): yield if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() @onlyOnline @onlyLinux 
@withPackage('yacs', 'pytorch_lightning') def test_train(destroy_process_group, tmp_path, capfd): warnings.filterwarnings('ignore', ".*does not have many workers.*") import pytorch_lightning as pl load_cfg(cfg, args) cfg.out_dir = osp.join(tmp_path, 'out_dir') cfg.run_dir = osp.join(tmp_path, 'run_dir') cfg.dataset.dir = osp.join('/', 'tmp', 'pyg_test_datasets', 'Planetoid') set_out_dir(cfg.out_dir, args.cfg_file) dump_cfg(cfg) set_printing() seed_everything(cfg.seed) auto_select_device() set_run_dir(cfg.out_dir) loaders = create_loader() model = create_model() cfg.params = params_count(model) # --- minimal logger callback that collects logs --- class LoggerCallback(pl.Callback): def __init__(self): super().__init__() self.logged = [] def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.logged.append({"type": "train", "step": trainer.global_step}) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): self.logged.append({"type": "val", "step": trainer.global_step}) logger = LoggerCallback() trainer = pl.Trainer(max_epochs=2, max_steps=4, callbacks=[logger], log_every_n_steps=1, enable_progress_bar=False) train_loader, val_loader = loaders[0], loaders[1] trainer.fit(model, train_loader, val_loader) assert trainer.current_epoch > 0 # ensure both train and val batches were seen types = {entry["type"] for entry in logger.logged} assert "val" in types, "Validation did not run" assert "train" in types, "Training did not run" ================================================ FILE: test/graphgym/test_logger.py ================================================ from torch_geometric.graphgym.config import set_run_dir from torch_geometric.graphgym.loader import create_loader from torch_geometric.graphgym.logger import Logger, LoggerCallback from torch_geometric.testing import withPackage @withPackage('yacs', 'pytorch_lightning') def test_logger_callback(): loaders = create_loader() assert len(loaders) == 3 
set_run_dir('.') logger = LoggerCallback() assert isinstance(logger.train_logger, Logger) assert isinstance(logger.val_logger, Logger) assert isinstance(logger.test_logger, Logger) ================================================ FILE: test/graphgym/test_register.py ================================================ import torch import torch_geometric.graphgym.register as register from torch_geometric.testing import withPackage @register.register_act('identity') def identity_act(x: torch.Tensor) -> torch.Tensor: return x @withPackage('yacs') def test_register(): assert len(register.act_dict) == 8 assert list(register.act_dict.keys()) == [ 'relu', 'selu', 'prelu', 'elu', 'lrelu_01', 'lrelu_025', 'lrelu_05', 'identity' ] assert str(register.act_dict['relu']()) == 'ReLU()' register.register_act('lrelu_03', torch.nn.LeakyReLU(0.3)) assert len(register.act_dict) == 9 assert 'lrelu_03' in register.act_dict ================================================ FILE: test/io/example1.off ================================================ OFF 4 2 0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 3 0 1 2 3 1 2 3 ================================================ FILE: test/io/example2.off ================================================ OFF 4 1 0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 4 0 1 2 3 ================================================ FILE: test/io/test_fs.py ================================================ import zipfile from os import path as osp import fsspec import pytest import torch import torch_geometric.typing from torch_geometric.data import extract_zip from torch_geometric.io import fs from torch_geometric.testing import noWindows if torch_geometric.typing.WITH_WINDOWS: # FIXME params = ['file'] else: params = ['file', 'memory'] @pytest.fixture(params=params) def tmp_fs_path(request, tmp_path) -> str: if request.param == 'file': return tmp_path.resolve().as_posix() elif request.param == 'memory': return f'memory://{tmp_path}' raise NotImplementedError def 
test_get_fs(): assert 'file' in fs.get_fs('/tmp/test').protocol assert 'memory' in fs.get_fs('memory:///tmp/test').protocol @noWindows def test_normpath(): assert fs.normpath('////home') == '/home' assert fs.normpath('memory:////home') == 'memory:////home' def test_exists(tmp_fs_path): path = osp.join(tmp_fs_path, 'file.txt') assert not fs.exists(path) with fsspec.open(path, 'w') as f: f.write('here') assert fs.exists(path) def test_makedirs(tmp_fs_path): path = osp.join(tmp_fs_path, '1', '2') assert not fs.isdir(path) fs.makedirs(path) assert fs.isdir(path) @pytest.mark.parametrize('detail', [False, True]) def test_ls(tmp_fs_path, detail): for i in range(2): with fsspec.open(osp.join(tmp_fs_path, str(i)), 'w') as f: f.write('here') res = fs.ls(tmp_fs_path, detail) assert len(res) == 2 expected_protocol = fs.get_fs(tmp_fs_path).protocol for output in res: if detail: output = output['name'] assert fs.get_fs(output).protocol == expected_protocol def test_cp(tmp_fs_path): src = osp.join(tmp_fs_path, 'src') for i in range(2): with fsspec.open(osp.join(src, str(i)), 'w') as f: f.write('here') assert fs.exists(src) dst = osp.join(tmp_fs_path, 'dst') assert not fs.exists(dst) # Can copy a file to new name: fs.cp(osp.join(src, '1'), dst) assert fs.isfile(dst) fs.rm(dst) # Can copy a single file to directory: fs.makedirs(dst) fs.cp(osp.join(src, '1'), dst) assert len(fs.ls(dst)) == 1 # Can copy multiple files to directory: fs.cp(src, dst) assert len(fs.ls(dst)) == 2 for i in range(2): fs.exists(osp.join(dst, str(i))) def test_extract(tmp_fs_path): def make_zip(path: str): with fsspec.open(path, mode='wb') as f: with zipfile.ZipFile(f, mode='w') as z: z.writestr('1', b'data') z.writestr('2', b'data') src = osp.join(tmp_fs_path, 'src', 'test.zip') make_zip(src) assert len(fsspec.open_files(f'zip://*::{src}')) == 2 dst = osp.join(tmp_fs_path, 'dst') assert not fs.exists(dst) # Can copy and extract afterwards: if fs.isdisk(tmp_fs_path): fs.cp(src, osp.join(dst, 'test.zip')) 
assert fs.exists(osp.join(dst, 'test.zip')) extract_zip(osp.join(dst, 'test.zip'), dst) assert len(fs.ls(dst)) == 3 for i in range(2): fs.exists(osp.join(dst, str(i))) fs.rm(dst) # Can copy and extract: fs.cp(src, dst, extract=True) assert len(fs.ls(dst)) == 2 for i in range(2): fs.exists(osp.join(dst, str(i))) def test_torch_save_load(tmp_fs_path): x = torch.randn(5, 5) path = osp.join(tmp_fs_path, 'x.pt') fs.torch_save(x, path) out = fs.torch_load(path) assert torch.equal(x, out) ================================================ FILE: test/io/test_off.py ================================================ import os import os.path as osp import random import sys import torch from torch_geometric.data import Data from torch_geometric.io import read_off, write_off def test_read_off(): root_dir = osp.join(osp.dirname(osp.realpath(__file__))) data = read_off(osp.join(root_dir, 'example1.off')) assert len(data) == 2 assert data.pos.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]] assert data.face.tolist() == [[0, 1], [1, 2], [2, 3]] data = read_off(osp.join(root_dir, 'example2.off')) assert len(data) == 2 assert data.pos.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]] assert data.face.tolist() == [[0, 0], [1, 2], [2, 3]] def test_write_off(): pos = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]) face = torch.tensor([[0, 1], [1, 2], [2, 3]]) name = str(random.randrange(sys.maxsize)) path = osp.join('/', 'tmp', f'{name}.off') write_off(Data(pos=pos, face=face), path) data = read_off(path) os.unlink(path) assert data.pos.tolist() == pos.tolist() assert data.face.tolist() == face.tolist() ================================================ FILE: test/llm/conftest.py ================================================ import pathlib import pytest LLM_DIR = pathlib.Path(__file__).parent def pytest_collection_modifyitems(items): for item in items: if pathlib.Path(item.fspath).is_relative_to(LLM_DIR): item.add_marker(pytest.mark.rag) 
================================================ FILE: test/llm/models/test_g_retriever.py ================================================ import gc from contextlib import nullcontext from types import SimpleNamespace import pytest import torch from torch import nn from torch_geometric.llm.models import LLM, GRetriever from torch_geometric.nn import GAT from torch_geometric.testing import withPackage @withPackage('transformers', 'sentencepiece', 'accelerate', 'peft') @pytest.mark.parametrize('use_lora', [True, False]) def test_g_retriever(use_lora: bool) -> None: llm = LLM(model_name='Qwen/Qwen3-0.6B', dtype=torch.float32, sys_prompt="You're an agent, answer my questions.") gnn = GAT( in_channels=1024, out_channels=1024, hidden_channels=1024, num_layers=2, heads=4, norm='batch_norm', ) model = GRetriever( llm=llm, gnn=gnn, use_lora=use_lora, ) assert str(model) == ('GRetriever(\n' ' llm=LLM(Qwen/Qwen3-0.6B),\n' ' gnn=GAT(1024, 1024, num_layers=2),\n' ')') x = torch.randn(10, 1024) edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], ]) edge_attr = torch.randn(edge_index.size(1), 1024) batch = torch.zeros(x.size(0), dtype=torch.long) question = ["Is PyG the best open-source GNN library?"] label = ["yes!"] # Test train: loss = model(question, x, edge_index, batch, label, edge_attr) assert loss >= 0 # Test inference: pred = model.inference(question, x, edge_index, batch, edge_attr) assert len(pred) == 1 del model, llm, gnn gc.collect() torch.cuda.empty_cache() @withPackage('transformers', 'sentencepiece', 'accelerate', 'peft') def test_g_retriever_many_tokens() -> None: llm = LLM(model_name='Qwen/Qwen3-0.6B', dtype=torch.float32, sys_prompt="You're an agent, answer my questions.") gnn = GAT( in_channels=1024, out_channels=1024, hidden_channels=1024, num_layers=2, heads=4, norm='batch_norm', ) model = GRetriever( llm=llm, gnn=gnn, mlp_out_tokens=2, use_lora=True, ) assert str(model) == ('GRetriever(\n' ' 
llm=LLM(Qwen/Qwen3-0.6B),\n' ' gnn=GAT(1024, 1024, num_layers=2),\n' ')') x = torch.randn(10, 1024) edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], ]) edge_attr = torch.randn(edge_index.size(1), 1024) batch = torch.zeros(x.size(0), dtype=torch.long) question = ["Is PyG the best open-source GNN library?"] label = ["yes!"] # Test train: loss = model(question, x, edge_index, batch, label, edge_attr) assert loss >= 0 # Test inference: pred = model.inference(question, x, edge_index, batch, edge_attr) assert len(pred) == 1 del model, llm, gnn gc.collect() torch.cuda.empty_cache() class DummyHFModel(nn.Module): def __init__(self, vocab_size=10): super().__init__() self.vocab_size = vocab_size self.dummy = nn.Parameter(torch.zeros(1)) def forward(self, inputs_embeds=None, **kwargs): B, T, _ = inputs_embeds.shape logits = torch.randn(B, T, self.vocab_size, device=inputs_embeds.device) loss = torch.tensor(0.0, device=inputs_embeds.device) loss.logits = logits return SimpleNamespace( logits=logits, loss=loss, ) class DummyLLM: def __init__(self, hidden_dim): self.word_embedding = nn.Embedding(100, hidden_dim) self.llm = DummyHFModel() self.device = torch.device("cpu") self.autocast_context = nullcontext() def _get_embeds(self, question, *args): batch_size = len(question) seq_len = 4 hidden = self.word_embedding.embedding_dim inputs_embeds = torch.randn(batch_size, seq_len, hidden) attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long) return inputs_embeds, attention_mask, None class DummyGNN(nn.Module): """Simple GNN stub returning node embeddings.""" def __init__(self, in_channels=4, out_channels=8): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.lin = nn.Linear(in_channels, out_channels) def forward(self, *args, **kwargs): x = args[0] return self.lin(x) @pytest.mark.parametrize("batch_size", [1, 3]) def test_gretriever_prefix_embedding_injection(batch_size): hidden_dim = 8 
num_nodes = 5 llm = DummyLLM(hidden_dim) gnn = DummyGNN(in_channels=4, out_channels=8) model = GRetriever( llm=llm, gnn=gnn, mlp_out_tokens=2, ) # graph inputs x = torch.randn(num_nodes, 4) edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]]) batch = torch.zeros(num_nodes, dtype=torch.long) # token ids questions = ["What is this graph?"] * batch_size labels = ["dummy answer"] * batch_size out = model( x=x, edge_index=edge_index, batch=batch, question=questions, label=labels, ) # basic correctness assertions assert hasattr(out, "logits") assert out.logits.shape[0] == batch_size ================================================ FILE: test/llm/models/test_git_mol.py ================================================ import torch from torch_geometric.llm.models import GITMol from torch_geometric.testing import withPackage @withPackage('transformers', 'sentencepiece', 'accelerate') def test_git_mol(): model = GITMol() x = torch.ones(10, 16, dtype=torch.long) edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], ]) edge_attr = torch.zeros(edge_index.size(1), 16, dtype=torch.long) # batch size = 1 batch = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O'] captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.'] images = torch.randn(1, 3, 224, 224) loss = model(x, edge_index, batch, edge_attr, smiles, images, captions) assert loss >= 0 # batch size > 1 batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) smiles = [ 'CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O', 'CCOc1ccccc1', ] captions = [ 'The molecule is the (R)-(-)-enantiomer of columbianetin.', 'Ethoxybenzene is an aromatic ether.', ] images = torch.randn(2, 3, 224, 224) loss = model(x, edge_index, batch, edge_attr, smiles, images, captions) assert loss >= 0 ================================================ FILE: test/llm/models/test_glem.py ================================================ import pytest import torch from 
torch_geometric.llm.models.glem import deal_nan def test_deal_nan_tensor_replaces_nans(): x = torch.tensor([1.0, float('nan'), 3.0]) result = deal_nan(x) expected = torch.tensor([1.0, 0.0, 3.0]) assert torch.allclose(result, expected, equal_nan=True) assert isinstance(result, torch.Tensor) assert not torch.isnan(result).any() def test_deal_nan_non_tensor_passthrough(): assert deal_nan(42.0) == 42.0 assert deal_nan("foo") == "foo" @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) def test_deal_nan_tensor_dtypes(dtype): # Create a tensor with one NaN value x = torch.tensor([1.0, float('nan'), 3.0], dtype=dtype) result = deal_nan(x) expected = torch.tensor([1.0, 0.0, 3.0], dtype=dtype) # `bfloat16` doesn't support `allclose` directly on CPU, # so we cast to float32 for comparison if dtype == torch.bfloat16: assert torch.allclose(result.to(torch.float32), expected.to(torch.float32), atol=1e-2) else: assert torch.allclose(result, expected, equal_nan=True) assert isinstance(result, torch.Tensor) assert not torch.isnan(result).any() assert result.dtype == dtype def test_deal_nan_is_non_mutating(): x = torch.tensor([1.0, float('nan'), 3.0]) x_copy = x.clone() _ = deal_nan(x) assert torch.isnan(x).any() # Original still contains NaN assert torch.allclose(x, x_copy, equal_nan=True) ================================================ FILE: test/llm/models/test_llm.py ================================================ import gc import pytest import torch from torch import Tensor from torch_geometric.llm.models import LLM from torch_geometric.llm.models.llm import get_llm_kwargs from torch_geometric.testing import withPackage def test_get_llm_kwargs(): kwargs = get_llm_kwargs(required_memory=640) assert kwargs == {'revision': 'main'} @withPackage('transformers', 'accelerate') @pytest.mark.parametrize('sys_prompt', ['You are an agent, answer my questions.', None]) @pytest.mark.parametrize('context', [['This is context.'], None]) 
@pytest.mark.parametrize('use_embedding', [True, False]) def test_llm(sys_prompt, context, use_embedding) -> None: question = ["Is PyG the best open-source GNN library?"] answer = ["yes!"] model = LLM( model_name='Qwen/Qwen3-0.6B', num_params=1, sys_prompt=sys_prompt, ) assert str(model) == 'LLM(Qwen/Qwen3-0.6B)' embedding = [torch.randn(1, 1024, dtype=torch.bfloat16).to(model.device) ] if use_embedding else None loss = model(question, answer, context=context, embedding=embedding) assert isinstance(loss, Tensor) assert loss.dim() == 0 assert loss >= 0.0 pred = model.inference(question) assert len(pred) == 1 del model gc.collect() torch.cuda.empty_cache() class DummyBatch(dict): """Mimics HuggingFace BatchEncoding.""" def to(self, device): return self class DummyTokenizer: pad_token_id = 0 padding_side = "left" def __call__(self, texts, return_tensors=None, padding=True): lengths = [len(t) for t in texts] max_len = max(lengths) ids = [] mask = [] for seq_len in lengths: padding = max_len - seq_len ids.append([0] * padding + list(range(1, seq_len + 1))) mask.append([0] * padding + [1] * seq_len) return DummyBatch({ "input_ids": torch.tensor(ids), "attention_mask": torch.tensor(mask) }) class DummyModel(torch.nn.Module): def get_input_embeddings(self): return torch.nn.Embedding(100, 8) def forward(self, inputs_embeds=None, attention_mask=None, **kwargs): batch, seq, dim = inputs_embeds.shape class Out: pass out = Out() out.logits = torch.zeros(batch, seq, 10) return out @pytest.fixture def dummy_llm(): llm = LLM.__new__(LLM) torch.nn.Module.__init__(llm) llm.device = torch.device("cpu") llm.tokenizer = DummyTokenizer() llm.model = DummyModel() return llm def test_llm_prepare_inputs(dummy_llm): prompts = ["hello", "hi"] encoded = dummy_llm.tokenizer(prompts) input_ids = encoded["input_ids"] attention_mask = encoded["attention_mask"] emb = dummy_llm.model.get_input_embeddings() inputs_embeds = emb(input_ids) out = dummy_llm.model(inputs_embeds=inputs_embeds, 
attention_mask=attention_mask) assert inputs_embeds.shape[0] == 2 assert attention_mask.shape == input_ids.shape assert hasattr(out, "logits") assert out.logits.shape[:2] == inputs_embeds.shape[:2] def test_llm_single_prompt(dummy_llm): encoded = dummy_llm.tokenizer(["test"]) assert encoded["input_ids"].shape[0] == 1 def test_llm_variable_lengths(dummy_llm): prompts = ["a", "abcdef", "abc"] encoded = dummy_llm.tokenizer(prompts) input_ids = encoded["input_ids"] assert input_ids.shape[0] == 3 assert input_ids.shape[1] == max(len(p) for p in prompts) ================================================ FILE: test/llm/models/test_llm_judge.py ================================================ import numpy as np from torch_geometric.llm.models import LLMJudge def test_llm_judge(): judge = LLMJudge() assert judge._process_score('1234') == 1.0 assert judge._average_scores(1, 3) == 2 assert judge._average_scores(-1, 3) == 3 assert np.isnan(judge.score('question', 'model_pred', 'correct_answer')) ================================================ FILE: test/llm/models/test_molecule_gpt.py ================================================ import torch from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq from torch_geometric.llm.models import LLM, MoleculeGPT, SentenceTransformer from torch_geometric.nn import GINEConv from torch_geometric.testing import withPackage @withPackage('transformers', 'sentencepiece', 'accelerate') def test_molecule_gpt() -> None: llm = LLM( # model_name='lmsys/vicuna-7b-v1.5', model_name='Qwen/Qwen3-0.6B', num_params=1, dtype=torch.float32) graph_encoder = GINEConv(nn=Seq(Lin(16, 16), ReLU(), Lin(16, 16)), train_eps=True, edge_dim=16) smiles_encoder = SentenceTransformer( model_name='DeepChem/ChemBERTa-77M-MTR', pooling_strategy='last_hidden_state', ) model = MoleculeGPT( llm=llm, graph_encoder=graph_encoder, smiles_encoder=smiles_encoder, ) assert str(model) == ( 'MoleculeGPT(\n' ' 
llm=LLM(Qwen/Qwen3-0.6B),\n' ' graph=GINEConv,\n' ' smiles=SentenceTransformer(model_name=DeepChem/ChemBERTa-77M-MTR),\n' # noqa: E501 ')') x = torch.randn(10, 16) edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], ]) edge_attr = torch.randn(edge_index.size(1), 16) batch = torch.zeros(x.size(0), dtype=torch.long) smiles = ['CCCCCCCCCC'] instructions = ['What is ∼ functional related to?'] label = ['I do not know!'] # Test train: loss = model(x, edge_index, batch, edge_attr, smiles, instructions, label) assert loss >= 0 # Test inference: pred = model.inference(x, edge_index, batch, edge_attr, smiles, instructions) assert len(pred) == 1 ================================================ FILE: test/llm/models/test_protein_mpnn.py ================================================ import torch from torch_geometric.llm.models import ProteinMPNN from torch_geometric.testing import withPackage @withPackage('torch_cluster') def test_protein_mpnn(): num_nodes = 10 vocab_size = 21 model = ProteinMPNN(vocab_size=vocab_size) x = torch.randn(num_nodes, 4, 3) chain_seq_label = torch.randint(0, vocab_size, (num_nodes, )) mask = torch.ones(num_nodes) chain_mask_all = torch.ones(num_nodes) residue_idx = torch.randint(0, 10, (num_nodes, )) chain_encoding_all = torch.ones(num_nodes) batch = torch.zeros(num_nodes, dtype=torch.long) logits = model(x, chain_seq_label, mask, chain_mask_all, residue_idx, chain_encoding_all, batch) assert logits.size() == (num_nodes, vocab_size) ================================================ FILE: test/llm/models/test_sentence_transformer.py ================================================ import pytest import torch from torch_geometric.llm.models import SentenceTransformer from torch_geometric.llm.models.sentence_transformer import ( last_pooling, mean_pooling, ) from torch_geometric.testing import withCUDA, withPackage @withCUDA @withPackage('transformers') @pytest.mark.parametrize('batch_size', [None, 1]) 
@pytest.mark.parametrize('pooling_strategy', ['mean', 'last', 'cls']) @pytest.mark.parametrize('verbose', [True, False]) def test_sentence_transformer(batch_size, pooling_strategy, device, verbose): model_name = 'bert-base-uncased' model = SentenceTransformer( model_name=model_name, pooling_strategy=pooling_strategy, ).to(device) assert model.device == device assert str(model) == f'SentenceTransformer(model_name={model_name})' text = [ "this is a basic english text", "PyG is the best open-source GNN library :)", ] model_embedding_dim = model.model.config.hidden_size out = model.encode(text, batch_size=batch_size, verbose=verbose) assert out.device == device assert out.shape == (2, model_embedding_dim) out = model.encode(text, batch_size=batch_size, output_device='cpu', verbose=verbose) assert out.is_cpu assert out.shape == (2, model_embedding_dim) out = model.encode([], batch_size=batch_size, verbose=verbose) assert out.device == device assert out.shape == (0, model_embedding_dim) def test_mean_pooling(): x = torch.randn(2, 1, 2) attention_mask = torch.zeros(2, 1) result = mean_pooling(x, attention_mask) expected = torch.zeros_like(x) assert torch.allclose(result, expected, atol=1e-6) @pytest.mark.parametrize('mask', [torch.ones, torch.zeros]) def test_last_pooling(mask): x = torch.randn(2, 1, 2) attention_mask = mask(2, 1, dtype=torch.long) out = last_pooling(x, attention_mask) assert torch.allclose(out, x[:, 0, :], atol=1e-6) ================================================ FILE: test/llm/models/test_txt2kg.py ================================================ import sys import types import pytest import torch_geometric.llm.models.txt2kg as txt2kg from torch_geometric.llm.models.txt2kg import ( TXT2KG, _chunk_text, _merge_triples_deterministically, _multiproc_helper, _parse_n_check_triples, ) def test_init_local_lm_flag(): model = TXT2KG(local_LM=True, chunk_size=20) assert model.local_LM is True assert model.initd_LM is False def 
test_parse_n_check_triples_formats(): s = "(A, rel, B)\n(C, rel2, D)" parsed = _parse_n_check_triples(s) assert ("A", "rel", "B") in parsed assert ("C", "rel2", "D") in parsed def test_chunk_text_simple_sentence(): text = "Hello world. Another sentence!" chunks = _chunk_text(text, chunk_size=10) # Only makes chunks at sentence boundaries assert any("Hello" in c for c in chunks) class DummyLLM: def __init__(self): pass def inference(self, *args, **kwargs): return ["(X,edge,Y)"] def test_local_lm_integration(monkeypatch): model = TXT2KG(local_LM=True) model.model = DummyLLM() model.initd_LM = True # Simulate time progression times = iter([100.0, 100.05]) # 0.05 sec elapsed monkeypatch.setattr("time.time", lambda: next(times)) out = model._chunk_to_triples_str_local("text") assert out == "(X,edge,Y)" assert model.time_to_parse > 0 def test_add_doc_empty(monkeypatch): model = TXT2KG(local_LM=True) model.add_doc_2_KG("", QA_pair=None) assert model.relevant_triples[0] == [] # Mock LLM + parsing on real text: def test_add_doc_to_KG(monkeypatch): model = TXT2KG(local_LM=True, chunk_size=10) # Mock only the LLM output stage monkeypatch.setattr(model, "_chunk_to_triples_str_local", lambda *_: "(A,rel,B)\n(C,rel,D)") model.add_doc_2_KG("Some text") triples = model.relevant_triples[0] assert len(triples) == 2 assert ("A", "rel", "B") in triples assert model.doc_id_counter == 1 def test_merge_triples_deterministically_basic(): # Simple case: multiple sublists, strings only results = [ [["b", "rel", "c"], ["A", "rel", "d"]], [["a", "rel", "c"]], ] merged = _merge_triples_deterministically(results) # Expect deterministic, casefolded lexicographic order expected = [ ("a", "rel", "c"), ("A", "rel", "d"), ("b", "rel", "c"), ] assert merged == expected def test_merge_triples_deterministically_unicode_and_nonstring(): # Include unicode and a numeric element to cover else branch in lambda results = [ [["ä", 2, "x"], ["A", 1, "y"]], [["a", 3, "z"]], ] merged = 
_merge_triples_deterministically(results) # Ensure tuples, unicode sorted, numeric untouched expected = [ ("A", 1, "y"), ("a", 3, "z"), ("ä", 2, "x"), ] assert merged == expected def test_merge_triples_deterministically_empty(): # Edge case: empty input results = [] merged = _merge_triples_deterministically(results) assert merged == [] def test_merge_triples_deterministically_singleton(): # Edge case: single sublist, single triple results = [[["only", "one", "triple"]]] merged = _merge_triples_deterministically(results) assert merged == [("only", "one", "triple")] def test_chunk_to_triples_str_cloud(monkeypatch): # Fake streaming chunk object class DummyChunk: class Choice: class Delta: content = "A" delta = Delta() choices = [Choice()] class DummyCompletion: def __iter__(self): return iter([DummyChunk()]) class DummyClient: class Chat: class Completions: def create(self, **kwargs): return DummyCompletion() completions = Completions() chat = Chat() class DummyOpenAI: def __init__(self, *args, **kwargs): pass chat = DummyClient.chat fake_openai = types.ModuleType("openai") fake_openai.OpenAI = DummyOpenAI monkeypatch.setitem(sys.modules, "openai", fake_openai) txt2kg.CLIENT_INITD = False out = txt2kg._chunk_to_triples_str_cloud("text") assert isinstance(out, str) def dummy_multiproc_helper( rank, chunks, py_fn, llm_fn, NIM_KEY, NIM_MODEL, ENDPOINT_URL, max_retries=3, base_delay=0, ): return [("A", "rel", "B")] def test_extract_relevant_triples_cloud(monkeypatch): model = TXT2KG(local_LM=False, chunk_size=10) # Mock the multiproc helper (module-level) monkeypatch.setattr(txt2kg, "_multiproc_helper", dummy_multiproc_helper) triples = model._extract_relevant_triples("Some text") assert ("A", "rel", "B") in triples def test_multiproc_helper_success(monkeypatch): # Dummy LLM/Python parser def dummy_llm_fn(x, **kwargs): return ["llm:" + str(x)] def dummy_py_fn(x): return ["py:" + str(i) for i in x] # Patch _llm_then_python_parse monkeypatch.setattr( 
"torch_geometric.llm.models.txt2kg._llm_then_python_parse", lambda chunks, py_fn, llm_fn, **kwargs: ["PARSED:" + str(chunks)]) # Input chunks for rank 0 chunks_for_rank = ["chunk0", "chunk1"] result = _multiproc_helper( rank=0, chunks_for_rank=chunks_for_rank, py_fn=dummy_py_fn, llm_fn=dummy_llm_fn, NIM_KEY="dummy", NIM_MODEL="dummy", ENDPOINT_URL="dummy", max_retries=3, base_delay=0.01 # keep backoff small in tests ) assert result == ["PARSED:['chunk0', 'chunk1']"] def test_multiproc_helper_retry(monkeypatch): attempts = [] def failing_parse(chunks, py_fn, llm_fn, **kwargs): attempts.append(1) if len(attempts) < 3: raise RuntimeError("fail") return ["SUCCESS"] monkeypatch.setattr( "torch_geometric.llm.models.txt2kg._llm_then_python_parse", failing_parse) result = _multiproc_helper( rank=0, chunks_for_rank=["chunk"], py_fn=lambda x: x, llm_fn=lambda x: x, NIM_KEY="dummy", NIM_MODEL="dummy", ENDPOINT_URL="dummy", max_retries=5, base_delay=0 # instant retries for test ) assert result == ["SUCCESS"] assert len(attempts) == 3 # retried twice, succeeded on 3rd def test_add_doc_empty_text(): kg = TXT2KG(local_LM=True) kg.add_doc_2_KG(txt="") # first doc uses doc_id_counter=0 as key assert 0 in kg.relevant_triples assert kg.relevant_triples[0] == [] # doc counter should increment assert kg.doc_id_counter == 1 def test_add_doc_empty_text_with_QA_pair(): kg = TXT2KG(local_LM=True) qa = ("What is PyG?", "Graph ML library") kg.add_doc_2_KG(txt="", QA_pair=qa) assert qa in kg.relevant_triples assert kg.relevant_triples[qa] == [] @pytest.fixture def kg_cpu(): # TXT2KG instance using CPU local LLM mode return TXT2KG(local_LM=True) def test_add_doc_empty_text_cpu(kg_cpu): """Cover the empty text branch (lines 194-201).""" kg_cpu.add_doc_2_KG(txt="") # doc_id_counter starts at 0 assert kg_cpu.relevant_triples[0] == [] assert kg_cpu.doc_id_counter == 1 def test_add_doc_empty_text_with_QA_pair_cpu(kg_cpu): """Cover QA_pair key path with empty text.""" qa = ("What is PyG?", "Graph ML 
library") kg_cpu.add_doc_2_KG(txt="", QA_pair=qa) assert qa in kg_cpu.relevant_triples assert kg_cpu.relevant_triples[qa] == [] def test_add_doc_nonempty_text_placeholder(kg_cpu, monkeypatch): """Minimal coverage for non-empty text branch. Avoids importing the real LLM. """ # Patch the module-level function _llm_then_python_parse monkeypatch.setattr(txt2kg, "_llm_then_python_parse", lambda chunks, *args, **kwargs: []) # Call add_doc_2_KG with non-empty text kg_cpu.add_doc_2_KG(txt="some text") # Ensure doc_id_counter incremented and key exists key = kg_cpu.doc_id_counter - 1 assert key in kg_cpu.relevant_triples ================================================ FILE: test/llm/models/test_vision_transformer.py ================================================ import torch from torch_geometric.llm.models import VisionTransformer from torch_geometric.testing import onlyFullTest, withCUDA, withPackage @withCUDA @onlyFullTest @withPackage('transformers') def test_vision_transformer(device): model = VisionTransformer( model_name='microsoft/swin-base-patch4-window7-224', ).to(device) assert model.device == device assert str( model ) == 'VisionTransformer(model_name=microsoft/swin-base-patch4-window7-224)' images = torch.randn(2, 3, 224, 224).to(device) out = model(images) assert out.device == device assert out.size() == (2, 49, 1024) out = model(images, output_device='cpu') assert out.is_cpu assert out.size() == (2, 49, 1024) ================================================ FILE: test/llm/test_large_graph_indexer.py ================================================ import random import string from typing import List import pytest import torch from torch_geometric.data import Data from torch_geometric.llm.large_graph_indexer import ( EDGE_PID, EDGE_RELATION, NODE_PID, LargeGraphIndexer, TripletLike, get_features_for_triplets, ) from torch_geometric.llm.utils.backend_utils import preprocess_triplet from torch_geometric.typing import WITH_PT20 # create possible nodes and edges 
for graph strkeys = string.ascii_letters + string.digits NODE_POOL = list( {"".join(random.sample(strkeys, 10)).lower() for i in range(1000)}) EDGE_POOL = list( {"".join(random.sample(strkeys, 10)).lower() for i in range(50)}) def featurize(s: str) -> int: return int.from_bytes(s.encode(), 'little') def sample_triplets(amount: int = 1) -> List[TripletLike]: trips = [] for _ in range(amount): h, t = random.sample(NODE_POOL, k=2) r = random.sample(EDGE_POOL, k=1)[0] trips.append(tuple([h, r, t])) return trips def test_basic_collate(): graphs = [sample_triplets(1000) for i in range(2)] indexer_0 = LargeGraphIndexer.from_triplets( graphs[0], pre_transform=preprocess_triplet) indexer_1 = LargeGraphIndexer.from_triplets( graphs[1], pre_transform=preprocess_triplet) big_indexer = LargeGraphIndexer.collate([indexer_0, indexer_1]) assert len(indexer_0._nodes) + len( indexer_1._nodes) - len(indexer_0._nodes.keys() & indexer_1._nodes.keys()) == len( big_indexer._nodes) assert len(indexer_0._edges) + len( indexer_1._edges) - len(indexer_0._edges.keys() & indexer_1._edges.keys()) == len( big_indexer._edges) assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes) assert len(set(big_indexer._edges.values())) == len(big_indexer._edges) for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()): assert big_indexer.node_attr[NODE_PID][ big_indexer._nodes[node]] == node def test_large_graph_index(): graphs = [sample_triplets(1000) for i in range(100)] # Preprocessing of trips lowercases nodes but not edges node_feature_vecs = {s.lower(): featurize(s.lower()) for s in NODE_POOL} edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} def encode_graph_from_trips(triplets: List[TripletLike]) -> Data: seen_nodes = dict() edge_attrs = list() edge_idx = [] for trip in triplets: trip = preprocess_triplet(trip) h, r, t = trip seen_nodes[h] = len( seen_nodes) if h not in seen_nodes else seen_nodes[h] seen_nodes[t] = len( seen_nodes) if t not in seen_nodes else 
seen_nodes[t] edge_attrs.append(edge_feature_vecs[r]) edge_idx.append((seen_nodes[h], seen_nodes[t])) x = torch.Tensor([node_feature_vecs[n] for n in seen_nodes.keys()]) edge_idx = torch.LongTensor(edge_idx).T edge_attrs = torch.Tensor(edge_attrs) return Data(x=x, edge_index=edge_idx, edge_attr=edge_attrs) naive_graph_ds = [ encode_graph_from_trips(triplets=trips) for trips in graphs ] indexer = LargeGraphIndexer.collate([ LargeGraphIndexer.from_triplets(g, pre_transform=preprocess_triplet) for g in graphs ]) indexer_nodes = indexer.get_unique_node_features() indexer_node_vals = torch.Tensor( [node_feature_vecs[n] for n in indexer_nodes]) indexer_edges = indexer.get_unique_edge_features( feature_name=EDGE_RELATION) indexer_edge_vals = torch.Tensor( [edge_feature_vecs[e] for e in indexer_edges]) indexer.add_node_feature('x', indexer_node_vals) indexer.add_edge_feature('edge_attr', indexer_edge_vals, map_from_feature=EDGE_RELATION) large_graph_ds = [ get_features_for_triplets(indexer=indexer, triplets=g, node_feature_name='x', edge_feature_name='edge_attr', pre_transform=preprocess_triplet) for g in graphs ] for ds in large_graph_ds: assert NODE_PID in ds assert EDGE_PID in ds assert "node_idx" in ds assert "edge_idx" in ds def results_are_close_enough(ground_truth: Data, new_method: Data, thresh=.99): def _sorted_tensors_are_close(tensor1, tensor2): return torch.all( torch.isclose(tensor1.sort()[0], tensor2.sort()[0]) > thresh) def _graphs_are_same(tensor1, tensor2): if not WITH_PT20: pytest.skip( "This test requires a PyG version with NetworkX as a " + "dependency.") import networkx as nx return nx.weisfeiler_lehman_graph_hash(nx.Graph( tensor1.T)) == nx.weisfeiler_lehman_graph_hash( nx.Graph(tensor2.T)) return _sorted_tensors_are_close( ground_truth.x, new_method.x) \ and _sorted_tensors_are_close( ground_truth.edge_attr, new_method.edge_attr) \ and _graphs_are_same( ground_truth.edge_index, new_method.edge_index) for dsets in zip(naive_graph_ds, large_graph_ds): 
assert results_are_close_enough(*dsets) def test_save_load(tmp_path): graph = sample_triplets(1000) node_feature_vecs = {s: featurize(s) for s in NODE_POOL} edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} indexer = LargeGraphIndexer.from_triplets(graph) indexer_nodes = indexer.get_unique_node_features() indexer_node_vals = torch.Tensor( [node_feature_vecs[n] for n in indexer_nodes]) indexer_edges = indexer.get_unique_edge_features( feature_name=EDGE_RELATION) indexer_edge_vals = torch.Tensor( [edge_feature_vecs[e] for e in indexer_edges]) indexer.add_node_feature('x', indexer_node_vals) indexer.add_edge_feature('edge_attr', indexer_edge_vals, map_from_feature=EDGE_RELATION) indexer.save(str(tmp_path)) assert indexer == LargeGraphIndexer.from_disk(str(tmp_path)) ================================================ FILE: test/llm/test_rag_loader.py ================================================ import os from typing import Any, Dict from unittest.mock import Mock import pytest import torch from torch_geometric.data import Data from torch_geometric.llm.models import SentenceTransformer from torch_geometric.llm.rag_loader import RAGQueryLoader from torch_geometric.llm.utils.backend_utils import ( create_graph_from_triples, create_remote_backend_from_graph_data, ) from torch_geometric.llm.utils.feature_store import KNNRAGFeatureStore from torch_geometric.llm.utils.graph_store import NeighborSamplingRAGGraphStore from torch_geometric.llm.utils.vectorrag import VectorRetriever from torch_geometric.sampler import SamplerOutput from torch_geometric.testing import withPackage class MockRAGFeatureStore: """Mock implementation of RAGFeatureStore protocol for testing.""" def __init__(self): self._config = {} self.x = torch.randn(10, 64) # Sample node features def retrieve_seed_nodes(self, query: Any, **kwargs): """Mock retrieve_seed_nodes method.""" seed_nodes = torch.tensor([0, 1, 2, 3, 4]) query_enc = torch.randn(1, 64) return seed_nodes, query_enc @property def 
config(self) -> Dict[str, Any]: return self._config @config.setter def config(self, config: Dict[str, Any]): if config is None: raise ValueError("Config cannot be None") if 'a' not in config: raise ValueError("Required config parameter 'a' not found") self._config = config def retrieve_seed_edges(self, query: Any, **kwargs): """Mock retrieve_seed_edges method.""" return torch.tensor([[0, 1], [1, 2], [2, 3]]) def load_subgraph(self, sample): """Mock load_subgraph method.""" data = Data() data.edge_idx = torch.tensor([0, 1, 2]) data.node_idx = torch.tensor([0, 1, 2, 3, 4]) return data class MockRAGGraphStore: """Mock implementation of RAGGraphStore protocol for testing.""" def __init__(self): self._config = {} self.edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]) def sample_subgraph(self, seed_nodes, seed_edges=None, **kwargs): """Mock sample_subgraph method.""" return SamplerOutput(node=seed_nodes, row=torch.tensor([0, 1, 2]), col=torch.tensor([1, 2, 3]), edge=torch.tensor([0, 1, 2]), batch=None) @property def config(self) -> Dict[str, Any]: return self._config @config.setter def config(self, config: Dict[str, Any]): if config is None: raise ValueError("Config cannot be None") if 'b' not in config: raise ValueError("Required config parameter 'b' not found") self._config = config def register_feature_store(self, feature_store): """Mock register_feature_store method.""" class TestRAGQueryLoader: """Test suite for RAGQueryLoader.""" def setup_method(self): """Set up test fixtures before each test method.""" self.mock_feature_store = MockRAGFeatureStore() self.mock_graph_store = MockRAGGraphStore() self.graph_data = (self.mock_feature_store, self.mock_graph_store) # Sample config self.sample_config = {"a": 5, "b": [10, 5], "c": "test_value"} def test_initialization_basic(self): """Test basic initialization of RAGQueryLoader.""" loader = RAGQueryLoader(self.graph_data, config=self.sample_config) assert loader.feature_store == self.mock_feature_store assert 
loader.graph_store == self.mock_graph_store assert loader.vector_retriever is None assert loader.augment_query is False assert loader.subgraph_filter is None assert loader.config == self.sample_config def test_initialization_with_all_params(self): """Test initialization with all parameters.""" mock_vector_retriever = Mock(spec=VectorRetriever) mock_subgraph_filter = Mock() loader = RAGQueryLoader(graph_data=self.graph_data, subgraph_filter=mock_subgraph_filter, augment_query=True, vector_retriever=mock_vector_retriever, config=self.sample_config) assert loader.feature_store == self.mock_feature_store assert loader.graph_store == self.mock_graph_store assert loader.vector_retriever == mock_vector_retriever assert loader.augment_query is True assert loader.subgraph_filter == mock_subgraph_filter assert loader.config == self.sample_config def test_bad_config(self): """Test bad config initialization.""" with pytest.raises(ValueError): RAGQueryLoader(self.graph_data) with pytest.raises(ValueError): RAGQueryLoader(self.graph_data, config={'d': 'foobar'}) def test_config_propagation(self): """Test that config is propagated during initialization.""" loader = RAGQueryLoader(self.graph_data, config=self.sample_config) assert loader.feature_store.config == self.sample_config assert loader.graph_store.config == self.sample_config def test_basic_query_without_vector_retriever(self): """Test basic query functionality without vector retriever.""" loader = RAGQueryLoader(self.graph_data, config=self.sample_config) query = "test query" result = loader.query(query) # Verify result is a Data object assert isinstance(result, Data) # Verify the data has expected attributes assert hasattr(result, 'node_idx') assert hasattr(result, 'num_nodes') assert hasattr(result, 'x') assert hasattr(result, 'edge_index') def test_query_with_vector_retriever(self): """Test query functionality with vector retriever.""" mock_vector_retriever = Mock(spec=VectorRetriever) 
mock_vector_retriever.query.return_value = [ "retrieved doc 1", "retrieved doc 2" ] loader = RAGQueryLoader(self.graph_data, vector_retriever=mock_vector_retriever, config=self.sample_config) query = "test query" result = loader.query(query) # Verify vector retriever was called mock_vector_retriever.query.assert_called_once_with(query) # Verify result has text_context assert hasattr(result, 'text_context') assert result.text_context == ["retrieved doc 1", "retrieved doc 2"] def test_query_with_subgraph_filter(self): """Test query functionality with subgraph filter.""" mock_filter_result = Data() mock_filter_result.filtered = True mock_subgraph_filter = Mock(return_value=mock_filter_result) loader = RAGQueryLoader(self.graph_data, subgraph_filter=mock_subgraph_filter, config=self.sample_config) query = "test query" result = loader.query(query) # Verify subgraph filter was called mock_subgraph_filter.assert_called_once() call_args = mock_subgraph_filter.call_args[0] assert len(call_args) == 2 assert call_args[1] == query # Verify result is the filtered result assert result == mock_filter_result assert hasattr(result, 'filtered') assert result.filtered is True @withPackage('pyg_lib', 'torch_sparse') def test_rag_loader_integration(tmp_path): """Test RAGQueryLoader with real feature and graph stores from triples.""" # Define test triplets - simple knowledge graph about cities/countries triplets = [ ["Paris", "capital_of", "France"], ["London", "capital_of", "UK"], ["Berlin", "capital_of", "Germany"], ["France", "in_continent", "Europe"], ["UK", "in_continent", "Europe"], ["Germany", "in_continent", "Europe"], ["Rome", "capital_of", "Italy"], ["Italy", "in_continent", "Europe"], ["Madrid", "capital_of", "Spain"], ["Spain", "in_continent", "Europe"], ] encoder_model = SentenceTransformer('bert-base-uncased') # Create graph from triplets graph_data = create_graph_from_triples(triplets, encoder_model.encode) save_path = os.path.join(tmp_path, "test_graph.pt") loader = 
create_remote_backend_from_graph_data( graph_data=graph_data, path=save_path, n_parts=1, graph_db=NeighborSamplingRAGGraphStore, feature_db=KNNRAGFeatureStore) feature_store, graph_store = loader.load() # Configuration config = { "k_nodes": 1, "encoder_model": encoder_model, "num_neighbors": [10] # 10 neighbors only one hop } # Create RAG loader rag_data = (feature_store, graph_store) loader = RAGQueryLoader(rag_data, config=config) # Test query about European capitals query = "countries in Europe" result = loader.query(query) # Verify result structure assert isinstance(result, Data) assert torch.equal(result.edge_index, torch.tensor([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])) expected_x = encoder_model.encode( ["Europe", "France", "UK", "Germany", "Italy", "Spain"]).cpu() expected_edge_attr = encoder_model.encode(["in_continent"] * 5).cpu() assert torch.allclose(result.x, expected_x, atol=1e-6) assert torch.allclose(result.edge_attr, expected_edge_attr, atol=1e-6) ================================================ FILE: test/llm/utils/test_rag_backend_utils.py ================================================ import os import tempfile from typing import List, Tuple import torch from torch_geometric.data import Data from torch_geometric.llm.utils.backend_utils import ( batch_knn, create_graph_from_triples, create_remote_backend_from_graph_data, make_pcst_filter, preprocess_triplet, retrieval_via_pcst, ) from torch_geometric.testing import onlyLinux def test_preprocess_triplet(): triplet = ('Alice', 'works with', 'Bob') processed = preprocess_triplet(triplet) assert processed == ('alice', 'works with', 'bob') def test_batch_knn(): query_embeddings = torch.randn(2, 64) candidate_embeddings = torch.randn(10, 64) k = 3 top_k_indices, top_k_scores = batch_knn( query_embeddings, candidate_embeddings, k, ) assert top_k_indices[0].size() == (k, ) assert top_k_indices[1].size() == (1, 64) assert top_k_scores[0].size() == (k, ) assert top_k_scores[1].size() == (1, 64) """Test 
retrieval_via_pcst""" def create_mock_data(num_nodes=3, num_edges=2): import pandas as pd x = torch.randn(num_nodes, 16) edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) edge_attr = torch.randn(num_edges, 16) node_idx = list(range(num_nodes)) edge_idx = list(range(num_edges)) textual_nodes = pd.DataFrame({ 'node_id': node_idx, 'text': [f"Node {i}" for i in node_idx] }) textual_edges = pd.DataFrame({ 'src': edge_index[0].tolist(), 'dst': edge_index[1].tolist(), 'edge_attr': [f"Edge {i}" for i in edge_idx] }) return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, node_idx=node_idx, edge_idx=edge_idx), textual_nodes, textual_edges @onlyLinux def test_empty_graph(): import pandas as pd # w/o node and edge bad_data = Data(x=None, edge_index=None, edge_attr=None) textual_nodes = pd.DataFrame({'node_id': [], 'text': []}) textual_edges = pd.DataFrame({'src': [], 'dst': [], 'edge_attr': []}) q_emb = torch.randn(1, 16) result_data, desc = retrieval_via_pcst(bad_data, q_emb, textual_nodes, textual_edges) assert result_data == bad_data assert desc.strip() == 'node_id,text\n\nsrc,edge_attr,dst' def test_topk_zero(): data, textual_nodes, textual_edges = create_mock_data() q_emb = torch.randn(1, 16) result_data, desc = retrieval_via_pcst(data, q_emb, textual_nodes, textual_edges, topk=0, topk_e=0) assert isinstance(result_data, Data) """Test make_pcst_filter""" class MockSentenceTransformer: def encode(self, sentences, **kwargs): return torch.randn(len(sentences), 32) def create_mock_graph_and_triples(): triples: List[Tuple[str, str, str]] = [("Alice", "works_at", "Google"), ("Bob", "works_at", "Meta"), ("Alice", "knows", "Bob")] # Alice=0, Bob=1, Google=2, Meta=3 node_idx = [0, 1, 2, 3] edge_idx = [0, 1, 2] x = torch.randn(4, 32) edge_index = torch.tensor([[0, 1, 0], [2, 3, 1]], dtype=torch.long) edge_attr = torch.randn(3, 32) graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, node_idx=node_idx, edge_idx=edge_idx) return triples, graph @onlyLinux 
def test_apply_retrieval_via_pcst_isolated_node(): triples, graph = create_mock_graph_and_triples() model = MockSentenceTransformer() mock_out_graph = Data(x=graph.x[:1], edge_index=torch.empty(2, 0, dtype=torch.long), edge_attr=torch.empty(0, 32)) mock_out_graph.node_idx = [0] mock_out_graph.edge_idx = [] filter_fn = make_pcst_filter(triples, model) result = filter_fn(graph, "Who is Alice?") assert result.triples == [ ('Alice', 'works_at', 'Google'), ('Bob', 'works_at', 'Meta'), ('Alice', 'knows', 'Bob'), ] class MockEmbeddingModel: """Mock embedding model for testing.""" def __init__(self, embed_dim: int = 64): self.embed_dim = embed_dim def __call__(self, texts: List[str], **kwargs) -> torch.Tensor: """Mock embedding generation - creates deterministic embeddings.""" # Create simple hash-based embeddings for reproducible testing if len(texts) == 0: return torch.empty(0, self.embed_dim) embeddings = [] for text in texts: # Simple deterministic embedding based on text hash hash_val = hash(text) # Use the hash to create a reproducible embedding torch.manual_seed(abs(hash_val) % 2**31) embedding = torch.randn(self.embed_dim) embeddings.append(embedding) return torch.stack(embeddings) class TestCreateGraphFromTriples: """Test suite for create_graph_from_triples function.""" def setup_method(self): """Set up test fixtures.""" self.sample_triples = [('Alice', 'works with', 'Bob'), ('Alice', 'leads', 'Carol'), ('Carol', 'works with', 'Dave')] self.mock_embedding_model = MockEmbeddingModel(embed_dim=32) def test_create_graph_basic_functionality(self): """Test basic functionality of create_graph_from_triples.""" result = create_graph_from_triples( triples=self.sample_triples, embedding_model=self.mock_embedding_model) # Verify result is a Data object assert isinstance(result, Data) x = result.x edge_attr = result.edge_attr assert x.shape == (4, 32) assert edge_attr.shape == (3, 32) for t in self.sample_triples: assert self.mock_embedding_model([t[0]]) in x assert 
self.mock_embedding_model([t[2]]) in x assert self.mock_embedding_model([t[1]]) in edge_attr expected_edge_index = torch.tensor([[0, 0, 2], [1, 2, 3]]) assert torch.allclose(result.edge_index, expected_edge_index) def test_create_graph_empty_triples(self): """Test create_graph_from_triples with empty triples list.""" empty_triples = [] result = create_graph_from_triples( triples=empty_triples, embedding_model=self.mock_embedding_model) # Should create an empty graph assert isinstance(result, Data) assert result.num_nodes == 0 assert result.num_edges == 0 class TestCreateRemoteBackendFromGraphData: """Test suite for create_remote_backend_from_graph_data function.""" def setup_method(self): """Set up test fixtures.""" self.sample_triples = [('Alice', 'works with', 'Bob'), ('Alice', 'leads', 'Carol'), ('Carol', 'works with', 'Dave')] self.mock_embedding_model = MockEmbeddingModel(embed_dim=32) # Create sample graph data using create_graph_from_triples self.sample_graph_data = create_graph_from_triples( triples=self.sample_triples, embedding_model=self.mock_embedding_model) def test_create_backend_data_load(self): """Test that data integrity is preserved in backend creation.""" with tempfile.TemporaryDirectory() as temp_dir: save_path = os.path.join(temp_dir, "test_graph.pt") loader = create_remote_backend_from_graph_data( graph_data=self.sample_graph_data, path=save_path, n_parts=1) # Load and verify data feature_store, graph_store = loader.load() # Check that the original graph structure is preserved loaded_data = torch.load(save_path, weights_only=False) # Verify basic properties match assert loaded_data.num_nodes == self.sample_graph_data.num_nodes assert loaded_data.num_edges == self.sample_graph_data.num_edges # Verify tensors match assert torch.allclose(loaded_data.x, self.sample_graph_data.x) assert torch.allclose(loaded_data.edge_index, self.sample_graph_data.edge_index) assert torch.allclose(loaded_data.edge_attr, self.sample_graph_data.edge_attr) 
================================================ FILE: test/llm/utils/test_rag_feature_store.py ================================================ from unittest.mock import Mock, patch import pytest import torch from torch_geometric.data import Data from torch_geometric.llm.utils.feature_store import KNNRAGFeatureStore from torch_geometric.sampler import SamplerOutput class TestKNNRAGFeatureStore: """Test suite for KNNRAGFeatureStore methods.""" def setup_method(self): """Set up test fixtures.""" self.mock_encoder = Mock() self.mock_encoder.encode = Mock() self.mock_encoder.to = Mock(return_value=self.mock_encoder) self.mock_encoder.eval = Mock() self.config = {"k_nodes": 5, "encoder_model": self.mock_encoder} self.sample_x = torch.randn(40, 128) # 40 nodes, 128 features self.sample_edge_attr = torch.randn(40, 64) # 40 edges, 64 features def test_bad_config(self): """Test bad config initialization.""" with pytest.raises(ValueError, match="Required config parameter"): store = KNNRAGFeatureStore() store.config = {} def create_feature_store(self): """Create a FeatureStore with mocked dependencies.""" store = KNNRAGFeatureStore() store.config = self.config # Mock the tensor storage store.put_tensor(self.sample_x, group_name=None, attr_name='x') store.put_tensor(self.sample_edge_attr, group_name=(None, None), attr_name='edge_attr') return store def test_retrieve_seed_nodes_single_query(self): """Test retrieve_seed_nodes with a single query.""" store = self.create_feature_store() # Mock the encoder output and batch_knn query_text = "test query" mock_query_enc = torch.randn(1, 128) self.mock_encoder.encode.return_value = mock_query_enc expected_indices = torch.tensor([0, 3, 7, 2, 9]) with patch('torch_geometric.llm.utils.feature_store.batch_knn' ) as mock_batch_knn: # Mock batch_knn to return an iterator def mock_generator(): yield (expected_indices, mock_query_enc) mock_batch_knn.return_value = mock_generator() result, query_enc = store.retrieve_seed_nodes(query_text) # 
Verify encoder was called correctly self.mock_encoder.encode.assert_called_once_with([query_text]) # Verify batch_knn was called correctly mock_batch_knn.assert_called_once() args = mock_batch_knn.call_args[0] assert torch.equal(args[0], mock_query_enc) assert torch.equal(args[1], self.sample_x) assert args[2] == 5 # k_nodes # Verify results assert torch.equal(result, expected_indices) assert torch.equal(query_enc, mock_query_enc) def test_retrieve_seed_nodes_multiple_queries(self): """Test retrieve_seed_nodes with multiple queries.""" store = self.create_feature_store() queries = ["query 1", "query 2"] mock_query_enc = torch.randn(2, 128) self.mock_encoder.encode.return_value = mock_query_enc expected_indices = [ torch.tensor([1, 4, 6, 8, 0]), torch.tensor([0, 3, 7, 2, 9]) ] with patch('torch_geometric.llm.utils.feature_store.batch_knn' ) as mock_batch_knn: def mock_generator(): for i in range(len(expected_indices)): yield (expected_indices[i], mock_query_enc[i]) mock_batch_knn.return_value = mock_generator() out_dict = store.retrieve_seed_nodes(queries) # Verify encoder was called with the list directly self.mock_encoder.encode.assert_called_once_with(queries) # Verify results for i, query in enumerate(queries): result, query_enc = out_dict[query] assert torch.equal(result, expected_indices[i]) assert torch.equal(query_enc, mock_query_enc[i]) @pytest.mark.parametrize("induced", [True, False]) def test_load_subgraph_valid_sample(self, induced): """Test load_subgraph with valid SamplerOutput.""" store = self.create_feature_store() # Create a mock SamplerOutput sample = SamplerOutput(node=torch.tensor([6, 7, 8, 9]), row=torch.tensor([0, 1, 2]), col=torch.tensor([1, 2, 3]), edge=torch.tensor([0, 1, 2]), batch=None) expected_edge_indices = torch.tensor([[0, 1, 2], [1, 2, 3]]) \ if induced else torch.tensor([[6, 7, 8], [7, 8, 9]]) result = store.load_subgraph(sample, induced=induced) # Verify result is a Data object assert isinstance(result, Data) # Verify edge 
attributes are correctly extracted expected_edge_attr = self.sample_edge_attr[torch.tensor([0, 1, 2])] assert torch.equal(result.edge_attr, expected_edge_attr) assert torch.equal(result.edge_index, expected_edge_indices) if induced: assert torch.equal(result.node_idx, sample.node) assert torch.equal(result.edge_idx, sample.edge) ================================================ FILE: test/llm/utils/test_rag_graph_store.py ================================================ from unittest.mock import Mock, patch import pytest import torch from torch_geometric.data import FeatureStore from torch_geometric.llm.utils.graph_store import NeighborSamplingRAGGraphStore from torch_geometric.sampler import BidirectionalNeighborSampler, SamplerOutput def setup_test_fixtures(): """Set up test fixtures.""" feature_store = Mock(spec=FeatureStore) config = {"num_neighbors": [10, 5]} return feature_store, config def test_sample_subgraph_with_valid_tensor_input(): """Test sample_subgraph with valid tensor input.""" # Create graph store and set config feature_store, config = setup_test_fixtures() graph_store = NeighborSamplingRAGGraphStore(replace=True, disjoint=False) graph_store.register_feature_store(feature_store=feature_store) graph_store.config = config assert graph_store.put_edge_id(torch.tensor([10]), edge_type=None, layout='coo') # Create mock sampler and its output mock_sampler = Mock(spec=BidirectionalNeighborSampler) expected_output = SamplerOutput(node=torch.tensor([0, 1, 2, 3]), row=torch.tensor([0, 1, 1]), col=torch.tensor([1, 2, 3]), edge=torch.tensor([0, 1, 2]), batch=None, num_sampled_nodes=[2, 2], num_sampled_edges=[3]) mock_sampler.sample_from_nodes.return_value = expected_output # Intentionally not sorted graph_store.edge_index = torch.tensor([[3, 1, 1, 0], [4, 2, 3, 1]]) # Initially sampler should not be initialized assert not graph_store._sampler_is_initialized # Mock the _init_sampler method to set our mock sampler with patch.object(graph_store, '_init_sampler') 
as mock_init: def set_sampler(): graph_store.sampler = mock_sampler graph_store._sampler_is_initialized = True mock_init.side_effect = set_sampler # Test input seed_nodes = torch.tensor([0]) result = graph_store.sample_subgraph(seed_nodes) # Verify sampler was initialized mock_init.assert_called_once() # Verify sample_from_nodes was called with correct input mock_sampler.sample_from_nodes.assert_called_once() assert result == expected_output def test_bad_config(): """Test bad config initialization.""" with pytest.raises(ValueError, match="Required config parameter"): store = NeighborSamplingRAGGraphStore() store.config = {} ================================================ FILE: test/llm/utils/test_vectorrag.py ================================================ import pytest import torch from torch_geometric.llm.utils.vectorrag import DocumentRetriever @pytest.fixture def sample_documents(): """Fixture providing sample documents for testing.""" return [ "This is the first test document.", "This is the second test document.", "This is the third test document.", ] @pytest.fixture def sample_model(): """Fixture providing a mock model for testing.""" from unittest.mock import Mock mock_model = Mock() # Mock the model to return a simple tensor when called mock_model.side_effect = [ torch.zeros(1, 384), torch.ones(1, 384), torch.ones(1, 384) * 2, torch.ones(1, 384) * 1, ] return mock_model def test_save_load(sample_documents, sample_model, tmp_path): """Test whether saving/loading a DocumentRetriever maintains state.""" retriever = DocumentRetriever(sample_documents, model=sample_model) retriever.save(tmp_path / "retriever.pth") loaded_retriever = DocumentRetriever.load(tmp_path / "retriever.pth", sample_model) assert retriever.raw_docs == loaded_retriever.raw_docs assert torch.allclose(retriever.embedded_docs, loaded_retriever.embedded_docs) assert retriever.k_for_docs == loaded_retriever.k_for_docs assert retriever.model == loaded_retriever.model def 
test_query(sample_documents, sample_model): """Test query functionality of DocumentRetriever.""" retriever = DocumentRetriever(sample_documents, model=sample_model) query = "What is the first test document?" retrieved_docs = retriever.query(query) assert retrieved_docs == [sample_documents[0]] ================================================ FILE: test/loader/test_cache.py ================================================ import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.loader import CachedLoader, NeighborLoader from torch_geometric.testing import withDevice, withPackage @withDevice @withPackage('pyg_lib') def test_cached_loader(device): x = torch.randn(14, 16) edge_index = torch.tensor([ [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], ]) loader = NeighborLoader( Data(x=x, edge_index=edge_index), num_neighbors=[2], batch_size=10, shuffle=False, ) cached_loader = CachedLoader(loader, device=device) assert len(cached_loader) == len(loader) assert len(cached_loader._cache) == 0 cache = [] for i, batch in enumerate(cached_loader): assert len(cached_loader._cache) == i + 1 assert batch.x.device == device assert batch.edge_index.device == device cache.append(batch) for i, batch in enumerate(cached_loader): assert batch == cache[i] cached_loader.clear() assert len(cached_loader._cache) == 0 @withDevice @withPackage('pyg_lib') def test_cached_loader_transform(device): x = torch.randn(14, 16) edge_index = torch.tensor([ [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], ]) loader = NeighborLoader( Data(x=x, edge_index=edge_index), num_neighbors=[2], batch_size=10, shuffle=False, ) cached_loader = CachedLoader( loader, device=device, transform=lambda batch: batch.edge_index, ) assert len(cached_loader) == len(loader) assert len(cached_loader._cache) == 0 cache = [] for i, batch in enumerate(cached_loader): assert len(cached_loader._cache) == i + 1 assert isinstance(batch, Tensor) assert 
batch.dim() == 2 and batch.size(0) == 2 assert batch.device == device cache.append(batch) for i, batch in enumerate(cached_loader): assert torch.equal(batch, cache[i]) ================================================ FILE: test/loader/test_cluster.py ================================================ import pytest import torch from torch_geometric.data import Data from torch_geometric.loader import ClusterData, ClusterLoader from torch_geometric.testing import onlyFullTest, onlyOnline, withMETIS from torch_geometric.utils import sort_edge_index @withMETIS def test_cluster_gcn(): adj = torch.tensor([ [1, 1, 1, 0, 1, 0], [1, 1, 0, 1, 0, 1], [1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1], [1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1], ]) x = torch.tensor([ [0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0], ]) edge_index = adj.nonzero(as_tuple=False).t() edge_attr = torch.arange(edge_index.size(1)) n_id = torch.arange(6) data = Data(x=x, n_id=n_id, edge_index=edge_index, edge_attr=edge_attr) data.num_nodes = 6 cluster_data = ClusterData(data, num_parts=2, log=False) partition = cluster_data._partition( edge_index, cluster=torch.tensor([0, 1, 0, 1, 0, 1])) assert partition.partptr.tolist() == [0, 3, 6] assert partition.node_perm.tolist() == [0, 2, 4, 1, 3, 5] assert partition.edge_perm.tolist() == [ 0, 2, 3, 1, 8, 9, 10, 14, 15, 16, 4, 5, 6, 7, 11, 12, 13, 17, 18, 19 ] assert cluster_data.partition.partptr.tolist() == [0, 3, 6] assert torch.equal( cluster_data.partition.node_perm.sort()[0], torch.arange(data.num_nodes), ) assert torch.equal( cluster_data.partition.edge_perm.sort()[0], torch.arange(data.num_edges), ) out = cluster_data[0] expected = data.subgraph(out.n_id) out.validate() assert out.num_nodes == 3 assert out.n_id.size() == (3, ) assert torch.equal(out.x, expected.x) tmp = sort_edge_index(expected.edge_index, expected.edge_attr) assert torch.equal(out.edge_index, tmp[0]) assert torch.equal(out.edge_attr, tmp[1]) out = cluster_data[1] out.validate() 
assert out.num_nodes == 3 assert out.n_id.size() == (3, ) expected = data.subgraph(out.n_id) assert torch.equal(out.x, expected.x) tmp = sort_edge_index(expected.edge_index, expected.edge_attr) assert torch.equal(out.edge_index, tmp[0]) assert torch.equal(out.edge_attr, tmp[1]) loader = ClusterLoader(cluster_data, batch_size=1) iterator = iter(loader) out = next(iterator) out.validate() assert out.num_nodes == 3 assert out.n_id.size() == (3, ) expected = data.subgraph(out.n_id) assert torch.equal(out.x, expected.x) tmp = sort_edge_index(expected.edge_index, expected.edge_attr) assert torch.equal(out.edge_index, tmp[0]) assert torch.equal(out.edge_attr, tmp[1]) out = next(iterator) out.validate() assert out.num_nodes == 3 assert out.n_id.size() == (3, ) expected = data.subgraph(out.n_id) assert torch.equal(out.x, expected.x) tmp = sort_edge_index(expected.edge_index, expected.edge_attr) assert torch.equal(out.edge_index, tmp[0]) assert torch.equal(out.edge_attr, tmp[1]) loader = ClusterLoader(cluster_data, batch_size=2, shuffle=False) out = next(iter(loader)) out.validate() assert out.num_nodes == 6 assert out.n_id.size() == (6, ) expected = data.subgraph(out.n_id) assert torch.equal(out.x, expected.x) tmp = sort_edge_index(expected.edge_index, expected.edge_attr) assert torch.equal(out.edge_index, tmp[0]) assert torch.equal(out.edge_attr, tmp[1]) @withMETIS def test_keep_inter_cluster_edges(): adj = torch.tensor([ [1, 1, 1, 0, 1, 0], [1, 1, 0, 1, 0, 1], [1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1], [1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1], ]) x = torch.tensor([ [0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0], ]) edge_index = adj.nonzero(as_tuple=False).t() edge_attr = torch.arange(edge_index.size(1)) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) data.num_nodes = 6 cluster_data = ClusterData(data, num_parts=2, log=False, keep_inter_cluster_edges=True) data = cluster_data[0] assert data.edge_index[0].min() == 0 assert 
data.edge_index[0].max() == 2 assert data.edge_index[1].min() == 0 assert data.edge_index[1].max() > 2 assert data.edge_index.size(1) == data.edge_attr.size(0) data = cluster_data[1] assert data.edge_index[0].min() == 0 assert data.edge_index[0].max() == 2 assert data.edge_index[1].min() == 0 assert data.edge_index[1].max() > 2 assert data.edge_index.size(1) == data.edge_attr.size(0) @withMETIS @onlyOnline @onlyFullTest @pytest.mark.parametrize('sparse_format', ['csr', 'csc']) def test_cluster_gcn_correctness(get_dataset, sparse_format): dataset = get_dataset('Cora') data = dataset[0].clone() data.n_id = torch.arange(data.num_nodes) cluster_data = ClusterData( data, num_parts=10, log=False, sparse_format=sparse_format, ) loader = ClusterLoader(cluster_data, batch_size=3, shuffle=False) for batch1 in loader: batch1.validate() batch2 = data.subgraph(batch1.n_id) assert batch1.num_nodes == batch2.num_nodes assert batch1.num_edges == batch2.num_edges assert torch.equal(batch1.x, batch2.x) assert torch.equal( batch1.edge_index, sort_edge_index( batch2.edge_index, sort_by_row=sparse_format == 'csr', ), ) if __name__ == '__main__': import argparse from ogb.nodeproppred import PygNodePropPredDataset from tqdm import tqdm parser = argparse.ArgumentParser() parser.add_argument('--num_workers', type=int, default=0) args = parser.parse_args() data = PygNodePropPredDataset('ogbn-products', root='/tmp/ogb')[0] loader = ClusterLoader( ClusterData(data, num_parts=15_000, save_dir='/tmp/ogb/ogbn_products'), batch_size=32, shuffle=True, num_workers=args.num_workers, ) for _ in tqdm(loader): pass ================================================ FILE: test/loader/test_dataloader.py ================================================ import multiprocessing import sys from collections import namedtuple import pytest import torch from torch_geometric import EdgeIndex, Index from torch_geometric.data import Data, HeteroData, OnDiskDataset from torch_geometric.loader import DataLoader from 
torch_geometric.testing import ( get_random_edge_index, get_random_tensor_frame, onlyLinux, withDevice, withPackage, ) with_mp = sys.platform not in ['win32'] num_workers_list = [0, 2] if with_mp else [0] if sys.platform == 'darwin': multiprocessing.set_start_method('spawn') @withDevice @pytest.mark.parametrize('num_workers', num_workers_list) def test_dataloader(num_workers, device): if num_workers > 0 and device != torch.device('cpu'): return x = torch.tensor([[1.0], [1.0], [1.0]]) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) face = torch.tensor([[0], [1], [2]]) y = 2. z = torch.tensor(0.) name = 'data' data = Data(x=x, edge_index=edge_index, y=y, z=z, name=name).to(device) assert str(data) == ("Data(x=[3, 1], edge_index=[2, 4], y=2.0, z=0.0, " "name='data')") data.face = face loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False, num_workers=num_workers) assert len(loader) == 2 for batch in loader: assert batch.x.device == device assert batch.edge_index.device == device assert batch.z.device == device assert batch.num_graphs == len(batch) == 2 assert batch.batch.tolist() == [0, 0, 0, 1, 1, 1] assert batch.ptr.tolist() == [0, 3, 6] assert batch.x.tolist() == [[1], [1], [1], [1], [1], [1]] assert batch.edge_index.tolist() == [[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]] assert batch.y.tolist() == [2.0, 2.0] assert batch.z.tolist() == [0.0, 0.0] assert batch.name == ['data', 'data'] assert batch.face.tolist() == [[0, 3], [1, 4], [2, 5]] for store in batch.stores: assert id(batch) == id(store._parent()) loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False, follow_batch=['edge_index'], num_workers=num_workers, collate_fn=None) assert len(loader) == 2 for batch in loader: assert batch.num_graphs == len(batch) == 2 assert batch.edge_index_batch.tolist() == [0, 0, 0, 0, 1, 1, 1, 1] @onlyLinux @pytest.mark.parametrize('num_workers', num_workers_list) def test_dataloader_on_disk_dataset(tmp_path, num_workers): 
dataset = OnDiskDataset(tmp_path) data1 = Data(x=torch.randn(3, 8)) data2 = Data(x=torch.randn(4, 8)) dataset.extend([data1, data2]) loader = DataLoader(dataset, batch_size=2, num_workers=num_workers) assert len(loader) == 1 batch = next(iter(loader)) assert batch.num_nodes == 7 assert torch.equal(batch.x, torch.cat([data1.x, data2.x], dim=0)) assert batch.batch.tolist() == [0, 0, 0, 1, 1, 1, 1] dataset.close() def test_dataloader_fallbacks(): # Test inputs of type List[torch.Tensor]: data_list = [torch.ones(3) for _ in range(4)] batch = next(iter(DataLoader(data_list, batch_size=4))) assert torch.equal(batch, torch.ones(4, 3)) # Test inputs of type List[float]: data_list = [1.0, 1.0, 1.0, 1.0] batch = next(iter(DataLoader(data_list, batch_size=4))) assert torch.equal(batch, torch.ones(4)) # Test inputs of type List[int]: data_list = [1, 1, 1, 1] batch = next(iter(DataLoader(data_list, batch_size=4))) assert torch.equal(batch, torch.ones(4, dtype=torch.long)) # Test inputs of type List[str]: data_list = ['test'] * 4 batch = next(iter(DataLoader(data_list, batch_size=4))) assert batch == data_list # Test inputs of type List[Mapping]: data_list = [{'x': torch.ones(3), 'y': 1}] * 4 batch = next(iter(DataLoader(data_list, batch_size=4))) assert torch.equal(batch['x'], torch.ones(4, 3)) assert torch.equal(batch['y'], torch.ones(4, dtype=torch.long)) # Test inputs of type List[Tuple]: DataTuple = namedtuple('DataTuple', 'x y') data_list = [DataTuple(0.0, 1)] * 4 batch = next(iter(DataLoader(data_list, batch_size=4))) assert torch.equal(batch.x, torch.zeros(4)) assert torch.equal(batch[1], torch.ones(4, dtype=torch.long)) # Test inputs of type List[Sequence]: data_list = [[0.0, 1]] * 4 batch = next(iter(DataLoader(data_list, batch_size=4))) assert torch.equal(batch[0], torch.zeros(4)) assert torch.equal(batch[1], torch.ones(4, dtype=torch.long)) # Test that inputs of unsupported types raise an error: class DummyClass: pass with pytest.raises(TypeError): data_list = 
[DummyClass()] * 4 next(iter(DataLoader(data_list, batch_size=4))) @pytest.mark.skipif(not with_mp, reason='Multi-processing not available') def test_multiprocessing(): queue = torch.multiprocessing.Manager().Queue() data = Data(x=torch.randn(5, 16)) data_list = [data, data, data, data] loader = DataLoader(data_list, batch_size=2) for batch in loader: queue.put(batch) batch = queue.get() assert batch.num_graphs == len(batch) == 2 batch = queue.get() assert batch.num_graphs == len(batch) == 2 def test_pin_memory(): x = torch.randn(3, 16) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) data = Data(x=x, edge_index=edge_index) loader = DataLoader([data] * 4, batch_size=2, pin_memory=True) for batch in loader: assert batch.x.is_pinned() or not torch.cuda.is_available() assert batch.edge_index.is_pinned() or not torch.cuda.is_available() @pytest.mark.parametrize('num_workers', num_workers_list) def test_heterogeneous_dataloader(num_workers): data = HeteroData() data['p'].x = torch.randn(100, 128) data['a'].x = torch.randn(200, 128) data['p', 'a'].edge_index = get_random_edge_index(100, 200, 500) data['p'].edge_attr = torch.randn(500, 32) data['a', 'p'].edge_index = get_random_edge_index(200, 100, 400) data['a', 'p'].edge_attr = torch.randn(400, 32) loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False, num_workers=num_workers) assert len(loader) == 2 for batch in loader: assert batch.num_graphs == len(batch) == 2 assert batch.num_nodes == 600 for store in batch.stores: assert id(batch) == id(store._parent()) @pytest.mark.parametrize('num_workers', num_workers_list) def test_index_dataloader(num_workers): index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True) data1 = Data(index=index1, num_nodes=3) data2 = Data(index=index2, num_nodes=4) loader = DataLoader( [data1, data2, data1, data2], batch_size=2, num_workers=num_workers, ) assert len(loader) == 2 for batch in loader: assert 
isinstance(batch.index, Index) assert batch.index.dtype == torch.long assert batch.index.dim_size == 7 assert batch.index.is_sorted @pytest.mark.parametrize('num_workers', num_workers_list) @pytest.mark.parametrize('sort_order', [None, 'row', 'col']) def test_edge_index_dataloader(num_workers, sort_order): if sort_order == 'col': edge_index = [[1, 0, 2, 1], [0, 1, 1, 2]] else: edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]] edge_index = EdgeIndex( edge_index, sparse_size=(3, 3), sort_order=sort_order, is_undirected=True, ) data = Data(edge_index=edge_index) assert data.num_nodes == 3 loader = DataLoader( [data, data, data, data], batch_size=2, num_workers=num_workers, ) assert len(loader) == 2 for batch in loader: assert isinstance(batch.edge_index, EdgeIndex) assert batch.edge_index.dtype == torch.long assert batch.edge_index.sparse_size() == (6, 6) assert batch.edge_index.sort_order == sort_order assert batch.edge_index.is_undirected @withPackage('torch_frame') def test_dataloader_tensor_frame(): tf = get_random_tensor_frame(num_rows=10) loader = DataLoader([tf, tf, tf, tf], batch_size=2, shuffle=False) assert len(loader) == 2 for batch in loader: assert batch.num_rows == 20 data = Data(tf=tf, edge_index=get_random_edge_index(10, 10, 20)) loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False) assert len(loader) == 2 for batch in loader: assert batch.num_graphs == len(batch) == 2 assert batch.num_nodes == 20 assert batch.tf.num_rows == 20 assert batch.edge_index.max() >= 10 def test_dataloader_sparse(): adj_t = torch.sparse_coo_tensor( indices=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), values=torch.randn(4), size=(3, 3), ) data = Data(adj_t=adj_t) loader = DataLoader([data, data], batch_size=2) for batch in loader: assert batch.adj_t.size() == (6, 6) if __name__ == '__main__': import argparse import time from torch_geometric.datasets import QM9 parser = argparse.ArgumentParser() parser.add_argument('--num_workers', type=int, default=0) args = 
parser.parse_args() kwargs = dict(batch_size=128, shuffle=True, num_workers=args.num_workers) in_memory_dataset = QM9('/tmp/QM9') loader = DataLoader(in_memory_dataset, **kwargs) print('In-Memory Dataset:') for _ in range(2): print(f'Start loading {len(loader)} mini-batches ... ', end='') t = time.perf_counter() for _ in loader: pass print(f'Done! [{time.perf_counter() - t:.4f}s]') on_disk_dataset = in_memory_dataset.to_on_disk_dataset() loader = DataLoader(on_disk_dataset, **kwargs) print('On-Disk Dataset:') for _ in range(2): print(f'Start loading {len(loader)} mini-batches ... ', end='') t = time.perf_counter() for _ in loader: pass print(f'Done! [{time.perf_counter() - t:.4f}s]') on_disk_dataset.close() ================================================ FILE: test/loader/test_dynamic_batch_sampler.py ================================================ from typing import List import pytest import torch from torch_geometric.data import Data from torch_geometric.loader import DataLoader, DynamicBatchSampler def test_dataloader_with_dynamic_batches(): data_list: List[Data] = [] for num_nodes in range(100, 110): data_list.append(Data(num_nodes=num_nodes)) torch.manual_seed(12345) batch_sampler = DynamicBatchSampler(data_list, 300, shuffle=True) loader = DataLoader(data_list, batch_sampler=batch_sampler) num_nodes_total = 0 for data in loader: assert data.num_nodes <= 300 num_nodes_total += data.num_nodes assert num_nodes_total == 1045 # Test skipping data_list = [Data(num_nodes=400)] + data_list batch_sampler = DynamicBatchSampler(data_list, 300, skip_too_big=True, num_steps=2) loader = DataLoader(data_list, batch_sampler=batch_sampler) num_nodes_total = 0 for data in loader: num_nodes_total += data.num_nodes assert num_nodes_total == 404 with pytest.raises(ValueError, match="length of 'DynamicBatchSampler'"): len(DynamicBatchSampler(data_list, max_num=300)) assert len(DynamicBatchSampler(data_list, max_num=300, num_steps=2)) == 2 
================================================ FILE: test/loader/test_graph_saint.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.loader import ( GraphSAINTEdgeSampler, GraphSAINTNodeSampler, GraphSAINTRandomWalkSampler, ) from torch_geometric.testing import withPackage @withPackage('torch_sparse') def test_graph_saint(): adj = torch.tensor([ [+1, +2, +3, +0, +4, +0], [+5, +6, +0, +7, +0, +8], [+9, +0, 10, +0, 11, +0], [+0, 12, +0, 13, +0, 14], [15, +0, 16, +0, 17, +0], [+0, 18, +0, 19, +0, 20], ]) edge_index = adj.nonzero(as_tuple=False).t() edge_id = adj[edge_index[0], edge_index[1]] x = torch.tensor([ [0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0], ]) n_id = torch.arange(6) data = Data(edge_index=edge_index, x=x, n_id=n_id, edge_id=edge_id, num_nodes=6) loader = GraphSAINTNodeSampler(data, batch_size=3, num_steps=4, sample_coverage=10, log=False) assert len(loader) == 4 for sample in loader: assert sample.num_nodes <= data.num_nodes assert sample.n_id.min() >= 0 and sample.n_id.max() < 6 assert sample.num_nodes == sample.n_id.numel() assert sample.x.tolist() == x[sample.n_id].tolist() assert sample.edge_index.min() >= 0 assert sample.edge_index.max() < sample.num_nodes assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21 assert sample.edge_id.numel() == sample.num_edges assert sample.node_norm.numel() == sample.num_nodes assert sample.edge_norm.numel() == sample.num_edges loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4, sample_coverage=10, log=False) assert len(loader) == 4 for sample in loader: assert sample.num_nodes <= data.num_nodes assert sample.n_id.min() >= 0 and sample.n_id.max() < 6 assert sample.num_nodes == sample.n_id.numel() assert sample.x.tolist() == x[sample.n_id].tolist() assert sample.edge_index.min() >= 0 assert sample.edge_index.max() < sample.num_nodes assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21 
assert sample.edge_id.numel() == sample.num_edges assert sample.node_norm.numel() == sample.num_nodes assert sample.edge_norm.numel() == sample.num_edges loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1, num_steps=4, sample_coverage=10, log=False) assert len(loader) == 4 for sample in loader: assert sample.num_nodes <= data.num_nodes assert sample.n_id.min() >= 0 and sample.n_id.max() < 6 assert sample.num_nodes == sample.n_id.numel() assert sample.x.tolist() == x[sample.n_id].tolist() assert sample.edge_index.min() >= 0 assert sample.edge_index.max() < sample.num_nodes assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21 assert sample.edge_id.numel() == sample.num_edges assert sample.node_norm.numel() == sample.num_nodes assert sample.edge_norm.numel() == sample.num_edges ================================================ FILE: test/loader/test_hgt_loader.py ================================================ import numpy as np import torch from torch_geometric.data import HeteroData from torch_geometric.loader import HGTLoader from torch_geometric.nn import GraphConv, to_hetero from torch_geometric.testing import ( get_random_edge_index, onlyOnline, withPackage, ) from torch_geometric.typing import SparseTensor from torch_geometric.utils import k_hop_subgraph def is_subset(subedge_index, edge_index, src_idx, dst_idx): num_nodes = int(edge_index.max()) + 1 idx = num_nodes * edge_index[0] + edge_index[1] subidx = num_nodes * src_idx[subedge_index[0]] + dst_idx[subedge_index[1]] mask = torch.from_numpy(np.isin(subidx, idx)) return int(mask.sum()) == mask.numel() @withPackage('torch_sparse') def test_hgt_loader(): torch.manual_seed(12345) data = HeteroData() data['paper'].x = torch.arange(100) data['author'].x = torch.arange(100, 300) data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500) data['paper', 'paper'].edge_attr = torch.arange(500) data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) 
data['paper', 'author'].edge_attr = torch.arange(500, 1500) data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) data['author', 'paper'].edge_attr = torch.arange(1500, 2500) r1, c1 = data['paper', 'paper'].edge_index r2, c2 = data['paper', 'author'].edge_index + torch.tensor([[0], [100]]) r3, c3 = data['author', 'paper'].edge_index + torch.tensor([[100], [0]]) full_adj = SparseTensor( row=torch.cat([r1, r2, r3]), col=torch.cat([c1, c2, c3]), value=torch.arange(2500), ) batch_size = 20 loader = HGTLoader(data, num_samples=[5] * 4, batch_size=batch_size, input_nodes='paper') assert str(loader) == 'HGTLoader()' assert len(loader) == (100 + batch_size - 1) // batch_size for batch in loader: assert isinstance(batch, HeteroData) # Test node and types: assert set(batch.node_types) == {'paper', 'author'} assert set(batch.edge_types) == set(data.edge_types) assert len(batch['paper']) == 4 assert batch['paper'].n_id.size() == (batch['paper'].num_nodes, ) assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5 assert batch['paper'].input_id.numel() == batch_size assert batch['paper'].batch_size == batch_size assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 assert len(batch['author']) == 2 assert batch['author'].n_id.size() == (batch['author'].num_nodes, ) assert batch['author'].x.size() == (20, ) # 4 * 5 assert batch['author'].x.min() >= 100 and batch['author'].x.max() < 300 # Test edge type selection: assert set(batch.edge_types) == {('paper', 'to', 'paper'), ('paper', 'to', 'author'), ('author', 'to', 'paper')} assert len(batch['paper', 'paper']) == 3 num_edges = batch['paper', 'paper'].num_edges assert batch['paper', 'paper'].e_id.size() == (num_edges, ) row, col = batch['paper', 'paper'].edge_index value = batch['paper', 'paper'].edge_attr adj = full_adj[batch['paper'].x, batch['paper'].x] assert row.min() >= 0 and row.max() < 40 assert col.min() >= 0 and col.max() < 40 assert value.min() >= 0 and value.max() < 500 assert adj.nnz() 
== row.size(0) assert torch.allclose(row.unique(), adj.storage.row().unique()) assert torch.allclose(col.unique(), adj.storage.col().unique()) assert torch.allclose(value.unique(), adj.storage.value().unique()) assert is_subset(batch['paper', 'paper'].edge_index, data['paper', 'paper'].edge_index, batch['paper'].x, batch['paper'].x) assert len(batch['paper', 'author']) == 3 num_edges = batch['paper', 'author'].num_edges assert batch['paper', 'author'].e_id.size() == (num_edges, ) row, col = batch['paper', 'author'].edge_index value = batch['paper', 'author'].edge_attr adj = full_adj[batch['paper'].x, batch['author'].x] assert row.min() >= 0 and row.max() < 40 assert col.min() >= 0 and col.max() < 20 assert value.min() >= 500 and value.max() < 1500 assert adj.nnz() == row.size(0) assert torch.allclose(row.unique(), adj.storage.row().unique()) assert torch.allclose(col.unique(), adj.storage.col().unique()) assert torch.allclose(value.unique(), adj.storage.value().unique()) assert is_subset(batch['paper', 'author'].edge_index, data['paper', 'author'].edge_index, batch['paper'].x, batch['author'].x - 100) assert len(batch['author', 'paper']) == 3 num_edges = batch['author', 'paper'].num_edges assert batch['author', 'paper'].e_id.size() == (num_edges, ) row, col = batch['author', 'paper'].edge_index value = batch['author', 'paper'].edge_attr adj = full_adj[batch['author'].x, batch['paper'].x] assert row.min() >= 0 and row.max() < 20 assert col.min() >= 0 and col.max() < 40 assert value.min() >= 1500 and value.max() < 2500 assert adj.nnz() == row.size(0) assert torch.allclose(row.unique(), adj.storage.row().unique()) assert torch.allclose(col.unique(), adj.storage.col().unique()) assert torch.allclose(value.unique(), adj.storage.value().unique()) assert is_subset(batch['author', 'paper'].edge_index, data['author', 'paper'].edge_index, batch['author'].x - 100, batch['paper'].x) # Test for isolated nodes (there shouldn't exist any): n_id = torch.cat([batch['paper'].x, 
batch['author'].x]) row, col, _ = full_adj[n_id, n_id].coo() assert torch.cat([row, col]).unique().numel() >= 59 @onlyOnline @withPackage('torch_sparse') def test_hgt_loader_on_cora(get_dataset): dataset = get_dataset(name='Cora') data = dataset[0] data.edge_weight = torch.rand(data.num_edges) hetero_data = HeteroData() hetero_data['paper'].x = data.x hetero_data['paper'].n_id = torch.arange(data.num_nodes) hetero_data['paper', 'paper'].edge_index = data.edge_index hetero_data['paper', 'paper'].edge_weight = data.edge_weight split_idx = torch.arange(5, 8) # Sample the complete two-hop neighborhood: loader = HGTLoader(hetero_data, num_samples=[data.num_nodes] * 2, batch_size=split_idx.numel(), input_nodes=('paper', split_idx)) assert len(loader) == 1 hetero_batch = next(iter(loader)) batch_size = hetero_batch['paper'].batch_size n_id, _, _, e_mask = k_hop_subgraph(split_idx, num_hops=2, edge_index=data.edge_index, num_nodes=data.num_nodes) n_id = n_id.sort()[0] assert n_id.tolist() == hetero_batch['paper'].n_id.sort()[0].tolist() assert hetero_batch['paper', 'paper'].num_edges == int(e_mask.sum()) class GNN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GraphConv(in_channels, hidden_channels) self.conv2 = GraphConv(hidden_channels, out_channels) def forward(self, x, edge_index, edge_weight): x = self.conv1(x, edge_index, edge_weight).relu() x = self.conv2(x, edge_index, edge_weight).relu() return x model = GNN(dataset.num_features, 16, dataset.num_classes) hetero_model = to_hetero(model, hetero_data.metadata()) out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx] out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict, hetero_batch.edge_weight_dict)['paper'][:batch_size] assert torch.allclose(out1, out2, atol=1e-6) @withPackage('torch_sparse') def test_hgt_loader_disconnected(): data = HeteroData() data['paper'].x = torch.randn(10, 16) data['author'].x = torch.randn(10, 
16) # Paper nodes are disconnected from author nodes: data['paper', 'paper'].edge_index = get_random_edge_index(10, 10, 15) data['paper', 'paper'].edge_attr = torch.randn(15, 8) data['author', 'author'].edge_index = get_random_edge_index(10, 10, 15) data['author', 'author'].edge_attr = torch.randn(15, 8) loader = HGTLoader(data, num_samples=[2], batch_size=2, input_nodes='paper') for batch in loader: assert isinstance(batch, HeteroData) # Test node and edge types: assert set(batch.node_types) == set(data.node_types) assert set(batch.edge_types) == set(data.edge_types) assert batch['author'].num_nodes == 0 assert batch['author'].x.size() == (0, 16) assert batch['author', 'author'].num_edges == 0 assert batch['author', 'author'].edge_index.size() == (2, 0) assert batch['author', 'author'].edge_attr.size() == (0, 8) ================================================ FILE: test/loader/test_ibmb_loader.py ================================================ import pytest import torch from torch import Tensor import torch_geometric.typing from torch_geometric.datasets import KarateClub from torch_geometric.loader.ibmb_loader import IBMBBatchLoader, IBMBNodeLoader from torch_geometric.testing import withPackage from torch_geometric.typing import SparseTensor @withPackage('python_tsp') @pytest.mark.parametrize( 'use_sparse_tensor', [False] + [True] if torch_geometric.typing.WITH_TORCH_SPARSE else []) @pytest.mark.parametrize('kwargs', [ dict(num_partitions=4, batch_size=1), dict(num_partitions=8, batch_size=2), ]) def test_ibmb_batch_loader(use_sparse_tensor, kwargs): data = KarateClub()[0] loader = IBMBBatchLoader( data, batch_order='order', input_nodes=torch.randperm(data.num_nodes)[:20], return_edge_index_type='adj' if use_sparse_tensor else 'edge_index', **kwargs, ) assert str(loader) == 'IBMBBatchLoader()' assert len(loader) == 4 assert sum([batch.output_node_mask.sum() for batch in loader]) == 20 for batch in loader: if use_sparse_tensor: assert 
isinstance(batch.edge_index, SparseTensor) else: assert isinstance(batch.edge_index, Tensor) @withPackage('python_tsp', 'numba') @pytest.mark.parametrize( 'use_sparse_tensor', [False] + [True] if torch_geometric.typing.WITH_TORCH_SPARSE else []) @pytest.mark.parametrize('kwargs', [ dict(num_nodes_per_batch=4, batch_size=1), dict(num_nodes_per_batch=2, batch_size=2), ]) def test_ibmb_node_loader(use_sparse_tensor, kwargs): data = KarateClub()[0] loader = IBMBNodeLoader( data, batch_order='order', input_nodes=torch.randperm(data.num_nodes)[:20], num_auxiliary_nodes=4, return_edge_index_type='adj' if use_sparse_tensor else 'edge_index', **kwargs, ) assert str(loader) == 'IBMBNodeLoader()' assert len(loader) == 5 assert sum([batch.output_node_mask.sum() for batch in loader]) == 20 for batch in loader: if use_sparse_tensor: assert isinstance(batch.edge_index, SparseTensor) else: assert isinstance(batch.edge_index, Tensor) ================================================ FILE: test/loader/test_imbalanced_sampler.py ================================================ from typing import List import torch from torch_geometric.data import Data from torch_geometric.datasets import FakeDataset, FakeHeteroDataset from torch_geometric.loader import ( DataLoader, ImbalancedSampler, NeighborLoader, ) from torch_geometric.testing import onlyNeighborSampler def test_dataloader_with_imbalanced_sampler(): data_list: List[Data] = [] for _ in range(10): data_list.append(Data(num_nodes=10, y=0)) for _ in range(90): data_list.append(Data(num_nodes=10, y=1)) torch.manual_seed(12345) sampler = ImbalancedSampler(data_list) loader = DataLoader(data_list, batch_size=10, sampler=sampler) y = torch.cat([batch.y for batch in loader]) histogram = y.bincount() prob = histogram / histogram.sum() assert histogram.sum() == len(data_list) assert prob.min() > 0.4 and prob.max() < 0.6 # Test with label tensor as input: torch.manual_seed(12345) sampler = ImbalancedSampler(torch.tensor([data.y for data in 
data_list])) loader = DataLoader(data_list, batch_size=10, sampler=sampler) assert torch.allclose(y, torch.cat([batch.y for batch in loader])) # Test with list of data objects as input where each y is a tensor: torch.manual_seed(12345) for data in data_list: data.y = torch.tensor([data.y]) sampler = ImbalancedSampler(data_list) loader = DataLoader(data_list, batch_size=100, sampler=sampler) assert torch.allclose(y, torch.cat([batch.y for batch in loader])) def test_in_memory_dataset_imbalanced_sampler(): torch.manual_seed(12345) dataset = FakeDataset(num_graphs=100, avg_num_nodes=10, avg_degree=0, num_channels=0, num_classes=2) sampler = ImbalancedSampler(dataset) loader = DataLoader(dataset, batch_size=10, sampler=sampler) y = torch.cat([batch.y for batch in loader]) histogram = y.bincount() prob = histogram / histogram.sum() assert histogram.sum() == len(dataset) assert prob.min() > 0.4 and prob.max() < 0.6 @onlyNeighborSampler def test_neighbor_loader_with_imbalanced_sampler(): zeros = torch.zeros(10, dtype=torch.long) ones = torch.ones(90, dtype=torch.long) y = torch.cat([zeros, ones], dim=0) edge_index = torch.empty((2, 0), dtype=torch.long) data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0)) torch.manual_seed(12345) sampler = ImbalancedSampler(data) loader = NeighborLoader(data, batch_size=10, sampler=sampler, num_neighbors=[-1]) y = torch.cat([batch.y for batch in loader]) histogram = y.bincount() prob = histogram / histogram.sum() assert histogram.sum() == data.num_nodes assert prob.min() > 0.4 and prob.max() < 0.6 # Test with label tensor as input: torch.manual_seed(12345) sampler = ImbalancedSampler(data.y) loader = NeighborLoader(data, batch_size=10, sampler=sampler, num_neighbors=[-1]) assert torch.allclose(y, torch.cat([batch.y for batch in loader])) @onlyNeighborSampler def test_hetero_neighbor_loader_with_imbalanced_sampler(): torch.manual_seed(12345) data = FakeHeteroDataset(num_classes=2)[0] loader = NeighborLoader( data, batch_size=100, 
input_nodes='v0', num_neighbors=[-1], sampler=ImbalancedSampler(data['v0'].y), ) y = torch.cat([batch['v0'].y[:batch['v0'].batch_size] for batch in loader]) histogram = y.bincount() prob = histogram / histogram.sum() assert histogram.sum() == data['v0'].num_nodes assert prob.min() > 0.4 and prob.max() < 0.6 ================================================ FILE: test/loader/test_link_neighbor_loader.py ================================================ import pytest import torch from torch_geometric.data import Data, HeteroData from torch_geometric.loader import LinkNeighborLoader from torch_geometric.testing import ( MyFeatureStore, MyGraphStore, get_random_edge_index, onlyNeighborSampler, withCUDA, withPackage, ) def unique_edge_pairs(edge_index): return set(map(tuple, edge_index.t().tolist())) @withCUDA @onlyNeighborSampler @pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional']) @pytest.mark.parametrize('neg_sampling_ratio', [None, 1.0]) @pytest.mark.parametrize('filter_per_worker', [None, True, False]) def test_homo_link_neighbor_loader_basic(device, subgraph_type, neg_sampling_ratio, filter_per_worker): pos_edge_index = get_random_edge_index(50, 50, 500, device=device) neg_edge_index = get_random_edge_index(50, 50, 500, device=device) neg_edge_index += 50 input_edges = torch.cat([pos_edge_index, neg_edge_index], dim=-1) edge_label = torch.cat([ torch.ones(500, device=device), torch.zeros(500, device=device), ], dim=0) data = Data() data.edge_index = pos_edge_index data.x = torch.arange(100, device=device) data.edge_attr = torch.arange(500, device=device) loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, batch_size=20, edge_label_index=input_edges, edge_label=edge_label if neg_sampling_ratio is None else None, subgraph_type=subgraph_type, neg_sampling_ratio=neg_sampling_ratio, shuffle=True, filter_per_worker=filter_per_worker, ) assert str(loader) == 'LinkNeighborLoader()' assert len(loader) == 1000 / 20 batch = loader([0]) assert 
isinstance(batch, Data) assert int(input_edges[0, 0]) in batch.n_id.tolist() assert int(input_edges[1, 0]) in batch.n_id.tolist() for batch in loader: assert isinstance(batch, Data) assert batch.n_id.size() == (batch.num_nodes, ) assert batch.e_id.size() == (batch.num_edges, ) assert batch.x.device == device assert batch.x.size(0) <= 100 assert batch.x.min() >= 0 and batch.x.max() < 100 assert batch.input_id.numel() == 20 assert batch.edge_index.device == device assert batch.edge_index.min() >= 0 assert batch.edge_index.max() < batch.num_nodes assert batch.edge_attr.device == device assert batch.edge_attr.min() >= 0 assert batch.edge_attr.max() < 500 if neg_sampling_ratio is None: assert batch.edge_label_index.size(1) == 20 # Assert positive samples are present in the original graph: edge_index = unique_edge_pairs(batch.edge_index) edge_label_index = batch.edge_label_index[:, batch.edge_label == 1] edge_label_index = unique_edge_pairs(edge_label_index) assert len(edge_index | edge_label_index) == len(edge_index) # Assert negative samples are not present in the original graph: edge_index = unique_edge_pairs(batch.edge_index) edge_label_index = batch.edge_label_index[:, batch.edge_label == 0] edge_label_index = unique_edge_pairs(edge_label_index) assert len(edge_index & edge_label_index) == 0 else: assert batch.edge_label_index.size(1) == 40 assert torch.all(batch.edge_label[:20] == 1) assert torch.all(batch.edge_label[20:] == 0) # Ensure local `edge_label_index` correctly maps to input edges. 
global_edge_label_index = batch.n_id[batch.edge_label_index] global_edge_label_index = ( global_edge_label_index[:, batch.edge_label >= 1]) global_edge_label_index = unique_edge_pairs(global_edge_label_index) assert (len(global_edge_label_index & unique_edge_pairs(input_edges)) == len(global_edge_label_index)) @onlyNeighborSampler @pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional']) @pytest.mark.parametrize('neg_sampling_ratio', [None, 1.0]) def test_hetero_link_neighbor_loader_basic(subgraph_type, neg_sampling_ratio): data = HeteroData() data['paper'].x = torch.arange(100) data['author'].x = torch.arange(100, 300) data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500) data['paper', 'paper'].edge_attr = torch.arange(500) data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) data['paper', 'author'].edge_attr = torch.arange(500, 1500) data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) data['author', 'paper'].edge_attr = torch.arange(1500, 2500) loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, edge_label_index=('paper', 'author'), batch_size=20, subgraph_type=subgraph_type, neg_sampling_ratio=neg_sampling_ratio, shuffle=True, ) assert str(loader) == 'LinkNeighborLoader()' assert len(loader) == 1000 / 20 for batch in loader: assert isinstance(batch, HeteroData) assert batch.input_type == ('paper', 'to', 'author') if neg_sampling_ratio is None: # Assert only positive samples are present in the original graph: edge_index = unique_edge_pairs(batch['paper', 'author'].edge_index) edge_label_index = batch['paper', 'author'].edge_label_index edge_label_index = unique_edge_pairs(edge_label_index) assert len(edge_index | edge_label_index) == len(edge_index) else: assert batch['paper', 'author'].edge_label_index.size(1) == 40 assert torch.all(batch['paper', 'author'].edge_label[:20] == 1) assert torch.all(batch['paper', 'author'].edge_label[20:] == 0) @onlyNeighborSampler 
@pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional']) def test_hetero_link_neighbor_loader_loop(subgraph_type): data = HeteroData() data['paper'].x = torch.arange(100) data['author'].x = torch.arange(100, 300) data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500) data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, edge_label_index=('paper', 'paper'), batch_size=20, subgraph_type=subgraph_type, ) for batch in loader: assert batch['paper'].x.size(0) <= 100 assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 # Assert positive samples are present in the original graph: edge_index = unique_edge_pairs(batch['paper', 'paper'].edge_index) edge_label_index = batch['paper', 'paper'].edge_label_index edge_label_index = unique_edge_pairs(edge_label_index) assert len(edge_index | edge_label_index) == len(edge_index) @onlyNeighborSampler def test_link_neighbor_loader_edge_label(): edge_index = get_random_edge_index(100, 100, 500) data = Data(edge_index=edge_index, x=torch.arange(100)) loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, batch_size=10, neg_sampling_ratio=1.0, ) for batch in loader: assert batch.edge_label.dtype == torch.float assert torch.all(batch.edge_label[:10] == 1.0) assert torch.all(batch.edge_label[10:] == 0.0) loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, batch_size=10, edge_label=torch.ones(500, dtype=torch.long), neg_sampling_ratio=1.0, ) for batch in loader: assert batch.edge_label.dtype == torch.long assert torch.all(batch.edge_label[:10] == 1) assert torch.all(batch.edge_label[10:] == 0) @withPackage('pyg_lib') @pytest.mark.parametrize('batch_size', [1]) def test_temporal_homo_link_neighbor_loader(batch_size): data = Data( x=torch.randn(10, 5), edge_index=torch.randint(0, 10, (2, 123)), time=torch.arange(10), ) # Ensure 
that nodes exist at the time of the `edge_label_time`: edge_label_time = torch.max( data.time[data.edge_index[0]], data.time[data.edge_index[1]], ) loader = LinkNeighborLoader( data, num_neighbors=[-1], time_attr='time', edge_label=torch.ones(data.num_edges), edge_label_time=edge_label_time, batch_size=batch_size, shuffle=True, ) for batch in loader: assert batch.edge_label_index.size() == (2, batch_size) assert batch.edge_label_time.size() == (batch_size, ) assert batch.edge_label.size() == (batch_size, ) assert torch.all(batch.time <= batch.edge_label_time) @withPackage('pyg_lib') def test_temporal_hetero_link_neighbor_loader(): data = HeteroData() data['paper'].x = torch.arange(100) data['paper'].time = torch.arange(data['paper'].num_nodes) - 200 data['author'].x = torch.arange(100, 300) data['author'].time = torch.arange(data['author'].num_nodes) data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500) data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) with pytest.raises(ValueError, match=r"'edge_label_time' is not set"): loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, edge_label_index=('paper', 'paper'), batch_size=32, time_attr='time', ) # With edge_time: edge_time = torch.arange(data['paper', 'paper'].edge_index.size(1)) loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, edge_label_index=('paper', 'paper'), edge_label_time=edge_time, batch_size=32, time_attr='time', neg_sampling_ratio=0.5, drop_last=True, ) for batch in loader: # Check if each seed edge has a different batch: assert int(batch['paper'].batch.max()) + 1 == 32 author_max = batch['author'].time.max() edge_max = batch['paper', 'paper'].edge_label_time.max() assert edge_max >= author_max author_min = batch['author'].time.min() edge_min = batch['paper', 'paper'].edge_label_time.min() assert edge_min >= author_min @onlyNeighborSampler def 
test_custom_hetero_link_neighbor_loader(): data = HeteroData() feature_store = MyFeatureStore() graph_store = MyGraphStore() # Set up node features: x = torch.arange(100) data['paper'].x = x feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None) x = torch.arange(100, 300) data['author'].x = x feature_store.put_tensor(x, group_name='author', attr_name='x', index=None) # Set up edge indices (GraphStore does not support `edge_attr` at the # moment): edge_index = get_random_edge_index(100, 100, 500) data['paper', 'to', 'paper'].edge_index = edge_index graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), edge_type=('paper', 'to', 'paper'), layout='coo', size=(100, 100)) edge_index = get_random_edge_index(100, 200, 1000) data['paper', 'to', 'author'].edge_index = edge_index graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), edge_type=('paper', 'to', 'author'), layout='coo', size=(100, 200)) edge_index = get_random_edge_index(200, 100, 1000) data['author', 'to', 'paper'].edge_index = edge_index graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), edge_type=('author', 'to', 'paper'), layout='coo', size=(200, 100)) loader1 = LinkNeighborLoader( data, num_neighbors=[-1] * 2, edge_label_index=('paper', 'to', 'author'), batch_size=20, ) loader2 = LinkNeighborLoader( (feature_store, graph_store), num_neighbors=[-1] * 2, edge_label_index=('paper', 'to', 'author'), batch_size=20, ) assert str(loader1) == str(loader2) for (batch1, batch2) in zip(loader1, loader2): # Mapped indices of neighbors may be differently sorted: assert torch.allclose(batch1['paper'].x.sort()[0], batch2['paper'].x.sort()[0]) assert torch.allclose(batch1['author'].x.sort()[0], batch2['author'].x.sort()[0]) # Assert that edge indices have the same size: assert (batch1['paper', 'to', 'paper'].edge_index.size() == batch1[ 'paper', 'to', 'paper'].edge_index.size()) assert (batch1['paper', 'to', 'author'].edge_index.size() == batch1[ 
'paper', 'to', 'author'].edge_index.size()) assert (batch1['author', 'to', 'paper'].edge_index.size() == batch1[ 'author', 'to', 'paper'].edge_index.size()) @onlyNeighborSampler def test_homo_link_neighbor_loader_no_edges(): loader = LinkNeighborLoader( Data(num_nodes=100), num_neighbors=[], batch_size=20, edge_label_index=get_random_edge_index(100, 100, 100), ) for batch in loader: assert isinstance(batch, Data) assert batch.input_id.numel() == 20 assert batch.edge_label_index.size(1) == 20 assert batch.num_nodes == batch.edge_label_index.unique().numel() @onlyNeighborSampler def test_hetero_link_neighbor_loader_no_edges(): loader = LinkNeighborLoader( HeteroData(paper=dict(num_nodes=100)), num_neighbors=[], edge_label_index=( ('paper', 'paper'), get_random_edge_index(100, 100, 100), ), batch_size=20, ) for batch in loader: assert isinstance(batch, HeteroData) assert batch['paper', 'paper'].input_id.numel() == 20 assert batch['paper', 'paper'].edge_label_index.size(1) == 20 assert batch['paper'].num_nodes == batch[ 'paper', 'paper'].edge_label_index.unique().numel() @withPackage('pyg_lib') @pytest.mark.parametrize('disjoint', [False, True]) @pytest.mark.parametrize('temporal', [False, True]) @pytest.mark.parametrize('amount', [1, 2]) def test_homo_link_neighbor_loader_triplet(disjoint, temporal, amount): if not disjoint and temporal: return data = Data() data.x = torch.arange(100) data.edge_index = get_random_edge_index(100, 100, 400) data.edge_label_index = get_random_edge_index(100, 100, 500) data.edge_attr = torch.arange(data.num_edges) time_attr = edge_label_time = None if temporal: time_attr = 'time' data.time = torch.arange(data.num_nodes) edge_label_time = torch.max(data.time[data.edge_label_index[0]], data.time[data.edge_label_index[1]]) edge_label_time = edge_label_time + 50 batch_size = 20 loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, batch_size=batch_size, edge_label_index=data.edge_label_index, edge_label_time=edge_label_time, 
time_attr=time_attr, disjoint=disjoint, neg_sampling=dict(mode='triplet', amount=amount), shuffle=True, ) assert str(loader) == 'LinkNeighborLoader()' assert len(loader) == 500 / batch_size for batch in loader: assert isinstance(batch, Data) # Check that `src_index` and `dst_pos_index` point to valid edges: assert torch.equal(batch.x[batch.src_index], data.edge_label_index[0, batch.input_id]) assert torch.equal(batch.x[batch.dst_pos_index], data.edge_label_index[1, batch.input_id]) # Check that `dst_neg_index` points to valid nodes in the batch: if amount == 1: assert batch.dst_neg_index.size() == (batch_size, ) else: assert batch.dst_neg_index.size() == (batch_size, amount) assert batch.dst_neg_index.min() >= 0 assert batch.dst_neg_index.max() < batch.num_nodes if disjoint: # In disjoint mode, seed nodes should always be placed first: assert batch.src_index.min() == 0 assert batch.src_index.max() == batch_size - 1 assert batch.dst_pos_index.min() == batch_size assert batch.dst_pos_index.max() == 2 * batch_size - 1 assert batch.dst_neg_index.min() == 2 * batch_size max_seed_nodes = 2 * batch_size + batch_size * amount assert batch.dst_neg_index.max() == max_seed_nodes - 1 assert batch.batch.min() == 0 assert batch.batch.max() == batch_size - 1 # Check that `batch` is always increasing: for i in range(0, max_seed_nodes, batch_size): batch_vector = batch.batch[i:i + batch_size] assert torch.equal(batch_vector, torch.arange(batch_size)) if temporal: for i in range(batch_size): assert batch.time[batch.batch == i].max() <= batch.seed_time[i] @withPackage('pyg_lib') @pytest.mark.parametrize('disjoint', [False, True]) @pytest.mark.parametrize('temporal', [False, True]) @pytest.mark.parametrize('amount', [1, 2]) def test_hetero_link_neighbor_loader_triplet(disjoint, temporal, amount): if not disjoint and temporal: return data = HeteroData() data['paper'].x = torch.arange(100) data['author'].x = torch.arange(100, 300) data['paper', 'paper'].edge_index = 
get_random_edge_index(100, 100, 400) edge_label_index = get_random_edge_index(100, 100, 500) data['paper', 'paper'].edge_label_index = edge_label_index data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) time_attr = edge_label_time = None if temporal: time_attr = 'time' data['paper'].time = torch.arange(data['paper'].num_nodes) data['author'].time = torch.arange(data['author'].num_nodes) edge_label_time = torch.max( data['paper'].time[data['paper', 'paper'].edge_label_index[0]], data['paper'].time[data['paper', 'paper'].edge_label_index[1]], ) edge_label_time = edge_label_time + 50 weight = torch.rand(data['paper'].num_nodes) if not temporal else None batch_size = 20 index = (('paper', 'paper'), data['paper', 'paper'].edge_label_index) loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, batch_size=batch_size, edge_label_index=index, edge_label_time=edge_label_time, time_attr=time_attr, disjoint=disjoint, neg_sampling=dict( mode='triplet', amount=amount, src_weight=weight, dst_weight=weight, ), shuffle=True, ) assert str(loader) == 'LinkNeighborLoader()' assert len(loader) == 500 / batch_size for batch in loader: assert isinstance(batch, HeteroData) node_store = batch['paper'] edge_store = batch['paper', 'paper'] # Check that `src_index` and `dst_pos_index` point to valid edges: assert torch.equal( node_store.x[node_store.src_index], data['paper', 'paper'].edge_label_index[0, edge_store.input_id]) assert torch.equal( node_store.x[node_store.dst_pos_index], data['paper', 'paper'].edge_label_index[1, edge_store.input_id]) # Check that `dst_neg_index` points to valid nodes in the batch: if amount == 1: assert node_store.dst_neg_index.size() == (batch_size, ) else: assert node_store.dst_neg_index.size() == (batch_size, amount) assert node_store.dst_neg_index.min() >= 0 assert node_store.dst_neg_index.max() < node_store.num_nodes if disjoint: # In disjoint mode, seed 
nodes should always be placed first: assert node_store.src_index.min() == 0 assert node_store.src_index.max() == batch_size - 1 assert node_store.dst_pos_index.min() == batch_size assert node_store.dst_pos_index.max() == 2 * batch_size - 1 assert node_store.dst_neg_index.min() == 2 * batch_size max_seed_nodes = 2 * batch_size + batch_size * amount assert node_store.dst_neg_index.max() == max_seed_nodes - 1 assert node_store.batch.min() == 0 assert node_store.batch.max() == batch_size - 1 # Check that `batch` is always increasing: for i in range(0, max_seed_nodes, batch_size): batch_vector = node_store.batch[i:i + batch_size] assert torch.equal(batch_vector, torch.arange(batch_size)) if temporal: for i in range(batch_size): assert (node_store.time[node_store.batch == i].max() <= node_store.seed_time[i]) @withPackage('pyg_lib') def test_link_neighbor_loader_mapping(): edge_index = torch.tensor([ [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 5], [1, 2, 3, 4, 5, 8, 6, 7, 9, 10, 6, 11], ]) data = Data(edge_index=edge_index, num_nodes=12) loader = LinkNeighborLoader( data, edge_label_index=data.edge_index, num_neighbors=[1], batch_size=2, shuffle=True, ) for batch in loader: assert torch.equal( batch.n_id[batch.edge_index], data.edge_index[:, batch.e_id], ) ================================================ FILE: test/loader/test_mixin.py ================================================ import subprocess from time import sleep import psutil import pytest import torch from torch_geometric.data import Data from torch_geometric.loader import NeighborLoader from torch_geometric.testing import onlyLinux, onlyNeighborSampler @pytest.mark.xfail(reason="TODO: Fix test") @onlyLinux @onlyNeighborSampler @pytest.mark.skipif( psutil.cpu_count(logical=False) == 1, reason="Requires multiple CPU cores") @pytest.mark.parametrize('loader_cores', [None, [1, 2]]) def test_cpu_affinity_neighbor_loader(loader_cores, spawn_context): data = Data(x=torch.randn(1, 1)) loader = NeighborLoader(data, 
num_neighbors=[-1], batch_size=1, num_workers=2) out = [] with loader.enable_cpu_affinity(loader_cores): iterator = loader._get_iterator() workers = iterator._workers sleep(3) # Gives time for worker to initialize. for worker in workers: process = subprocess.Popen( ['taskset', '-c', '-p', f'{worker.pid}'], stdout=subprocess.PIPE) stdout = process.communicate()[0].decode('utf-8') # returns "pid 's current affinity list -" out.append(stdout.split(':')[1].strip()) if loader_cores: assert out == ['[1]', '[2]'] else: assert out[0] != out[1] def init_fn(worker_id): assert torch.get_num_threads() == 2 @onlyLinux @onlyNeighborSampler @pytest.mark.skipif( psutil.cpu_count(logical=False) == 1, reason="Requires multiple CPU cores") def test_multithreading_neighbor_loader(spawn_context): loader = NeighborLoader( data=Data(x=torch.randn(1, 1)), num_neighbors=[-1], batch_size=1, num_workers=2, worker_init_fn=init_fn, ) with loader.enable_multithreading(2): loader._get_iterator() # Runs assertion in `init_fn`. 
================================================ FILE: test/loader/test_neighbor_loader.py ================================================ import os.path as osp import numpy as np import pytest import torch import torch_geometric.typing from torch_geometric import EdgeIndex from torch_geometric.data import Data, HeteroData from torch_geometric.loader import NeighborLoader from torch_geometric.nn import GraphConv, to_hetero from torch_geometric.sampler.base import SubgraphType from torch_geometric.testing import ( MyFeatureStore, MyGraphStore, get_random_edge_index, get_random_tensor_frame, onlyLinux, onlyNeighborSampler, onlyOnline, withCUDA, withPackage, ) from torch_geometric.typing import ( WITH_EDGE_TIME_NEIGHBOR_SAMPLE, WITH_PYG_LIB, WITH_TORCH_SPARSE, WITH_WEIGHTED_NEIGHBOR_SAMPLE, TensorFrame, ) from torch_geometric.utils import ( is_undirected, sort_edge_index, to_torch_csr_tensor, to_undirected, ) DTYPES = [ pytest.param(torch.int64, id='int64'), pytest.param(torch.int32, id='int32'), ] SUBGRAPH_TYPES = [ pytest.param(SubgraphType.directional, id='directional'), pytest.param(SubgraphType.bidirectional, id='bidirectional'), pytest.param(SubgraphType.induced, id='induced'), ] FILTER_PER_WORKERS = [ pytest.param(None, id='auto_filter'), pytest.param(True, id='filter_per_worker'), pytest.param(False, id='filter_in_main'), ] def is_subset(subedge_index, edge_index, src_idx, dst_idx): num_nodes = int(edge_index.max()) + 1 idx = num_nodes * edge_index[0] + edge_index[1] subidx = num_nodes * src_idx[subedge_index[0]] + dst_idx[subedge_index[1]] mask = torch.from_numpy(np.isin(subidx.cpu().numpy(), idx.cpu().numpy())) return int(mask.sum()) == mask.numel() @withCUDA @onlyNeighborSampler @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES) @pytest.mark.parametrize('filter_per_worker', FILTER_PER_WORKERS) def test_homo_neighbor_loader_basic( device, subgraph_type, dtype, filter_per_worker, ): if dtype != torch.int64 and 
not torch_geometric.typing.WITH_PT20: return induced = SubgraphType.induced if subgraph_type == SubgraphType.induced and not WITH_TORCH_SPARSE: return if dtype != torch.int64 and (not WITH_PYG_LIB or subgraph_type == induced): return torch.manual_seed(12345) data = Data() data.x = torch.arange(100, device=device) data.edge_index = get_random_edge_index(100, 100, 500, dtype, device) data.edge_attr = torch.arange(500, device=device) loader = NeighborLoader( data, num_neighbors=[5] * 2, batch_size=20, subgraph_type=subgraph_type, filter_per_worker=filter_per_worker, ) assert str(loader) == 'NeighborLoader()' assert len(loader) == 5 batch = loader([0]) assert isinstance(batch, Data) assert batch.n_id[:1].tolist() == [0] for i, batch in enumerate(loader): assert isinstance(batch, Data) assert batch.x.device == device assert batch.x.size(0) <= 100 assert batch.n_id.size() == (batch.num_nodes, ) assert batch.input_id.numel() == batch.batch_size == 20 assert batch.x.min() >= 0 and batch.x.max() < 100 # TODO Re-enable once `EdgeIndex` is stable. 
assert not isinstance(batch.edge_index, EdgeIndex) # batch.edge_index.validate() # size = (batch.num_nodes, batch.num_nodes) # assert batch.edge_index.sparse_size() == size # assert batch.edge_index.sort_order == 'col' assert batch.edge_index.device == device assert batch.edge_index.min() >= 0 assert batch.edge_index.max() < batch.num_nodes assert batch.edge_attr.device == device assert batch.edge_attr.size(0) == batch.edge_index.size(1) # Input nodes are always sampled first: assert torch.equal( batch.x[:batch.batch_size], torch.arange(i * batch.batch_size, (i + 1) * batch.batch_size, device=device), ) if subgraph_type != SubgraphType.bidirectional: assert batch.e_id.size() == (batch.num_edges, ) assert batch.edge_attr.min() >= 0 assert batch.edge_attr.max() < 500 assert is_subset( batch.edge_index.to(torch.int64), data.edge_index.to(torch.int64), batch.x, batch.x, ) @onlyNeighborSampler @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES) def test_hetero_neighbor_loader_basic(subgraph_type, dtype): if dtype != torch.int64 and not torch_geometric.typing.WITH_PT20: return induced = SubgraphType.induced if subgraph_type == SubgraphType.induced and not WITH_TORCH_SPARSE: return if dtype != torch.int64 and (not WITH_PYG_LIB or subgraph_type == induced): return torch.manual_seed(12345) data = HeteroData() data['paper'].x = torch.arange(100) data['author'].x = torch.arange(100, 300) edge_index = get_random_edge_index(100, 100, 500, dtype) data['paper', 'paper'].edge_index = edge_index data['paper', 'paper'].edge_attr = torch.arange(500) edge_index = get_random_edge_index(100, 200, 1000, dtype) data['paper', 'author'].edge_index = edge_index data['paper', 'author'].edge_attr = torch.arange(500, 1500) edge_index = get_random_edge_index(200, 100, 1000, dtype) data['author', 'paper'].edge_index = edge_index data['author', 'paper'].edge_attr = torch.arange(1500, 2500) r1, c1 = data['paper', 'paper'].edge_index r2, c2 = 
data['paper', 'author'].edge_index + torch.tensor([[0], [100]]) r3, c3 = data['author', 'paper'].edge_index + torch.tensor([[100], [0]]) batch_size = 20 with pytest.raises(ValueError, match="hops must be the same across all"): loader = NeighborLoader( data, num_neighbors={ ('paper', 'to', 'paper'): [-1], ('paper', 'to', 'author'): [-1, -1], ('author', 'to', 'paper'): [-1, -1], }, input_nodes='paper', batch_size=batch_size, subgraph_type=subgraph_type, ) next(iter(loader)) loader = NeighborLoader( data, num_neighbors=[10] * 2, input_nodes='paper', batch_size=batch_size, subgraph_type=subgraph_type, ) assert str(loader) == 'NeighborLoader()' assert len(loader) == (100 + batch_size - 1) // batch_size for batch in loader: assert isinstance(batch, HeteroData) assert batch.input_type == 'paper' # Test node type selection: assert set(batch.node_types) == {'paper', 'author'} assert batch['paper'].n_id.size() == (batch['paper'].num_nodes, ) assert batch['paper'].x.size(0) <= 100 assert batch['paper'].input_id.numel() == batch_size assert batch['paper'].batch_size == batch_size assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 assert batch['author'].n_id.size() == (batch['author'].num_nodes, ) assert batch['author'].x.size(0) <= 200 assert batch['author'].x.min() >= 100 and batch['author'].x.max() < 300 # Test edge type selection: assert set(batch.edge_types) == {('paper', 'to', 'paper'), ('paper', 'to', 'author'), ('author', 'to', 'paper')} for edge_type, edge_index in batch.edge_index_dict.items(): src, _, dst = edge_type # TODO Re-enable once `EdgeIndex` is stable. 
assert not isinstance(edge_index, EdgeIndex) # edge_index.validate() # size = (batch[src].num_nodes, batch[dst].num_nodes) # assert edge_index.sparse_size() == size # assert edge_index.sort_order == 'col' row, col = batch['paper', 'paper'].edge_index assert row.min() >= 0 and row.max() < batch['paper'].num_nodes assert col.min() >= 0 and col.max() < batch['paper'].num_nodes if subgraph_type != SubgraphType.bidirectional: assert batch['paper', 'paper'].e_id.size() == (row.numel(), ) value = batch['paper', 'paper'].edge_attr assert value.min() >= 0 and value.max() < 500 assert is_subset( batch['paper', 'paper'].edge_index.to(torch.int64), data['paper', 'paper'].edge_index.to(torch.int64), batch['paper'].x, batch['paper'].x, ) elif subgraph_type != SubgraphType.directional: assert 'e_id' not in batch['paper', 'paper'] assert 'edge_attr' not in batch['paper', 'paper'] assert is_undirected(batch['paper', 'paper'].edge_index) row, col = batch['paper', 'author'].edge_index assert row.min() >= 0 and row.max() < batch['paper'].num_nodes assert col.min() >= 0 and col.max() < batch['author'].num_nodes if subgraph_type != SubgraphType.bidirectional: assert batch['paper', 'author'].e_id.size() == (row.numel(), ) value = batch['paper', 'author'].edge_attr assert value.min() >= 500 and value.max() < 1500 assert is_subset( batch['paper', 'author'].edge_index.to(torch.int64), data['paper', 'author'].edge_index.to(torch.int64), batch['paper'].x, batch['author'].x - 100, ) elif subgraph_type != SubgraphType.directional: assert 'e_id' not in batch['paper', 'author'] assert 'edge_attr' not in batch['paper', 'author'] edge_index1 = batch['paper', 'author'].edge_index edge_index2 = batch['author', 'paper'].edge_index assert torch.equal( edge_index1, sort_edge_index(edge_index2.flip([0]), sort_by_row=False), ) row, col = batch['author', 'paper'].edge_index assert row.min() >= 0 and row.max() < batch['author'].num_nodes assert col.min() >= 0 and col.max() < batch['paper'].num_nodes if 
subgraph_type != SubgraphType.bidirectional: assert batch['author', 'paper'].e_id.size() == (row.numel(), ) value = batch['author', 'paper'].edge_attr assert value.min() >= 1500 and value.max() < 2500 assert is_subset( batch['author', 'paper'].edge_index.to(torch.int64), data['author', 'paper'].edge_index.to(torch.int64), batch['author'].x - 100, batch['paper'].x, ) elif subgraph_type != SubgraphType.directional: assert 'e_id' not in batch['author', 'paper'] assert 'edge_attr' not in batch['author', 'paper'] edge_index1 = batch['author', 'paper'].edge_index edge_index2 = batch['paper', 'author'].edge_index assert torch.equal( edge_index1, sort_edge_index(edge_index2.flip([0]), sort_by_row=False), ) # Test for isolated nodes (there shouldn't exist any): assert not batch.has_isolated_nodes() @onlyOnline @onlyNeighborSampler @pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES) def test_homo_neighbor_loader_on_karate(get_dataset, subgraph_type): if subgraph_type == SubgraphType.induced and not WITH_TORCH_SPARSE: return dataset = get_dataset(name='karate') data = dataset[0] mask = data.edge_index[0] < data.edge_index[1] edge_index = data.edge_index[:, mask] edge_weight = torch.rand(edge_index.size(1)) data.edge_index, data.edge_weight = to_undirected(edge_index, edge_weight) split_idx = torch.arange(5, 8) loader = NeighborLoader( data, num_neighbors=[-1, -1], batch_size=split_idx.numel(), input_nodes=split_idx, subgraph_type=subgraph_type, ) assert len(loader) == 1 batch = next(iter(loader)) batch_size = batch.batch_size class GNN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GraphConv(in_channels, hidden_channels) self.conv2 = GraphConv(hidden_channels, out_channels) def forward(self, x, edge_index, edge_weight): x = self.conv1(x, edge_index, edge_weight).relu() x = self.conv2(x, edge_index, edge_weight) return x model = GNN(dataset.num_features, 16, dataset.num_classes) out1 = model(data.x, 
data.edge_index, data.edge_weight)[split_idx] out2 = model(batch.x, batch.edge_index, batch.edge_weight)[:batch_size] assert torch.allclose(out1, out2, atol=1e-6) @onlyOnline @onlyNeighborSampler @pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES) def test_hetero_neighbor_loader_on_karate(get_dataset, subgraph_type): if subgraph_type == SubgraphType.induced and not WITH_TORCH_SPARSE: return dataset = get_dataset(name='karate') data = dataset[0] hetero_data = HeteroData() hetero_data['v'].x = data.x hetero_data['v', 'v'].edge_index = data.edge_index split_idx = torch.arange(5, 8) loader = NeighborLoader( hetero_data, num_neighbors=[-1, -1], batch_size=split_idx.numel(), input_nodes=('v', split_idx), subgraph_type=subgraph_type, ) assert len(loader) == 1 hetero_batch = next(iter(loader)) batch_size = hetero_batch['v'].batch_size class GNN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GraphConv(in_channels, hidden_channels) self.conv2 = GraphConv(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) return x model = GNN(dataset.num_features, 16, dataset.num_classes) hetero_model = to_hetero(model, hetero_data.metadata()) out1 = model(data.x, data.edge_index)[split_idx] out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict)['v'][:batch_size] assert torch.allclose(out1, out2, atol=1e-6) @onlyOnline @withPackage('pyg_lib') def test_temporal_hetero_neighbor_loader_on_karate(get_dataset): dataset = get_dataset(name='karate') data = dataset[0] hetero_data = HeteroData() hetero_data['v'].x = data.x hetero_data['v'].time = torch.arange(data.num_nodes, 0, -1) hetero_data['v', 'v'].edge_index = data.edge_index loader = NeighborLoader(hetero_data, num_neighbors=[-1, -1], input_nodes='v', time_attr='time', batch_size=1) for batch in loader: mask = batch['v'].time[0] >= batch['v'].time[1:] assert torch.all(mask) 
@onlyNeighborSampler def test_custom_neighbor_loader(): # Initialize feature store, graph store, and reference: feature_store = MyFeatureStore() graph_store = MyGraphStore() # Set up node features: x = torch.arange(100, 300) feature_store.put_tensor(x, group_name=None, attr_name='x', index=None) y = torch.arange(100, 300) feature_store.put_tensor(y, group_name=None, attr_name='y', index=None) # COO: edge_index = get_random_edge_index(100, 100, 500, coalesce=True) edge_index = edge_index[:, torch.randperm(edge_index.size(1))] coo = (edge_index[0], edge_index[1]) graph_store.put_edge_index(edge_index=coo, edge_type=None, layout='coo', size=(100, 100)) data = Data(x=x, edge_index=edge_index, y=y, num_nodes=200) # Construct neighbor loaders: loader1 = NeighborLoader(data, batch_size=20, input_nodes=torch.arange(100), num_neighbors=[-1] * 2) loader2 = NeighborLoader((feature_store, graph_store), batch_size=20, input_nodes=torch.arange(100), num_neighbors=[-1] * 2) assert str(loader1) == str(loader2) assert len(loader1) == len(loader2) for batch1, batch2 in zip(loader1, loader2): assert len(batch1) == len(batch2) assert batch1.num_nodes == batch2.num_nodes assert batch1.num_edges == batch2.num_edges assert batch1.batch_size == batch2.batch_size # Mapped indices of neighbors may be differently sorted ... 
assert torch.allclose(batch1.x.sort()[0], batch2.x.sort()[0]) assert torch.allclose(batch1.y.sort()[0], batch2.y.sort()[0]) @onlyNeighborSampler def test_custom_hetero_neighbor_loader(): # Initialize feature store, graph store, and reference: feature_store = MyFeatureStore() graph_store = MyGraphStore() data = HeteroData() # Set up node features: x = torch.arange(100) data['paper'].x = x feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None) x = torch.arange(100, 300) data['author'].x = x feature_store.put_tensor(x, group_name='author', attr_name='x', index=None) # COO: edge_index = get_random_edge_index(100, 100, 500, coalesce=True) edge_index = edge_index[:, torch.randperm(edge_index.size(1))] data['paper', 'to', 'paper'].edge_index = edge_index coo = (edge_index[0], edge_index[1]) graph_store.put_edge_index(edge_index=coo, edge_type=('paper', 'to', 'paper'), layout='coo', size=(100, 100)) # CSR: edge_index = get_random_edge_index(100, 200, 1000, coalesce=True) data['paper', 'to', 'author'].edge_index = edge_index adj = to_torch_csr_tensor(edge_index, size=(100, 200)) csr = (adj.crow_indices(), adj.col_indices()) graph_store.put_edge_index(edge_index=csr, edge_type=('paper', 'to', 'author'), layout='csr', size=(100, 200)) # CSC: edge_index = get_random_edge_index(200, 100, 1000, coalesce=True) data['author', 'to', 'paper'].edge_index = edge_index adj = to_torch_csr_tensor(edge_index.flip([0]), size=(100, 200)) csc = (adj.col_indices(), adj.crow_indices()) graph_store.put_edge_index(edge_index=csc, edge_type=('author', 'to', 'paper'), layout='csc', size=(200, 100)) # COO (sorted): edge_index = get_random_edge_index(200, 200, 100, coalesce=True) edge_index = edge_index[:, edge_index[1].argsort()] data['author', 'to', 'author'].edge_index = edge_index coo = (edge_index[0], edge_index[1]) graph_store.put_edge_index(edge_index=coo, edge_type=('author', 'to', 'author'), layout='coo', size=(200, 200), is_sorted=True) # Construct neighbor loaders: 
loader1 = NeighborLoader(data, batch_size=20, input_nodes=('paper', range(100)), num_neighbors=[-1] * 2) loader2 = NeighborLoader((feature_store, graph_store), batch_size=20, input_nodes=('paper', range(100)), num_neighbors=[-1] * 2) assert str(loader1) == str(loader2) assert len(loader1) == len(loader2) for batch1, batch2 in zip(loader1, loader2): # `loader2` explicitly adds `num_nodes` to the batch: assert len(batch1) + 1 == len(batch2) assert batch1['paper'].batch_size == batch2['paper'].batch_size # Mapped indices of neighbors may be differently sorted ... for node_type in data.node_types: assert torch.allclose( batch1[node_type].x.sort()[0], batch2[node_type].x.sort()[0], ) # ... but should sample the exact same number of edges: for edge_type in data.edge_types: assert batch1[edge_type].num_edges == batch2[edge_type].num_edges @onlyNeighborSampler def test_custom_hetero_neighbor_loader_duplicate(): feature_store = MyFeatureStore() graph_store = MyGraphStore() x = torch.arange(10) feature_store.put_tensor(x, group_name='user', attr_name='x', index=None) edge_index = get_random_edge_index(10, 10, 20, coalesce=True) graph_store.put_edge_index( edge_index=(edge_index[0], edge_index[1]), edge_type=('user', 'user', 'user'), layout='coo', size=(10, 10), ) loader = NeighborLoader( (feature_store, graph_store), batch_size=10, input_nodes=('user', range(10)), num_neighbors=[-1] * 2, ) batch = next(iter(loader)) assert batch.node_types == ['user'] assert batch['user'].num_nodes == 10 assert batch.edge_types == [('user', 'user', 'user')] assert batch['user', 'user'].num_edges == edge_index.size(1) @onlyOnline @withPackage('pyg_lib') def test_temporal_custom_neighbor_loader_on_karate(get_dataset): dataset = get_dataset(name='karate') data = dataset[0] data.time = torch.arange(data.num_nodes, 0, -1) # Initialize feature store, graph store, and reference: feature_store = MyFeatureStore() graph_store = MyGraphStore() hetero_data = HeteroData() feature_store.put_tensor( 
data.x, group_name='v', attr_name='x', index=None, ) hetero_data['v'].x = data.x feature_store.put_tensor( data.time, group_name='v', attr_name='time', index=None, ) hetero_data['v'].time = data.time # Sort according to time in local neighborhoods: row, col = data.edge_index perm = ((col * (data.num_nodes + 1)) + data.time[row]).argsort() edge_index = data.edge_index[:, perm] graph_store.put_edge_index( edge_index, edge_type=('v', 'to', 'v'), layout='coo', is_sorted=True, size=(data.num_nodes, data.num_nodes), ) hetero_data['v', 'to', 'v'].edge_index = data.edge_index loader1 = NeighborLoader( hetero_data, num_neighbors=[-1, -1], input_nodes='v', time_attr='time', batch_size=128, ) loader2 = NeighborLoader( (feature_store, graph_store), num_neighbors=[-1, -1], input_nodes='v', time_attr='time', batch_size=128, ) for batch1, batch2 in zip(loader1, loader2): assert torch.equal(batch1['v'].time, batch2['v'].time) @withPackage('pyg_lib', 'torch_sparse') def test_pyg_lib_and_torch_sparse_homo_equality(): edge_index = get_random_edge_index(20, 20, 100) adj = to_torch_csr_tensor(edge_index.flip([0]), size=(20, 20)) colptr, row = adj.crow_indices(), adj.col_indices() seed = torch.arange(10) sample = torch.ops.pyg.neighbor_sample out1 = sample(colptr, row, seed, [-1, -1], None, None, None, None, True) sample = torch.ops.torch_sparse.neighbor_sample out2 = sample(colptr, row, seed, [-1, -1], False, True) row1, col1, node_id1, edge_id1 = out1[:4] node_id2, row2, col2, edge_id2 = out2 assert torch.equal(node_id1, node_id2) assert torch.equal(row1, row2) assert torch.equal(col1, col2) assert torch.equal(edge_id1, edge_id2) @withPackage('pyg_lib', 'torch_sparse') def test_pyg_lib_and_torch_sparse_hetero_equality(): edge_index = get_random_edge_index(20, 10, 50) adj = to_torch_csr_tensor(edge_index.flip([0]), size=(10, 20)) colptr1, row1 = adj.crow_indices(), adj.col_indices() edge_index = get_random_edge_index(10, 20, 50) adj = to_torch_csr_tensor(edge_index.flip([0]), size=(20, 
10)) colptr2, row2 = adj.crow_indices(), adj.col_indices() node_types = ['paper', 'author'] edge_types = [('paper', 'to', 'author'), ('author', 'to', 'paper')] colptr_dict = { 'paper__to__author': colptr1, 'author__to__paper': colptr2, } row_dict = { 'paper__to__author': row1, 'author__to__paper': row2, } seed_dict = {'paper': torch.arange(1)} num_neighbors_dict = { 'paper__to__author': [-1, -1], 'author__to__paper': [-1, -1], } sample = torch.ops.pyg.hetero_neighbor_sample out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict, num_neighbors_dict, None, None, None, None, True, False, True, False, "uniform", True) sample = torch.ops.torch_sparse.hetero_neighbor_sample out2 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict, num_neighbors_dict, 2, False, True) row1_dict, col1_dict, node_id1_dict, edge_id1_dict = out1[:4] node_id2_dict, row2_dict, col2_dict, edge_id2_dict = out2 assert len(node_id1_dict) == len(node_id2_dict) for key in node_id1_dict.keys(): assert torch.equal(node_id1_dict[key], node_id2_dict[key]) assert len(row1_dict) == len(row2_dict) for key in row1_dict.keys(): assert torch.equal(row1_dict[key], row2_dict[key]) assert len(col1_dict) == len(col2_dict) for key in col1_dict.keys(): assert torch.equal(col1_dict[key], col2_dict[key]) assert len(edge_id1_dict) == len(edge_id2_dict) for key in edge_id1_dict.keys(): assert torch.equal(edge_id1_dict[key], edge_id2_dict[key]) @onlyLinux @onlyNeighborSampler def test_memmap_neighbor_loader(tmp_path): path = osp.join(tmp_path, 'x.npy') x = np.memmap(path, dtype=np.float32, mode='w+', shape=(100, 32)) x[:] = np.random.randn(100, 32) data = Data() data.x = np.memmap(path, dtype=np.float32, mode='r', shape=(100, 32)) data.edge_index = get_random_edge_index(100, 100, 500) assert str(data) == 'Data(x=[100, 32], edge_index=[2, 500])' assert data.num_nodes == 100 loader = NeighborLoader(data, num_neighbors=[5] * 2, batch_size=20, num_workers=2) batch = next(iter(loader)) assert 
batch.num_nodes <= 100 assert isinstance(batch.x, torch.Tensor) assert batch.x.size() == (batch.num_nodes, 32) @withPackage('pyg_lib') def test_homo_neighbor_loader_sampled_info(): edge_index = torch.tensor([ [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], ]) data = Data(edge_index=edge_index, num_nodes=14) loader = NeighborLoader( data, num_neighbors=[1, 2, 4], batch_size=2, shuffle=False, ) batch = next(iter(loader)) assert batch.num_sampled_nodes == [2, 2, 3, 4] assert batch.num_sampled_edges == [2, 4, 4] @withPackage('pyg_lib') def test_hetero_neighbor_loader_sampled_info(): edge_index = torch.tensor([ [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], ]) data = HeteroData() data['paper'].num_nodes = data['author'].num_nodes = 14 data['paper', 'paper'].edge_index = edge_index data['paper', 'author'].edge_index = edge_index data['author', 'paper'].edge_index = edge_index loader = NeighborLoader( data, num_neighbors=[1, 2, 4], batch_size=2, input_nodes='paper', shuffle=False, ) batch = next(iter(loader)) expected_num_sampled_nodes = { 'paper': [2, 2, 3, 4], 'author': [0, 2, 3, 4], } expected_num_sampled_edges = { ('paper', 'to', 'paper'): [2, 4, 4], ('paper', 'to', 'author'): [0, 4, 4], ('author', 'to', 'paper'): [2, 4, 4], } for node_type in batch.node_types: assert (batch[node_type].num_sampled_nodes == expected_num_sampled_nodes[node_type]) for edge_type in batch.edge_types: assert (batch[edge_type].num_sampled_edges == expected_num_sampled_edges[edge_type]) @withPackage('pyg_lib') def test_neighbor_loader_mapping(): edge_index = torch.tensor([ [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 5], [1, 2, 3, 4, 5, 8, 6, 7, 9, 10, 6, 11], ]) data = Data(edge_index=edge_index, num_nodes=12) loader = NeighborLoader( data, num_neighbors=[1], batch_size=2, shuffle=True, ) for batch in loader: assert torch.equal( batch.n_id[batch.edge_index], data.edge_index[:, batch.e_id], ) @pytest.mark.skipif( not WITH_WEIGHTED_NEIGHBOR_SAMPLE, 
reason="'pyg-lib' does not support weighted neighbor sampling", ) def test_weighted_homo_neighbor_loader(): edge_index = torch.tensor([ [1, 3, 0, 4], [2, 2, 1, 3], ]) edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0]) data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight) loader = NeighborLoader( data, input_nodes=torch.tensor([2]), num_neighbors=[1] * 2, batch_size=1, weight_attr='edge_weight', ) assert len(loader) == 1 batch = next(iter(loader)) assert batch.num_nodes == 3 assert batch.n_id.tolist() == [2, 3, 4] assert batch.num_edges == 2 assert batch.n_id[batch.edge_index].tolist() == [[3, 4], [2, 3]] @pytest.mark.skipif( not WITH_WEIGHTED_NEIGHBOR_SAMPLE, reason="'pyg-lib' does not support weighted neighbor sampling", ) def test_weighted_hetero_neighbor_loader(): edge_index = torch.tensor([ [1, 3, 0, 4], [2, 2, 1, 3], ]) edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0]) data = HeteroData() data['paper'].num_nodes = 5 data['paper', 'to', 'paper'].edge_index = edge_index data['paper', 'to', 'paper'].edge_weight = edge_weight loader = NeighborLoader( data, input_nodes=('paper', torch.tensor([2])), num_neighbors=[1] * 2, batch_size=1, weight_attr='edge_weight', ) assert len(loader) == 1 batch = next(iter(loader)) assert batch['paper'].num_nodes == 3 assert batch['paper'].n_id.tolist() == [2, 3, 4] assert batch['paper', 'paper'].num_edges == 2 global_edge_index = batch['paper'].n_id[batch['paper', 'paper'].edge_index] assert global_edge_index.tolist() == [[3, 4], [2, 3]] @pytest.mark.skipif( not WITH_EDGE_TIME_NEIGHBOR_SAMPLE, reason="'pyg-lib' does not support weighted neighbor sampling", ) def test_edge_level_temporal_homo_neighbor_loader(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3], ]) edge_time = torch.arange(edge_index.size(1)) data = Data(edge_index=edge_index, edge_time=edge_time, num_nodes=5) loader = NeighborLoader( data, num_neighbors=[-1, -1], input_time=torch.tensor([4, 4, 4, 4, 4]), 
time_attr='edge_time', batch_size=1, ) for batch in loader: assert batch.edge_time.numel() == batch.num_edges if batch.edge_time.numel() > 0: assert batch.edge_time.max() <= 4 @pytest.mark.skipif( not WITH_EDGE_TIME_NEIGHBOR_SAMPLE, reason="'pyg-lib' does not support weighted neighbor sampling", ) def test_edge_level_temporal_hetero_neighbor_loader(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3], ]) edge_time = torch.arange(edge_index.size(1)) data = HeteroData() data['A'].num_nodes = 5 data['A', 'A'].edge_index = edge_index data['A', 'A'].edge_time = edge_time loader = NeighborLoader( data, num_neighbors=[-1, -1], input_nodes='A', input_time=torch.tensor([4, 4, 4, 4, 4]), time_attr='edge_time', batch_size=1, ) for batch in loader: assert batch['A', 'A'].edge_time.numel() == batch['A', 'A'].num_edges if batch['A', 'A'].edge_time.numel() > 0: assert batch['A', 'A'].edge_time.max() <= 4 @withCUDA @onlyNeighborSampler @withPackage('torch_frame') def test_neighbor_loader_with_tensor_frame(device): data = Data() data.tf = get_random_tensor_frame(num_rows=100, device=device) data.edge_index = get_random_edge_index(100, 100, 500, device=device) data.edge_attr = get_random_tensor_frame(500, device=device) data.global_tf = get_random_tensor_frame(num_rows=1, device=device) loader = NeighborLoader(data, num_neighbors=[5] * 2, batch_size=20) assert len(loader) == 5 for batch in loader: assert isinstance(batch.tf, TensorFrame) assert batch.tf.device == device assert batch.tf.num_rows == batch.n_id.numel() assert batch.tf == data.tf[batch.n_id] assert isinstance(batch.edge_attr, TensorFrame) assert batch.edge_attr.device == device assert batch.edge_attr.num_rows == batch.e_id.numel() assert batch.edge_attr == data.edge_attr[batch.e_id] assert isinstance(batch.global_tf, TensorFrame) assert batch.global_tf.device == device assert batch.global_tf.num_rows == 1 assert batch.global_tf == data.global_tf @onlyNeighborSampler def 
test_neighbor_loader_input_id(): data = HeteroData() data['a'].num_nodes = 10 data['b'].num_nodes = 12 row = torch.randint(0, data['a'].num_nodes, (40, )) col = torch.randint(0, data['b'].num_nodes, (40, )) data['a', 'b'].edge_index = torch.stack([row, col], dim=0) data['b', 'a'].edge_index = torch.stack([col, row], dim=0) mask = torch.ones(data['a'].num_nodes, dtype=torch.bool) mask[0] = False loader = NeighborLoader( data, input_nodes=('a', mask), batch_size=2, num_neighbors=[2, 2], ) for i, batch in enumerate(loader): if i < 4: expected = [(2 * i) + 1, (2 * i) + 2] else: expected = [(2 * i) + 1] assert batch['a'].input_id.tolist() == expected @withPackage('pyg_lib') def test_temporal_neighbor_loader_single_link(): data = HeteroData() data['a'].x = torch.arange(10) data['b'].x = torch.arange(10) data['c'].x = torch.arange(10) data['b'].time = torch.arange(0, 10) data['c'].time = torch.arange(1, 11) data['a', 'b'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1) data['b', 'a'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1) data['a', 'c'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1) data['c', 'a'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1) loader = NeighborLoader( data, num_neighbors=[-1], input_nodes='a', time_attr='time', input_time=torch.arange(0, 10), batch_size=10, ) batch = next(iter(loader)) assert batch['a'].num_nodes == 10 assert batch['b'].num_nodes == 10 assert batch['c'].num_nodes == 0 ================================================ FILE: test/loader/test_neighbor_sampler.py ================================================ import numpy as np import torch from torch_geometric.loader import NeighborSampler from torch_geometric.nn.conv import GATConv, SAGEConv from torch_geometric.testing import onlyOnline, withPackage from torch_geometric.typing import SparseTensor from torch_geometric.utils import erdos_renyi_graph @withPackage('torch_sparse') def test_neighbor_sampler_basic(): edge_index = 
erdos_renyi_graph(num_nodes=10, edge_prob=0.5) adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(10, 10)).t() E = edge_index.size(1) loader = NeighborSampler(edge_index, sizes=[2, 4], batch_size=2) assert str(loader) == 'NeighborSampler(sizes=[2, 4])' assert len(loader) == 5 for batch_size, n_id, adjs in loader: assert batch_size == 2 assert all(np.isin(n_id, torch.arange(10)).tolist()) assert n_id.unique().size(0) == n_id.size(0) for (edge_index, e_id, size) in adjs: assert int(edge_index[0].max() + 1) <= size[0] assert int(edge_index[1].max() + 1) <= size[1] assert all(np.isin(e_id, torch.arange(E)).tolist()) assert e_id.unique().size(0) == e_id.size(0) assert size[0] >= size[1] out = loader.sample([1, 2]) assert len(out) == 3 loader = NeighborSampler(adj_t, sizes=[2, 4], batch_size=2) for _, _, adjs in loader: for adj_t, _, size in adjs: assert adj_t.size(0) == size[1] assert adj_t.size(1) == size[0] @withPackage('torch_sparse') def test_neighbor_sampler_invalid_kwargs(): # Ignore `collate_fn` and `dataset` arguments: edge_index = torch.tensor([[0, 1], [1, 0]]) NeighborSampler(edge_index, sizes=[-1], collate_fn=None, dataset=None) @onlyOnline @withPackage('torch_sparse') def test_neighbor_sampler_on_cora(get_dataset): dataset = get_dataset(name='Cora') data = dataset[0] batch = torch.arange(10) loader = NeighborSampler(data.edge_index, sizes=[-1, -1, -1], node_idx=batch, batch_size=10) class SAGE(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.convs = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, 16)) self.convs.append(SAGEConv(16, 16)) self.convs.append(SAGEConv(16, out_channels)) def batch(self, x, adjs): for i, (edge_index, _, size) in enumerate(adjs): x_target = x[:size[1]] # Target nodes are always placed first. 
x = self.convs[i]((x, x_target), edge_index) return x def full(self, x, edge_index): for conv in self.convs: x = conv(x, edge_index) return x model = SAGE(dataset.num_features, dataset.num_classes) _, n_id, adjs = next(iter(loader)) out1 = model.batch(data.x[n_id], adjs) out2 = model.full(data.x, data.edge_index)[batch] assert torch.allclose(out1, out2, atol=1e-7) class GAT(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.convs = torch.nn.ModuleList() self.convs.append(GATConv(in_channels, 16, heads=2)) self.convs.append(GATConv(32, 16, heads=2)) self.convs.append(GATConv(32, out_channels, heads=2, concat=False)) def batch(self, x, adjs): for i, (edge_index, _, size) in enumerate(adjs): x_target = x[:size[1]] # Target nodes are always placed first. x = self.convs[i]((x, x_target), edge_index) return x def full(self, x, edge_index): for conv in self.convs: x = conv(x, edge_index) return x _, n_id, adjs = next(iter(loader)) out1 = model.batch(data.x[n_id], adjs) out2 = model.full(data.x, data.edge_index)[batch] assert torch.allclose(out1, out2, atol=1e-7) ================================================ FILE: test/loader/test_prefetch.py ================================================ import torch from torch_geometric.loader import NeighborLoader, PrefetchLoader from torch_geometric.nn import GraphSAGE from torch_geometric.testing import withCUDA @withCUDA def test_prefetch_loader(device): data = [torch.randn(5, 5) for _ in range(10)] loader = PrefetchLoader(data, device=device) assert str(loader).startswith('PrefetchLoader') assert len(loader) == 10 for i, batch in enumerate(loader): assert batch.device == device assert torch.equal(batch.cpu(), data[i]) if __name__ == '__main__': import argparse from ogb.nodeproppred import PygNodePropPredDataset from tqdm import tqdm parser = argparse.ArgumentParser() parser.add_argument('--num_workers', type=int, default=0) args = parser.parse_args() data = 
PygNodePropPredDataset('ogbn-products', root='/tmp/ogb')[0] model = GraphSAGE( in_channels=data.x.size(-1), hidden_channels=64, num_layers=2, ).cuda() loader = NeighborLoader( data, input_nodes=torch.arange(1024 * 200), batch_size=1024, num_neighbors=[10, 10], num_workers=args.num_workers, persistent_workers=args.num_workers > 0, ) print('Forward pass without prefetching...') for batch in tqdm(loader): with torch.no_grad(): batch = batch.cuda() model(batch.x, batch.edge_index) print('Forward pass with prefetching...') for batch in tqdm(PrefetchLoader(loader)): with torch.no_grad(): model(batch.x, batch.edge_index) ================================================ FILE: test/loader/test_random_node_loader.py ================================================ import torch from torch_geometric.data import Data, HeteroData from torch_geometric.loader import RandomNodeLoader from torch_geometric.testing import get_random_edge_index def test_random_node_loader(): data = Data() data.x = torch.randn(100, 128) data.node_id = torch.arange(100) data.edge_index = get_random_edge_index(100, 100, 500) data.edge_attr = torch.randn(500, 32) loader = RandomNodeLoader(data, num_parts=4, shuffle=True) assert len(loader) == 4 for batch in loader: assert len(batch) == 4 assert batch.node_id.min() >= 0 assert batch.node_id.max() < 100 assert batch.edge_index.size(1) == batch.edge_attr.size(0) assert torch.allclose(batch.x, data.x[batch.node_id]) batch.validate() def test_heterogeneous_random_node_loader(): data = HeteroData() data['paper'].x = torch.randn(100, 128) data['paper'].node_id = torch.arange(100) data['author'].x = torch.randn(200, 128) data['author'].node_id = torch.arange(200) data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 500) data['paper', 'author'].edge_attr = torch.randn(500, 32) data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 400) data['author', 'paper'].edge_attr = torch.randn(400, 32) data['paper', 'paper'].edge_index = 
get_random_edge_index(100, 100, 600) data['paper', 'paper'].edge_attr = torch.randn(600, 32) loader = RandomNodeLoader(data, num_parts=4, shuffle=True) assert len(loader) == 4 for batch in loader: assert len(batch) == 4 assert batch.node_types == data.node_types assert batch.edge_types == data.edge_types batch.validate() ================================================ FILE: test/loader/test_shadow.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.loader import ShaDowKHopSampler from torch_geometric.testing import withPackage from torch_geometric.typing import SparseTensor @withPackage('torch_sparse') def test_shadow_k_hop_sampler(): row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5]) col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]) edge_index = torch.stack([row, col], dim=0) edge_weight = torch.arange(row.size(0)) x = torch.randn(6, 16) y = torch.randint(3, (6, ), dtype=torch.long) data = Data(edge_index=edge_index, edge_weight=edge_weight, x=x, y=y) train_mask = torch.tensor([1, 1, 0, 0, 0, 0], dtype=torch.bool) loader = ShaDowKHopSampler(data, depth=1, num_neighbors=3, node_idx=train_mask, batch_size=2) assert len(loader) == 1 batch1 = next(iter(loader)) assert batch1.num_graphs == len(batch1) == 2 assert batch1.batch.tolist() == [0, 0, 0, 0, 1, 1, 1] assert batch1.ptr.tolist() == [0, 4, 7] assert batch1.root_n_id.tolist() == [0, 5] assert batch1.x.tolist() == x[torch.tensor([0, 1, 2, 3, 0, 1, 2])].tolist() assert batch1.y.tolist() == y[train_mask].tolist() row, col = batch1.edge_index assert row.tolist() == [0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6] assert col.tolist() == [1, 2, 3, 0, 2, 0, 1, 0, 5, 6, 4, 6, 4, 5] e_id = torch.tensor([0, 1, 2, 3, 4, 5, 6, 9, 0, 1, 3, 4, 5, 6]) assert batch1.edge_weight.tolist() == edge_weight[e_id].tolist() adj_t = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight).t() data = Data(adj_t=adj_t, x=x, y=y) loader = 
ShaDowKHopSampler(data, depth=1, num_neighbors=3, node_idx=train_mask, batch_size=2) assert len(loader) == 1 batch2 = next(iter(loader)) assert batch2.num_graphs == len(batch2) == 2 assert batch1.batch.tolist() == batch2.batch.tolist() assert batch1.ptr.tolist() == batch2.ptr.tolist() assert batch1.root_n_id.tolist() == batch2.root_n_id.tolist() assert batch1.x.tolist() == batch2.x.tolist() assert batch1.y.tolist() == batch2.y.tolist() row, col, value = batch2.adj_t.t().coo() assert batch1.edge_index[0].tolist() == row.tolist() assert batch1.edge_index[1].tolist() == col.tolist() assert batch1.edge_weight.tolist() == value.tolist() ================================================ FILE: test/loader/test_temporal_dataloader.py ================================================ import pytest import torch from torch_geometric.data import TemporalData from torch_geometric.loader import TemporalDataLoader @pytest.mark.parametrize('batch_size,drop_last', [(4, True), (2, False)]) def test_temporal_dataloader(batch_size, drop_last): src = dst = t = torch.arange(10) msg = torch.randn(10, 16) data = TemporalData(src=src, dst=dst, t=t, msg=msg) loader = TemporalDataLoader( data, batch_size=batch_size, drop_last=drop_last, ) assert len(loader) == 10 // batch_size for i, batch in enumerate(loader): assert len(batch) == batch_size arange = range(len(batch) * i, len(batch) * i + len(batch)) assert batch.src.tolist() == data.src[arange].tolist() assert batch.dst.tolist() == data.dst[arange].tolist() assert batch.t.tolist() == data.t[arange].tolist() assert batch.msg.tolist() == data.msg[arange].tolist() ================================================ FILE: test/loader/test_utils.py ================================================ import pytest import torch from torch_geometric.loader.utils import index_select def test_index_select(): x = torch.randn(3, 5) index = torch.tensor([0, 2]) assert torch.equal(index_select(x, index), x[index]) assert torch.equal(index_select(x, index, 
dim=-1), x[..., index]) def test_index_select_out_of_range(): with pytest.raises(IndexError, match="out of range"): index_select(torch.randn(3, 5), torch.tensor([0, 2, 3])) ================================================ FILE: test/loader/test_zip_loader.py ================================================ import pytest import torch from torch_geometric.data import Data from torch_geometric.loader import NeighborLoader, ZipLoader from torch_geometric.testing import onlyNeighborSampler @onlyNeighborSampler @pytest.mark.parametrize('filter_per_worker', [True, False]) def test_zip_loader(filter_per_worker): x = torch.arange(100) edge_index = torch.randint(0, 100, (2, 1000)) data = Data(x=x, edge_index=edge_index) loaders = [ NeighborLoader(data, [5], input_nodes=torch.arange(0, 50)), NeighborLoader(data, [5], input_nodes=torch.arange(50, 95)), ] loader = ZipLoader(loaders, batch_size=10, filter_per_worker=filter_per_worker) batches = loader(torch.arange(5)) assert isinstance(batches, tuple) assert len(batches) == 2 assert str(loader) == ('ZipLoader(loaders=[NeighborLoader(), ' 'NeighborLoader()])') assert len(loader) == 5 assert loader.dataset == range(0, 45) for i, (batch1, batch2) in enumerate(loader): n_id1 = batch1.n_id[:batch1.batch_size] n_id2 = batch2.n_id[:batch2.batch_size] if i < 4: assert batch1.batch_size == 10 assert batch2.batch_size == 10 assert torch.equal(n_id1, torch.arange(0 + i * 10, 10 + i * 10)) assert torch.equal(n_id2, torch.arange(50 + i * 10, 60 + i * 10)) else: assert batch1.batch_size == 5 assert batch2.batch_size == 5 assert torch.equal(n_id1, torch.arange(0 + i * 10, 5 + i * 10)) assert torch.equal(n_id2, torch.arange(50 + i * 10, 55 + i * 10)) ================================================ FILE: test/metrics/test_link_pred_metric.py ================================================ from typing import List import pytest import torch from torch_geometric.metrics import ( LinkPredAveragePopularity, LinkPredCoverage, LinkPredDiversity, 
LinkPredF1, LinkPredHitRatio, LinkPredMAP, LinkPredMetricCollection, LinkPredMRR, LinkPredNDCG, LinkPredPersonalization, LinkPredPrecision, LinkPredRecall, ) from torch_geometric.testing import withCUDA @pytest.mark.parametrize('num_src_nodes', [100]) @pytest.mark.parametrize('num_dst_nodes', [1000]) @pytest.mark.parametrize('num_edges', [3000]) @pytest.mark.parametrize('batch_size', [32]) @pytest.mark.parametrize('k', [1, 10, 100]) def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k): row = torch.randint(0, num_src_nodes, (num_edges, )) col = torch.randint(0, num_dst_nodes, (num_edges, )) edge_label_index = torch.stack([row, col], dim=0) pred = torch.rand(num_src_nodes, num_dst_nodes) pred[row, col] += 0.3 # Offset positive links by a little. pred_index_mat = pred.topk(k, dim=1)[1] metric = LinkPredPrecision(k) assert str(metric) == f'LinkPredPrecision(k={k})' for node_id in torch.split(torch.randperm(num_src_nodes), batch_size): mask = torch.isin(edge_label_index[0], node_id) y_batch, y_index = edge_label_index[:, mask] # Remap `y_batch` back to `[0, batch_size - 1]` range: arange = torch.empty(num_src_nodes, dtype=node_id.dtype) arange[node_id] = torch.arange(node_id.numel()) y_batch = arange[y_batch] metric.update(pred_index_mat[node_id], (y_batch, y_index)) out = metric.compute() metric.reset() values: List[float] = [] for i in range(num_src_nodes): # Naive computation per node: y_index = col[row == i] if y_index.numel() > 0: mask = torch.isin(pred_index_mat[i], y_index) precision = float(mask.sum() / k) values.append(precision) expected = torch.tensor(values).mean() assert torch.allclose(out, expected) # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :k - 1], edge_label_index) metric.compute() metric.reset() def test_recall(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) edge_label_weight = torch.tensor([4.0, 1.0, 2.0, 3.0, 0.5]) 
metric = LinkPredRecall(k=2) assert str(metric) == 'LinkPredRecall(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx(0.5 * (2 / 3 + 0.5)) # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() metric = LinkPredRecall(k=2, weighted=True) assert str(metric) == 'LinkPredRecall(k=2, weighted=True)' with pytest.raises(ValueError, match="'edge_label_weight'"): metric.update(pred_index_mat, edge_label_index) metric.update(pred_index_mat, edge_label_index, edge_label_weight) result = metric.compute() metric.reset() assert float(result) == pytest.approx(0.5 * (5.0 / 7.0 + 3.0 / 3.5)) edge_label_weight[0] = -2 metric.update(pred_index_mat, edge_label_index, edge_label_weight) result = metric.compute() metric.reset() assert float(result) == pytest.approx(0.5 * (1.0 / 3.0 + 3.0 / 3.5)) # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight) metric.compute() metric.reset() def test_f1(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) metric = LinkPredF1(k=2) assert str(metric) == 'LinkPredF1(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx(0.6500) # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() def test_map(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) metric = LinkPredMAP(k=2) assert str(metric) == 'LinkPredMAP(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx(0.6250) # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() def 
test_ndcg(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) edge_label_weight = torch.tensor([1.0, 2.0, 0.1, 3.0, 0.5]) metric = LinkPredNDCG(k=2) assert str(metric) == 'LinkPredNDCG(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx(0.6934264) # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() metric = LinkPredNDCG(k=2, weighted=True) assert str(metric) == 'LinkPredNDCG(k=2, weighted=True)' with pytest.raises(ValueError, match="'edge_label_weight'"): metric.update(pred_index_mat, edge_label_index) metric.update(pred_index_mat, edge_label_index, edge_label_weight) result = metric.compute() metric.reset() assert float(result) == pytest.approx(0.7854486) perm = torch.randperm(edge_label_weight.size(0)) metric.update(pred_index_mat, edge_label_index[:, perm], edge_label_weight[perm]) assert metric.compute() == result # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight) metric.compute() metric.reset() def test_mrr(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]]) edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]]) metric = LinkPredMRR(k=2) assert str(metric) == 'LinkPredMRR(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx((1 + 0.5 + 0) / 3) # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() def test_hit_ratio(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]]) edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]]) metric = LinkPredHitRatio(k=2) assert str(metric) == 'LinkPredHitRatio(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == 
pytest.approx(2 / 3) # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() def test_coverage(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]]) edge_label_index = torch.empty(2, 0, dtype=torch.long) metric = LinkPredCoverage(k=2, num_dst_nodes=3) assert str(metric) == 'LinkPredCoverage(k=2, num_dst_nodes=3)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() metric.reset() assert metric.mask.sum() == 0 assert float(result) == 1.0 metric = LinkPredCoverage(k=1, num_dst_nodes=4) assert str(metric) == 'LinkPredCoverage(k=1, num_dst_nodes=4)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() metric.reset() assert metric.mask.sum() == 0 assert float(result) == 2 / 4 def test_diversity(): pred_index_mat = torch.tensor([[0, 1, 2], [3, 1, 0]]) category = torch.tensor([0, 1, 2, 0]) edge_label_index = torch.empty(2, 0, dtype=torch.long) metric = LinkPredDiversity(k=3, category=category) assert str(metric) == 'LinkPredDiversity(k=3)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() metric.reset() assert pytest.approx(float(result)) == (1 + 2 / 3) / 2 @withCUDA def test_personalization(device): pred_index_mat = torch.tensor([[0, 1, 2, 3], [2, 1, 0, 4], [1, 0, 2, 5]], device=device) edge_label_index = torch.empty(2, 0, dtype=torch.long, device=device) metric = LinkPredPersonalization(k=4).to(device) assert str(metric) == 'LinkPredPersonalization(k=4)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert result.device == device assert float(result) == 0.25 metric.reset() assert metric.preds == [] metric.update(pred_index_mat[:0], edge_label_index) result = metric.compute() assert result.device == device assert float(result) == 0.0 metric.reset() def test_average_popularity(): pred_index_mat = torch.tensor([[0, 1, 2], [3, 1, 0]]) popularity = torch.tensor([10, 5, 2, 1]) edge_label_index = 
torch.empty(2, 0, dtype=torch.long) metric = LinkPredAveragePopularity(k=3, popularity=popularity) assert str(metric) == 'LinkPredAveragePopularity(k=3)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() metric.reset() assert pytest.approx(float(result)) == (10 + 5 + 2 + 1 + 5 + 10) / 6 @pytest.mark.parametrize('num_src_nodes', [10]) @pytest.mark.parametrize('num_dst_nodes', [50]) @pytest.mark.parametrize('num_edges', [200]) def test_metric_collection(num_src_nodes, num_dst_nodes, num_edges): metrics = [ LinkPredMAP(k=10), LinkPredPrecision(k=100), LinkPredRecall(k=50), LinkPredF1(k=20), LinkPredMRR(k=40), LinkPredNDCG(k=80), LinkPredCoverage(k=5, num_dst_nodes=num_dst_nodes), ] row = torch.randint(0, num_src_nodes, (num_edges, )) col = torch.randint(0, num_dst_nodes, (num_edges, )) edge_label_index = torch.stack([row, col], dim=0) pred = torch.rand(num_src_nodes, num_dst_nodes) pred[row, col] += 0.3 # Offset positive links by a little. pred_index_mat = pred.argsort(dim=1) metric_collection = LinkPredMetricCollection(metrics) assert str(metric_collection) == ( 'LinkPredMetricCollection([\n' ' LinkPredMAP@10: LinkPredMAP(k=10),\n' ' LinkPredPrecision@100: LinkPredPrecision(k=100),\n' ' LinkPredRecall@50: LinkPredRecall(k=50),\n' ' LinkPredF1@20: LinkPredF1(k=20),\n' ' LinkPredMRR@40: LinkPredMRR(k=40),\n' ' LinkPredNDCG@80: LinkPredNDCG(k=80),\n' ' LinkPredCoverage@5: LinkPredCoverage(k=5, num_dst_nodes=50),\n' '])') assert metric_collection.max_k == 100 expected = {} for metric in metrics: metric.update(pred_index_mat[:, :metric.k], edge_label_index) out = metric.compute() expected[f'{metric.__class__.__name__}@{metric.k}'] = out metric.reset() metric_collection.update(pred_index_mat, edge_label_index) assert metric_collection.compute() == expected metric_collection.reset() def test_empty_ground_truth(): pred = torch.rand(10, 5) pred_index_mat = pred.argsort(dim=1) edge_label_index = torch.empty(2, 0, dtype=torch.long) edge_label_weight = 
torch.empty(0) metric = LinkPredMAP(k=5) metric.update(pred_index_mat, edge_label_index) assert metric.compute() == 0 metric.reset() metric = LinkPredNDCG(k=5, weighted=True) metric.update(pred_index_mat, edge_label_index, edge_label_weight) assert metric.compute() == 0 metric.reset() ================================================ FILE: test/my_config.yaml ================================================ defaults: - dataset: KarateClub - transform@dataset.transform: - NormalizeFeatures - AddSelfLoops - model: GCN - optimizer: Adam - lr_scheduler: ReduceLROnPlateau - _self_ model: in_channels: 34 out_channels: 4 hidden_channels: 16 num_layers: 2 ================================================ FILE: test/nn/aggr/test_aggr_utils.py ================================================ import torch from torch_geometric.nn.aggr.utils import ( InducedSetAttentionBlock, MultiheadAttentionBlock, PoolingByMultiheadAttention, SetAttentionBlock, ) from torch_geometric.testing import withCUDA @withCUDA def test_multihead_attention_block(device: torch.device): x = torch.randn(2, 4, 8, device=device) y = torch.randn(2, 3, 8, device=device) x_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool, device=device) y_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool, device=device) block = MultiheadAttentionBlock(8, heads=2, device=device) block.reset_parameters() assert str(block) == ('MultiheadAttentionBlock(8, heads=2, ' 'layer_norm=True, dropout=0.0)') out = block(x, y, x_mask, y_mask) assert out.size() == (2, 4, 8) jit = torch.jit.script(block) assert torch.allclose(jit(x, y, x_mask, y_mask), out) @withCUDA def test_multihead_attention_block_dropout(device: torch.device): x = torch.randn(2, 4, 8, device=device) block = MultiheadAttentionBlock(8, dropout=0.5, device=device) assert not torch.allclose(block(x, x), block(x, x)) def test_set_attention_block(): x = torch.randn(2, 4, 8) mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool) block = 
SetAttentionBlock(8, heads=2) block.reset_parameters() assert str(block) == ('SetAttentionBlock(8, heads=2, layer_norm=True, ' 'dropout=0.0)') out = block(x, mask) assert out.size() == (2, 4, 8) jit = torch.jit.script(block) assert torch.allclose(jit(x, mask), out) def test_induced_set_attention_block(): x = torch.randn(2, 4, 8) mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool) block = InducedSetAttentionBlock(8, num_induced_points=2, heads=2) assert str(block) == ('InducedSetAttentionBlock(8, num_induced_points=2, ' 'heads=2, layer_norm=True, dropout=0.0)') out = block(x, mask) assert out.size() == (2, 4, 8) jit = torch.jit.script(block) assert torch.allclose(jit(x, mask), out) def test_pooling_by_multihead_attention(): x = torch.randn(2, 4, 8) mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool) block = PoolingByMultiheadAttention(8, num_seed_points=2, heads=2) assert str(block) == ('PoolingByMultiheadAttention(8, num_seed_points=2, ' 'heads=2, layer_norm=True, dropout=0.0)') out = block(x, mask) assert out.size() == (2, 2, 8) jit = torch.jit.script(block) assert torch.allclose(jit(x, mask), out) ================================================ FILE: test/nn/aggr/test_attention.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import MLP from torch_geometric.nn.aggr import AttentionalAggregation @pytest.mark.parametrize('dim', [2, 3]) def test_attentional_aggregation(dim): channels = 16 x = torch.randn(6, channels) if dim == 2 else torch.randn(2, 6, channels) index = torch.tensor([0, 0, 1, 1, 1, 2]) ptr = torch.tensor([0, 2, 5, 6]) gate_nn = MLP([channels, 1], act='relu') nn = MLP([channels, channels], act='relu') aggr = AttentionalAggregation(gate_nn, nn) aggr.reset_parameters() assert str(aggr) == (f'AttentionalAggregation(gate_nn=MLP({channels}, 1), ' f'nn=MLP({channels}, {channels}))') out = aggr(x, index) assert out.size() == (3, channels) if 
dim == 2 else (2, 3, channels) if (not torch_geometric.typing.WITH_TORCH_SCATTER and (dim == 3 or not torch_geometric.typing.WITH_PT20)): with pytest.raises(ImportError, match="requires the 'torch-scatter'"): aggr(x, ptr=ptr) else: assert torch.allclose(out, aggr(x, ptr=ptr)) ================================================ FILE: test/nn/aggr/test_basic.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import ( MaxAggregation, MeanAggregation, MinAggregation, MulAggregation, PowerMeanAggregation, SoftmaxAggregation, StdAggregation, SumAggregation, VarAggregation, ) def test_validate(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) ptr = torch.tensor([0, 2, 5, 6]) aggr = MeanAggregation() with pytest.raises(ValueError, match="invalid dimension"): aggr(x, index, dim=-3) with pytest.raises(ValueError, match="invalid 'dim_size'"): aggr(x, ptr=ptr, dim_size=2) with pytest.raises(ValueError, match="invalid 'dim_size'"): aggr(x, index, dim_size=2) @pytest.mark.parametrize('Aggregation', [ MeanAggregation, SumAggregation, MaxAggregation, MinAggregation, MulAggregation, VarAggregation, StdAggregation, ]) def test_basic_aggregation(Aggregation): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) ptr = torch.tensor([0, 2, 5, 6]) aggr = Aggregation() assert str(aggr) == f'{Aggregation.__name__}()' out = aggr(x, index) assert out.size() == (3, x.size(1)) if isinstance(aggr, MulAggregation): with pytest.raises(RuntimeError, match="requires 'index'"): aggr(x, ptr=ptr) elif (not torch_geometric.typing.WITH_TORCH_SCATTER and not torch_geometric.typing.WITH_PT20): with pytest.raises(ImportError, match="requires the 'torch-scatter'"): aggr(x, ptr=ptr) else: assert torch.allclose(out, aggr(x, ptr=ptr)) def test_var_aggregation(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) var_aggr = VarAggregation() out = var_aggr(x, index) mean_aggr = 
MeanAggregation() expected = mean_aggr((x - mean_aggr(x, index)[index]).pow(2), index) assert torch.allclose(out, expected, atol=1e-6) def test_empty_std_aggregation(): aggr = StdAggregation() x = torch.empty(0, 6).reshape(0, 6) index = torch.empty(0, dtype=torch.long) out = aggr(x, index, dim_size=5) assert out.size() == (5, 6) assert float(out.abs().sum()) == 0.0 @pytest.mark.parametrize('Aggregation', [ SoftmaxAggregation, PowerMeanAggregation, ]) @pytest.mark.parametrize('learn', [True, False]) def test_learnable_aggregation(Aggregation, learn): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) ptr = torch.tensor([0, 2, 5, 6]) aggr = Aggregation(learn=learn) assert str(aggr) == f'{Aggregation.__name__}(learn={learn})' out = aggr(x, index) assert out.size() == (3, x.size(1)) if (not torch_geometric.typing.WITH_TORCH_SCATTER and not torch_geometric.typing.WITH_PT20): with pytest.raises(ImportError, match="requires the 'torch-scatter'"): aggr(x, ptr=ptr) else: assert torch.allclose(out, aggr(x, ptr=ptr)) if learn: out.mean().backward() for param in aggr.parameters(): assert not torch.isnan(param.grad).any() @pytest.mark.parametrize('Aggregation', [ SoftmaxAggregation, PowerMeanAggregation, ]) def test_learnable_channels_aggregation(Aggregation): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) ptr = torch.tensor([0, 2, 5, 6]) aggr = Aggregation(learn=True, channels=16) assert str(aggr) == f'{Aggregation.__name__}(learn=True)' out = aggr(x, index) assert out.size() == (3, x.size(1)) if (not torch_geometric.typing.WITH_TORCH_SCATTER and not torch_geometric.typing.WITH_PT20): with pytest.raises(ImportError, match="requires the 'torch-scatter'"): aggr(x, ptr=ptr) else: assert torch.allclose(out, aggr(x, ptr=ptr)) out.mean().backward() for param in aggr.parameters(): assert not torch.isnan(param.grad).any() ================================================ FILE: test/nn/aggr/test_deep_sets.py ================================================ 
import torch from torch_geometric.nn import DeepSetsAggregation, Linear def test_deep_sets_aggregation(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) aggr = DeepSetsAggregation( local_nn=Linear(16, 32), global_nn=Linear(32, 64), ) aggr.reset_parameters() assert str(aggr) == ('DeepSetsAggregation(' 'local_nn=Linear(16, 32, bias=True), ' 'global_nn=Linear(32, 64, bias=True))') out = aggr(x, index) assert out.size() == (3, 64) ================================================ FILE: test/nn/aggr/test_equilibrium.py ================================================ import pytest import torch from torch_geometric.nn.aggr import EquilibriumAggregation @pytest.mark.parametrize('iter', [0, 1, 5]) @pytest.mark.parametrize('alpha', [0, .1, 5]) def test_equilibrium(iter, alpha): batch_size = 10 feature_channels = 3 output_channels = 2 x = torch.randn(batch_size, feature_channels) model = EquilibriumAggregation(feature_channels, output_channels, num_layers=[10, 10], grad_iter=iter) assert str(model) == 'EquilibriumAggregation()' out = model(x) assert out.size() == (1, 2) out = model(x, dim_size=3) assert out.size() == (3, 2) assert torch.all(out[1:, :] == 0) @pytest.mark.parametrize('iter', [0, 1, 5]) @pytest.mark.parametrize('alpha', [0, .1, 5]) def test_equilibrium_batch(iter, alpha): batch_1, batch_2 = 4, 6 feature_channels = 3 output_channels = 2 x = torch.randn(batch_1 + batch_2, feature_channels) batch = torch.tensor([0 for _ in range(batch_1)] + [1 for _ in range(batch_2)]) model = EquilibriumAggregation(feature_channels, output_channels, num_layers=[10, 10], grad_iter=iter) assert str(model) == 'EquilibriumAggregation()' out = model(x, batch) assert out.size() == (2, 2) out = model(x, dim_size=3) assert out.size() == (3, 2) assert torch.all(out[1:, :] == 0) ================================================ FILE: test/nn/aggr/test_fused.py ================================================ import pytest import torch from torch_geometric.nn.aggr.fused 
import FusedAggregation from torch_geometric.nn.resolver import aggregation_resolver from torch_geometric.profile import benchmark @pytest.mark.parametrize('aggrs', [ ['sum', 'mean', 'min', 'max', 'mul', 'var', 'std'], ['sum', 'min', 'max', 'mul', 'var', 'std'], ['min', 'max', 'mul', 'var', 'std'], ['mean', 'min', 'max', 'mul', 'var', 'std'], ['sum', 'min', 'max', 'mul', 'std'], ['mean', 'min', 'max', 'mul', 'std'], ['min', 'max', 'mul', 'std'], ]) def test_fused_aggregation(aggrs): aggrs = [aggregation_resolver(aggr) for aggr in aggrs] x = torch.randn(6, 1) y = x.clone() index = torch.tensor([0, 0, 1, 1, 1, 3]) x.requires_grad_(True) y.requires_grad_(True) aggr = FusedAggregation(aggrs) assert str(aggr) == 'FusedAggregation()' out = torch.cat(aggr(x, index), dim=-1) expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1) assert torch.allclose(out, expected, atol=1e-5) jit = torch.jit.script(aggr) assert torch.allclose(torch.cat(jit(x, index), dim=-1), out, atol=1e-5) out.mean().backward() assert x.grad is not None expected.mean().backward() assert y.grad is not None assert torch.allclose(x.grad, y.grad, atol=1e-5) def test_empty_fused_std_aggregation(): aggrs = [aggregation_resolver(aggr) for aggr in ['mean', 'var', 'std']] aggr = FusedAggregation(aggrs) x = torch.empty(0, 6).reshape(0, 6) index = torch.empty(0, dtype=torch.long) out = torch.cat(aggr(x, index, dim_size=5), dim=-1) assert out.size() == (5, 18) assert float(out.abs().sum()) == 0.0 if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() num_nodes, num_edges = 1_000, 50_000 x = torch.randn(num_edges, 64, device=args.device) index = torch.randint(num_nodes, (num_edges, ), device=args.device) aggrs = ['sum', 'mean', 'max', 'std'] print(f'Aggregators: {", ".join(aggrs)}') aggrs = [aggregation_resolver(aggr) for aggr in aggrs] 
fused_aggregation = FusedAggregation(aggrs) def naive_aggr(x, index, dim_size): outs = [aggr(x, index, dim_size=dim_size) for aggr in aggrs] return torch.cat(outs, dim=-1) def fused_aggr(x, index, dim_size): outs = fused_aggregation(x, index, dim_size=dim_size) return torch.cat(outs, dim=-1) benchmark( funcs=[naive_aggr, fused_aggr], func_names=['Naive', 'Fused'], args=(x, index, num_nodes), num_steps=100 if args.device == 'cpu' else 1000, num_warmups=50 if args.device == 'cpu' else 500, backward=args.backward, ) ================================================ FILE: test/nn/aggr/test_gmt.py ================================================ import torch from torch_geometric.nn.aggr import GraphMultisetTransformer from torch_geometric.testing import is_full_test def test_graph_multiset_transformer(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) aggr = GraphMultisetTransformer(16, k=2, heads=2) aggr.reset_parameters() assert str(aggr) == ('GraphMultisetTransformer(16, k=2, heads=2, ' 'layer_norm=False, dropout=0.0)') out = aggr(x, index) assert out.size() == (3, 16) if is_full_test(): jit = torch.jit.script(aggr) assert torch.allclose(jit(x, index), out) ================================================ FILE: test/nn/aggr/test_gru.py ================================================ import torch from torch_geometric.nn import GRUAggregation def test_gru_aggregation(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) aggr = GRUAggregation(16, 32) assert str(aggr) == 'GRUAggregation(16, 32)' out = aggr(x, index) assert out.size() == (3, 32) ================================================ FILE: test/nn/aggr/test_lcm.py ================================================ from itertools import product import pytest import torch from torch_geometric.nn import LCMAggregation from torch_geometric.profile import benchmark def test_lcm_aggregation_with_project(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) aggr = 
LCMAggregation(16, 32) assert str(aggr) == 'LCMAggregation(16, 32, project=True)' out = aggr(x, index) assert out.size() == (3, 32) def test_lcm_aggregation_without_project(): x = torch.randn(5, 16) index = torch.tensor([0, 1, 1, 2, 2]) aggr = LCMAggregation(16, 16, project=False) assert str(aggr) == 'LCMAggregation(16, 16, project=False)' out = aggr(x, index) assert out.size() == (3, 16) def test_lcm_aggregation_error_handling(): with pytest.raises(ValueError, match="must be projected"): LCMAggregation(16, 32, project=False) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() channels = 128 batch_size_list = [2**i for i in range(10, 12)] num_nodes_list = [2**i for i in range(15, 18)] aggr = LCMAggregation(channels, channels, project=False) aggr = aggr.to(args.device) funcs = [] func_names = [] args_list = [] for batch_size, num_nodes in product(batch_size_list, num_nodes_list): x = torch.randn((num_nodes, channels), device=args.device) index = torch.randint(0, batch_size, (num_nodes, ), device=args.device) index = index.sort()[0] funcs.append(aggr) func_names.append(f'B={batch_size}, N={num_nodes}') args_list.append((x, index)) benchmark( funcs=funcs, func_names=func_names, args=args_list, num_steps=10 if args.device == 'cpu' else 100, num_warmups=5 if args.device == 'cpu' else 50, backward=args.backward, progress_bar=True, ) ================================================ FILE: test/nn/aggr/test_lstm.py ================================================ import pytest import torch from torch_geometric.nn import LSTMAggregation def test_lstm_aggregation(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) aggr = LSTMAggregation(16, 32) assert str(aggr) == 'LSTMAggregation(16, 32)' with pytest.raises(ValueError, match="is not sorted"): aggr(x, torch.tensor([0, 1, 0, 1, 2, 1])) out = aggr(x, 
index) assert out.size() == (3, 32) ================================================ FILE: test/nn/aggr/test_mlp_aggr.py ================================================ import torch from torch_geometric.nn import MLPAggregation def test_mlp_aggregation(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) aggr = MLPAggregation( in_channels=16, out_channels=32, max_num_elements=3, num_layers=1, ) assert str(aggr) == 'MLPAggregation(16, 32, max_num_elements=3)' out = aggr(x, index) assert out.size() == (3, 32) ================================================ FILE: test/nn/aggr/test_multi.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import MultiAggregation @pytest.mark.parametrize('multi_aggr_tuple', [ (dict(mode='cat'), 3), (dict(mode='proj', mode_kwargs=dict(in_channels=16, out_channels=16)), 1), (dict(mode='attn', mode_kwargs=dict(in_channels=16, out_channels=16, num_heads=4)), 1), (dict(mode='sum'), 1), (dict(mode='mean'), 1), (dict(mode='max'), 1), (dict(mode='min'), 1), (dict(mode='logsumexp'), 1), (dict(mode='std'), 1), (dict(mode='var'), 1), ]) def test_multi_aggr(multi_aggr_tuple): # The 'cat' combine mode will expand the output dimensions by # the number of aggregators which is 3 here, while the other # modes keep output dimensions unchanged. 
aggr_kwargs, expand = multi_aggr_tuple x = torch.randn(7, 16) index = torch.tensor([0, 0, 1, 1, 1, 2, 3]) ptr = torch.tensor([0, 2, 5, 6, 7]) aggrs = ['mean', 'sum', 'max'] aggr = MultiAggregation(aggrs, **aggr_kwargs) aggr.reset_parameters() assert str(aggr) == ('MultiAggregation([\n' ' MeanAggregation(),\n' ' SumAggregation(),\n' ' MaxAggregation(),\n' f"], mode={aggr_kwargs['mode']})") out = aggr(x, index) assert out.size() == (4, expand * x.size(1)) if (not torch_geometric.typing.WITH_TORCH_SCATTER and not torch_geometric.typing.WITH_PT20): with pytest.raises(ImportError, match="requires the 'torch-scatter'"): aggr(x, ptr=ptr) else: assert torch.allclose(out, aggr(x, ptr=ptr)) jit = torch.jit.script(aggr) assert torch.allclose(out, jit(x, index)) ================================================ FILE: test/nn/aggr/test_patch_transformer.py ================================================ import torch from torch_geometric.nn import PatchTransformerAggregation from torch_geometric.testing import withCUDA @withCUDA def test_patch_transformer_aggregation(device: torch.device) -> None: aggr = PatchTransformerAggregation( in_channels=16, out_channels=32, patch_size=2, hidden_channels=8, num_transformer_blocks=1, heads=2, dropout=0.2, aggr=['sum', 'mean', 'min', 'max', 'var', 'std'], device=device, ) aggr.reset_parameters() assert str(aggr) == 'PatchTransformerAggregation(16, 32, patch_size=2)' index = torch.tensor([0, 0, 1, 1, 1, 2], device=device) x = torch.randn(index.size(0), 16, device=device) out = aggr(x, index) assert out.device == device assert out.size() == (3, aggr.out_channels) ================================================ FILE: test/nn/aggr/test_quantile.py ================================================ import pytest import torch from torch_geometric.nn import MedianAggregation, QuantileAggregation @pytest.mark.parametrize('q', [0., .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.]) @pytest.mark.parametrize('interpolation', QuantileAggregation.interpolations) 
@pytest.mark.parametrize('dim', [0, 1]) @pytest.mark.parametrize('dim_size', [None, 15]) @pytest.mark.parametrize('fill_value', [0.0, 10.0]) def test_quantile_aggregation(q, interpolation, dim, dim_size, fill_value): x = torch.tensor([ [0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 0.0, 1.0], [2.0, 3.0, 4.0], [5.0, 6.0, 7.0], [8.0, 9.0, 0.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], ]) index = torch.zeros(x.size(dim), dtype=torch.long) aggr = QuantileAggregation(q=q, interpolation=interpolation, fill_value=fill_value) assert str(aggr) == f"QuantileAggregation(q={q})" out = aggr(x, index, dim=dim, dim_size=dim_size) expected = x.quantile(q, dim, interpolation=interpolation, keepdim=True) assert torch.allclose(out.narrow(dim, 0, 1), expected) if out.size(0) > index.max() + 1: padding = out.narrow(dim, 1, out.size(dim) - 1) assert torch.allclose(padding, torch.tensor(fill_value)) def test_median_aggregation(): x = torch.tensor([ [0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 0.0, 1.0], [2.0, 3.0, 4.0], [5.0, 6.0, 7.0], [8.0, 9.0, 0.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], ]) aggr = MedianAggregation() assert str(aggr) == "MedianAggregation()" index = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2]) assert aggr(x, index).tolist() == [ [3.0, 1.0, 2.0], [5.0, 6.0, 4.0], [4.0, 5.0, 6.0], ] index = torch.tensor([0, 1, 0]) assert aggr(x, index, dim=1).tolist() == [ [0.0, 1.0], [3.0, 4.0], [6.0, 7.0], [1.0, 0.0], [2.0, 3.0], [5.0, 6.0], [0.0, 9.0], [1.0, 2.0], [4.0, 5.0], [7.0, 8.0], ] def test_quantile_aggregation_multi(): x = torch.tensor([ [0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 0.0, 1.0], [2.0, 3.0, 4.0], [5.0, 6.0, 7.0], [8.0, 9.0, 0.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], ]) index = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2]) qs = [0.25, 0.5, 0.75] assert torch.allclose( QuantileAggregation(qs)(x, index), torch.cat([QuantileAggregation(q)(x, index) for q in qs], dim=-1), ) def 
test_quantile_aggregation_validate(): with pytest.raises(ValueError, match="at least one quantile"): QuantileAggregation(q=[]) with pytest.raises(ValueError, match="must be in the range"): QuantileAggregation(q=-1) with pytest.raises(ValueError, match="Invalid interpolation method"): QuantileAggregation(q=0.5, interpolation=None) ================================================ FILE: test/nn/aggr/test_scaler.py ================================================ import pytest import torch from torch_geometric.nn import DegreeScalerAggregation @pytest.mark.parametrize('train_norm', [True, False]) def test_degree_scaler_aggregation(train_norm): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) ptr = torch.tensor([0, 2, 5, 6]) deg = torch.tensor([0, 3, 0, 1, 1, 0]) aggr = ['mean', 'sum', 'max'] scaler = [ 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear' ] aggr = DegreeScalerAggregation(aggr, scaler, deg, train_norm=train_norm) assert str(aggr) == 'DegreeScalerAggregation()' out = aggr(x, index) assert out.size() == (3, 240) assert torch.allclose(torch.jit.script(aggr)(x, index), out) with pytest.raises(NotImplementedError, match="requires 'index'"): aggr(x, ptr=ptr) ================================================ FILE: test/nn/aggr/test_set2set.py ================================================ import torch from torch_geometric.nn.aggr import Set2Set def test_set2set(): set2set = Set2Set(in_channels=2, processing_steps=1) assert str(set2set) == 'Set2Set(2, 4)' N = 4 x_1, batch_1 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long) out_1 = set2set(x_1, batch_1).view(-1) N = 6 x_2, batch_2 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long) out_2 = set2set(x_2, batch_2).view(-1) x, batch = torch.cat([x_1, x_2]), torch.cat([batch_1, batch_2 + 1]) out = set2set(x, batch) assert out.size() == (2, 4) assert torch.allclose(out_1, out[0]) assert torch.allclose(out_2, out[1]) x, batch = torch.cat([x_2, x_1]), torch.cat([batch_2, batch_1 + 
1]) out = set2set(x, batch) assert out.size() == (2, 4) assert torch.allclose(out_1, out[1]) assert torch.allclose(out_2, out[0]) ================================================ FILE: test/nn/aggr/test_set_transformer.py ================================================ import warnings import torch import torch_geometric.typing from torch_geometric.nn.aggr import SetTransformerAggregation from torch_geometric.testing import is_full_test def test_set_transformer_aggregation(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 3]) aggr = SetTransformerAggregation(16, num_seed_points=2, heads=2) aggr.reset_parameters() assert str(aggr) == ('SetTransformerAggregation(16, num_seed_points=2, ' 'heads=2, layer_norm=False, dropout=0.0)') out = aggr(x, index) assert out.size() == (4, 2 * 16) assert out.isnan().sum() == 0 if torch_geometric.typing.WITH_PT25: if not out[2].abs().sum() != 0: warnings.warn("'SetTransformerAggregation' broken on PyTorch>2.4", stacklevel=2) else: assert out[2].abs().sum() == 0 if is_full_test(): jit = torch.jit.script(aggr) assert torch.allclose(jit(x, index), out) ================================================ FILE: test/nn/aggr/test_sort.py ================================================ import torch from torch_geometric.nn.aggr import SortAggregation def test_sort_aggregation(): N_1, N_2 = 4, 6 x = torch.randn(N_1 + N_2, 4) index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)]) aggr = SortAggregation(k=5) assert str(aggr) == 'SortAggregation(k=5)' out = aggr(x, index) assert out.size() == (2, 5 * 4) out_dim = out = aggr(x, index, dim=0) assert torch.allclose(out_dim, out) out = out.view(2, 5, 4) # First graph output has been filled up with zeros. assert out[0, -1].tolist() == [0, 0, 0, 0] # Nodes are sorted. 
assert torch.equal(out[0, :4, -1].argsort(), 3 - torch.arange(4)) assert torch.equal(out[1, :, -1].argsort(), 4 - torch.arange(5)) def test_sort_aggregation_smaller_than_k(): N_1, N_2 = 4, 6 x = torch.randn(N_1 + N_2, 4) index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)]) # Set k which is bigger than both N_1=4 and N_2=6. aggr = SortAggregation(k=10) assert str(aggr) == 'SortAggregation(k=10)' out = aggr(x, index) assert out.size() == (2, 10 * 4) out_dim = out = aggr(x, index, dim=0) assert torch.allclose(out_dim, out) out = out.view(2, 10, 4) # Both graph outputs have been filled up with zeros. assert out[0, -1].tolist() == [0, 0, 0, 0] assert out[1, -1].tolist() == [0, 0, 0, 0] # Nodes are sorted. assert torch.equal(out[0, :4, -1].argsort(), 3 - torch.arange(4)) assert torch.equal(out[1, :6, -1].argsort(), 5 - torch.arange(6)) def test_sort_aggregation_dim_size(): N_1, N_2 = 4, 6 x = torch.randn(N_1 + N_2, 4) index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)]) aggr = SortAggregation(k=5) assert str(aggr) == 'SortAggregation(k=5)' # expand batch output by 1 out = aggr(x, index, dim_size=3) assert out.size() == (3, 5 * 4) out = out.view(3, 5, 4) # Both first and last graph outputs have been filled up with zeros. 
assert out[0, -1].tolist() == [0, 0, 0, 0] assert out[2, -1].tolist() == [0, 0, 0, 0] ================================================ FILE: test/nn/aggr/test_variance_preserving.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import ( MeanAggregation, SumAggregation, VariancePreservingAggregation, ) def test_variance_preserving(): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 3]) ptr = torch.tensor([0, 2, 5, 5, 6]) vpa_aggr = VariancePreservingAggregation() mean_aggr = MeanAggregation() sum_aggr = SumAggregation() out_vpa = vpa_aggr(x, index) out_mean = mean_aggr(x, index) out_sum = sum_aggr(x, index) # Equivalent formulation: expected = torch.sqrt(out_mean.abs() * out_sum.abs()) * out_sum.sign() assert out_vpa.size() == (4, 16) assert torch.allclose(out_vpa, expected) if (not torch_geometric.typing.WITH_TORCH_SCATTER and not torch_geometric.typing.WITH_PT20): with pytest.raises(ImportError, match="requires the 'torch-scatter'"): vpa_aggr(x, ptr=ptr) else: assert torch.allclose(out_vpa, vpa_aggr(x, ptr=ptr)) ================================================ FILE: test/nn/attention/test_performer_attention.py ================================================ import torch from torch_geometric.nn.attention import PerformerAttention def test_performer_attention(): x = torch.randn(1, 4, 16) mask = torch.ones([1, 4], dtype=torch.bool) attn = PerformerAttention(channels=16, heads=4) out = attn(x, mask) assert out.shape == (1, 4, 16) assert str(attn) == ('PerformerAttention(heads=4, ' 'head_channels=64 kernel=ReLU())') ================================================ FILE: test/nn/attention/test_polynormer_attention.py ================================================ import torch from torch_geometric.nn.attention import PolynormerAttention def test_performer_attention(): x = torch.randn(1, 4, 16) mask = torch.ones([1, 4], dtype=torch.bool) attn = 
PolynormerAttention(channels=16, heads=4) out = attn(x, mask) assert out.shape == (1, 4, 256) assert str(attn) == 'PolynormerAttention(heads=4, head_channels=64)' ================================================ FILE: test/nn/attention/test_qformer.py ================================================ import torch from torch_geometric.nn.attention import QFormer def test_qformer(): x = torch.randn(1, 4, 16) attn = QFormer(input_dim=16, hidden_dim=16, output_dim=32, num_heads=4, num_layers=2) out = attn(x) assert out.shape == (1, 4, 32) assert str(attn) == ('QFormer(num_heads=4, num_layers=2)') ================================================ FILE: test/nn/conv/cugraph/test_cugraph_gat_conv.py ================================================ import pytest import torch from torch_geometric import EdgeIndex from torch_geometric.nn import CuGraphGATConv, GATConv from torch_geometric.testing import onlyCUDA, withPackage @onlyCUDA @withPackage('pylibcugraphops>=23.02') @pytest.mark.parametrize('bias', [True, False]) @pytest.mark.parametrize('bipartite', [True, False]) @pytest.mark.parametrize('concat', [True, False]) @pytest.mark.parametrize('edge_attr', [True, False]) @pytest.mark.parametrize('heads', [1, 2, 3]) @pytest.mark.parametrize('max_num_neighbors', [8, None]) def test_gat_conv_equality(bias, bipartite, concat, edge_attr, heads, max_num_neighbors): in_channels, out_channels = 5, 2 kwargs = dict(bias=bias, concat=concat) size = (10, 8) if bipartite else (10, 10) x = torch.rand(size[0], in_channels, device='cuda') edge_index = torch.tensor([ [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7], ], device='cuda') conv1 = GATConv(in_channels, out_channels, heads, add_self_loops=False, **kwargs).cuda() conv2 = CuGraphGATConv(in_channels, out_channels, heads, **kwargs).cuda() with torch.no_grad(): conv2.lin.weight.data[:, :] = conv1.lin.weight.data conv2.att.data[:heads * out_channels] = conv1.att_src.data.flatten() 
# (continuation of test_gat_conv_equality from the previous line)
        conv2.att.data[heads * out_channels:] = conv1.att_dst.data.flatten()

    if edge_attr and not bipartite:
        # NOTE(review): `e_attrs` is created on the default device while `x`
        # and `edge_index` live on CUDA, and neither conv is built with an
        # `edge_dim` — confirm whether either layer actually consumes it.
        e_attrs = torch.randn(size=(edge_index.size(1), 10))
        out1 = conv1(x, edge_index, edge_attr=e_attrs)
        out2 = conv2(
            x,
            EdgeIndex(edge_index, sparse_size=size),
            max_num_neighbors=max_num_neighbors,
            edge_attr=e_attrs,
        )
    else:
        if bipartite:
            out1 = conv1((x, x[:size[1]]), edge_index)
        else:
            out1 = conv1(x, edge_index)
        out2 = conv2(
            x,
            EdgeIndex(edge_index, sparse_size=size),
            max_num_neighbors=max_num_neighbors,
        )
    assert torch.allclose(out1, out2, atol=1e-3)

    # Backpropagate one shared upstream gradient through both layers and
    # compare parameter gradients slice by slice.
    grad_output = torch.rand_like(out1)
    out1.backward(grad_output)
    out2.backward(grad_output)

    assert torch.allclose(conv1.lin.weight.grad, conv2.lin.weight.grad,
                          atol=1e-3)
    assert torch.allclose(conv1.att_src.grad.flatten(),
                          conv2.att.grad[:heads * out_channels], atol=1e-3)
    assert torch.allclose(conv1.att_dst.grad.flatten(),
                          conv2.att.grad[heads * out_channels:], atol=1e-3)
    if bias:
        assert torch.allclose(conv1.bias.grad, conv2.bias.grad, atol=1e-3)


================================================
FILE: test/nn/conv/cugraph/test_cugraph_rgcn_conv.py
================================================
import pytest
import torch

from torch_geometric import EdgeIndex
from torch_geometric.nn import CuGraphRGCNConv
from torch_geometric.nn import FastRGCNConv as RGCNConv
from torch_geometric.testing import onlyCUDA, withPackage


@onlyCUDA
@withPackage('pylibcugraphops>=23.02')
@pytest.mark.parametrize('aggr', ['add', 'sum', 'mean'])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('bipartite', [True, False])
@pytest.mark.parametrize('max_num_neighbors', [8, None])
@pytest.mark.parametrize('num_bases', [1, 2, None])
@pytest.mark.parametrize('root_weight', [True, False])
def test_rgcn_conv_equality(aggr, bias, bipartite, max_num_neighbors,
                            num_bases, root_weight):
    in_channels, out_channels, num_relations = (4, 2, 3)
    kwargs = dict(aggr=aggr, bias=bias, num_bases=num_bases,
                  root_weight=root_weight)

    size = (10, 8)
if bipartite else (10, 10)
    x = torch.rand(size[0], in_channels, device='cuda')
    edge_index = torch.tensor([
        [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7],
    ], device='cuda')
    edge_type = torch.tensor([1, 2, 1, 0, 2, 1, 2, 0, 2, 2, 1, 1, 1, 2,
                              2]).cuda()

    # Re-seed before each constructor; no parameters are copied below, so
    # presumably this is what makes both layers start from identical weights.
    torch.manual_seed(12345)
    conv1 = RGCNConv(in_channels, out_channels, num_relations, **kwargs).cuda()
    torch.manual_seed(12345)
    conv2 = CuGraphRGCNConv(in_channels, out_channels, num_relations,
                            **kwargs).cuda()

    if bipartite:
        out1 = conv1((x, x[:size[1]]), edge_index, edge_type)
    else:
        out1 = conv1(x, edge_index, edge_type)

    out2 = conv2(
        x,
        EdgeIndex(edge_index, sparse_size=size),
        edge_type,
        max_num_neighbors=max_num_neighbors,
    )
    assert torch.allclose(out1, out2, atol=1e-3)

    grad_out = torch.rand_like(out1)
    out1.backward(grad_out)
    out2.backward(grad_out)

    # With root_weight, conv2.weight carries the root weight as its last slice
    # (compared against conv1.root below), so exclude it here.
    end = -1 if root_weight else None
    assert torch.allclose(conv1.weight.grad, conv2.weight.grad[:end],
                          atol=1e-3)

    if root_weight:
        assert torch.allclose(conv1.root.grad, conv2.weight.grad[-1],
                              atol=1e-3)

    if num_bases is not None:
        assert torch.allclose(conv1.comp.grad, conv2.comp.grad, atol=1e-3)


================================================
FILE: test/nn/conv/cugraph/test_cugraph_sage_conv.py
================================================
import pytest
import torch

from torch_geometric import EdgeIndex
from torch_geometric.nn import CuGraphSAGEConv, SAGEConv
from torch_geometric.testing import onlyCUDA, withPackage


@onlyCUDA
@withPackage('pylibcugraphops>=23.02')
@pytest.mark.parametrize('aggr', ['sum', 'mean', 'min', 'max'])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('bipartite', [True, False])
@pytest.mark.parametrize('max_num_neighbors', [8, None])
@pytest.mark.parametrize('normalize', [True, False])
@pytest.mark.parametrize('root_weight', [True, False])
def test_sage_conv_equality(aggr, bias, bipartite, max_num_neighbors,
                            normalize, root_weight):
    in_channels, out_channels = (8, 16)
    kwargs =
dict(aggr=aggr, bias=bias, normalize=normalize, root_weight=root_weight)

    size = (10, 8) if bipartite else (10, 10)
    x = torch.rand(size[0], in_channels, device='cuda')
    edge_index = torch.tensor([
        [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7],
    ], device='cuda')

    conv1 = SAGEConv(in_channels, out_channels, **kwargs).cuda()
    conv2 = CuGraphSAGEConv(in_channels, out_channels, **kwargs).cuda()

    # conv2 uses one `lin` whose first in_channels columns are filled from
    # conv1.lin_l and whose remaining columns are filled from conv1.lin_r.
    with torch.no_grad():
        conv2.lin.weight.data[:, :in_channels] = conv1.lin_l.weight.data
        if root_weight:
            conv2.lin.weight.data[:, in_channels:] = conv1.lin_r.weight.data
        if bias:
            conv2.lin.bias.data[:] = conv1.lin_l.bias.data

    if bipartite:
        out1 = conv1((x, x[:size[1]]), edge_index)
    else:
        out1 = conv1(x, edge_index)

    out2 = conv2(
        x,
        EdgeIndex(edge_index, sparse_size=size),
        max_num_neighbors=max_num_neighbors,
    )
    assert torch.allclose(out1, out2, atol=1e-6)

    grad_out = torch.rand_like(out1)
    out1.backward(grad_out)
    out2.backward(grad_out)

    assert torch.allclose(
        conv1.lin_l.weight.grad,
        conv2.lin.weight.grad[:, :in_channels],
        atol=1e-6,
    )

    if root_weight:
        assert torch.allclose(
            conv1.lin_r.weight.grad,
            conv2.lin.weight.grad[:, in_channels:],
            atol=1e-6,
        )

    if bias:
        assert torch.allclose(
            conv1.lin_l.bias.grad,
            conv2.lin.bias.grad,
            atol=1e-6,
        )


================================================
FILE: test/nn/conv/test_agnn_conv.py
================================================
import pytest
import torch

import torch_geometric.typing
from torch_geometric.nn import AGNNConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


@pytest.mark.parametrize('requires_grad', [True, False])
def test_agnn_conv(requires_grad):
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

    conv = AGNNConv(requires_grad=requires_grad)
    assert str(conv) == 'AGNNConv()'
    out = conv(x,
edge_index)
    assert out.size() == (4, 16)
    # The transposed sparse adjacency must give the same result as edge_index.
    assert torch.allclose(conv(x, adj1.t()), out, atol=1e-6)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        assert torch.allclose(conv(x, adj2.t()), out, atol=1e-6)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit(x, adj2.t()), out, atol=1e-6)


================================================
FILE: test/nn/conv/test_antisymmetric_conv.py
================================================
import torch

import torch_geometric.typing
from torch_geometric.nn import AntiSymmetricConv
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


def test_antisymmetric_conv():
    """Unweighted (out1) and weighted (out2) runs across adjacency formats."""
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    value = torch.rand(edge_index.size(1))
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))
    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))

    conv = AntiSymmetricConv(8)
    assert str(conv) == ('AntiSymmetricConv(8, phi=GCNConv(8, 8), '
                         'num_iters=1, epsilon=0.1, gamma=0.1)')

    out1 = conv(x, edge_index)
    assert out1.size() == (4, 8)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)

    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 8)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))
        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)


================================================
FILE: test/nn/conv/test_appnp.py
================================================
import torch

import torch_geometric.typing
from torch_geometric.nn import APPNP
from torch_geometric.testing import
is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


def test_appnp():
    """APPNP across adjacency formats, including cache set/reset behavior."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

    conv = APPNP(K=3, alpha=0.1, cached=True)
    assert str(conv) == 'APPNP(K=3, alpha=0.1)'
    out = conv(x, edge_index)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj1.t()), out, rtol=1e-5, atol=1e-6)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        assert torch.allclose(conv(x, adj2.t()), out, rtol=1e-5, atol=1e-6)

    # Run again to test the cached functionality:
    assert conv._cached_edge_index is not None
    assert torch.allclose(conv(x, edge_index), conv(x, adj1.t()), rtol=1e-5,
                          atol=1e-6)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        assert conv._cached_adj_t is not None
        assert torch.allclose(conv(x, edge_index), conv(x, adj2.t()),
                              rtol=1e-5, atol=1e-6)

    # reset_parameters() is expected to clear both caches.
    conv.reset_parameters()
    assert conv._cached_edge_index is None
    assert conv._cached_adj_t is None

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out, rtol=1e-5, atol=1e-6)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit(x, adj2.t()), out, rtol=1e-5,
                                  atol=1e-6)


def test_appnp_dropout():
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

    # With dropout probability of 1.0, the final output equals to alpha * x:
    conv = APPNP(K=2, alpha=0.1, dropout=1.0)
    assert torch.allclose(0.1 * x, conv(x, edge_index), rtol=1e-5, atol=1e-6)
    assert torch.allclose(0.1 * x, conv(x, adj1.t()), rtol=1e-5, atol=1e-6)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        assert torch.allclose(0.1 * x, conv(x, adj2.t()), rtol=1e-5,
                              atol=1e-6)


================================================
FILE: test/nn/conv/test_arma_conv.py
================================================
import pytest
import torch

import torch_geometric.typing
from torch_geometric.nn import ARMAConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


def test_arma_conv():
    """ARMAConv output shape and parity across supported adjacency inputs."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

    conv = ARMAConv(16, 32, num_stacks=8, num_layers=4)
    assert str(conv) == 'ARMAConv(16, 32, num_stacks=8, num_layers=4)'
    out = conv(x, edge_index)
    assert out.size() == (4, 32)
    with pytest.raises(RuntimeError):  # No 3D feature tensor support.
        assert torch.allclose(conv(x, adj1.t()), out)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        assert torch.allclose(conv(x, adj2.t()), out)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit(x, adj2.t()), out, atol=1e-6)


def test_lazy_arma_conv():
    """Same checks with lazily inferred in_channels (-1)."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])

    conv = ARMAConv(-1, 32, num_stacks=8, num_layers=4)
    assert str(conv) == 'ARMAConv(-1, 32, num_stacks=8, num_layers=4)'
    out = conv(x, edge_index)
    assert out.size() == (4, 32)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        assert torch.allclose(conv(x, adj2.t()), out)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit(x, adj2.t()), out, atol=1e-6)


================================================
FILE: test/nn/conv/test_cg_conv.py
================================================
import pytest
import torch

import torch_geometric.typing
from
torch_geometric.nn import CGConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


@pytest.mark.parametrize('batch_norm', [False, True])
def test_cg_conv(batch_norm):
    """CGConv on a homogeneous graph, then in bipartite mode."""
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

    conv = CGConv(8, batch_norm=batch_norm)
    assert str(conv) == 'CGConv(8, dim=0)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 8)
    assert torch.allclose(conv(x1, adj1.t()), out)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        assert torch.allclose(conv(x1, adj2.t()), out)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)

    # Test bipartite message passing:
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))

    conv = CGConv((8, 16))
    assert str(conv) == 'CGConv((8, 16), dim=0)'
    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 16)
    assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))
        assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)


def test_cg_conv_with_edge_features():
    """CGConv with 3-dim edge features, homogeneous and bipartite."""
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    value = torch.rand(edge_index.size(1), 3)

    conv = CGConv(8, dim=3)
    assert str(conv) == 'CGConv(8, dim=3)'
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 8)

    if
torch_geometric.typing.WITH_TORCH_SPARSE:
        adj = SparseTensor.from_edge_index(edge_index, value, (4, 4))
        assert torch.allclose(conv(x1, adj.t()), out)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x1, edge_index, value), out)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit(x1, adj.t()), out)

    # Test bipartite message passing:
    conv = CGConv((8, 16), dim=3)
    assert str(conv) == 'CGConv((8, 16), dim=3)'
    out = conv((x1, x2), edge_index, value)
    assert out.size() == (2, 16)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj = SparseTensor.from_edge_index(edge_index, value, (4, 2))
        assert torch.allclose(conv((x1, x2), adj.t()), out)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit((x1, x2), edge_index, value), out)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit((x1, x2), adj.t()), out)


================================================
FILE: test/nn/conv/test_cheb_conv.py
================================================
import torch

from torch_geometric.data import Batch, Data
from torch_geometric.nn import ChebConv
from torch_geometric.testing import is_full_test


def test_cheb_conv():
    """ChebConv shapes with/without weights, lambda_max and batching."""
    in_channels, out_channels = (16, 32)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    edge_weight = torch.rand(edge_index.size(1))
    x = torch.randn((num_nodes, in_channels))

    conv = ChebConv(in_channels, out_channels, K=3)
    assert str(conv) == 'ChebConv(16, 32, K=3, normalization=sym)'
    out1 = conv(x, edge_index)
    assert out1.size() == (num_nodes, out_channels)
    out2 = conv(x, edge_index, edge_weight)
    assert out2.size() == (num_nodes, out_channels)
    out3 = conv(x, edge_index, edge_weight, lambda_max=3.0)
    assert out3.size() == (num_nodes, out_channels)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out1)
        assert torch.allclose(jit(x, edge_index, edge_weight), out2)
        # The scripted module receives lambda_max as a tensor here.
        assert torch.allclose(
            jit(x, edge_index,
edge_weight, lambda_max=torch.tensor(3.0)), out3)

    # Two disconnected 2-node graphs with one lambda_max per graph.
    batch = torch.tensor([0, 0, 1, 1])
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
    num_nodes = edge_index.max().item() + 1
    edge_weight = torch.rand(edge_index.size(1))
    x = torch.randn((num_nodes, in_channels))
    lambda_max = torch.tensor([2.0, 3.0])

    out4 = conv(x, edge_index, edge_weight, batch)
    assert out4.size() == (num_nodes, out_channels)
    out5 = conv(x, edge_index, edge_weight, batch, lambda_max)
    assert out5.size() == (num_nodes, out_channels)

    # `jit` was created in this function's earlier `is_full_test()` block.
    if is_full_test():
        assert torch.allclose(jit(x, edge_index, edge_weight, batch), out4)
        assert torch.allclose(
            jit(x, edge_index, edge_weight, batch, lambda_max), out5)


def test_cheb_conv_batch():
    """A batched forward pass must match per-graph forward passes."""
    x1 = torch.randn(4, 8)
    edge_index1 = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    edge_weight1 = torch.rand(edge_index1.size(1))
    data1 = Data(x=x1, edge_index=edge_index1, edge_weight=edge_weight1)

    x2 = torch.randn(3, 8)
    edge_index2 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_weight2 = torch.rand(edge_index2.size(1))
    data2 = Data(x=x2, edge_index=edge_index2, edge_weight=edge_weight2)

    conv = ChebConv(8, 16, K=2)
    out1 = conv(x1, edge_index1, edge_weight1)
    out2 = conv(x2, edge_index2, edge_weight2)

    batch = Batch.from_data_list([data1, data2])
    out = conv(batch.x, batch.edge_index, batch.edge_weight, batch.batch)

    assert out.size() == (7, 16)
    assert torch.allclose(out1, out[:4], atol=1e-6)
    assert torch.allclose(out2, out[4:], atol=1e-6)


================================================
FILE: test/nn/conv/test_cluster_gcn_conv.py
================================================
import torch

import torch_geometric.typing
from torch_geometric.nn import ClusterGCNConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


def test_cluster_gcn_conv():
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    adj1 = to_torch_csc_tensor(edge_index, size=(4,
4)) conv = ClusterGCNConv(16, 32, diag_lambda=1.) assert str(conv) == 'ClusterGCNConv(16, 32, diag_lambda=1.0)' out = conv(x, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out, atol=1e-5) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, adj2.t()), out, atol=1e-5) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj2.t()), out, atol=1e-5) ================================================ FILE: test/nn/conv/test_create_gnn.py ================================================ import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops, degree class GCNConv(MessagePassing): def __init__(self, in_channels, out_channels): super().__init__(aggr='add') self.lin = torch.nn.Linear(in_channels, out_channels) def forward(self, x, edge_index): edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) row, col = edge_index deg = degree(row, x.size(0), dtype=x.dtype) deg_inv_sqrt = deg.pow(-0.5) norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] x = self.lin(x) return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, norm=norm) def message(self, x_j, norm): return norm.view(-1, 1) * x_j def update(self, aggr_out): return aggr_out def test_create_gnn(): conv = GCNConv(16, 32) x = torch.randn(5, 16) edge_index = torch.randint(5, (2, 64), dtype=torch.long) out = conv(x, edge_index) assert out.size() == (5, 32) ================================================ FILE: test/nn/conv/test_dir_gnn_conv.py ================================================ import torch from torch_geometric.nn import DirGNNConv, SAGEConv def test_dir_gnn_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]]) conv = DirGNNConv(SAGEConv(16, 32)) assert str(conv) == 
'DirGNNConv(SAGEConv(16, 32, aggr=mean), alpha=0.5)'

    out = conv(x, edge_index)
    assert out.size() == (4, 32)


def test_static_dir_gnn_conv():
    """DirGNNConv on a 3D input (leading dim of 3, 4 nodes, 16 features)."""
    x = torch.randn(3, 4, 16)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

    conv = DirGNNConv(SAGEConv(16, 32))

    out = conv(x, edge_index)
    assert out.size() == (3, 4, 32)


================================================
FILE: test/nn/conv/test_dna_conv.py
================================================
import pytest
import torch

import torch_geometric.typing
from torch_geometric.nn import DNAConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


@pytest.mark.parametrize('channels', [32])
@pytest.mark.parametrize('num_layers', [3])
def test_dna_conv(channels, num_layers):
    """DNAConv with several head/group configs and with caching."""
    # Node features carry a per-layer history: (nodes, num_layers, channels).
    x = torch.randn((4, num_layers, channels))
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])

    conv = DNAConv(channels, heads=4, groups=8, dropout=0.0)
    assert str(conv) == 'DNAConv(32, heads=4, groups=8)'
    out = conv(x, edge_index)
    assert out.size() == (4, channels)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)

    conv = DNAConv(channels, heads=1, groups=1, dropout=0.0)
    assert str(conv) == 'DNAConv(32, heads=1, groups=1)'
    out = conv(x, edge_index)
    assert out.size() == (4, channels)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)

    # The first call populates the cache; the second call runs with it set.
    conv = DNAConv(channels, heads=1, groups=1, dropout=0.0, cached=True)
    out = conv(x, edge_index)
    assert conv._cached_edge_index is not None
    out = conv(x, edge_index)
    assert out.size() == (4, channels)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)


@pytest.mark.parametrize('channels', [32])
@pytest.mark.parametrize('num_layers', [3])
def test_dna_conv_sparse_tensor(channels, num_layers):
    x = torch.randn((4, num_layers, channels))
    edge_index =
torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) value = torch.rand(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = DNAConv(32, heads=4, groups=8, dropout=0.0) assert str(conv) == 'DNAConv(32, heads=4, groups=8)' out1 = conv(x, edge_index) assert out1.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) conv = DNAConv(channels, heads=1, groups=1, dropout=0.0, cached=True) out1 = conv(x, adj1.t()) assert conv._cached_edge_index is not None assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) assert conv._cached_adj_t is not None assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) ================================================ FILE: test/nn/conv/test_edge_conv.py ================================================ import torch from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq import torch_geometric.typing from torch_geometric.nn import DynamicEdgeConv, EdgeConv from torch_geometric.testing import is_full_test, withPackage from torch_geometric.typing import 
SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_edge_conv_conv(): x1 = torch.randn(4, 16) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) nn = Seq(Lin(32, 16), ReLU(), Lin(16, 32)) conv = EdgeConv(nn) assert str(conv) == ( 'EdgeConv(nn=Sequential(\n' ' (0): Linear(in_features=32, out_features=16, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=16, out_features=32, bias=True)\n' '))') out = conv(x1, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv((x1, x1), edge_index), out, atol=1e-6) assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) assert torch.allclose(conv((x1, x1), adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) assert torch.allclose(conv((x1, x1), adj2.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index), out, atol=1e-6) assert torch.allclose(jit((x1, x1), edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) assert torch.allclose(jit((x1, x1), adj2.t()), out, atol=1e-6) # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) out = conv((x1, x2), edge_index) assert out.size() == (2, 32) assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6) if is_full_test(): assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6) @withPackage('torch_cluster') def test_dynamic_edge_conv(): x1 = torch.randn(8, 16) x2 = 
torch.randn(4, 16) batch1 = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) batch2 = torch.tensor([0, 0, 1, 1]) nn = Seq(Lin(32, 16), ReLU(), Lin(16, 32)) conv = DynamicEdgeConv(nn, k=2) assert str(conv) == ( 'DynamicEdgeConv(nn=Sequential(\n' ' (0): Linear(in_features=32, out_features=16, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=16, out_features=32, bias=True)\n' '), k=2)') out11 = conv(x1) assert out11.size() == (8, 32) out12 = conv(x1, batch1) assert out12.size() == (8, 32) out21 = conv((x1, x2)) assert out21.size() == (4, 32) out22 = conv((x1, x2), (batch1, batch2)) assert out22.size() == (4, 32) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1), out11) assert torch.allclose(jit(x1, batch1), out12) assert torch.allclose(jit((x1, x2)), out21) assert torch.allclose(jit((x1, x2), (batch1, batch2)), out22) ================================================ FILE: test/nn/conv/test_eg_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import EGConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_eg_conv_with_error(): with pytest.raises(ValueError, match="must be divisible by the number of"): EGConv(16, 30, num_heads=8) with pytest.raises(ValueError, match="Unsupported aggregator"): EGConv(16, 32, aggregators=['xxx']) @pytest.mark.parametrize('aggregators', [ ['symnorm'], ['sum', 'symnorm', 'std'], ]) @pytest.mark.parametrize('add_self_loops', [True, False]) def test_eg_conv(aggregators, add_self_loops): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = EGConv( in_channels=16, out_channels=32, aggregators=aggregators, add_self_loops=add_self_loops, ) assert str(conv) == f"EGConv(16, 32, aggregators={aggregators})" out = conv(x, edge_index) 
assert out.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out, atol=1e-2) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, adj2.t()), out, atol=1e-2) conv.cached = True assert torch.allclose(conv(x, edge_index), out, atol=1e-2) assert conv._cached_edge_index is not None assert torch.allclose(conv(x, edge_index), out, atol=1e-2) assert torch.allclose(conv(x, adj1.t()), out, atol=1e-2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x, adj2.t()), out, atol=1e-2) assert conv._cached_adj_t is not None assert torch.allclose(conv(x, adj2.t()), out, atol=1e-2) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out, atol=1e-2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj2.t()), out, atol=1e-2) def test_eg_conv_with_sparse_input_feature(): x = torch.randn(4, 16).to_sparse_coo() edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) conv = EGConv(16, 32) assert conv(x, edge_index).size() == (4, 32) ================================================ FILE: test/nn/conv/test_fa_conv.py ================================================ from typing import Tuple import torch from torch import Tensor import torch_geometric.typing from torch_geometric.nn import FAConv from torch_geometric.testing import is_full_test from torch_geometric.typing import Adj, SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_fa_conv(): x = torch.randn(4, 16) x_0 = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = FAConv(16, eps=1.0, cached=True) assert str(conv) == 'FAConv(16, eps=1.0)' out = conv(x, x_0, edge_index) assert conv._cached_edge_index is not None assert out.size() == (4, 16) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = 
SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6) assert conv._cached_adj_t is not None assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6) if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, x_0: Tensor, edge_index: Adj, ) -> Tensor: return self.conv(x, x_0, edge_index) jit = torch.jit.script(MyModule()) assert torch.allclose(jit(x, x_0, edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, x_0, adj2.t()), out) conv.reset_parameters() assert conv._cached_edge_index is None assert conv._cached_adj_t is None # Test without caching: conv.cached = False out = conv(x, x_0, edge_index) assert torch.allclose(conv(x, x_0, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6) # Test `return_attention_weights`: result = conv(x, x_0, edge_index, return_attention_weights=True) assert torch.allclose(result[0], out, atol=1e-6) assert result[1][0].size() == (2, 10) assert result[1][1].size() == (10, ) assert conv._alpha is None result = conv(x, x_0, adj1.t(), return_attention_weights=True) assert torch.allclose(result[0], out, atol=1e-6) assert result[1][0].size() == torch.Size([4, 4]) assert result[1][0]._nnz() == 10 assert conv._alpha is None if torch_geometric.typing.WITH_TORCH_SPARSE: result = conv(x, x_0, adj2.t(), return_attention_weights=True) assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4] and result[1].nnz() == 10 assert conv._alpha is None if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, x_0: Tensor, edge_index: Tensor, ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: return self.conv(x, x_0, edge_index, return_attention_weights=True) jit = 
torch.jit.script(MyModule()) result = jit(x, x_0, edge_index) assert torch.allclose(result[0], out, atol=1e-6) assert result[1][0].size() == (2, 10) assert result[1][1].size() == (10, ) assert conv._alpha is None if torch_geometric.typing.WITH_TORCH_SPARSE: class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, x_0: Tensor, edge_index: SparseTensor, ) -> Tuple[Tensor, SparseTensor]: return self.conv(x, x_0, edge_index, return_attention_weights=True) jit = torch.jit.script(MyModule()) result = jit(x, x_0, adj2.t()) assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4] and result[1].nnz() == 10 assert conv._alpha is None ================================================ FILE: test/nn/conv/test_feast_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import FeaStConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_feast_conv(): x1 = torch.randn(4, 16) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = FeaStConv(16, 32, heads=2) assert str(conv) == 'FeaStConv(16, 32, heads=2)' out = conv(x1, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) out = conv((x1, x2), edge_index) assert out.size() == (2, 32) 
assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6) if is_full_test(): assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6) ================================================ FILE: test/nn/conv/test_film_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import FiLMConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor def test_film_conv(): x1 = torch.randn(4, 4) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1]]) edge_type = torch.tensor([0, 1, 1, 0, 0, 1]) conv = FiLMConv(4, 32) assert str(conv) == 'FiLMConv(4, 32, num_relations=1)' out = conv(x1, edge_index) assert out.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6) conv = FiLMConv(4, 32, num_relations=2) assert str(conv) == 'FiLMConv(4, 32, num_relations=2)' out = conv(x1, edge_index, edge_type) assert out.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 4)) assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index, edge_type), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6) # Test 
bipartite message passing: conv = FiLMConv((4, 16), 32) assert str(conv) == 'FiLMConv((4, 16), 32, num_relations=1)' out = conv((x1, x2), edge_index) assert out.size() == (2, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6) conv = FiLMConv((4, 16), 32, num_relations=2) assert str(conv) == 'FiLMConv((4, 16), 32, num_relations=2)' out = conv((x1, x2), edge_index, edge_type) assert out.size() == (2, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 2)) assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index, edge_type), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6) ================================================ FILE: test/nn/conv/test_fused_gat_conv.py ================================================ import torch from torch_geometric.nn import FusedGATConv from torch_geometric.testing import onlyCUDA, withPackage def test_to_graph_format() -> None: edge_index = torch.tensor([[1, 0, 2, 3], [0, 0, 1, 1]]) csr, csc, perm = FusedGATConv.to_graph_format(edge_index, size=(4, 4)) assert csr[0].dtype == torch.int assert torch.equal(csr[0], torch.tensor([0, 1, 2, 3, 4], dtype=torch.int)) assert csr[1].dtype == torch.int assert torch.equal(csr[1], torch.tensor([0, 0, 1, 1], dtype=torch.int)) assert csc[0].dtype == torch.int assert torch.equal(csc[0], torch.tensor([0, 1, 2, 3], dtype=torch.int)) assert csc[1].dtype == torch.int assert torch.equal(csc[1], torch.tensor([0, 2, 4, 4, 4], 
dtype=torch.int)) assert perm.dtype == torch.int assert torch.equal(perm, torch.tensor([0, 1, 2, 3], dtype=torch.int)) @onlyCUDA @withPackage('dgNN') def test_fused_gat_conv() -> None: device = torch.device('cuda') x = torch.randn(4, 8, device=device) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], device=device) csr, csc, perm = FusedGATConv.to_graph_format(edge_index, size=(4, 4)) conv = FusedGATConv(8, 32, heads=2, add_self_loops=False).to(device) assert str(conv) == 'FusedGATConv(8, 32, heads=2)' out = conv(x, csr, csc, perm) assert out.size() == (4, 64) ================================================ FILE: test/nn/conv/test_gat_conv.py ================================================ from typing import Optional, Tuple import pytest import torch from torch import Tensor import torch_geometric.typing from torch_geometric.nn import GATConv from torch_geometric.testing import is_full_test, withDevice from torch_geometric.typing import Adj, Size, SparseTensor from torch_geometric.utils import to_torch_csc_tensor @pytest.mark.parametrize('residual', [False, True]) def test_gat_conv(residual): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = GATConv(8, 32, heads=2, residual=residual) assert str(conv) == 'GATConv(8, 32, heads=2)' out = conv(x1, edge_index) assert out.size() == (4, 64) assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: Adj, size: Size = None, ) -> Tensor: return self.conv(x, edge_index, size=size) jit = torch.jit.script(MyModule()) assert 
torch.allclose(jit(x1, edge_index), out) assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) # Test `return_attention_weights`. result = conv(x1, edge_index, return_attention_weights=True) assert torch.allclose(result[0], out) assert result[1][0].size() == (2, 7) assert result[1][1].size() == (7, 2) assert result[1][1].min() >= 0 and result[1][1].max() <= 1 result = conv(x1, adj1.t(), return_attention_weights=True) assert torch.allclose(result[0], out, atol=1e-6) assert result[1][0].size() == torch.Size([4, 4, 2]) assert result[1][0]._nnz() == 7 if torch_geometric.typing.WITH_TORCH_SPARSE: result = conv(x1, adj2.t(), return_attention_weights=True) assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: Tensor, ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: return self.conv(x, edge_index, return_attention_weights=True) jit = torch.jit.script(MyModule()) result = jit(x1, edge_index) assert torch.allclose(result[0], out) assert result[1][0].size() == (2, 7) assert result[1][1].size() == (7, 2) assert result[1][1].min() >= 0 and result[1][1].max() <= 1 if torch_geometric.typing.WITH_TORCH_SPARSE: class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: SparseTensor, ) -> Tuple[Tensor, SparseTensor]: return self.conv(x, edge_index, return_attention_weights=True) jit = torch.jit.script(MyModule()) result = jit(x1, adj2.t()) assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) conv = GATConv((8, 16), 32, heads=2, residual=residual) assert str(conv) 
== 'GATConv((8, 16), 32, heads=2)' out1 = conv((x1, x2), edge_index) assert out1.size() == (2, 64) assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1) assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6) out2 = conv((x1, None), edge_index, size=(4, 2)) assert out2.size() == (2, 64) assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6) if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tuple[Tensor, Optional[Tensor]], edge_index: Adj, size: Size = None, ) -> Tensor: return self.conv(x, edge_index, size=size) jit = torch.jit.script(MyModule()) assert torch.allclose(jit((x1, x2), edge_index), out1) assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out1, atol=1e-6) assert torch.allclose(jit((x1, None), adj2.t()), out2, atol=1e-6) def test_gat_conv_with_edge_attr(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]]) edge_weight = torch.randn(edge_index.size(1)) edge_attr = torch.randn(edge_index.size(1), 4) conv = GATConv(8, 32, heads=2, edge_dim=1, fill_value=0.5) out = conv(x, edge_index, edge_weight) assert out.size() == (4, 64) if torch_geometric.typing.WITH_TORCH_SPARSE: adj1 = SparseTensor.from_edge_index(edge_index, edge_weight, (4, 4)) with pytest.raises(NotImplementedError): assert torch.allclose(conv(x, adj1.t()), out) conv = GATConv(8, 32, heads=2, edge_dim=1, fill_value='mean') out = conv(x, edge_index, edge_weight) assert out.size() == (4, 64) if torch_geometric.typing.WITH_TORCH_SPARSE: with 
pytest.raises(NotImplementedError): assert torch.allclose(conv(x, adj1.t()), out) conv = GATConv(8, 32, heads=2, edge_dim=4, fill_value=0.5) out = conv(x, edge_index, edge_attr) assert out.size() == (4, 64) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) with pytest.raises(NotImplementedError): assert torch.allclose(conv(x, adj2.t()), out) conv = GATConv(8, 32, heads=2, edge_dim=4, fill_value='mean') out = conv(x, edge_index, edge_attr) assert out.size() == (4, 64) if torch_geometric.typing.WITH_TORCH_SPARSE: with pytest.raises(NotImplementedError): assert torch.allclose(conv(x, adj2.t()), out) @withDevice def test_gat_conv_empty_edge_index(device): x = torch.randn(0, 8, device=device) edge_index = torch.empty(2, 0, dtype=torch.long, device=device) conv = GATConv(8, 32, heads=2).to(device) out = conv(x, edge_index) assert out.size() == (0, 64) ================================================ FILE: test/nn/conv/test_gated_graph_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import GatedGraphConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_gated_graph_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) value = torch.rand(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = GatedGraphConv(32, num_layers=3) assert str(conv) == 'GatedGraphConv(32, num_layers=3)' out1 = conv(x, edge_index) assert out1.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, 
sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) ================================================ FILE: test/nn/conv/test_gatv2_conv.py ================================================ from typing import Tuple import pytest import torch from torch import Tensor import torch_geometric.typing from torch_geometric.nn import GATv2Conv from torch_geometric.testing import is_full_test from torch_geometric.typing import Adj, SparseTensor from torch_geometric.utils import to_torch_csc_tensor @pytest.mark.parametrize('residual', [False, True]) def test_gatv2_conv(residual): x1 = torch.randn(4, 8) x2 = torch.randn(2, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = GATv2Conv(8, 32, heads=2, residual=residual) assert str(conv) == 'GATv2Conv(8, 32, heads=2)' out = conv(x1, edge_index) assert out.size() == (4, 64) assert torch.allclose(conv(x1, edge_index), out) assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: Adj, ) -> Tensor: return self.conv(x, edge_index) jit = torch.jit.script(MyModule()) assert torch.allclose(jit(x1, edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert 
torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) # Test `return_attention_weights`. result = conv(x1, edge_index, return_attention_weights=True) assert torch.allclose(result[0], out) assert result[1][0].size() == (2, 7) assert result[1][1].size() == (7, 2) assert result[1][1].min() >= 0 and result[1][1].max() <= 1 result = conv(x1, adj1.t(), return_attention_weights=True) assert torch.allclose(result[0], out, atol=1e-6) assert result[1][0].size() == torch.Size([4, 4, 2]) assert result[1][0]._nnz() == 7 if torch_geometric.typing.WITH_TORCH_SPARSE: result = conv(x1, adj2.t(), return_attention_weights=True) assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: Tensor, ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: return self.conv(x, edge_index, return_attention_weights=True) jit = torch.jit.script(MyModule()) result = jit(x1, edge_index) assert torch.allclose(result[0], out) assert result[1][0].size() == (2, 7) assert result[1][1].size() == (7, 2) assert result[1][1].min() >= 0 and result[1][1].max() <= 1 if torch_geometric.typing.WITH_TORCH_SPARSE: class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: SparseTensor, ) -> Tuple[Tensor, SparseTensor]: return self.conv(x, edge_index, return_attention_weights=True) jit = torch.jit.script(MyModule()) result = jit(x1, adj2.t()) assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) out = conv((x1, x2), edge_index) assert out.size() == (2, 64) assert torch.allclose(conv((x1, x2), edge_index), out) assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: 
adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6) if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tuple[Tensor, Tensor], edge_index: Adj, ) -> Tensor: return self.conv(x, edge_index) jit = torch.jit.script(MyModule()) assert torch.allclose(jit((x1, x2), edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6) def test_gatv2_conv_with_edge_attr(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]]) edge_weight = torch.randn(edge_index.size(1)) edge_attr = torch.randn(edge_index.size(1), 4) conv = GATv2Conv(8, 32, heads=2, edge_dim=1, fill_value=0.5) out = conv(x, edge_index, edge_weight) assert out.size() == (4, 64) conv = GATv2Conv(8, 32, heads=2, edge_dim=1, fill_value='mean') out = conv(x, edge_index, edge_weight) assert out.size() == (4, 64) conv = GATv2Conv(8, 32, heads=2, edge_dim=4, fill_value=0.5) out = conv(x, edge_index, edge_attr) assert out.size() == (4, 64) conv = GATv2Conv(8, 32, heads=2, edge_dim=4, fill_value='mean') out = conv(x, edge_index, edge_attr) assert out.size() == (4, 64) ================================================ FILE: test/nn/conv/test_gcn2_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import GCN2Conv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_gcn2_conv(): x = torch.randn(4, 16) x_0 = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) value = torch.rand(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = GCN2Conv(16, alpha=0.2) assert str(conv) == 'GCN2Conv(16, 
alpha=0.2, beta=1.0)' out1 = conv(x, x_0, edge_index) assert out1.size() == (4, 16) assert torch.allclose(conv(x, x_0, adj1.t()), out1, atol=1e-6) out2 = conv(x, x_0, edge_index, value) assert out2.size() == (4, 16) assert torch.allclose(conv(x, x_0, adj2.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, x_0, adj3.t()), out1, atol=1e-6) assert torch.allclose(conv(x, x_0, adj4.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, x_0, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, x_0, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, x_0, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, x_0, adj4.t()), out2, atol=1e-6) conv.cached = True conv(x, x_0, edge_index) assert conv._cached_edge_index is not None assert torch.allclose(conv(x, x_0, edge_index), out1, atol=1e-6) assert torch.allclose(conv(x, x_0, adj1.t()), out1, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: conv(x, x_0, adj3.t()) assert conv._cached_adj_t is not None assert torch.allclose(conv(x, x_0, adj3.t()), out1, atol=1e-6) ================================================ FILE: test/nn/conv/test_gcn_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import GCNConv from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.testing import is_full_test from torch_geometric.typing import WITH_PT21, SparseTensor from torch_geometric.utils import to_torch_coo_tensor, to_torch_csc_tensor def test_gcn_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) value = torch.rand(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = 
to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = GCNConv(16, 32) assert str(conv) == 'GCNConv(16, 32)' out1 = conv(x, edge_index) assert out1.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) conv.cached = True conv(x, edge_index) assert conv._cached_edge_index is not None assert torch.allclose(conv(x, edge_index), out1, atol=1e-6) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: conv(x, adj3.t()) assert conv._cached_adj_t is not None assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) def test_gcn_conv_with_decomposed_layers(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) def hook(module, inputs): assert inputs[0]['x_j'].size() == (10, 32 // module.decomposed_layers) conv = GCNConv(16, 32) conv.register_message_forward_pre_hook(hook) out1 = conv(x, edge_index) conv.decomposed_layers = 2 assert conv.propagate.__module__.endswith('message_passing') out2 = conv(x, edge_index) assert torch.allclose(out1, out2) # TorchScript should still work since it relies on class methods # (but without decomposition). 
torch.jit.script(conv) conv.decomposed_layers = 1 assert conv.propagate.__module__.endswith('GCNConv_propagate') def test_gcn_conv_with_sparse_input_feature(): x = torch.sparse_coo_tensor( indices=torch.tensor([[0, 0], [0, 1]]), values=torch.tensor([1., 1.]), size=torch.Size([4, 16]), ) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) conv = GCNConv(16, 32) assert conv(x, edge_index).size() == (4, 32) def test_static_gcn_conv(): x = torch.randn(3, 4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) conv = GCNConv(16, 32) out = conv(x, edge_index) assert out.size() == (3, 4, 32) def test_gcn_conv_error(): with pytest.raises(ValueError, match="does not support adding self-loops"): GCNConv(16, 32, normalize=False, add_self_loops=True) def test_gcn_conv_flow(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]]) conv = GCNConv(16, 32, flow="source_to_target") out1 = conv(x, edge_index) conv.flow = "target_to_source" out2 = conv(x, edge_index.flip(0)) assert torch.allclose(out1, out2, atol=1e-6) @pytest.mark.parametrize('requires_grad', [False, True]) @pytest.mark.parametrize('layout', [torch.sparse_coo, torch.sparse_csr]) def test_gcn_norm_gradient(requires_grad, layout): edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) edge_weight = torch.ones(edge_index.size(1), requires_grad=requires_grad) adj = to_torch_coo_tensor(edge_index, edge_weight) if layout == torch.sparse_csr: adj = adj.to_sparse_csr() # TODO Sparse CSR tensor doesn't inherit `requires_grad` for PyTorch < 2.1. 
if layout == torch.sparse_csr and not WITH_PT21: assert not gcn_norm(adj)[0].requires_grad else: assert adj.requires_grad == gcn_norm(adj)[0].requires_grad ================================================ FILE: test/nn/conv/test_gen_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import GENConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_coo_tensor @pytest.mark.parametrize('aggr', [ 'softmax', 'powermean', ['softmax', 'powermean'], ]) def test_gen_conv(aggr): x1 = torch.randn(4, 16) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.randn(edge_index.size(1), 16) adj1 = to_torch_coo_tensor(edge_index, size=(4, 4)) adj2 = to_torch_coo_tensor(edge_index, value, size=(4, 4)) conv = GENConv(16, 32, aggr, edge_dim=16, msg_norm=True) assert str(conv) == f'GENConv(16, 32, aggr={aggr})' out1 = conv(x1, edge_index) assert out1.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out1) assert torch.allclose(conv(x1, adj1.t().coalesce()), out1) out2 = conv(x1, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out2) # t() expects a tensor with <= 2 sparse and 0 dense dimensions assert torch.allclose(conv(x1, adj2.transpose(1, 0).coalesce()), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x1, adj3.t()), out1, atol=1e-4) assert torch.allclose(conv(x1, adj4.t()), out2, atol=1e-4) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index), out1, atol=1e-4) assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out1, atol=1e-4) assert torch.allclose(jit(x1, edge_index, value), out2, 
atol=1e-4) assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out2, atol=1e-4) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj3.t()), out1, atol=1e-4) assert torch.allclose(jit(x1, adj4.t()), out2, atol=1e-4) # Test bipartite message passing: adj1 = to_torch_coo_tensor(edge_index, size=(4, 2)) adj2 = to_torch_coo_tensor(edge_index, value, size=(4, 2)) out1 = conv((x1, x2), edge_index) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1) assert torch.allclose(conv((x1, x2), adj1.t().coalesce()), out1) out2 = conv((x1, x2), edge_index, value) assert out2.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out2) assert torch.allclose(conv((x1, x2), adj2.transpose(1, 0).coalesce()), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) assert torch.allclose(conv((x1, x2), adj3.t()), out1, atol=1e-4) assert torch.allclose(conv((x1, x2), adj4.t()), out2, atol=1e-4) if is_full_test(): assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-4) assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1, atol=1e-4) assert torch.allclose(jit((x1, x2), edge_index, value), out2, atol=1e-4) assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out2, atol=1e-4) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj3.t()), out1, atol=1e-4) assert torch.allclose(jit((x1, x2), adj4.t()), out2, atol=1e-4) # Test bipartite message passing with unequal feature dimensions: conv.reset_parameters() assert float(conv.msg_norm.scale) == 1 x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) conv = GENConv((8, 16), 32, aggr) assert str(conv) == f'GENConv((8, 16), 32, aggr={aggr})' out1 = conv((x1, x2), edge_index) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, 
size=(4, 2)), out1) assert torch.allclose(conv((x1, x2), adj1.t().coalesce()), out1) out2 = conv((x1, None), edge_index, size=(4, 2)) assert out2.size() == (2, 32) assert torch.allclose(conv((x1, None), adj1.t().coalesce()), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv((x1, x2), adj3.t()), out1, atol=1e-4) assert torch.allclose(conv((x1, None), adj3.t()), out2, atol=1e-4) # Test lazy initialization: conv = GENConv((-1, -1), 32, aggr, edge_dim=-1) assert str(conv) == f'GENConv((-1, -1), 32, aggr={aggr})' out1 = conv((x1, x2), edge_index, value) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, size=(4, 2)), out1) assert torch.allclose(conv((x1, x2), adj2.transpose(1, 0).coalesce()), out1) out2 = conv((x1, None), edge_index, value, size=(4, 2)) assert out2.size() == (2, 32) assert torch.allclose(conv((x1, None), adj2.transpose(1, 0).coalesce()), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv((x1, x2), adj4.t()), out1, atol=1e-4) assert torch.allclose(conv((x1, None), adj4.t()), out2, atol=1e-4) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index, value), out1, atol=1e-4) assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)), out1, atol=1e-4) assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)), out2, atol=1e-4) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj4.t()), out1, atol=1e-4) assert torch.allclose(jit((x1, None), adj4.t()), out2, atol=1e-4) ================================================ FILE: test/nn/conv/test_general_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import GeneralConv from torch_geometric.typing import SparseTensor @pytest.mark.parametrize('kwargs', [ dict(), dict(skip_linear=True), dict(directed_msg=False), dict(heads=3), dict(attention=True), 
dict(heads=3, attention=True), dict(heads=3, attention=True, attention_type='dot_product'), dict(l2_normalize=True), ]) def test_general_conv(kwargs): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_attr = torch.randn(edge_index.size(1), 16) conv = GeneralConv(8, 32, **kwargs) assert str(conv) == 'GeneralConv(8, 32)' out = conv(x, edge_index) assert out.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, adj.t()), out, atol=1e-6) conv = GeneralConv(8, 32, in_edge_channels=16, **kwargs) assert str(conv) == 'GeneralConv(8, 32)' out = conv(x, edge_index, edge_attr) assert out.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) assert torch.allclose(conv(x, adj.t()), out, atol=1e-6) ================================================ FILE: test/nn/conv/test_gin_conv.py ================================================ import torch from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq import torch_geometric.typing from torch_geometric.nn import GINConv, GINEConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_gin_conv(): x1 = torch.randn(4, 16) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) conv = GINConv(nn, train_eps=True) assert str(conv) == ( 'GINConv(nn=Sequential(\n' ' (0): Linear(in_features=16, out_features=32, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=32, out_features=32, bias=True)\n' '))') out = conv(x1, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6) assert torch.allclose(conv(x1, 
adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index), out, atol=1e-6) assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) out1 = conv((x1, x2), edge_index) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6) assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6) out2 = conv((x1, None), edge_index, (4, 2)) assert out2.size() == (2, 32) assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6) if is_full_test(): assert torch.allclose(jit((x1, x2), edge_index), out1) assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out1) assert torch.allclose(jit((x1, None), adj2.t()), out2) def test_gine_conv(): x1 = torch.randn(4, 16) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.randn(edge_index.size(1), 16) nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) conv = GINEConv(nn, train_eps=True) assert str(conv) == ( 'GINEConv(nn=Sequential(\n' ' (0): Linear(in_features=16, out_features=32, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=32, out_features=32, bias=True)\n' '))') out = conv(x1, edge_index, value) 
assert out.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x1, adj.t()), out) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index, value), out) assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj.t()), out) # Test bipartite message passing: out1 = conv((x1, x2), edge_index, value) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) out2 = conv((x1, None), edge_index, value, (4, 2)) assert out2.size() == (2, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, value, (4, 2)) assert torch.allclose(conv((x1, x2), adj.t()), out1) assert torch.allclose(conv((x1, None), adj.t()), out2) if is_full_test(): assert torch.allclose(jit((x1, x2), edge_index, value), out1) assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj.t()), out1) assert torch.allclose(jit((x1, None), adj.t()), out2) def test_gine_conv_edge_dim(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_attr = torch.randn(edge_index.size(1), 8) nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) conv = GINEConv(nn, train_eps=True, edge_dim=8) out = conv(x, edge_index, edge_attr) assert out.size() == (4, 32) nn = Lin(16, 32) conv = GINEConv(nn, train_eps=True, edge_dim=8) out = conv(x, edge_index, edge_attr) assert out.size() == (4, 32) def test_static_gin_conv(): x = torch.randn(3, 4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) conv = GINConv(nn, 
train_eps=True) out = conv(x, edge_index) assert out.size() == (3, 4, 32) def test_static_gine_conv(): x = torch.randn(3, 4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) edge_attr = torch.randn(edge_index.size(1), 16) nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) conv = GINEConv(nn, train_eps=True) out = conv(x, edge_index, edge_attr) assert out.size() == (3, 4, 32) ================================================ FILE: test/nn/conv/test_gmm_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import GMMConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_coo_tensor @pytest.mark.parametrize('separate_gaussians', [True, False]) def test_gmm_conv(separate_gaussians): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.rand(edge_index.size(1), 3) adj1 = to_torch_coo_tensor(edge_index, value, size=(4, 4)) conv = GMMConv(8, 32, dim=3, kernel_size=25, separate_gaussians=separate_gaussians) assert str(conv) == 'GMMConv(8, 32, dim=3)' out = conv(x1, edge_index, value) assert out.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out) # t() expects a tensor with <= 2 sparse and 0 dense dimensions assert torch.allclose(conv(x1, adj1.transpose(0, 1).coalesce()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x1, adj2.t()), out) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index, value), out) assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out) # Test bipartite message passing: adj1 = to_torch_coo_tensor(edge_index, value, size=(4, 2)) conv = 
GMMConv((8, 16), 32, dim=3, kernel_size=5, separate_gaussians=separate_gaussians) assert str(conv) == 'GMMConv((8, 16), 32, dim=3)' out1 = conv((x1, x2), edge_index, value) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) assert torch.allclose(conv((x1, x2), adj1.transpose(0, 1).coalesce()), out1) out2 = conv((x1, None), edge_index, value, (4, 2)) assert out2.size() == (2, 32) assert torch.allclose(conv((x1, None), adj1.transpose(0, 1).coalesce()), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out1) assert torch.allclose(conv((x1, None), adj2.t()), out2) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index, value), out1) assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out1) assert torch.allclose(jit((x1, None), adj2.t()), out2) @pytest.mark.parametrize('separate_gaussians', [True, False]) def test_lazy_gmm_conv(separate_gaussians): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.rand(edge_index.size(1), 3) conv = GMMConv(-1, 32, dim=3, kernel_size=25, separate_gaussians=separate_gaussians) assert str(conv) == 'GMMConv(-1, 32, dim=3)' out = conv(x1, edge_index, value) assert out.size() == (4, 32) conv = GMMConv((-1, -1), 32, dim=3, kernel_size=25, separate_gaussians=separate_gaussians) assert str(conv) == 'GMMConv((-1, -1), 32, dim=3)' out = conv((x1, x2), edge_index, value) assert out.size() == (2, 32) ================================================ FILE: test/nn/conv/test_gps_conv.py ================================================ import pytest import torch import torch_geometric.typing from 
torch_geometric.nn import GPSConv, SAGEConv from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor @pytest.mark.parametrize('attn_type', ['multihead', 'performer']) @pytest.mark.parametrize('norm', [None, 'batch_norm', 'layer_norm']) def test_gps_conv(norm, attn_type): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]) batch = torch.tensor([0, 0, 1, 1]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = GPSConv(16, conv=SAGEConv(16, 16), heads=4, norm=norm, attn_type=attn_type) conv.reset_parameters() assert str(conv) == (f'GPSConv(16, conv=SAGEConv(16, 16, aggr=mean), ' f'heads=4, attn_type={attn_type})') out = conv(x, edge_index) assert out.size() == (4, 16) assert torch.allclose(conv(x, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, adj2.t()), out, atol=1e-6) out = conv(x, edge_index, batch) assert out.size() == (4, 16) assert torch.allclose(conv(x, adj1.t(), batch), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x, adj2.t(), batch), out, atol=1e-6) ================================================ FILE: test/nn/conv/test_graph_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric import EdgeIndex from torch_geometric.nn import GraphConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_graph_conv(): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.randn(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = GraphConv(8, 32) assert str(conv) == 'GraphConv(8, 32)' out1 = conv(x1, edge_index) assert 
out1.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out1, atol=1e-6) assert torch.allclose(conv(x1, adj1.t()), out1, atol=1e-6) assert conv( x1, EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 4)), ).allclose(out1, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj3.t()), out1, atol=1e-6) out2 = conv(x1, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out2, atol=1e-6) assert torch.allclose(conv(x1, adj2.t()), out2, atol=1e-6) assert conv( x1, EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 4)), value, ).allclose(out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x1, adj4.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index), out1) assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out1) assert torch.allclose(jit(x1, edge_index, value), out2) assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x1, adj4.t()), out2, atol=1e-6) # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 2)) conv = GraphConv((8, 16), 32) assert str(conv) == 'GraphConv((8, 16), 32)' out1 = conv((x1, x2), edge_index) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1) assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6) assert conv( (x1, x2), EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)), ).allclose(out1, atol=1e-6) out2 = conv((x1, None), edge_index, size=(4, 2)) assert out2.size() == (2, 32) assert 
torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6) assert conv( (x1, None), EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)), ).allclose(out2, atol=1e-6) out3 = conv((x1, x2), edge_index, value) assert out3.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out3) assert torch.allclose(conv((x1, x2), adj2.t()), out3, atol=1e-6) assert conv( (x1, x2), EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)), value, ).allclose(out3, atol=1e-6) out4 = conv((x1, None), edge_index, value, size=(4, 2)) assert out4.size() == (2, 32) assert torch.allclose(conv((x1, None), adj2.t()), out4, atol=1e-6) assert conv( (x1, None), EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)), value, ).allclose(out4, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) assert torch.allclose(conv((x1, x2), adj3.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj3.t()), out2, atol=1e-6) assert torch.allclose(conv((x1, x2), adj3.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj4.t()), out4, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index), out1) assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2) assert torch.allclose(jit((x1, x2), edge_index, value), out3) assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out3) assert torch.allclose(jit((x1, None), edge_index, value, (4, 2)), out4) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj3.t()), out1, atol=1e-6) assert torch.allclose(jit((x1, None), adj3.t()), out2, atol=1e-6) assert torch.allclose(jit((x1, x2), adj4.t()), out3, atol=1e-6) assert torch.allclose(jit((x1, None), adj4.t()), out4, atol=1e-6) class EdgeGraphConv(GraphConv): 
def message(self, x_j, edge_weight): return edge_weight.view(-1, 1) * x_j def test_inheritance(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_weight = torch.rand(4) conv = EdgeGraphConv(8, 16) assert conv(x, edge_index, edge_weight).size() == (4, 16) ================================================ FILE: test/nn/conv/test_gravnet_conv.py ================================================ import torch from torch_geometric.nn import GravNetConv from torch_geometric.testing import is_full_test, withPackage @withPackage('torch_cluster') def test_gravnet_conv(): x1 = torch.randn(8, 16) x2 = torch.randn(4, 16) batch1 = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) batch2 = torch.tensor([0, 0, 1, 1]) conv = GravNetConv(16, 32, space_dimensions=4, propagate_dimensions=8, k=2) assert str(conv) == 'GravNetConv(16, 32, k=2)' out11 = conv(x1) assert out11.size() == (8, 32) out12 = conv(x1, batch1) assert out12.size() == (8, 32) out21 = conv((x1, x2)) assert out21.size() == (4, 32) out22 = conv((x1, x2), (batch1, batch2)) assert out22.size() == (4, 32) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1), out11) assert torch.allclose(jit(x1, batch1), out12) assert torch.allclose(jit((x1, x2)), out21) assert torch.allclose(jit((x1, x2), (batch1, batch2)), out22) ================================================ FILE: test/nn/conv/test_han_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import HANConv from torch_geometric.typing import SparseTensor from torch_geometric.utils import coalesce, to_torch_csc_tensor def test_han_conv(): x_dict = { 'author': torch.randn(6, 16), 'paper': torch.randn(5, 12), 'term': torch.randn(4, 3) } edge_index1 = coalesce(torch.randint(0, 6, (2, 7))) edge_index2 = coalesce(torch.randint(0, 5, (2, 4))) edge_index3 = coalesce(torch.randint(0, 3, (2, 5))) edge_index_dict = { ('author', 'metapath0', 'author'): edge_index1, 
('paper', 'metapath1', 'paper'): edge_index2, ('paper', 'metapath2', 'paper'): edge_index3, } adj_t_dict1 = {} for edge_type, edge_index in edge_index_dict.items(): src_type, _, dst_type = edge_type adj_t_dict1[edge_type] = to_torch_csc_tensor( edge_index, size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)), ).t() metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) in_channels = {'author': 16, 'paper': 12, 'term': 3} conv = HANConv(in_channels, 16, metadata, heads=2) assert str(conv) == 'HANConv(16, heads=2)' out_dict1 = conv(x_dict, edge_index_dict) assert len(out_dict1) == 3 assert out_dict1['author'].size() == (6, 16) assert out_dict1['paper'].size() == (5, 16) assert out_dict1['term'] is None del out_dict1['term'] del x_dict['term'] out_dict2 = conv(x_dict, adj_t_dict1) assert len(out_dict1) == len(out_dict2) for key in out_dict1.keys(): assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t_dict2 = {} for edge_type, edge_index in edge_index_dict.items(): adj_t_dict2[edge_type] = SparseTensor.from_edge_index( edge_index, sparse_sizes=adj_t_dict1[edge_type].size()[::-1], ).t() out_dict3 = conv(x_dict, adj_t_dict2) assert len(out_dict1) == len(out_dict3) for key in out_dict3.keys(): assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6) # Test non-zero dropout: conv = HANConv(in_channels, 16, metadata, heads=2, dropout=0.1) assert str(conv) == 'HANConv(16, heads=2)' out_dict1 = conv(x_dict, edge_index_dict) assert len(out_dict1) == 2 assert out_dict1['author'].size() == (6, 16) assert out_dict1['paper'].size() == (5, 16) def test_han_conv_lazy(): x_dict = { 'author': torch.randn(6, 16), 'paper': torch.randn(5, 12), } edge_index1 = coalesce(torch.randint(0, 6, (2, 8))) edge_index2 = coalesce(torch.randint(0, 5, (2, 6))) edge_index_dict = { ('author', 'to', 'author'): edge_index1, ('paper', 'to', 'paper'): edge_index2, } adj_t_dict1 = {} for edge_type, edge_index in 
edge_index_dict.items(): src_type, _, dst_type = edge_type adj_t_dict1[edge_type] = to_torch_csc_tensor( edge_index, size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)), ).t() metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) conv = HANConv(-1, 16, metadata, heads=2) assert str(conv) == 'HANConv(16, heads=2)' out_dict1 = conv(x_dict, edge_index_dict) assert len(out_dict1) == 2 assert out_dict1['author'].size() == (6, 16) assert out_dict1['paper'].size() == (5, 16) out_dict2 = conv(x_dict, adj_t_dict1) assert len(out_dict1) == len(out_dict2) for key in out_dict1.keys(): assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t_dict2 = {} for edge_type, edge_index in edge_index_dict.items(): adj_t_dict2[edge_type] = SparseTensor.from_edge_index( edge_index, sparse_sizes=adj_t_dict1[edge_type].size()[::-1], ).t() out_dict3 = conv(x_dict, adj_t_dict2) assert len(out_dict1) == len(out_dict3) for key in out_dict1.keys(): assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6) def test_han_conv_empty_tensor(): x_dict = { 'author': torch.randn(6, 16), 'paper': torch.empty(0, 12), } edge_index_dict = { ('paper', 'to', 'author'): torch.empty((2, 0), dtype=torch.long), ('author', 'to', 'paper'): torch.empty((2, 0), dtype=torch.long), ('paper', 'to', 'paper'): torch.empty((2, 0), dtype=torch.long), } metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) in_channels = {'author': 16, 'paper': 12} conv = HANConv(in_channels, 16, metadata, heads=2) out_dict = conv(x_dict, edge_index_dict) assert len(out_dict) == 2 assert out_dict['author'].size() == (6, 16) assert torch.all(out_dict['author'] == 0) assert out_dict['paper'].size() == (0, 16) ================================================ FILE: test/nn/conv/test_heat_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import HEATConv from 
torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor @pytest.mark.parametrize('concat', [True, False]) def test_heat_conv(concat): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_attr = torch.randn((4, 2)) node_type = torch.tensor([0, 0, 1, 2]) edge_type = torch.tensor([0, 2, 1, 2]) conv = HEATConv(in_channels=8, out_channels=16, num_node_types=3, num_edge_types=3, edge_type_emb_dim=5, edge_dim=2, edge_attr_emb_dim=6, heads=2, concat=concat) assert str(conv) == 'HEATConv(8, 16, heads=2)' out = conv(x, edge_index, node_type, edge_type, edge_attr) assert out.size() == (4, 32 if concat else 16) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out, atol=1e-5) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose( jit(x, edge_index, node_type, edge_type, edge_attr), out, atol=1e-5) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj.t(), node_type, edge_type), out, atol=1e-5) ================================================ FILE: test/nn/conv/test_hetero_conv.py ================================================ import random import pytest import torch from torch_geometric.data import HeteroData from torch_geometric.datasets import FakeHeteroDataset from torch_geometric.nn import ( GATConv, GCN2Conv, GCNConv, HeteroConv, Linear, MessagePassing, SAGEConv, ) from torch_geometric.profile import benchmark from torch_geometric.testing import ( get_random_edge_index, onlyLinux, withDevice, withPackage, ) @pytest.mark.parametrize('aggr', ['sum', 'mean', 'min', 'max', 'cat', None]) def test_hetero_conv(aggr): data = HeteroData() data['paper'].x = torch.randn(50, 32) data['author'].x = torch.randn(30, 64) data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100) 
data['paper', 'author'].edge_attr = torch.randn(100, 3) data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100) data['paper', 'paper'].edge_weight = torch.rand(200) # Unspecified edge types should be ignored: data['author', 'author'].edge_index = get_random_edge_index(30, 30, 100) conv = HeteroConv( { ('paper', 'to', 'paper'): GCNConv(-1, 64), ('author', 'to', 'paper'): SAGEConv((-1, -1), 64), ('paper', 'to', 'author'): GATConv((-1, -1), 64, edge_dim=3, add_self_loops=False), }, aggr=aggr, ) assert len(list(conv.parameters())) > 0 assert str(conv) == 'HeteroConv(num_relations=3)' out_dict = conv( data.x_dict, data.edge_index_dict, data.edge_attr_dict, edge_weight_dict=data.edge_weight_dict, ) assert len(out_dict) == 2 if aggr == 'cat': assert out_dict['paper'].size() == (50, 128) assert out_dict['author'].size() == (30, 64) elif aggr is not None: assert out_dict['paper'].size() == (50, 64) assert out_dict['author'].size() == (30, 64) else: assert out_dict['paper'].size() == (50, 2, 64) assert out_dict['author'].size() == (30, 1, 64) def test_gcn2_hetero_conv(): data = HeteroData() data['paper'].x = torch.randn(50, 32) data['author'].x = torch.randn(30, 64) data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) data['author', 'author'].edge_index = get_random_edge_index(30, 30, 100) data['paper', 'paper'].edge_weight = torch.rand(200) conv = HeteroConv({ ('paper', 'to', 'paper'): GCN2Conv(32, alpha=0.1), ('author', 'to', 'author'): GCN2Conv(64, alpha=0.2), }) out_dict = conv( data.x_dict, data.x_dict, data.edge_index_dict, edge_weight_dict=data.edge_weight_dict, ) assert len(out_dict) == 2 assert out_dict['paper'].size() == (50, 32) assert out_dict['author'].size() == (30, 64) class CustomConv(MessagePassing): def __init__(self, out_channels): super().__init__(aggr='add') self.lin = Linear(-1, out_channels) def forward(self, x, edge_index, y, z): return self.propagate(edge_index, x=x, y=y, z=z) def message(self, x_j, y_j, z_j): return 
self.lin(torch.cat([x_j, y_j, z_j], dim=-1)) def test_hetero_conv_with_custom_conv(): data = HeteroData() data['paper'].x = torch.randn(50, 32) data['paper'].y = torch.randn(50, 3) data['paper'].z = torch.randn(50, 3) data['author'].x = torch.randn(30, 64) data['author'].y = torch.randn(30, 3) data['author'].z = torch.randn(30, 3) data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100) data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100) conv = HeteroConv({key: CustomConv(64) for key in data.edge_types}) # Test node `args_dict` and `kwargs_dict` with `y_dict` and `z_dict`: out_dict = conv( data.x_dict, data.edge_index_dict, data.y_dict, z_dict=data.z_dict, ) assert len(out_dict) == 2 assert out_dict['paper'].size() == (50, 64) assert out_dict['author'].size() == (30, 64) class MessagePassingLoops(MessagePassing): def __init__(self): super().__init__() self.add_self_loops = True def test_hetero_conv_self_loop_error(): HeteroConv({('a', 'to', 'a'): MessagePassingLoops()}) with pytest.raises(ValueError, match="incorrect message passing"): HeteroConv({('a', 'to', 'b'): MessagePassingLoops()}) def test_hetero_conv_with_dot_syntax_node_types(): data = HeteroData() data['src.paper'].x = torch.randn(50, 32) data['author'].x = torch.randn(30, 64) edge_index = get_random_edge_index(50, 50, 200) data['src.paper', 'src.paper'].edge_index = edge_index data['src.paper', 'author'].edge_index = get_random_edge_index(50, 30, 100) data['author', 'src.paper'].edge_index = get_random_edge_index(30, 50, 100) data['src.paper', 'src.paper'].edge_weight = torch.rand(200) conv = HeteroConv({ ('src.paper', 'to', 'src.paper'): GCNConv(-1, 64), ('author', 'to', 'src.paper'): SAGEConv((-1, -1), 64), ('src.paper', 'to', 'author'): GATConv((-1, -1), 64, add_self_loops=False), }) assert len(list(conv.parameters())) > 0 assert str(conv) == 'HeteroConv(num_relations=3)' out_dict = conv( 
data.x_dict, data.edge_index_dict, edge_weight_dict=data.edge_weight_dict, ) assert len(out_dict) == 2 assert out_dict['src.paper'].size() == (50, 64) assert out_dict['author'].size() == (30, 64) @withDevice @onlyLinux @withPackage('torch>=2.1.0') def test_compile_hetero_conv_graph_breaks(device): import torch._dynamo as dynamo data = HeteroData() data['a'].x = torch.randn(50, 16, device=device) data['b'].x = torch.randn(50, 16, device=device) edge_index = get_random_edge_index(50, 50, 100, device=device) data['a', 'to', 'b'].edge_index = edge_index data['b', 'to', 'a'].edge_index = edge_index.flip([0]) conv = HeteroConv({ ('a', 'to', 'b'): SAGEConv(16, 32), ('b', 'to', 'a'): SAGEConv(16, 32), }).to(device) explanation = dynamo.explain(conv)(data.x_dict, data.edge_index_dict) assert explanation.graph_break_count == 0 compiled_conv = torch.compile(conv) expected = conv(data.x_dict, data.edge_index_dict) out = compiled_conv(data.x_dict, data.edge_index_dict) assert len(out) == len(expected) for key in expected.keys(): assert torch.allclose(out[key], expected[key], atol=1e-6) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() dataset = FakeHeteroDataset(num_graphs=10).to(args.device) def gen_args(): data = dataset[random.randrange(len(dataset))] return data.x_dict, data.edge_index_dict class HeteroGNN(torch.nn.Module): def __init__(self, channels: int = 32, num_layers: int = 2): super().__init__() self.convs = torch.nn.ModuleList() conv = HeteroConv({ edge_type: SAGEConv( in_channels=( dataset.num_features[edge_type[0]], dataset.num_features[edge_type[-1]], ), out_channels=channels, ) for edge_type in dataset[0].edge_types }) self.convs.append(conv) for _ in range(num_layers - 1): conv = HeteroConv({ edge_type: SAGEConv((channels, channels), channels) for edge_type in dataset[0].edge_types }) 
self.convs.append(conv) self.lin = Linear(channels, 1) def forward(self, x_dict, edge_index_dict): for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) x_dict = {key: x.relu() for key, x in x_dict.items()} return self.lin(x_dict['v0']) model = HeteroGNN().to(args.device) compiled_model = torch.compile(model) benchmark( funcs=[model, compiled_model], func_names=['Vanilla', 'Compiled'], args=gen_args, num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) ================================================ FILE: test/nn/conv/test_hgt_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.data import HeteroData from torch_geometric.nn import HGTConv from torch_geometric.profile import benchmark from torch_geometric.testing import get_random_edge_index from torch_geometric.typing import SparseTensor from torch_geometric.utils import coalesce, to_torch_csc_tensor def test_hgt_conv_same_dimensions(): x_dict = { 'author': torch.randn(4, 16), 'paper': torch.randn(6, 16), } edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20)) edge_index_dict = { ('author', 'writes', 'paper'): edge_index, ('paper', 'written_by', 'author'): edge_index.flip([0]), } adj_t_dict1 = {} for edge_type, edge_index in edge_index_dict.items(): src_type, _, dst_type = edge_type adj_t_dict1[edge_type] = to_torch_csc_tensor( edge_index, size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)), ).t() metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) conv = HGTConv(16, 16, metadata, heads=2) assert str(conv) == 'HGTConv(-1, 16, heads=2)' out_dict1 = conv(x_dict, edge_index_dict) assert len(out_dict1) == 2 assert out_dict1['author'].size() == (4, 16) assert out_dict1['paper'].size() == (6, 16) out_dict2 = conv(x_dict, adj_t_dict1) assert len(out_dict1) == len(out_dict2) for key in out_dict1.keys(): assert torch.allclose(out_dict1[key], 
out_dict2[key], atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t_dict2 = {} for edge_type, edge_index in edge_index_dict.items(): adj_t_dict2[edge_type] = SparseTensor.from_edge_index( edge_index, sparse_sizes=adj_t_dict1[edge_type].size()[::-1], ).t() out_dict3 = conv(x_dict, adj_t_dict2) assert len(out_dict1) == len(out_dict3) for key in out_dict1.keys(): assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6) # TODO: Test JIT functionality. We need to wait on this one until PyTorch # allows indexing `ParameterDict` mappings :( def test_hgt_conv_different_dimensions(): x_dict = { 'author': torch.randn(4, 16), 'paper': torch.randn(6, 32), } edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20)) edge_index_dict = { ('author', 'writes', 'paper'): edge_index, ('paper', 'written_by', 'author'): edge_index.flip([0]), } adj_t_dict1 = {} for edge_type, edge_index in edge_index_dict.items(): src_type, _, dst_type = edge_type adj_t_dict1[edge_type] = to_torch_csc_tensor( edge_index, size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)), ).t() metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) conv = HGTConv(in_channels={ 'author': 16, 'paper': 32 }, out_channels=32, metadata=metadata, heads=2) assert str(conv) == 'HGTConv(-1, 32, heads=2)' out_dict1 = conv(x_dict, edge_index_dict) assert len(out_dict1) == 2 assert out_dict1['author'].size() == (4, 32) assert out_dict1['paper'].size() == (6, 32) out_dict2 = conv(x_dict, adj_t_dict1) assert len(out_dict1) == len(out_dict2) for key in out_dict1.keys(): assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t_dict2 = {} for edge_type, edge_index in edge_index_dict.items(): adj_t_dict2[edge_type] = SparseTensor.from_edge_index( edge_index, sparse_sizes=adj_t_dict1[edge_type].size()[::-1], ).t() out_dict3 = conv(x_dict, adj_t_dict2) assert len(out_dict1) == len(out_dict3) for key in out_dict1.keys(): assert 
torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6) def test_hgt_conv_lazy(): x_dict = { 'author': torch.randn(4, 16), 'paper': torch.randn(6, 32), } edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20)) edge_index_dict = { ('author', 'writes', 'paper'): edge_index, ('paper', 'written_by', 'author'): edge_index.flip([0]), } adj_t_dict1 = {} for edge_type, edge_index in edge_index_dict.items(): src_type, _, dst_type = edge_type adj_t_dict1[edge_type] = to_torch_csc_tensor( edge_index, size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)), ).t() metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) conv = HGTConv(-1, 32, metadata, heads=2) assert str(conv) == 'HGTConv(-1, 32, heads=2)' out_dict1 = conv(x_dict, edge_index_dict) assert len(out_dict1) == 2 assert out_dict1['author'].size() == (4, 32) assert out_dict1['paper'].size() == (6, 32) out_dict2 = conv(x_dict, adj_t_dict1) assert len(out_dict1) == len(out_dict2) for key in out_dict1.keys(): assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6) if False and torch_geometric.typing.WITH_TORCH_SPARSE: adj_t_dict2 = {} for edge_type, edge_index in edge_index_dict.items(): adj_t_dict2[edge_type] = SparseTensor.from_edge_index( edge_index, sparse_sizes=adj_t_dict1[edge_type].size()[::-1], ).t() out_dict3 = conv(x_dict, adj_t_dict2) assert len(out_dict1) == len(out_dict3) for key in out_dict1.keys(): assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6) def test_hgt_conv_out_of_place(): data = HeteroData() data['author'].x = torch.randn(4, 16) data['paper'].x = torch.randn(6, 32) edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20)) data['author', 'paper'].edge_index = edge_index data['paper', 'author'].edge_index = edge_index.flip([0]) conv = HGTConv(-1, 64, data.metadata(), heads=1) x_dict, edge_index_dict = data.x_dict, data.edge_index_dict assert x_dict['author'].size() == (4, 16) assert x_dict['paper'].size() == (6, 32) _ = conv(x_dict, edge_index_dict) 
assert x_dict['author'].size() == (4, 16) assert x_dict['paper'].size() == (6, 32) def test_hgt_conv_missing_dst_node_type(): data = HeteroData() data['author'].x = torch.randn(4, 16) data['paper'].x = torch.randn(6, 32) data['university'].x = torch.randn(10, 32) data['author', 'paper'].edge_index = get_random_edge_index(4, 6, 20) data['paper', 'author'].edge_index = get_random_edge_index(6, 4, 20) data['university', 'author'].edge_index = get_random_edge_index(10, 4, 10) conv = HGTConv(-1, 64, data.metadata(), heads=1) out_dict = conv(data.x_dict, data.edge_index_dict) assert out_dict['author'].size() == (4, 64) assert out_dict['paper'].size() == (6, 64) assert 'university' not in out_dict def test_hgt_conv_missing_input_node_type(): data = HeteroData() data['author'].x = torch.randn(4, 16) data['paper'].x = torch.randn(6, 32) data['author', 'writes', 'paper'].edge_index = get_random_edge_index(4, 6, 20) # Some nodes from metadata are missing in data. # This might happen while using NeighborLoader. 
metadata = (['author', 'paper', 'university'], [('author', 'writes', 'paper')]) conv = HGTConv(-1, 64, metadata, heads=1) out_dict = conv(data.x_dict, data.edge_index_dict) assert out_dict['paper'].size() == (6, 64) assert 'university' not in out_dict def test_hgt_conv_missing_edge_type(): data = HeteroData() data['author'].x = torch.randn(4, 16) data['paper'].x = torch.randn(6, 32) data['university'].x = torch.randn(10, 32) data['author', 'writes', 'paper'].edge_index = get_random_edge_index(4, 6, 20) metadata = (['author', 'paper', 'university'], [('author', 'writes', 'paper'), ('university', 'employs', 'author')]) conv = HGTConv(-1, 64, metadata, heads=1) out_dict = conv(data.x_dict, data.edge_index_dict) assert out_dict['author'].size() == (4, 64) assert out_dict['paper'].size() == (6, 64) assert 'university' not in out_dict if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') args = parser.parse_args() num_nodes, num_edges = 30_000, 300_000 x_dict = { 'paper': torch.randn(num_nodes, 64, device=args.device), 'author': torch.randn(num_nodes, 64, device=args.device), } edge_index_dict = { ('paper', 'to', 'paper'): torch.randint(num_nodes, (2, num_edges), device=args.device), ('author', 'to', 'paper'): torch.randint(num_nodes, (2, num_edges), device=args.device), ('paper', 'to', 'author'): torch.randint(num_nodes, (2, num_edges), device=args.device), } conv = HGTConv( in_channels=64, out_channels=64, metadata=(list(x_dict.keys()), list(edge_index_dict.keys())), heads=4, ).to(args.device) benchmark( funcs=[conv], args=(x_dict, edge_index_dict), num_steps=10 if args.device == 'cpu' else 100, num_warmups=5 if args.device == 'cpu' else 50, backward=False, ) ================================================ FILE: test/nn/conv/test_hypergraph_conv.py ================================================ import torch from torch_geometric.nn import HypergraphConv def 
test_hypergraph_conv_with_more_nodes_than_edges(): in_channels, out_channels = (16, 32) hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3], [0, 1, 0, 1, 0, 1]]) num_nodes = hyperedge_index[0].max().item() + 1 num_edges = hyperedge_index[1].max().item() + 1 x = torch.randn((num_nodes, in_channels)) hyperedge_weight = torch.tensor([1.0, 0.5]) hyperedge_attr = torch.randn((num_edges, in_channels)) conv = HypergraphConv(in_channels, out_channels) assert str(conv) == 'HypergraphConv(16, 32)' out = conv(x, hyperedge_index) assert out.size() == (num_nodes, out_channels) out = conv(x, hyperedge_index, hyperedge_weight) assert out.size() == (num_nodes, out_channels) conv = HypergraphConv(in_channels, out_channels, use_attention=True, heads=2) out = conv(x, hyperedge_index, hyperedge_attr=hyperedge_attr) assert out.size() == (num_nodes, 2 * out_channels) out = conv(x, hyperedge_index, hyperedge_weight, hyperedge_attr) assert out.size() == (num_nodes, 2 * out_channels) conv = HypergraphConv(in_channels, out_channels, use_attention=True, heads=2, concat=False, dropout=0.5) out = conv(x, hyperedge_index, hyperedge_weight, hyperedge_attr) assert out.size() == (num_nodes, out_channels) def test_hypergraph_conv_with_more_edges_than_nodes(): in_channels, out_channels = (16, 32) hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3, 3, 3, 2, 1, 2], [0, 1, 2, 1, 2, 1, 0, 3, 3, 4, 4]]) hyperedge_weight = torch.tensor([1.0, 0.5, 0.8, 0.2, 0.7]) num_nodes = hyperedge_index[0].max().item() + 1 x = torch.randn((num_nodes, in_channels)) conv = HypergraphConv(in_channels, out_channels) assert str(conv) == 'HypergraphConv(16, 32)' out = conv(x, hyperedge_index) assert out.size() == (num_nodes, out_channels) out = conv(x, hyperedge_index, hyperedge_weight) assert out.size() == (num_nodes, out_channels) ================================================ FILE: test/nn/conv/test_le_conv.py ================================================ import torch import torch_geometric.typing from 
torch_geometric.nn import LEConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_le_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = LEConv(16, 32) assert str(conv) == 'LEConv(16, 32)' out = conv(x, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, adj2.t()), out) if is_full_test(): jit = torch.jit.script(conv) torch.allclose(jit(x, edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj2.t()), out) ================================================ FILE: test/nn/conv/test_lg_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import LGConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_lg_conv(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.rand(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = LGConv() assert str(conv) == 'LGConv()' out1 = conv(x, edge_index) assert out1.size() == (4, 8) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 8) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) assert 
torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) ================================================ FILE: test/nn/conv/test_meshcnn_conv.py ================================================ import pytest import torch from torch.nn import Linear, ModuleList, Sequential, Sigmoid from torch_geometric.nn import MeshCNNConv @pytest.mark.parametrize('in_channels, out_channels', [ (1, 1), (1, 2), (8, 3), (8, 3), (42, 40), ]) def test_meshcnn_conv(in_channels: int, out_channels: int): # m = (V, F), shape [|V| x 3, 3 * |F|] # The simplest manifold triangular mesh is a tetrahedron E_cardinality = 6 # |E|, the number of edges x0 = torch.randn(E_cardinality, in_channels) # X^(k), the prior layer edge_index = torch.tensor([[ 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5 ], [ 1, 2, 3, 4, 2, 0, 4, 5, 5, 3, 0, 1, 2, 5, 4, 0, 0, 3, 5, 1, 1, 4, 3, 2 ]], dtype=torch.int64) # in_channels is the `Dim-Out(k)` in torch.nn.conv.MeshCNNConv # out_channels is the `Dim-Out(k+1)` in torch.nn.conv.MeshCNNConv conv = MeshCNNConv(in_channels, out_channels) # Assert right representation (defined by the class's __repr__ method) # WARN: For now we do not account for the 5 default kernels in the # representation. assert str(conv) == f"MeshCNNConv({in_channels}, {out_channels})" x1 = conv(x0, edge_index) assert x1.size() == (E_cardinality, out_channels) # assert determinism assert torch.allclose(conv(x0, edge_index), x1) # kernels MUST be a ModuleList of length 5. 
# Where kernels[0] is known as W_0^{(k+1)} in MeshCNNConv etc kernels = ModuleList([ Sequential(Linear(in_channels, out_channels), Sigmoid()) for _ in range(5) ]) with pytest.warns(UserWarning, match="does not have attribute"): conv = MeshCNNConv(in_channels, out_channels, kernels) # WARN: For now we do not account for the 5 kernels in the # representation assert str(conv) == f"MeshCNNConv({in_channels}, {out_channels})" x1 = conv(x0, edge_index) assert x1.size() == (E_cardinality, out_channels) ================================================ FILE: test/nn/conv/test_message_passing.py ================================================ import copy import os.path as osp from typing import Optional, Tuple, Union import pytest import torch from torch import Tensor from torch.nn import Linear import torch_geometric.typing from torch_geometric import EdgeIndex from torch_geometric.nn import MessagePassing, aggr from torch_geometric.typing import ( Adj, OptPairTensor, OptTensor, Size, SparseTensor, ) from torch_geometric.utils import ( add_self_loops, scatter, spmm, to_torch_csc_tensor, ) class MyConv(MessagePassing): def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: str = 'add'): super().__init__(aggr=aggr) if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.lin_l = Linear(in_channels[0], out_channels) self.lin_r = Linear(in_channels[1], out_channels) def forward( self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight: OptTensor = None, size: Size = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: OptPairTensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size) out = self.lin_l(out) x_r = x[1] if x_r is not None: out += self.lin_r(x_r) return out def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, 
adj_t: Adj, x: OptPairTensor) -> Tensor: return spmm(adj_t, x[0], reduce=self.aggr) class MyConvWithSelfLoops(MessagePassing): def __init__(self, aggr: str = 'add'): super().__init__(aggr=aggr) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: edge_index, _ = add_self_loops(edge_index) # propagate_type: (x: Tensor) return self.propagate(edge_index, x=x) def test_my_conv_basic(): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.randn(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) conv = MyConv(8, 32) out = conv(x1, edge_index, value) assert out.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out, atol=1e-6) assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) conv.fuse = False assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) conv.fuse = True # Bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, value, size=(4, 2)) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) conv = MyConv((8, 16), 32) out1 = conv((x1, x2), edge_index, value) out2 = conv((x1, None), edge_index, value, (4, 2)) assert out1.size() == (2, 32) assert out2.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6) conv.fuse = False assert 
torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6) # Test gradient computation for `torch.sparse` tensors: conv.fuse = True torch_adj_t = adj1.t().requires_grad_() out = conv((x1, x2), torch_adj_t) out.sum().backward() assert torch_adj_t.grad is not None def test_my_conv_save(tmp_path): conv = MyConv(8, 32) assert conv._jinja_propagate is not None assert conv.__class__._jinja_propagate is not None assert conv._orig_propagate is not None assert conv.__class__._orig_propagate is not None path = osp.join(tmp_path, 'model.pt') torch.save(conv, path) conv = torch.load(path, weights_only=False) assert conv._jinja_propagate is not None assert conv.__class__._jinja_propagate is not None assert conv._orig_propagate is not None assert conv.__class__._orig_propagate is not None def test_my_conv_edge_index(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_index = EdgeIndex(edge_index, sparse_size=(4, 4), sort_order='col') conv = MyConv(8, 32) out = conv(x, edge_index) assert out.size() == (4, 32) class MyCommentedConv(MessagePassing): r"""This layer calls `self.propagate()` internally.""" def __init__(self) -> None: super().__init__() def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: # `self.propagate()` is used here to propagate messages. return self.propagate(edge_index, x=x) def test_my_commented_conv(): # Check that `self.propagate` occurrences in comments are correctly # ignored. 
x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = MyCommentedConv() conv(x, edge_index) jit = torch.jit.script(conv) jit(x, edge_index) class MyKwargsConv(MessagePassing): def __init__(self) -> None: super().__init__() def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: return self.propagate(x=x, edge_index=edge_index) def test_my_kwargs_conv(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = MyKwargsConv() conv(x, edge_index) jit = torch.jit.script(conv) jit(x, edge_index) def test_my_conv_out_of_bounds(): x = torch.randn(3, 8) value = torch.randn(4) conv = MyConv(8, 32) with pytest.raises(IndexError, match="valid indices"): edge_index = torch.tensor([[-1, 1, 2, 2], [0, 0, 1, 1]]) conv(x, edge_index, value) with pytest.raises(IndexError, match="valid indices"): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv(x, edge_index, value) def test_my_conv_jit(): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.randn(edge_index.size(1)) conv = MyConv(8, 32) out = conv(x1, edge_index, value) jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index, value), out, atol=1e-6) assert torch.allclose(jit(x1, edge_index, value, (4, 4)), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6) jit.fuse = False assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6) jit.fuse = True conv = MyConv((8, 16), 32) out1 = conv((x1, x2), edge_index, value) out2 = conv((x1, None), edge_index, value, (4, 2)) jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index, value), out1) assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, value, (4, 2)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = 
SparseTensor.from_edge_index(edge_index, value, (4, 2)) assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6) assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6) jit.fuse = False assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6) assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6) jit.fuse = True def test_my_conv_jit_save(tmp_path): path = osp.join(tmp_path, 'model.pt') conv = MyConv(8, 32) conv = torch.jit.script(conv) torch.jit.save(conv, path) conv = torch.jit.load(path) @pytest.mark.parametrize('aggr', ['add', 'sum', 'mean', 'min', 'max', 'mul']) def test_my_conv_aggr(aggr): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_weight = torch.randn(edge_index.size(1)) conv = MyConv(8, 32, aggr=aggr) out = conv(x, edge_index, edge_weight) assert out.size() == (4, 32) def test_my_static_graph_conv(): x1 = torch.randn(3, 4, 8) x2 = torch.randn(3, 2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.randn(edge_index.size(1)) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) conv = MyConv(8, 32) out = conv(x1, edge_index, value) assert out.size() == (3, 4, 32) assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x1, adj.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, value, (4, 2)) conv = MyConv((8, 16), 32) out1 = conv((x1, x2), edge_index, value) out2 = conv((x1, None), edge_index, value, (4, 2)) assert out1.size() == (3, 2, 32) assert out2.size() == (3, 2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv((x1, x2), adj.t()), out1) assert torch.allclose(conv((x1, None), adj.t()), out2) class MyMultipleAggrConv(MessagePassing): def __init__(self, **kwargs): 
super().__init__(aggr=['add', 'mean', 'max'], **kwargs) def forward(self, x: Tensor, edge_index: Adj) -> Tensor: # propagate_type: (x: Tensor) return self.propagate(edge_index, x=x) @pytest.mark.parametrize('multi_aggr_tuple', [ (dict(mode='cat'), 3), (dict(mode='proj', mode_kwargs=dict(in_channels=16, out_channels=16)), 1) ]) def test_my_multiple_aggr_conv(multi_aggr_tuple): # The 'cat' combine mode will expand the output dimensions by # the number of aggregators which is 3 here, while the 'proj' # mode keeps output dimensions unchanged. aggr_kwargs, expand = multi_aggr_tuple x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) conv = MyMultipleAggrConv(aggr_kwargs=aggr_kwargs) out = conv(x, edge_index) assert out.size() == (4, 16 * expand) assert torch.allclose(conv(x, adj1.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x, adj2.t()), out) def test_my_multiple_aggr_conv_jit(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = MyMultipleAggrConv() out = conv(x, edge_index) jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(jit(x, adj.t()), out) def test_copy(): conv = MyConv(8, 32) conv2 = copy.copy(conv) assert conv != conv2 assert torch.equal(conv.lin_l.weight, conv2.lin_l.weight) assert torch.equal(conv.lin_r.weight, conv2.lin_r.weight) assert conv.lin_l.weight.data_ptr == conv2.lin_l.weight.data_ptr assert conv.lin_r.weight.data_ptr == conv2.lin_r.weight.data_ptr conv = copy.deepcopy(conv) assert conv != conv2 assert torch.equal(conv.lin_l.weight, conv2.lin_l.weight) assert torch.equal(conv.lin_r.weight, conv2.lin_r.weight) assert 
conv.lin_l.weight.data_ptr != conv2.lin_l.weight.data_ptr assert conv.lin_r.weight.data_ptr != conv2.lin_r.weight.data_ptr class MyEdgeConv(MessagePassing): def __init__(self): super().__init__(aggr='add') def forward(self, x: Tensor, edge_index: Adj) -> Tensor: # edge_updater_type: (x: Tensor) edge_attr = self.edge_updater(edge_index, x=x) # propagate_type: (edge_attr: Tensor) return self.propagate(edge_index, edge_attr=edge_attr, size=(x.size(0), x.size(0))) def edge_update(self, x_j: Tensor, x_i: Tensor) -> Tensor: return x_j - x_i def message(self, edge_attr: Tensor) -> Tensor: return edge_attr def test_my_edge_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) row, col = edge_index expected = scatter(x[row] - x[col], col, dim=0, dim_size=4, reduce='sum') conv = MyEdgeConv() out = conv(x, edge_index) assert out.size() == (4, 16) assert torch.allclose(out, expected) assert torch.allclose(conv(x, adj1.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, adj2.t()), out) def test_my_edge_conv_jit(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = MyEdgeConv() out = conv(x, edge_index) jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(jit(x, adj.t()), out) num_pre_hook_calls = 0 num_hook_calls = 0 def test_message_passing_hooks(): conv = MyConv(8, 32) x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.randn(edge_index.size(1)) adj = to_torch_csc_tensor(edge_index, value, size=(4, 4)) def pre_hook(module, inputs): assert module == conv global num_pre_hook_calls num_pre_hook_calls += 1 return inputs def hook(module, inputs, output): assert module 
== conv global num_hook_calls num_hook_calls += 1 return output handle1 = conv.register_propagate_forward_pre_hook(pre_hook) assert len(conv._propagate_forward_pre_hooks) == 1 handle2 = conv.register_propagate_forward_hook(hook) assert len(conv._propagate_forward_hooks) == 1 handle3 = conv.register_message_forward_pre_hook(pre_hook) assert len(conv._message_forward_pre_hooks) == 1 handle4 = conv.register_message_forward_hook(hook) assert len(conv._message_forward_hooks) == 1 handle5 = conv.register_aggregate_forward_pre_hook(pre_hook) assert len(conv._aggregate_forward_pre_hooks) == 1 handle6 = conv.register_aggregate_forward_hook(hook) assert len(conv._aggregate_forward_hooks) == 1 handle7 = conv.register_message_and_aggregate_forward_pre_hook(pre_hook) assert len(conv._message_and_aggregate_forward_pre_hooks) == 1 handle8 = conv.register_message_and_aggregate_forward_hook(hook) assert len(conv._message_and_aggregate_forward_hooks) == 1 out1 = conv(x, edge_index, value) assert num_pre_hook_calls == 3 assert num_hook_calls == 3 out2 = conv(x, adj.t()) assert num_pre_hook_calls == 5 assert num_hook_calls == 5 assert torch.allclose(out1, out2, atol=1e-6) handle1.remove() assert len(conv._propagate_forward_pre_hooks) == 0 handle2.remove() assert len(conv._propagate_forward_hooks) == 0 handle3.remove() assert len(conv._message_forward_pre_hooks) == 0 handle4.remove() assert len(conv._message_forward_hooks) == 0 handle5.remove() assert len(conv._aggregate_forward_pre_hooks) == 0 handle6.remove() assert len(conv._aggregate_forward_hooks) == 0 handle7.remove() assert len(conv._message_and_aggregate_forward_pre_hooks) == 0 handle8.remove() assert len(conv._message_and_aggregate_forward_hooks) == 0 conv = MyEdgeConv() handle1 = conv.register_edge_update_forward_pre_hook(pre_hook) assert len(conv._edge_update_forward_pre_hooks) == 1 handle2 = conv.register_edge_update_forward_hook(hook) assert len(conv._edge_update_forward_hooks) == 1 out1 = conv(x, edge_index) assert 
num_pre_hook_calls == 6 assert num_hook_calls == 6 out2 = conv(x, adj.t()) assert num_pre_hook_calls == 7 assert num_hook_calls == 7 assert torch.allclose(out1, out2, atol=1e-6) handle1.remove() assert len(conv._propagate_forward_pre_hooks) == 0 handle2.remove() assert len(conv._propagate_forward_hooks) == 0 def test_modified_message_passing_hook(): conv = MyConv(8, 32) x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_weight = torch.randn(edge_index.size(1)) out1 = conv(x, edge_index, edge_weight) def hook(module, inputs, output): assert len(inputs) == 1 assert len(inputs[-1]) == 2 assert 'x_j' in inputs[-1] assert 'edge_weight' in inputs[-1] return output + 1. conv.register_message_forward_hook(hook) out2 = conv(x, edge_index, edge_weight) assert not torch.allclose(out1, out2, atol=1e-6) class MyDefaultArgConv(MessagePassing): def __init__(self): super().__init__(aggr='mean') # propagate_type: (x: Tensor) def forward(self, x: Tensor, edge_index: Adj) -> Tensor: return self.propagate(edge_index, x=x) def message(self, x_j, zeros: bool = True): return x_j * 0 if zeros else x_j def test_my_default_arg_conv(): x = torch.randn(4, 1) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = MyDefaultArgConv() assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0] assert conv(x, adj1.t()).view(-1).tolist() == [0, 0, 0, 0] if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert conv(x, adj2.t()).view(-1).tolist() == [0, 0, 0, 0] jit = torch.jit.script(conv) assert jit(x, edge_index).view(-1).tolist() == [0, 0, 0, 0] assert jit(x, adj1.t()).view(-1).tolist() == [0, 0, 0, 0] class MyMultipleOutputConv(MessagePassing): def __init__(self): super().__init__() def forward(self, x: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]: # propagate_type: (x: Tensor) return self.propagate(edge_index, x=x) def message(self, 
x_j: Tensor) -> Tuple[Tensor, Tensor]: return x_j, x_j def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor) -> Tuple[Tensor, Tensor]: return (scatter(inputs[0], index, dim=0, reduce='sum'), scatter(inputs[0], index, dim=0, reduce='mean')) def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: return inputs def test_tuple_output(): conv = MyMultipleOutputConv() x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) out1 = conv(x, edge_index) assert isinstance(out1, tuple) and len(out1) == 2 def test_tuple_output_jit(): conv = MyMultipleOutputConv() x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) out1 = conv(x, edge_index) assert isinstance(out1, tuple) and len(out1) == 2 jit = torch.jit.script(conv) out2 = jit(x, edge_index) assert isinstance(out2, tuple) and len(out2) == 2 assert torch.allclose(out1[0], out2[0]) assert torch.allclose(out1[1], out2[1]) class MyExplainConv(MessagePassing): def __init__(self): super().__init__(aggr='add') def forward(self, x: Tensor, edge_index: Adj) -> Tensor: return self.propagate(edge_index, x=x) def test_explain_message(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = MyExplainConv() conv.explain = True assert conv.propagate.__module__.endswith('message_passing') with pytest.raises(ValueError, match="pre-defined 'edge_mask'"): conv(x, edge_index) conv._edge_mask = torch.tensor([0.0, 0.0, 0.0, 0.0]) conv._apply_sigmoid = False assert conv(x, edge_index).abs().sum() == 0. conv._edge_mask = torch.tensor([1.0, 1.0, 1.0, 1.0]) conv._apply_sigmoid = False out1 = conv(x, edge_index) # TorchScript should still work since it relies on class methods # (but without explainability). 
torch.jit.script(conv) conv.explain = False assert conv.propagate.__module__.endswith('MyExplainConv_propagate') out2 = conv(x, edge_index) assert torch.allclose(out1, out2) class MyAggregatorConv(MessagePassing): def __init__(self, **kwargs): super().__init__(**kwargs) def forward(self, x: Tensor, edge_index: Adj) -> Tensor: # propagate_type: (x: Tensor) return self.propagate(edge_index, x=x) @pytest.mark.parametrize('aggr_module', [ aggr.MeanAggregation(), aggr.SumAggregation(), aggr.MaxAggregation(), aggr.SoftmaxAggregation(), aggr.PowerMeanAggregation(), aggr.MultiAggregation(['mean', 'max']) ]) def test_message_passing_with_aggr_module(aggr_module): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) row, col = edge_index adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = MyAggregatorConv(aggr=aggr_module) assert isinstance(conv.aggr_module, aggr.Aggregation) out = conv(x, edge_index) assert out.size(0) == 4 and out.size(1) in {8, 16} assert torch.allclose(conv(x, adj1.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, adj2.t()), out) def test_message_passing_int32_edge_index(): # Check that we can dispatch an int32 edge_index up to aggregation x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.int32) edge_weight = torch.randn(edge_index.shape[1]) # Use a hook to promote the edge_index to long to workaround PyTorch CPU # backend restriction to int64 for the index. 
def cast_index_hook(module, inputs): input_dict = inputs[-1] input_dict['index'] = input_dict['index'].long() return (input_dict, ) conv = MyConv(8, 32) conv.register_aggregate_forward_pre_hook(cast_index_hook) assert conv(x, edge_index, edge_weight).size() == (4, 32) @pytest.mark.parametrize('num_nodes', [4, 8, 2, 0]) def test_traceable_my_conv_with_self_loops(num_nodes): # `torch.jit.trace` a `MessagePassing` layer that adds self loops and test # it across different input sizes. x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) conv = MyConvWithSelfLoops() traced_conv = torch.jit.trace(conv, ((x, edge_index))) scripted_conv = torch.jit.script(conv) x = torch.randn(num_nodes, 16) if num_nodes > 0: edge_index = torch.stack([ torch.arange(0, num_nodes - 1), torch.arange(1, num_nodes), ], dim=0) else: edge_index = torch.empty((2, 0), dtype=torch.long) out = conv(x, edge_index) traced_out = traced_conv(x, edge_index) scripted_out = scripted_conv(x, edge_index) assert torch.allclose(out, traced_out) assert torch.allclose(out, scripted_out) def test_pickle(tmp_path): path = osp.join(tmp_path, 'model.pt') model = MyConv(16, 32) torch.save(model, path) MyConv.propagate = MyConv._orig_propagate model = torch.load(path, weights_only=False) torch.jit.script(model) class MyOptionalEdgeAttrConv(MessagePassing): def __init__(self): super().__init__() def forward(self, x, edge_index, edge_attr=None): return self.propagate(edge_index, x=x, edge_attr=edge_attr) def message(self, x_j, edge_attr=None): return x_j if edge_attr is None else x_j * edge_attr.view(-1, 1) def test_my_optional_edge_attr_conv(): conv = MyOptionalEdgeAttrConv() x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) out = conv(x, edge_index) assert out.size() == (4, 8) ================================================ FILE: test/nn/conv/test_mf_conv.py ================================================ import torch import torch_geometric.typing from 
torch_geometric.nn import MFConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor def test_mf_conv(): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = MFConv(8, 32) assert str(conv) == 'MFConv(8, 32)' out = conv(x1, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj.t()), out) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index), out) assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj.t()), out) # Test bipartite message passing: conv = MFConv((8, 16), 32) assert str(conv) == 'MFConv((8, 16), 32)' out1 = conv((x1, x2), edge_index) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1) out2 = conv((x1, None), edge_index, (4, 2)) assert out2.size() == (2, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), adj.t()), out1) assert torch.allclose(conv((x1, None), adj.t()), out2) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index), out1) assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj.t()), out1) assert torch.allclose(jit((x1, None), adj.t()), out2) ================================================ FILE: test/nn/conv/test_mixhop_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import MixHopConv from 
torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_mixhop_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) value = torch.rand(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = MixHopConv(16, 32, powers=[0, 1, 2, 4]) assert str(conv) == 'MixHopConv(16, 32, powers=[0, 1, 2, 4])' out1 = conv(x, edge_index) assert out1.size() == (4, 128) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 128) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) ================================================ FILE: test/nn/conv/test_nn_conv.py ================================================ import torch from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq import torch_geometric.typing from torch_geometric.nn import NNConv from torch_geometric.testing import is_full_test, withCUDA from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_coo_tensor @withCUDA def test_nn_conv(device): x1 = torch.randn(4, 8, device=device) x2 = torch.randn(2, 16, device=device) edge_index = 
torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], device=device) value = torch.rand(edge_index.size(1), 3, device=device) adj1 = to_torch_coo_tensor(edge_index, value, size=(4, 4)) nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32)) conv = NNConv(8, 32, nn=nn).to(device) assert str(conv) == ( 'NNConv(8, 32, aggr=add, nn=Sequential(\n' ' (0): Linear(in_features=3, out_features=32, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=32, out_features=256, bias=True)\n' '))') out = conv(x1, edge_index, value) assert out.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out) assert torch.allclose(conv(x1, adj1.transpose(0, 1).coalesce()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x1, adj2.t()), out) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index, value), out) assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out) # Test bipartite message passing: adj1 = to_torch_coo_tensor(edge_index, value, size=(4, 2)) conv = NNConv((8, 16), 32, nn=nn).to(device) assert str(conv) == ( 'NNConv((8, 16), 32, aggr=add, nn=Sequential(\n' ' (0): Linear(in_features=3, out_features=32, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=32, out_features=256, bias=True)\n' '))') out1 = conv((x1, x2), edge_index, value) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) assert torch.allclose(conv((x1, x2), adj1.transpose(0, 1).coalesce()), out1) out2 = conv((x1, None), edge_index, value, (4, 2)) assert out2.size() == (2, 32) assert torch.allclose(conv((x1, None), adj1.transpose(0, 1).coalesce()), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out1) assert 
torch.allclose(conv((x1, None), adj2.t()), out2) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index, value), out1) assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out1) assert torch.allclose(jit((x1, None), adj2.t()), out2) ================================================ FILE: test/nn/conv/test_pan_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import PANConv from torch_geometric.testing import withPackage from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor @withPackage('torch_sparse') # TODO `PANConv` returns a `SparseTensor`. def test_pan_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) conv = PANConv(16, 32, filter_size=2) assert str(conv) == 'PANConv(16, 32, filter_size=2)' out1, M1 = conv(x, edge_index) assert out1.size() == (4, 32) out2, M2 = conv(x, adj1.t()) assert torch.allclose(out1, out2, atol=1e-6) assert torch.allclose(M1.to_dense(), M2.to_dense()) if torch_geometric.typing.WITH_TORCH_SPARSE: out3, M3 = conv(x, adj2.t()) assert torch.allclose(out1, out3, atol=1e-6) assert torch.allclose(M1.to_dense(), M3.to_dense()) ================================================ FILE: test/nn/conv/test_pdn_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import PDNConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor def test_pdn_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 
0, 0, 0]]) edge_attr = torch.randn(edge_index.size(1), 8) conv = PDNConv(16, 32, edge_dim=8, hidden_channels=128) assert str(conv) == "PDNConv(16, 32)" out = conv(x, edge_index, edge_attr) assert out.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) assert torch.allclose(conv(x, adj.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index, edge_attr), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj.t()), out, atol=1e-6) def test_pdn_conv_with_sparse_node_input_feature(): x = torch.sparse_coo_tensor( indices=torch.tensor([[0, 0], [0, 1]]), values=torch.tensor([1.0, 1.0]), size=torch.Size([4, 16]), ) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) edge_attr = torch.randn(edge_index.size(1), 8) conv = PDNConv(16, 32, edge_dim=8, hidden_channels=128) out = conv(x, edge_index, edge_attr) assert out.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) assert torch.allclose(conv(x, adj.t(), edge_attr), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index, edge_attr), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj.t(), edge_attr), out, atol=1e-6) ================================================ FILE: test/nn/conv/test_pna_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.data import Data from torch_geometric.loader import DataLoader, NeighborLoader from torch_geometric.nn import PNAConv from torch_geometric.testing import is_full_test, onlyNeighborSampler from torch_geometric.typing import SparseTensor aggregators = ['sum', 'mean', 'min', 'max', 'var', 'std'] scalers = [ 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear' ] 
@pytest.mark.parametrize('divide_input', [True, False]) def test_pna_conv(divide_input): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) deg = torch.tensor([0, 3, 0, 1]) value = torch.rand(edge_index.size(1), 3) conv = PNAConv(16, 32, aggregators, scalers, deg=deg, edge_dim=3, towers=4, pre_layers=2, post_layers=2, divide_input=divide_input) assert str(conv) == 'PNAConv(16, 32, towers=4, edge_dim=3)' out = conv(x, edge_index, value) assert out.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index, value), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj.t()), out, atol=1e-6) @onlyNeighborSampler def test_pna_conv_get_degree_histogram_neighbor_loader(): edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]]) data = Data(num_nodes=5, edge_index=edge_index) loader = NeighborLoader( data, num_neighbors=[-1], input_nodes=None, batch_size=5, shuffle=False, ) deg_hist = PNAConv.get_degree_histogram(loader) assert torch.equal(deg_hist, torch.tensor([1, 2, 1, 1])) def test_pna_conv_get_degree_histogram_dataloader(): edge_index_1 = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]]) edge_index_2 = torch.tensor([[1, 1, 2, 2, 0, 3, 3], [2, 3, 3, 1, 1, 0, 2]]) edge_index_3 = torch.tensor([[1, 3, 2, 0, 0, 4, 2], [2, 0, 4, 1, 1, 0, 3]]) edge_index_4 = torch.tensor([[0, 1, 2, 4, 0, 1, 3], [2, 3, 3, 1, 1, 0, 2]]) data_1 = Data(num_nodes=5, edge_index=edge_index_1) # hist = [1, 2 ,1 ,1] data_2 = Data(num_nodes=5, edge_index=edge_index_2) # hist = [1, 1, 3] data_3 = Data(num_nodes=5, edge_index=edge_index_3) # hist = [0, 3, 2] data_4 = Data(num_nodes=5, edge_index=edge_index_4) # hist = [1, 1, 3] loader = DataLoader( [data_1, data_2, data_3, data_4], 
batch_size=1, shuffle=False, ) deg_hist = PNAConv.get_degree_histogram(loader) assert torch.equal(deg_hist, torch.tensor([3, 7, 9, 1])) ================================================ FILE: test/nn/conv/test_point_conv.py ================================================ import torch from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq import torch_geometric.typing from torch_geometric.nn import PointNetConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_point_net_conv(): x1 = torch.randn(4, 16) pos1 = torch.randn(4, 3) pos2 = torch.randn(2, 3) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) local_nn = Seq(Lin(16 + 3, 32), ReLU(), Lin(32, 32)) global_nn = Seq(Lin(32, 32)) conv = PointNetConv(local_nn, global_nn) assert str(conv) == ( 'PointNetConv(local_nn=Sequential(\n' ' (0): Linear(in_features=19, out_features=32, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=32, out_features=32, bias=True)\n' '), global_nn=Sequential(\n' ' (0): Linear(in_features=32, out_features=32, bias=True)\n' '))') out = conv(x1, pos1, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, pos1, edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, pos1, adj2.t()), out, atol=1e-6) # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) out = conv(x1, (pos1, pos2), edge_index) assert out.size() == (2, 32) assert torch.allclose(conv((x1, None), (pos1, pos2), edge_index), out) assert 
torch.allclose(conv(x1, (pos1, pos2), adj1.t()), out, atol=1e-6) assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv(x1, (pos1, pos2), adj2.t()), out, atol=1e-6) assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out, atol=1e-6) if is_full_test(): assert torch.allclose(jit((x1, None), (pos1, pos2), edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, None), (pos1, pos2), adj2.t()), out, atol=1e-6) ================================================ FILE: test/nn/conv/test_point_gnn_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import MLP, PointGNNConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_point_gnn_conv(): x = torch.randn(6, 8) pos = torch.randn(6, 3) edge_index = torch.tensor([[0, 1, 1, 1, 2, 5], [1, 2, 3, 4, 3, 4]]) adj1 = to_torch_csc_tensor(edge_index, size=(6, 6)) conv = PointGNNConv( mlp_h=MLP([8, 16, 3]), mlp_f=MLP([3 + 8, 16, 8]), mlp_g=MLP([8, 16, 8]), ) assert str(conv) == ('PointGNNConv(\n' ' mlp_h=MLP(8, 16, 3),\n' ' mlp_f=MLP(11, 16, 8),\n' ' mlp_g=MLP(8, 16, 8),\n' ')') out = conv(x, pos, edge_index) assert out.size() == (6, 8) assert torch.allclose(conv(x, pos, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6)) assert torch.allclose(conv(x, pos, adj2.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, pos, edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, pos, adj2.t()), out, atol=1e-6) ================================================ FILE: 
test/nn/conv/test_point_transformer_conv.py ================================================ import torch from torch.nn import Linear, ReLU, Sequential import torch_geometric.typing from torch_geometric.nn import PointTransformerConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_point_transformer_conv(): x1 = torch.rand(4, 16) x2 = torch.randn(2, 8) pos1 = torch.rand(4, 3) pos2 = torch.randn(2, 3) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = PointTransformerConv(in_channels=16, out_channels=32) assert str(conv) == 'PointTransformerConv(16, 32)' out = conv(x1, pos1, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, pos1, edge_index), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, pos1, adj2.t()), out, atol=1e-6) pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32)) attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32)) conv = PointTransformerConv(16, 32, pos_nn, attn_nn) out = conv(x1, pos1, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6) # Test biparitite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) conv = PointTransformerConv((16, 8), 32) assert str(conv) == 'PointTransformerConv((16, 8), 32)' out = conv((x1, x2), (pos1, pos2), edge_index) assert out.size() == (2, 32) assert torch.allclose(conv((x1, x2), (pos1, pos2), adj1.t()), out) if 
torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), (pos1, pos2), adj2.t()), out) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), (pos1, pos2), edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), (pos1, pos2), adj2.t()), out) ================================================ FILE: test/nn/conv/test_ppf_conv.py ================================================ import torch import torch.nn.functional as F from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq import torch_geometric.typing from torch_geometric.nn import PPFConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_ppf_conv(): x1 = torch.randn(4, 16) pos1 = torch.randn(4, 3) pos2 = torch.randn(2, 3) n1 = F.normalize(torch.rand(4, 3), dim=-1) n2 = F.normalize(torch.rand(2, 3), dim=-1) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32)) global_nn = Seq(Lin(32, 32)) conv = PPFConv(local_nn, global_nn) assert str(conv) == ( 'PPFConv(local_nn=Sequential(\n' ' (0): Linear(in_features=20, out_features=32, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=32, out_features=32, bias=True)\n' '), global_nn=Sequential(\n' ' (0): Linear(in_features=32, out_features=32, bias=True)\n' '))') out = conv(x1, pos1, n1, edge_index) assert out.size() == (4, 32) assert torch.allclose(conv(x1, pos1, n1, adj1.t()), out, atol=1e-3) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, pos1, n1, adj2.t()), out, atol=1e-3) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, 
pos1, n1, edge_index), out, atol=1e-3) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, pos1, n1, adj2.t()), out, atol=1e-3) # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) out = conv(x1, (pos1, pos2), (n1, n2), edge_index) assert out.size() == (2, 32) assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), edge_index), out, atol=1e-3) assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj1.t()), out, atol=1e-3) assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj1.t()), out, atol=1e-3) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj2.t()), out, atol=1e-3) assert torch.allclose( conv((x1, None), (pos1, pos2), (n1, n2), adj2.t()), out, atol=1e-3) if is_full_test(): assert torch.allclose( jit((x1, None), (pos1, pos2), (n1, n2), edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose( jit((x1, None), (pos1, pos2), (n1, n2), adj2.t()), out, atol=1e-3) ================================================ FILE: test/nn/conv/test_res_gated_graph_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import ResGatedGraphConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor @pytest.mark.parametrize('edge_dim', [None, 4]) def test_res_gated_graph_conv(edge_dim): x1 = torch.randn(4, 8) x2 = torch.randn(2, 32) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_attr = torch.randn(edge_index.size(1), edge_dim) if edge_dim else None adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = ResGatedGraphConv(8, 32, edge_dim=edge_dim) assert str(conv) == 'ResGatedGraphConv(8, 32)' out = conv(x1, edge_index, edge_attr) assert out.size() == (4, 32) assert 
torch.allclose(conv(x1, adj1.t(), edge_attr), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index, edge_attr), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) conv = ResGatedGraphConv((8, 32), 32, edge_dim=edge_dim) assert str(conv) == 'ResGatedGraphConv((8, 32), 32)' out = conv((x1, x2), edge_index, edge_attr) assert out.size() == (2, 32) assert torch.allclose(conv((x1, x2), adj1.t(), edge_attr), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index, edge_attr), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6) ================================================ FILE: test/nn/conv/test_rgat_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import RGATConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_coo_tensor @pytest.mark.parametrize('mod', [ 'additive', 'scaled', 'f-additive', 'f-scaled', ]) @pytest.mark.parametrize('attention_mechanism', [ 'within-relation', 'across-relation', ]) @pytest.mark.parametrize('attention_mode', [ 'additive-self-attention', 'multiplicative-self-attention', ]) @pytest.mark.parametrize('concat', [True, False]) @pytest.mark.parametrize('edge_dim', [8, None]) def test_rgat_conv(mod, 
attention_mechanism, attention_mode, concat, edge_dim): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_type = torch.tensor([0, 2, 1, 2]) edge_attr = torch.randn((4, edge_dim)) if edge_dim else None conv1 = RGATConv( # `num_bases` is not None: in_channels=8, out_channels=16, num_relations=4, num_bases=4, mod=mod, attention_mechanism=attention_mechanism, attention_mode=attention_mode, heads=2, dim=1, concat=concat, edge_dim=edge_dim, ) conv2 = RGATConv( # `num_blocks` is not `None` in_channels=8, out_channels=16, num_relations=4, num_blocks=4, mod=mod, attention_mechanism=attention_mechanism, attention_mode=attention_mode, heads=2, dim=1, concat=concat, edge_dim=edge_dim, ) conv3 = RGATConv( # Both `num_bases` and `num_blocks` are `None`: in_channels=8, out_channels=16, num_relations=4, mod=mod, attention_mechanism=attention_mechanism, attention_mode=attention_mode, heads=2, dim=1, concat=concat, edge_dim=edge_dim, ) conv4 = RGATConv( # `dropout > 0` and `mod` is `None`: in_channels=8, out_channels=16, num_relations=4, mod=None, attention_mechanism=attention_mechanism, attention_mode=attention_mode, heads=2, dim=1, concat=concat, edge_dim=edge_dim, dropout=0.5, ) for conv in [conv1, conv2, conv3, conv4]: assert str(conv) == 'RGATConv(8, 16, heads=2)' out = conv(x, edge_index, edge_type, edge_attr) assert out.size() == (4, 16 * (2 if concat else 1)) out, (adj, alpha) = conv(x, edge_index, edge_type, edge_attr, return_attention_weights=True) assert out.size() == (4, 16 * (2 if concat else 1)) assert adj.size() == edge_index.size() assert alpha.size() == (4, 2) def test_rgat_conv_jit(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_attr = torch.randn((edge_index.size(1), 8)) edge_type = torch.tensor([0, 2, 1, 2]) adj1 = to_torch_coo_tensor(edge_index, edge_attr, size=(4, 4)) conv = RGATConv(8, 20, num_relations=4, num_bases=4, mod='additive', attention_mechanism='across-relation', 
attention_mode='additive-self-attention', heads=2, dim=1, edge_dim=8, bias=False) out = conv(x, edge_index, edge_type, edge_attr) assert out.size() == (4, 40) # t() expects a tensor with <= 2 sparse and 0 dense dimensions adj1_t = adj1.transpose(0, 1).coalesce() assert torch.allclose(conv(x, adj1_t, edge_type), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) assert torch.allclose(conv(x, adj2.t(), edge_type), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index, edge_type), conv(x, edge_index, edge_type)) ================================================ FILE: test/nn/conv/test_rgcn_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import FastRGCNConv, RGCNConv from torch_geometric.testing import is_full_test, withCUDA, withDevice from torch_geometric.typing import SparseTensor classes = [RGCNConv, FastRGCNConv] confs = [(None, None), (2, None), (None, 2)] @withDevice @pytest.mark.parametrize('conf', confs) def test_rgcn_conv_equality(conf, device): num_bases, num_blocks = conf x1 = torch.randn(4, 4, device=device) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1], ], device=device) edge_type = torch.tensor([0, 1, 1, 0, 0, 1], device=device) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1], ], device=device) edge_type = torch.tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3], device=device) torch.manual_seed(12345) conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks, aggr='sum').to(device) torch.manual_seed(12345) conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks, aggr='sum').to(device) out1 = conv1(x1, edge_index, edge_type) out2 = conv2(x1, edge_index, edge_type) assert torch.allclose(out1, out2, atol=1e-2) if num_blocks is None: out1 = conv1(None, edge_index, edge_type) out2 = 
conv2(None, edge_index, edge_type) assert torch.allclose(out1, out2, atol=1e-2) @withCUDA @pytest.mark.parametrize('cls', classes) @pytest.mark.parametrize('conf', confs) def test_rgcn_conv_basic(cls, conf, device): num_bases, num_blocks = conf x1 = torch.randn(4, 4, device=device) x2 = torch.randn(2, 16, device=device) idx1 = torch.arange(4, device=device) idx2 = torch.arange(2, device=device) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1], ], device=device) edge_type = torch.tensor([0, 1, 1, 0, 0, 1], device=device) conv = cls(4, 32, 2, num_bases, num_blocks, aggr='sum').to(device) assert str(conv) == f'{cls.__name__}(4, 32, num_relations=2)' out1 = conv(x1, edge_index, edge_type) assert out1.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 4)) assert torch.allclose(conv(x1, adj.t()), out1, atol=1e-3) if num_blocks is None: out2 = conv(None, edge_index, edge_type) assert torch.allclose(conv(idx1, edge_index, edge_type), out2, 1e-3) assert out2.size() == (4, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(None, adj.t()), out2, atol=1e-3) assert torch.allclose(conv(idx1, adj.t()), out2, atol=1e-3) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index, edge_type), out1, atol=1e-3) if num_blocks is None: assert torch.allclose(jit(idx1, edge_index, edge_type), out2, atol=1e-3) assert torch.allclose(jit(None, edge_index, edge_type), out2, atol=1e-3) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj.t()), out1) if num_blocks is None: assert torch.allclose(jit(idx1, adj.t()), out2, atol=1e-3) assert torch.allclose(jit(None, adj.t()), out2, atol=1e-3) # Test bipartite message passing: conv = cls((4, 16), 32, 2, num_bases, num_blocks, aggr='sum').to(device) assert str(conv) == f'{cls.__name__}((4, 16), 32, num_relations=2)' out1 = conv((x1, x2), edge_index, edge_type) assert out1.size() 
== (2, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 2)) assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-3) if num_blocks is None: out2 = conv((None, idx2), edge_index, edge_type) assert out2.size() == (2, 32) assert torch.allclose(conv((idx1, idx2), edge_index, edge_type), out2, atol=1e-3) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv((None, idx2), adj.t()), out2, atol=1e-3) assert torch.allclose(conv((idx1, idx2), adj.t()), out2, atol=1e-3) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index, edge_type), out1, atol=1e-3) if num_blocks is None: assert torch.allclose(jit((None, idx2), edge_index, edge_type), out2, atol=1e-3) assert torch.allclose(jit((idx1, idx2), edge_index, edge_type), out2, atol=1e-3) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-3) if num_blocks is None: assert torch.allclose(jit((None, idx2), adj.t()), out2, atol=1e-3) assert torch.allclose(jit((idx1, idx2), adj.t()), out2, atol=1e-3) ================================================ FILE: test/nn/conv/test_sage_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import MLPAggregation, SAGEConv from torch_geometric.testing import ( assert_module, is_full_test, onlyLinux, withDevice, withPackage, ) from torch_geometric.typing import SparseTensor @pytest.mark.parametrize('project', [False, True]) @pytest.mark.parametrize('aggr', ['mean', 'sum']) def test_sage_conv(project, aggr): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = SAGEConv(8, 32, project=project, aggr=aggr) assert str(conv) == f'SAGEConv(8, 32, aggr={aggr})' out = assert_module(conv, x, edge_index, expected_size=(4, 32)) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), 
out, atol=1e-6) assert torch.allclose(jit(x, edge_index, size=(4, 4)), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(jit(x, adj.t()), out, atol=1e-6) # Test bipartite message passing: x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) conv = SAGEConv((8, 16), 32, project=project, aggr=aggr) assert str(conv) == f'SAGEConv((8, 16), 32, aggr={aggr})' out1 = assert_module(conv, (x1, x2), edge_index, expected_size=(2, 32)) out2 = assert_module(conv, (x1, None), edge_index, size=(4, 2), expected_size=(2, 32)) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-6) assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1) assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6) assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6) @pytest.mark.parametrize('project', [False, True]) def test_lazy_sage_conv(project): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) if project: with pytest.raises(ValueError, match="does not support lazy"): SAGEConv(-1, 32, project=project) else: conv = SAGEConv(-1, 32, project=project) assert str(conv) == 'SAGEConv(-1, 32, aggr=mean)' out = conv(x, edge_index) assert out.size() == (4, 32) def test_lstm_aggr_sage_conv(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = SAGEConv(8, 32, aggr='lstm') assert str(conv) == 'SAGEConv(8, 32, aggr=lstm)' assert_module(conv, x, edge_index, expected_size=(4, 32), test_edge_permutation=False) edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 0]]) with pytest.raises(ValueError, match="'index' tensor is not sorted"): conv(x, edge_index) def test_mlp_sage_conv(): x = torch.randn(4, 
8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = SAGEConv( in_channels=8, out_channels=32, aggr=MLPAggregation( in_channels=8, out_channels=8, max_num_elements=2, num_layers=1, ), ) out = conv(x, edge_index) assert out.size() == (4, 32) @pytest.mark.parametrize('aggr_kwargs', [ dict(mode='cat'), dict(mode='proj', mode_kwargs=dict(in_channels=8, out_channels=16)), dict(mode='attn', mode_kwargs=dict(in_channels=8, out_channels=16, num_heads=4)), dict(mode='sum'), ]) def test_multi_aggr_sage_conv(aggr_kwargs): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) aggr_kwargs['aggrs_kwargs'] = [{}, {}, {}, dict(learn=True, t=1)] conv = SAGEConv(8, 32, aggr=['mean', 'max', 'sum', 'softmax'], aggr_kwargs=aggr_kwargs) assert_module(conv, x, edge_index, expected_size=(4, 32)) @withDevice @onlyLinux @withPackage('torch>=2.1.0') def test_compile_multi_aggr_sage_conv(device): import torch._dynamo as dynamo x = torch.randn(4, 8, device=device) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], device=device) conv = SAGEConv( in_channels=8, out_channels=32, aggr=['mean', 'sum', 'min', 'max', 'std'], ).to(device) explanation = dynamo.explain(conv)(x, edge_index) assert explanation.graph_break_count == 0 compiled_conv = torch.compile(conv) expected = conv(x, edge_index) out = compiled_conv(x, edge_index) assert torch.allclose(out, expected, atol=1e-6) ================================================ FILE: test/nn/conv/test_sg_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import SGConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_sg_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) value = torch.rand(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = 
to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = SGConv(16, 32, K=10) assert str(conv) == 'SGConv(16, 32, K=10)' out1 = conv(x, edge_index) assert out1.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) conv.cached = True conv(x, edge_index) assert conv._cached_x is not None assert torch.allclose(conv(x, edge_index), out1, atol=1e-6) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) ================================================ FILE: test/nn/conv/test_signed_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import SignedConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_signed_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv1 = SignedConv(16, 32, first_aggr=True) assert str(conv1) == 'SignedConv(16, 32, first_aggr=True)' conv2 = SignedConv(32, 48, first_aggr=False) assert str(conv2) == 'SignedConv(32, 48, 
first_aggr=False)' out1 = conv1(x, edge_index, edge_index) assert out1.size() == (4, 64) assert torch.allclose(conv1(x, adj1.t(), adj1.t()), out1) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv1(x, adj2.t(), adj2.t()), out1) out2 = conv2(out1, edge_index, edge_index) assert out2.size() == (4, 96) assert torch.allclose(conv2(out1, adj1.t(), adj1.t()), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv2(out1, adj2.t(), adj2.t()), out2) if is_full_test(): jit1 = torch.jit.script(conv1) jit2 = torch.jit.script(conv2) assert torch.allclose(jit1(x, edge_index, edge_index), out1) assert torch.allclose(jit2(out1, edge_index, edge_index), out2) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit1(x, adj2.t(), adj2.t()), out1) assert torch.allclose(jit2(out1, adj2.t(), adj2.t()), out2) # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) assert torch.allclose(conv1((x, x[:2]), edge_index, edge_index), out1[:2], atol=1e-6) assert torch.allclose(conv1((x, x[:2]), adj1.t(), adj1.t()), out1[:2], atol=1e-6) assert torch.allclose(conv2((out1, out1[:2]), edge_index, edge_index), out2[:2], atol=1e-6) assert torch.allclose(conv2((out1, out1[:2]), adj1.t(), adj1.t()), out2[:2], atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv1((x, x[:2]), adj2.t(), adj2.t()), out1[:2], atol=1e-6) assert torch.allclose(conv2((out1, out1[:2]), adj2.t(), adj2.t()), out2[:2], atol=1e-6) if is_full_test(): assert torch.allclose(jit1((x, x[:2]), edge_index, edge_index), out1[:2], atol=1e-6) assert torch.allclose(jit2((out1, out1[:2]), edge_index, edge_index), out2[:2], atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit1((x, x[:2]), adj2.t(), adj2.t()), out1[:2], atol=1e-6) assert torch.allclose(jit2((out1, 
out1[:2]), adj2.t(), adj2.t()), out2[:2], atol=1e-6) ================================================ FILE: test/nn/conv/test_simple_conv.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import SimpleConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor @pytest.mark.parametrize('aggr, combine_root', [ ('mean', None), ('sum', 'sum'), (['mean', 'max'], 'cat'), ('mean', 'self_loop'), ]) def test_simple_conv(aggr, combine_root): x1 = torch.randn(4, 8) x2 = torch.randn(2, 8) edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = SimpleConv(aggr, combine_root) assert str(conv) == 'SimpleConv()' num_aggrs = 1 if isinstance(aggr, str) else len(aggr) output_size = sum([8] * num_aggrs) + (8 if combine_root == 'cat' else 0) out = conv(x1, edge_index) assert out.size() == (4, output_size) assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) assert torch.allclose(conv(x1, adj1.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj2.t()), out) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index), out) assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out) # Test bipartite message passing: if combine_root != 'self_loop': adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) out = conv((x1, x2), edge_index) assert out.size() == (2, output_size) assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out) assert torch.allclose(conv((x1, x2), adj1.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert 
torch.allclose(conv((x1, x2), adj2.t()), out) ================================================ FILE: test/nn/conv/test_spline_conv.py ================================================ import warnings import torch import torch_geometric.typing from torch_geometric.nn import SplineConv from torch_geometric.testing import is_full_test, withPackage from torch_geometric.typing import SparseTensor @withPackage('pyg_lib') def test_spline_conv(): warnings.filterwarnings('ignore', '.*non-optimized CPU version.*') x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.rand(edge_index.size(1), 3) conv = SplineConv(8, 32, dim=3, kernel_size=5) assert str(conv) == 'SplineConv(8, 32, dim=3)' out = conv(x1, edge_index, value) assert out.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x1, edge_index, value), out, atol=1e-6) assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6) # Test bipartite message passing: conv = SplineConv((8, 16), 32, dim=3, kernel_size=5) assert str(conv) == 'SplineConv((8, 16), 32, dim=3)' out1 = conv((x1, x2), edge_index, value) assert out1.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) out2 = conv((x1, None), edge_index, value, (4, 2)) assert out2.size() == (2, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, value, (4, 2)) assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert 
torch.allclose(jit((x1, x2), edge_index, value), out1) assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)), out1, atol=1e-6) assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6) assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6) @withPackage('pyg_lib') def test_lazy_spline_conv(): warnings.filterwarnings('ignore', '.*non-optimized CPU version.*') x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) value = torch.rand(edge_index.size(1), 3) conv = SplineConv(-1, 32, dim=3, kernel_size=5) assert str(conv) == 'SplineConv(-1, 32, dim=3)' out = conv(x1, edge_index, value) assert out.size() == (4, 32) conv = SplineConv((-1, -1), 32, dim=3, kernel_size=5) assert str(conv) == 'SplineConv((-1, -1), 32, dim=3)' out = conv((x1, x2), edge_index, value) assert out.size() == (2, 32) ================================================ FILE: test/nn/conv/test_ssg_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import SSGConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_ssg_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) value = torch.rand(edge_index.size(1)) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = SSGConv(16, 32, alpha=0.1, K=10) assert str(conv) == 'SSGConv(16, 32, K=10, alpha=0.1)' out1 = conv(x, edge_index) assert out1.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) if 
torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) conv.cached = True conv(x, edge_index) assert conv._cached_h is not None assert torch.allclose(conv(x, edge_index), out1, atol=1e-6) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) ================================================ FILE: test/nn/conv/test_static_graph.py ================================================ import torch from torch_geometric.data import Batch, Data from torch_geometric.nn import ChebConv, GCNConv, MessagePassing class MyConv(MessagePassing): def forward(self, x, edge_index): return self.propagate(edge_index, x=x) def test_static_graph(): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) x1, x2 = torch.randn(3, 8), torch.randn(3, 8) data1 = Data(edge_index=edge_index, x=x1) data2 = Data(edge_index=edge_index, x=x2) batch = Batch.from_data_list([data1, data2]) x = torch.stack([x1, x2], dim=0) for conv in [MyConv(), GCNConv(8, 16), ChebConv(8, 16, K=2)]: out1 = conv(batch.x, batch.edge_index) assert out1.size(0) == 6 conv.node_dim = 1 out2 = conv(x, edge_index) assert out2.size()[:2] == (2, 3) assert torch.allclose(out1, out2.view(-1, out2.size(-1))) ================================================ FILE: test/nn/conv/test_supergat_conv.py ================================================ import pytest import torch 
import torch_geometric.typing from torch_geometric.nn import SuperGATConv from torch_geometric.typing import SparseTensor @pytest.mark.parametrize('att_type', ['MX', 'SD']) def test_supergat_conv(att_type): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) conv = SuperGATConv(8, 32, heads=2, attention_type=att_type, neg_sample_ratio=1.0, edge_sample_ratio=1.0) assert str(conv) == f'SuperGATConv(8, 32, heads=2, type={att_type})' out = conv(x, edge_index) assert out.size() == (4, 64) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x, adj.t()), out, atol=1e-6) # Negative samples are given: neg_edge_index = conv.negative_sampling(edge_index, x.size(0)) assert torch.allclose(conv(x, edge_index, neg_edge_index), out) att_loss = conv.get_attention_loss() assert isinstance(att_loss, torch.Tensor) and att_loss > 0 # Batch of graphs: x = torch.randn(8, 8) edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 1, 1, 4, 4, 5, 5]]) batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) out = conv(x, edge_index, batch=batch) assert out.size() == (8, 64) # Batch of graphs and negative samples are given: neg_edge_index = conv.negative_sampling(edge_index, x.size(0), batch) assert torch.allclose(conv(x, edge_index, neg_edge_index), out) att_loss = conv.get_attention_loss() assert isinstance(att_loss, torch.Tensor) and att_loss > 0 ================================================ FILE: test/nn/conv/test_tag_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import TAGConv from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_torch_csc_tensor def test_tag_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) value = torch.rand(edge_index.size(1)) adj1 = 
to_torch_csc_tensor(edge_index, size=(4, 4)) adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) conv = TAGConv(16, 32) assert str(conv) == 'TAGConv(16, 32, K=3)' out1 = conv(x, edge_index) assert out1.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) if is_full_test(): jit = torch.jit.script(conv) assert torch.allclose(jit(x, edge_index), out1, atol=1e-6) assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) def test_static_tag_conv(): x = torch.randn(3, 4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) conv = TAGConv(16, 32) out = conv(x, edge_index) assert out.size() == (3, 4, 32) ================================================ FILE: test/nn/conv/test_transformer_conv.py ================================================ from typing import Optional, Tuple import pytest import torch from torch import Tensor import torch_geometric.typing from torch_geometric.nn import TransformerConv from torch_geometric.testing import is_full_test from torch_geometric.typing import Adj, SparseTensor from torch_geometric.utils import to_torch_csc_tensor @pytest.mark.parametrize('edge_dim', [None, 8]) @pytest.mark.parametrize('concat', [True, False]) def test_transformer_conv(edge_dim, concat): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) out_channels = 32 heads = 2 edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_attr = 
torch.randn(edge_index.size(1), edge_dim) if edge_dim else None adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = TransformerConv(8, out_channels, heads, beta=True, edge_dim=edge_dim, concat=concat) assert str(conv) == f'TransformerConv(8, {out_channels}, heads={heads})' out = conv(x1, edge_index, edge_attr) assert out.size() == (4, out_channels * (heads if concat else 1)) assert torch.allclose(conv(x1, adj1.t(), edge_attr), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: Adj, edge_attr: Optional[Tensor] = None, ) -> Tensor: return self.conv(x, edge_index, edge_attr) jit = torch.jit.script(MyModule()) assert torch.allclose(jit(x1, edge_index, edge_attr), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) # Test `return_attention_weights`. 
result = conv(x1, edge_index, edge_attr, return_attention_weights=True) assert torch.allclose(result[0], out) assert result[1][0].size() == (2, 4) assert result[1][1].size() == (4, 2) assert result[1][1].min() >= 0 and result[1][1].max() <= 1 assert conv._alpha is None if torch_geometric.typing.WITH_TORCH_SPARSE: result = conv(x1, adj2.t(), return_attention_weights=True) assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 4 assert conv._alpha is None if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor], ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: return self.conv(x, edge_index, edge_attr, return_attention_weights=True) jit = torch.jit.script(MyModule()) result = jit(x1, edge_index, edge_attr) assert torch.allclose(result[0], out) assert result[1][0].size() == (2, 4) assert result[1][1].size() == (4, 2) assert result[1][1].min() >= 0 and result[1][1].max() <= 1 assert conv._alpha is None if torch_geometric.typing.WITH_TORCH_SPARSE: class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tensor, edge_index: SparseTensor, ) -> Tuple[Tensor, SparseTensor]: return self.conv(x, edge_index, return_attention_weights=True) jit = torch.jit.script(MyModule()) result = jit(x1, adj2.t()) assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 4 assert conv._alpha is None # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) conv = TransformerConv((8, 16), out_channels, heads=heads, beta=True, edge_dim=edge_dim, concat=concat) assert str(conv) == (f'TransformerConv((8, 16), {out_channels}, ' f'heads={heads})') out = conv((x1, x2), edge_index, edge_attr) assert out.size() == (2, out_channels * (heads if concat else 1)) assert torch.allclose(conv((x1, x2), 
adj1.t(), edge_attr), out, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, sparse_sizes=(4, 2)) assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6) if is_full_test(): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = conv def forward( self, x: Tuple[Tensor, Tensor], edge_index: Adj, edge_attr: Optional[Tensor] = None, ) -> Tensor: return self.conv(x, edge_index, edge_attr) jit = torch.jit.script(MyModule()) assert torch.allclose(jit((x1, x2), edge_index, edge_attr), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6) ================================================ FILE: test/nn/conv/test_wl_conv.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import WLConv from torch_geometric.typing import SparseTensor from torch_geometric.utils import one_hot, to_torch_csc_tensor def test_wl_conv(): x1 = torch.tensor([1, 0, 0, 1]) x2 = one_hot(x1) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) conv = WLConv() assert str(conv) == 'WLConv()' out = conv(x1, edge_index) assert out.tolist() == [0, 1, 1, 0] assert torch.equal(conv(x2, edge_index), out) assert torch.equal(conv(x1, adj1.t()), out) assert torch.equal(conv(x2, adj1.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.equal(conv(x1, adj2.t()), out) assert torch.equal(conv(x2, adj2.t()), out) assert conv.histogram(out).tolist() == [[2, 2]] assert torch.allclose(conv.histogram(out, norm=True), torch.tensor([[0.7071, 0.7071]])) ================================================ FILE: test/nn/conv/test_wl_conv_continuous.py ================================================ import torch import torch_geometric.typing from 
torch_geometric.nn import WLConvContinuous
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor


def test_wl_conv():
    # Checks `WLConvContinuous` on a 3-node path graph (0 - 1 - 2), first for
    # a single node set and afterwards in the bipartite setting.
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

    conv = WLConvContinuous()
    assert str(conv) == 'WLConvContinuous()'

    out = conv(x, edge_index)
    assert out.tolist() == [[-0.5], [0.0], [0.5]]

    # The `SparseTensor` adjacency has to reproduce the `edge_index` result:
    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3))
        assert torch.allclose(conv(x, adj.t()), out)

    if is_full_test():
        jit = torch.jit.script(conv)
        assert torch.allclose(jit(x, edge_index), out)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)

    # Test bipartite message passing:
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    edge_weight = torch.randn(edge_index.size(1))

    # No destination features are given here, so `size` is passed explicitly:
    out1 = conv((x1, None), edge_index, edge_weight, size=(4, 2))
    assert out1.size() == (2, 8)

    out2 = conv((x1, x2), edge_index, edge_weight)
    assert out2.size() == (2, 8)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj = SparseTensor.from_edge_index(edge_index, edge_weight, (4, 2))
        assert torch.allclose(conv((x1, None), adj.t()), out1)
        assert torch.allclose(conv((x1, x2), adj.t()), out2)

    # Re-uses the scripted module `jit` created in the first full-test branch:
    if is_full_test():
        assert torch.allclose(
            jit((x1, None), edge_index, edge_weight, size=(4, 2)), out1)
        assert torch.allclose(jit((x1, x2), edge_index, edge_weight), out2)

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            assert torch.allclose(jit((x1, None), adj.t()), out1, atol=1e-6)
            assert torch.allclose(jit((x1, x2), adj.t()), out2, atol=1e-6)


================================================
FILE: test/nn/conv/test_x_conv.py
================================================
import torch

from torch_geometric.nn import XConv
from torch_geometric.testing import is_full_test, withPackage


@withPackage('torch_cluster')
def test_x_conv():
    # Two point clouds of 4 points each, identified by `batch`:
    x = torch.randn(8, 16)
    pos = torch.rand(8, 3)
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

    conv = XConv(16, 32, dim=3, kernel_size=2, dilation=2)
    assert str(conv) == 'XConv(16, 32)'

    # The seed is reset before every call so that eager and scripted runs
    # can be compared; presumably the forward pass draws random numbers
    # (e.g. for dilated neighbor selection) -- TODO confirm.
    torch.manual_seed(12345)
    out1 = conv(x, pos)
    assert out1.size() == (8, 32)

    torch.manual_seed(12345)
    out2 = conv(x, pos, batch)
    assert out2.size() == (8, 32)

    if is_full_test():
        jit = torch.jit.script(conv)

        torch.manual_seed(12345)
        assert torch.allclose(jit(x, pos), out1, atol=1e-6)

        torch.manual_seed(12345)
        assert torch.allclose(jit(x, pos, batch), out2, atol=1e-6)


================================================
FILE: test/nn/conv/utils/test_gnn_cheatsheet.py
================================================
from torch_geometric.nn.conv import utils


def test_gnn_cheatsheet():
    # Paper metadata looked up by operator class name:
    assert utils.paper_title('GCNConv') == ('Semi-supervised Classification '
                                            'with Graph Convolutional '
                                            'Networks')
    assert utils.paper_link('GCNConv') == 'https://arxiv.org/abs/1609.02907'

    # Capability flags: every flag is checked against at least one operator
    # expected to have it and one expected not to.
    assert utils.supports_sparse_tensor('GCNConv')
    assert not utils.supports_sparse_tensor('ChebConv')
    assert utils.supports_edge_weights('GraphConv')
    assert not utils.supports_edge_weights('SAGEConv')
    assert utils.supports_edge_features('GATConv')
    assert not utils.supports_edge_features('SimpleConv')
    assert utils.supports_bipartite_graphs('SAGEConv')
    assert not utils.supports_bipartite_graphs('GCNConv')
    assert utils.supports_static_graphs('GCNConv')
    assert not utils.supports_static_graphs('GATConv')
    assert utils.supports_lazy_initialization('SAGEConv')
    assert not utils.supports_lazy_initialization('GatedGraphConv')
    assert utils.processes_heterogeneous_graphs('RGCNConv')
    assert utils.processes_heterogeneous_graphs('HeteroConv')
    assert not utils.processes_heterogeneous_graphs('GCNConv')
    assert utils.processes_hypergraphs('HypergraphConv')
    assert not utils.processes_hypergraphs('SAGEConv')
    assert utils.processes_point_clouds('DynamicEdgeConv')
    assert utils.processes_point_clouds('XConv')
    assert not utils.processes_point_clouds('CuGraphSAGEConv')


================================================
FILE: test/nn/dense/test_dense_gat_conv.py
================================================
import pytest
import torch

from torch_geometric.nn import DenseGATConv, GATConv
from torch_geometric.testing import is_full_test


@pytest.mark.parametrize('heads', [1, 4])
@pytest.mark.parametrize('concat', [True, False])
def test_dense_gat_conv(heads, concat):
    # Compares the dense operator against the sparse `GATConv` with shared
    # parameters on the same graph.
    channels = 16
    sparse_conv = GATConv(channels, channels, heads=heads, concat=concat)
    dense_conv = DenseGATConv(channels, channels, heads=heads, concat=concat)
    assert str(dense_conv) == f'DenseGATConv(16, 16, heads={heads})'

    # Ensure same weights and bias:
    dense_conv.lin = sparse_conv.lin
    dense_conv.att_src = sparse_conv.att_src
    dense_conv.att_dst = sparse_conv.att_dst
    dense_conv.bias = sparse_conv.bias

    # Edges 0-1, 1-2 and 3-4 (both directions):
    x = torch.randn((5, channels))
    edge_index = torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])

    sparse_out = sparse_conv(x, edge_index)

    # Append one all-zero padding node and split the 6 nodes into a batch of
    # 2 graphs with 3 nodes each; `adj` re-creates the sparse edges per graph
    # and `mask` flags the padding node (graph 1, node 2) as invalid.
    x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels)
    adj = torch.tensor([
        [
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
        ],
        [
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
        ],
    ])
    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)

    dense_out = dense_conv(x, adj, mask)

    if is_full_test():
        jit = torch.jit.script(dense_conv)
        assert torch.allclose(jit(x, adj, mask), dense_out)

    # The masked padding node must produce an all-zero output; dropping it
    # leaves the 5 real nodes, which must match the sparse result.
    assert dense_out[1, 2].abs().sum() == 0
    dense_out = dense_out.view(6, dense_out.size(-1))[:-1]
    assert torch.allclose(sparse_out, dense_out, atol=1e-4)


def test_dense_gat_conv_with_broadcasting():
    # A single 2-D `adj` (and 1-D `mask`) is shared by all graphs in the batch.
    batch_size, num_nodes, channels = 8, 3, 16
    conv = DenseGATConv(channels, channels, heads=4)

    x = torch.randn(batch_size, num_nodes, channels)
    adj = torch.tensor([
        [0.0, 1.0, 1.0],
        [1.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
    ])

    # 64 == 4 heads * 16 channels.
    assert conv(x, adj).size() == (batch_size, num_nodes, 64)
    mask = torch.tensor([1, 1, 1], dtype=torch.bool)
    assert conv(x, adj, mask).size() == (batch_size, num_nodes, 64)
================================================ FILE: test/nn/dense/test_dense_gcn_conv.py ================================================ import torch from torch_geometric.nn import DenseGCNConv, GCNConv from torch_geometric.testing import is_full_test def test_dense_gcn_conv(): channels = 16 sparse_conv = GCNConv(channels, channels) dense_conv = DenseGCNConv(channels, channels) assert str(dense_conv) == 'DenseGCNConv(16, 16)' # Ensure same weights and bias: dense_conv.lin.weight = sparse_conv.lin.weight dense_conv.bias = sparse_conv.bias x = torch.randn((5, channels)) edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4], [1, 2, 0, 2, 0, 1, 4, 3]]) sparse_out = sparse_conv(x, edge_index) assert sparse_out.size() == (5, channels) x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels) adj = torch.tensor([ [ [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], ], [ [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0], ], ]) mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool) dense_out = dense_conv(x, adj, mask) assert dense_out.size() == (2, 3, channels) if is_full_test(): jit = torch.jit.script(dense_conv) assert torch.allclose(jit(x, adj, mask), dense_out) assert dense_out[1, 2].abs().sum() == 0 dense_out = dense_out.view(6, channels)[:-1] assert torch.allclose(sparse_out, dense_out, atol=1e-4) def test_dense_gcn_conv_with_broadcasting(): batch_size, num_nodes, channels = 8, 3, 16 conv = DenseGCNConv(channels, channels) x = torch.randn(batch_size, num_nodes, channels) adj = torch.tensor([ [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], ]) assert conv(x, adj).size() == (batch_size, num_nodes, channels) mask = torch.tensor([1, 1, 1], dtype=torch.bool) assert conv(x, adj, mask).size() == (batch_size, num_nodes, channels) ================================================ FILE: test/nn/dense/test_dense_gin_conv.py ================================================ import torch from torch.nn import Linear as Lin from torch.nn import ReLU from 
torch.nn import Sequential as Seq from torch_geometric.nn import DenseGINConv, GINConv from torch_geometric.testing import is_full_test def test_dense_gin_conv(): channels = 16 nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels)) sparse_conv = GINConv(nn) dense_conv = DenseGINConv(nn) dense_conv = DenseGINConv(nn, train_eps=True) assert str(dense_conv) == ( 'DenseGINConv(nn=Sequential(\n' ' (0): Linear(in_features=16, out_features=16, bias=True)\n' ' (1): ReLU()\n' ' (2): Linear(in_features=16, out_features=16, bias=True)\n' '))') x = torch.randn((5, channels)) edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4], [1, 2, 0, 2, 0, 1, 4, 3]]) sparse_out = sparse_conv(x, edge_index) assert sparse_out.size() == (5, channels) x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels) adj = torch.tensor([ [ [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], ], [ [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0], ], ]) mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool) dense_out = dense_conv(x, adj, mask) assert dense_out.size() == (2, 3, channels) if is_full_test(): jit = torch.jit.script(dense_conv) assert torch.allclose(jit(x, adj, mask), dense_out) assert dense_out[1, 2].abs().sum().item() == 0 dense_out = dense_out.view(6, channels)[:-1] assert torch.allclose(sparse_out, dense_out, atol=1e-04) def test_dense_gin_conv_with_broadcasting(): batch_size, num_nodes, channels = 8, 3, 16 nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels)) conv = DenseGINConv(nn) x = torch.randn(batch_size, num_nodes, channels) adj = torch.tensor([ [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], ]) assert conv(x, adj).size() == (batch_size, num_nodes, channels) mask = torch.tensor([1, 1, 1], dtype=torch.bool) assert conv(x, adj, mask).size() == (batch_size, num_nodes, channels) ================================================ FILE: test/nn/dense/test_dense_graph_conv.py ================================================ import pytest 
import torch from torch_geometric.nn import DenseGraphConv, GraphConv from torch_geometric.testing import is_full_test from torch_geometric.utils import to_dense_adj @pytest.mark.parametrize('aggr', ['add', 'mean', 'max']) def test_dense_graph_conv(aggr): channels = 16 sparse_conv = GraphConv(channels, channels, aggr=aggr) dense_conv = DenseGraphConv(channels, channels, aggr=aggr) assert str(dense_conv) == 'DenseGraphConv(16, 16)' # Ensure same weights and bias. dense_conv.lin_rel = sparse_conv.lin_rel dense_conv.lin_root = sparse_conv.lin_root x = torch.randn((5, channels)) edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4], [1, 2, 0, 2, 0, 1, 4, 3]]) sparse_out = sparse_conv(x, edge_index) assert sparse_out.size() == (5, channels) adj = to_dense_adj(edge_index) mask = torch.ones(5, dtype=torch.bool) dense_out = dense_conv(x, adj, mask)[0] assert dense_out.size() == (5, channels) assert torch.allclose(sparse_out, dense_out, atol=1e-04) if is_full_test(): jit = torch.jit.script(dense_conv) assert torch.allclose(jit(x, adj, mask), dense_out) @pytest.mark.parametrize('aggr', ['add', 'mean', 'max']) def test_dense_graph_conv_batch(aggr): channels = 16 sparse_conv = GraphConv(channels, channels, aggr=aggr) dense_conv = DenseGraphConv(channels, channels, aggr=aggr) # Ensure same weights and bias. 
dense_conv.lin_rel = sparse_conv.lin_rel dense_conv.lin_root = sparse_conv.lin_root x = torch.randn((5, channels)) edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4], [1, 2, 0, 2, 0, 1, 4, 3]]) sparse_out = sparse_conv(x, edge_index) assert sparse_out.size() == (5, channels) x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels) adj = torch.tensor([ [ [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], ], [ [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0], ], ]) mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool) dense_out = dense_conv(x, adj, mask) assert dense_out.size() == (2, 3, channels) dense_out = dense_out.view(-1, channels) assert torch.allclose(sparse_out, dense_out[:5], atol=1e-04) assert dense_out[-1].abs().sum() == 0 @pytest.mark.parametrize('aggr', ['add', 'mean', 'max']) def test_dense_graph_conv_with_broadcasting(aggr): batch_size, num_nodes, channels = 8, 3, 16 conv = DenseGraphConv(channels, channels, aggr=aggr) x = torch.randn(batch_size, num_nodes, channels) adj = torch.tensor([ [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], ]) assert conv(x, adj).size() == (batch_size, num_nodes, channels) mask = torch.tensor([1, 1, 1], dtype=torch.bool) assert conv(x, adj, mask).size() == (batch_size, num_nodes, channels) ================================================ FILE: test/nn/dense/test_dense_sage_conv.py ================================================ import torch from torch_geometric.nn import DenseSAGEConv, SAGEConv from torch_geometric.testing import is_full_test def test_dense_sage_conv(): channels = 16 sparse_conv = SAGEConv(channels, channels, normalize=True) dense_conv = DenseSAGEConv(channels, channels, normalize=True) assert str(dense_conv) == 'DenseSAGEConv(16, 16)' # Ensure same weights and bias. 
dense_conv.lin_rel = sparse_conv.lin_l dense_conv.lin_root = sparse_conv.lin_r x = torch.randn((5, channels)) edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4], [1, 2, 0, 2, 0, 1, 4, 3]]) sparse_out = sparse_conv(x, edge_index) assert sparse_out.size() == (5, channels) x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels) adj = torch.tensor([ [ [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], ], [ [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0], ], ]) mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool) dense_out = dense_conv(x, adj, mask) assert dense_out.size() == (2, 3, channels) if is_full_test(): jit = torch.jit.script(dense_conv) assert torch.allclose(jit(x, adj, mask), dense_out) assert dense_out[1, 2].abs().sum().item() == 0 dense_out = dense_out.view(6, channels)[:-1] assert torch.allclose(sparse_out, dense_out, atol=1e-04) def test_dense_sage_conv_with_broadcasting(): batch_size, num_nodes, channels = 8, 3, 16 conv = DenseSAGEConv(channels, channels) x = torch.randn(batch_size, num_nodes, channels) adj = torch.tensor([ [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], ]) assert conv(x, adj).size() == (batch_size, num_nodes, channels) mask = torch.tensor([1, 1, 1], dtype=torch.bool) assert conv(x, adj, mask).size() == (batch_size, num_nodes, channels) ================================================ FILE: test/nn/dense/test_diff_pool.py ================================================ from itertools import product import torch from torch_geometric.nn import dense_diff_pool from torch_geometric.profile import benchmark from torch_geometric.testing import is_full_test def test_dense_diff_pool(): batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10) x = torch.randn((batch_size, num_nodes, channels)) adj = torch.rand((batch_size, num_nodes, num_nodes)) s = torch.randn((batch_size, num_nodes, num_clusters)) mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool) x_out, adj_out, link_loss, ent_loss = 
dense_diff_pool(x, adj, s, mask) assert x_out.size() == (2, 10, 16) assert adj_out.size() == (2, 10, 10) assert link_loss.item() >= 0 assert ent_loss.item() >= 0 if is_full_test(): jit = torch.jit.script(dense_diff_pool) x_jit, adj_jit, link_loss, ent_loss = jit(x, adj, s, mask) assert torch.allclose(x_jit, x_out) assert torch.allclose(adj_jit, adj_out) assert link_loss.item() >= 0 assert ent_loss.item() >= 0 if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() BS = [2**i for i in range(4, 8)] NS = [2**i for i in range(4, 8)] FS = [2**i for i in range(5, 9)] CS = [2**i for i in range(5, 9)] funcs = [] func_names = [] args_list = [] for B, N, F, C in product(BS, NS, FS, CS): x = torch.randn(B, N, F, device=args.device) adj = torch.randint(0, 2, (B, N, N), dtype=x.dtype, device=args.device) s = torch.randn(B, N, C, device=args.device) funcs.append(dense_diff_pool) func_names.append(f'B={B}, N={N}, F={F}, C={C}') args_list.append((x, adj, s)) benchmark( funcs=funcs, func_names=func_names, args=args_list, num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, progress_bar=True, ) ================================================ FILE: test/nn/dense/test_dmon_pool.py ================================================ import math import torch from torch_geometric.nn import DMoNPooling def test_dmon_pooling(): batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10) x = torch.randn((batch_size, num_nodes, channels)) adj = torch.ones((batch_size, num_nodes, num_nodes)) mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool) pool = DMoNPooling([channels, channels], num_clusters) assert str(pool) == 'DMoNPooling(16, num_clusters=10)' s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask) assert s.size() == (2, 20, 
10) assert x.size() == (2, 10, 16) assert adj.size() == (2, 10, 10) assert -1 <= spectral_loss <= 0.5 assert 0 <= ortho_loss <= math.sqrt(2) assert 0 <= cluster_loss <= math.sqrt(num_clusters) - 1 ================================================ FILE: test/nn/dense/test_linear.py ================================================ import copy import warnings from typing import List import pytest import torch from torch import Tensor from torch.nn import Linear as PTLinear from torch.nn.parameter import UninitializedParameter import torch_geometric.backend from torch_geometric.nn import HeteroDictLinear, HeteroLinear, Linear from torch_geometric.profile import benchmark from torch_geometric.testing import withCUDA, withDevice, withPackage from torch_geometric.typing import pyg_lib from torch_geometric.utils import cumsum weight_inits = ['glorot', 'kaiming_uniform', None] bias_inits = ['zeros', None] @withDevice @pytest.mark.parametrize('weight', weight_inits) @pytest.mark.parametrize('bias', bias_inits) def test_linear(weight, bias, device): x = torch.randn(3, 4, 16, device=device) lin = Linear(16, 32, weight_initializer=weight, bias_initializer=bias) lin = lin.to(device) assert str(lin) == 'Linear(16, 32, bias=True)' assert lin(x).size() == (3, 4, 32) @withDevice @pytest.mark.parametrize('weight', weight_inits) @pytest.mark.parametrize('bias', bias_inits) def test_lazy_linear(weight, bias, device): x = torch.randn(3, 4, 16, device=device) lin = Linear(-1, 32, weight_initializer=weight, bias_initializer=bias) lin = lin.to(device) copied_lin = copy.deepcopy(lin) assert lin.weight.device == device assert lin.bias.device == device assert str(lin) == 'Linear(-1, 32, bias=True)' assert lin(x).size() == (3, 4, 32) assert str(lin) == 'Linear(16, 32, bias=True)' assert copied_lin.weight.device == device assert copied_lin.bias.device == device assert copied_lin(x).size() == (3, 4, 32) @withDevice @pytest.mark.parametrize('dim1', [-1, 16]) @pytest.mark.parametrize('dim2', [-1, 
16]) @pytest.mark.parametrize('bias', [True, False]) def test_load_lazy_linear(dim1, dim2, bias, device): lin1 = Linear(dim1, 32, bias=bias).to(device) lin2 = Linear(dim2, 32, bias=bias).to(device) lin2.load_state_dict(lin1.state_dict()) if dim1 != -1: assert isinstance(lin1.weight, torch.nn.Parameter) assert isinstance(lin2.weight, torch.nn.Parameter) assert torch.allclose(lin1.weight, lin2.weight) assert not hasattr(lin1, '_hook') assert not hasattr(lin2, '_hook') else: assert isinstance(lin1.weight, UninitializedParameter) assert isinstance(lin2.weight, UninitializedParameter) assert hasattr(lin1, '_hook') assert hasattr(lin2, '_hook') if bias: assert isinstance(lin1.bias, torch.nn.Parameter) assert isinstance(lin2.bias, torch.nn.Parameter) if dim1 != -1: # Only check for equality on materialized bias: assert torch.allclose(lin1.bias, lin2.bias) else: assert lin1.bias is None assert lin2.bias is None with pytest.raises(RuntimeError, match="in state_dict"): lin1.load_state_dict({}, strict=True) lin1.load_state_dict({}, strict=False) @pytest.mark.parametrize('lazy', [True, False]) def test_identical_linear_default_initialization(lazy): x = torch.randn(3, 4, 16) torch.manual_seed(12345) lin1 = Linear(-1 if lazy else 16, 32) lin1(x) torch.manual_seed(12345) lin2 = PTLinear(16, 32) assert torch.equal(lin1.weight, lin2.weight) assert torch.equal(lin1.bias, lin2.bias) assert torch.allclose(lin1(x), lin2(x)) def test_copy_unintialized_parameter(): weight = UninitializedParameter() copy.deepcopy(weight) @withDevice @pytest.mark.parametrize('lazy', [True, False]) def test_copy_linear(lazy, device): lin = Linear(-1 if lazy else 16, 32).to(device) copied_lin = copy.copy(lin).to(device) assert id(copied_lin) != id(lin) assert id(copied_lin.weight) == id(lin.weight) if not isinstance(copied_lin.weight, UninitializedParameter): assert copied_lin.weight.data_ptr() == lin.weight.data_ptr() assert id(copied_lin.bias) == id(lin.bias) assert copied_lin.bias.data_ptr() == 
lin.bias.data_ptr() copied_lin = copy.deepcopy(lin).to(device) assert id(copied_lin) != id(lin) assert id(copied_lin.weight) != id(lin.weight) if not isinstance(copied_lin.weight, UninitializedParameter): assert copied_lin.weight.data_ptr() != lin.weight.data_ptr() assert torch.allclose(copied_lin.weight, lin.weight) assert id(copied_lin.bias) != id(lin.bias) assert copied_lin.bias.data_ptr() != lin.bias.data_ptr() if int(torch.isnan(lin.bias).sum()) == 0: assert torch.allclose(copied_lin.bias, lin.bias) @withCUDA def test_hetero_linear_basic(device): x = torch.randn(3, 16, device=device) type_vec = torch.tensor([0, 1, 2], device=device) lin = HeteroLinear(16, 32, num_types=3).to(device) assert str(lin) == 'HeteroLinear(16, 32, num_types=3, bias=True)' out = lin(x, type_vec) assert out.size() == (3, 32) jit = torch.jit.script(lin) assert torch.allclose(jit(x, type_vec), out, atol=1e-3) def test_hetero_linear_initializer(): lin = HeteroLinear( 16, 32, num_types=3, weight_initializer='glorot', bias_initializer='zeros', ) assert torch.equal(lin.bias, torch.zeros_like(lin.bias)) @withCUDA @pytest.mark.parametrize('use_segment_matmul', [None, True, False]) def test_hetero_linear_amp(device, use_segment_matmul): warnings.filterwarnings('ignore', '.*but CUDA is not available.*') old_state = torch_geometric.backend.use_segment_matmul torch_geometric.backend.use_segment_matmul = use_segment_matmul x = torch.randn(3, 16, device=device) type_vec = torch.tensor([0, 1, 2], device=device) lin = HeteroLinear(16, 32, num_types=3).to(device) with torch.amp.autocast('cuda'): assert lin(x, type_vec).size() == (3, 32) torch_geometric.backend.use_segment_matmul = old_state @withCUDA def test_lazy_hetero_linear(device): x = torch.randn(3, 16, device=device) type_vec = torch.tensor([0, 1, 2], device=device) lin = HeteroLinear(-1, 32, num_types=3).to(device) assert str(lin) == 'HeteroLinear(-1, 32, num_types=3, bias=True)' out = lin(x, type_vec) assert out.size() == (3, 32) @withDevice 
@pytest.mark.parametrize('bias', [True, False]) def test_hetero_dict_linear(bias, device): x_dict = { 'v': torch.randn(3, 16, device=device), 'w': torch.randn(2, 8, device=device), } lin = HeteroDictLinear({'v': 16, 'w': 8}, 32, bias=bias).to(device) assert str(lin) == (f"HeteroDictLinear({{'v': 16, 'w': 8}}, 32, " f"bias={bias})") out_dict = lin(x_dict) assert len(out_dict) == 2 assert out_dict['v'].size() == (3, 32) assert out_dict['w'].size() == (2, 32) x_dict = { 'v': torch.randn(3, 16, device=device), 'w': torch.randn(2, 16, device=device), } lin = HeteroDictLinear(16, 32, types=['v', 'w'], bias=bias).to(device) assert str(lin) == (f"HeteroDictLinear({{'v': 16, 'w': 16}}, 32, " f"bias={bias})") out_dict = lin(x_dict) assert len(out_dict) == 2 assert out_dict['v'].size() == (3, 32) assert out_dict['w'].size() == (2, 32) def test_hetero_dict_linear_jit(): x_dict = { 'v': torch.randn(3, 16), 'w': torch.randn(2, 8), } lin = HeteroDictLinear({'v': 16, 'w': 8}, 32) jit = torch.jit.script(lin) assert len(jit(x_dict)) == 2 @withDevice def test_lazy_hetero_dict_linear(device): x_dict = { 'v': torch.randn(3, 16, device=device), 'w': torch.randn(2, 8, device=device), } lin = HeteroDictLinear(-1, 32, types=['v', 'w']).to(device) assert str(lin) == "HeteroDictLinear({'v': -1, 'w': -1}, 32, bias=True)" out_dict = lin(x_dict) assert len(out_dict) == 2 assert out_dict['v'].size() == (3, 32) assert out_dict['w'].size() == (2, 32) @withCUDA @withPackage('pyg_lib') @pytest.mark.parametrize('type_vec', [ torch.tensor([0, 0, 1, 1, 2, 2]), torch.tensor([0, 1, 2, 0, 1, 2]), ]) def test_hetero_linear_sort(type_vec, device): x = torch.randn(type_vec.numel(), 16, device=device) lin = HeteroLinear(16, 32, num_types=3).to(device) out = lin(x, type_vec) for i in range(type_vec.numel()): node_type = int(type_vec[i]) expected = x[i] @ lin.weight[node_type] + lin.bias[node_type] assert torch.allclose(out[i], expected, atol=1e-3) if __name__ == '__main__': import argparse try: import dgl 
WITH_DLG = True except Exception: WITH_DGL = False warnings.filterwarnings('ignore', '.*API of nested tensors.*') warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*') parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() torch.manual_seed(12345) def get_xs(mean: float, std: float, num_types: int, channels: int) -> List[Tensor]: num_nodes_list = torch.normal( mean=torch.tensor([mean] * num_types, dtype=torch.float), std=torch.tensor([std] * num_types, dtype=torch.float), ).round().to(torch.long).tolist() return [ torch.randn(num_nodes, channels, device=args.device) for num_nodes in num_nodes_list ] def sequential(xs: List[Tensor], weights: List[Tensor]) -> List[Tensor]: return [x @ weight for x, weight in zip(xs, weights)] def nested(xs: List[Tensor], weights: List[Tensor]) -> List[Tensor]: x = torch.nested.nested_tensor(xs) weight = torch.nested.nested_tensor(weights) return list(torch.matmul(x, weight).unbind(0)) def grouped(x: Tensor, ptr: Tensor, weight: Tensor) -> Tensor: return pyg_lib.ops.segment_matmul(x, ptr, weight) def padded(x: Tensor, weight: Tensor) -> Tensor: return torch.matmul(x, weight) def dgl_mm(x: Tensor, count: Tensor, weight: Tensor) -> Tensor: return dgl.ops.segment_mm(x, weight, count) num_nodes, channels = 1_000_000, 64 for num_types in [3, 5, 10, 50, 100, 200, 500, 1000]: print(f'Number of types: {num_types}') mean = num_nodes // num_types std = mean // 4 xs = get_xs(mean, std, num_types, channels) count = torch.tensor([x.size(0) for x in xs]) ptr = cumsum(torch.tensor([x.size(0) for x in xs])) x = torch.cat(xs, dim=0) padded_x = torch.nested.nested_tensor(xs).to_padded_tensor(padding=0.0) weight = torch.randn(num_types, channels, channels, device=args.device) weights = list(weight.unbind(0)) funcs = [sequential, grouped, padded] func_names = ['Sequential', 'Grouped', 'Padded'] args_list = [(xs, weights), 
(x, ptr, weight), (padded_x, weight)] if WITH_DGL: funcs.append(dgl_mm) func_names.append('DGL') args_list.append((x, count, weight)) benchmark( funcs=funcs, func_names=func_names, args=args_list, num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) ================================================ FILE: test/nn/dense/test_mincut_pool.py ================================================ import math import torch from torch_geometric.nn import dense_mincut_pool from torch_geometric.testing import is_full_test def test_dense_mincut_pool(): batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10) x = torch.randn((batch_size, num_nodes, channels)) adj = torch.ones((batch_size, num_nodes, num_nodes)) s = torch.randn((batch_size, num_nodes, num_clusters)) mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool) x_out, adj_out, mincut_loss, ortho_loss = dense_mincut_pool( x, adj, s, mask) assert x_out.size() == (2, 10, 16) assert adj_out.size() == (2, 10, 10) assert -1 <= mincut_loss <= 0 assert 0 <= ortho_loss <= 2 if is_full_test(): jit = torch.jit.script(dense_mincut_pool) x_jit, adj_jit, mincut_loss, ortho_loss = jit(x, adj, s, mask) assert x_jit.size() == (2, 10, 16) assert adj_jit.size() == (2, 10, 10) assert -1 <= mincut_loss <= 0 assert 0 <= ortho_loss <= math.sqrt(2) ================================================ FILE: test/nn/functional/test_bro.py ================================================ import torch from torch_geometric.nn.functional import bro def test_bro(): batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2]) g1 = torch.tensor([ [0.2, 0.2, 0.2, 0.2], [0.0, 0.2, 0.2, 0.2], [0.2, 0.0, 0.2, 0.2], [0.2, 0.2, 0.0, 0.2], ]) g2 = torch.tensor([ [0.2, 0.2, 0.2, 0.2], [0.0, 0.2, 0.2, 0.2], [0.2, 0.0, 0.2, 0.2], ]) g3 = torch.tensor([ [0.2, 0.2, 0.2, 0.2], [0.2, 0.0, 0.2, 0.2], ]) s = 0. 
for g in [g1, g2, g3]: s += torch.norm(g @ g.t() - torch.eye(g.shape[0]), p=2) assert torch.isclose(s / 3., bro(torch.cat([g1, g2, g3], dim=0), batch)) ================================================ FILE: test/nn/functional/test_gini.py ================================================ import torch from torch_geometric.nn.functional import gini def test_gini(): w = torch.tensor([[0., 0., 0., 0.], [0., 0., 0., 1000.0]]) assert torch.isclose(gini(w), torch.tensor(0.5)) ================================================ FILE: test/nn/kge/test_complex.py ================================================ import torch from torch_geometric.nn import ComplEx def test_complex_scoring(): model = ComplEx(num_nodes=5, num_relations=2, hidden_channels=1) model.node_emb.weight.data = torch.tensor([ [2.], [3.], [5.], [1.], [2.], ]) model.node_emb_im.weight.data = torch.tensor([ [4.], [1.], [3.], [1.], [2.], ]) model.rel_emb.weight.data = torch.tensor([ [2.], [3.], ]) model.rel_emb_im.weight.data = torch.tensor([ [3.], [1.], ]) score = model( head_index=torch.tensor([1, 3]), rel_type=torch.tensor([1, 0]), tail_index=torch.tensor([2, 4]), ) assert score.tolist() == [58., 8.] def test_complex(): model = ComplEx(num_nodes=10, num_relations=5, hidden_channels=32) assert str(model) == 'ComplEx(10, num_relations=5, hidden_channels=32)' head_index = torch.tensor([0, 2, 4, 6, 8]) rel_type = torch.tensor([0, 1, 2, 3, 4]) tail_index = torch.tensor([1, 3, 5, 7, 9]) loader = model.loader(head_index, rel_type, tail_index, batch_size=5) for h, r, t in loader: out = model(h, r, t) assert out.size() == (5, ) loss = model.loss(h, r, t) assert loss >= 0. 
mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) assert 0 <= mean_rank <= 10 assert 0 < mrr <= 1 assert hits == 1.0 ================================================ FILE: test/nn/kge/test_distmult.py ================================================ import torch from torch_geometric.nn import DistMult def test_distmult(): model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32) assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)' head_index = torch.tensor([0, 2, 4, 6, 8]) rel_type = torch.tensor([0, 1, 2, 3, 4]) tail_index = torch.tensor([1, 3, 5, 7, 9]) loader = model.loader(head_index, rel_type, tail_index, batch_size=5) for h, r, t in loader: out = model(h, r, t) assert out.size() == (5, ) loss = model.loss(h, r, t) assert loss >= 0. mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) assert 0 <= mean_rank <= 10 assert 0 < mrr <= 1 assert hits == 1.0 ================================================ FILE: test/nn/kge/test_rotate.py ================================================ import torch from torch_geometric.nn import RotatE def test_rotate(): model = RotatE(num_nodes=10, num_relations=5, hidden_channels=32) assert str(model) == 'RotatE(10, num_relations=5, hidden_channels=32)' head_index = torch.tensor([0, 2, 4, 6, 8]) rel_type = torch.tensor([0, 1, 2, 3, 4]) tail_index = torch.tensor([1, 3, 5, 7, 9]) loader = model.loader(head_index, rel_type, tail_index, batch_size=5) for h, r, t in loader: out = model(h, r, t) assert out.size() == (5, ) loss = model.loss(h, r, t) assert loss >= 0. 
mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) assert 0 <= mean_rank <= 10 assert 0 < mrr <= 1 assert hits == 1.0 ================================================ FILE: test/nn/kge/test_transe.py ================================================ import torch from torch_geometric.nn import TransE def test_transe(): model = TransE(num_nodes=10, num_relations=5, hidden_channels=32) assert str(model) == 'TransE(10, num_relations=5, hidden_channels=32)' head_index = torch.tensor([0, 2, 4, 6, 8]) rel_type = torch.tensor([0, 1, 2, 3, 4]) tail_index = torch.tensor([1, 3, 5, 7, 9]) loader = model.loader(head_index, rel_type, tail_index, batch_size=5) for h, r, t in loader: out = model(h, r, t) assert out.size() == (5, ) loss = model.loss(h, r, t) assert loss >= 0. mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) assert 0 <= mean_rank <= 10 assert 0 < mrr <= 1 assert hits == 1.0 ================================================ FILE: test/nn/models/test_attentive_fp.py ================================================ import torch from torch_geometric.nn import AttentiveFP from torch_geometric.testing import is_full_test def test_attentive_fp(): model = AttentiveFP(8, 16, 32, edge_dim=3, num_layers=2, num_timesteps=2) assert str(model) == ('AttentiveFP(in_channels=8, hidden_channels=16, ' 'out_channels=32, edge_dim=3, num_layers=2, ' 'num_timesteps=2)') x = torch.randn(4, 8) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) edge_attr = torch.randn(edge_index.size(1), 3) batch = torch.tensor([0, 0, 0, 0]) out = model(x, edge_index, edge_attr, batch) assert out.size() == (1, 32) if is_full_test(): jit = torch.jit.script(model) assert torch.allclose(jit(x, edge_index, edge_attr, batch), out) ================================================ FILE: test/nn/models/test_attract_repel.py ================================================ import torch from torch_geometric.nn.models import ARLinkPredictor def 
test_ar_link_predictor(): model = ARLinkPredictor(in_channels=16, hidden_channels=32, num_layers=2) x = torch.randn(4, 16) # 4 nodes with 16 features each edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]]) # 3 edges # Test forward pass pred = model(x, edge_index) assert pred.size(0) == edge_index.size(1) assert torch.all(pred >= 0) and torch.all(pred <= 1) # Test encode function attract_z, repel_z = model.encode(x) assert attract_z.size() == ( 4, 16) # Default attract_ratio=0.5, so half of hidden_channels assert repel_z.size() == (4, 16) # Test decode function raw_scores = model.decode(attract_z, repel_z, edge_index) assert raw_scores.size(0) == edge_index.size(1) # Test R-fraction calculation r_fraction = model.calculate_r_fraction(attract_z, repel_z) assert 0 <= r_fraction <= 1 def test_ar_link_predictor_with_custom_ratio(): # Test with custom attract_ratio model = ARLinkPredictor(in_channels=8, hidden_channels=20, attract_ratio=0.7) x = torch.randn(5, 8) # Check dimensions attract_z, repel_z = model.encode(x) assert attract_z.size() == (5, 14) # 70% of 20 = 14 assert repel_z.size() == (5, 6) # 30% of 20 = 6 ================================================ FILE: test/nn/models/test_autoencoder.py ================================================ import torch from torch import Tensor as T from torch_geometric.data import Data from torch_geometric.nn import ARGA, ARGVA, GAE, VGAE from torch_geometric.testing import has_package, is_full_test from torch_geometric.transforms import RandomLinkSplit def test_gae(): model = GAE(encoder=lambda x: x) model.reset_parameters() x = torch.tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]]) z = model.encode(x) assert torch.allclose(z, x) adj = model.decoder.forward_all(z) expected = torch.tensor([ [2.0, -1.0, 1.0], [-1.0, 5.0, 4.0], [1.0, 4.0, 5.0], ]).sigmoid() assert torch.allclose(adj, expected) edge_index = torch.tensor([[0, 1], [1, 2]]) value = model.decode(z, edge_index) assert torch.allclose(value, torch.tensor([-1.0, 
4.0]).sigmoid()) if is_full_test(): jit = torch.jit.export(model) assert torch.allclose(jit.encode(x), z) assert torch.allclose(jit.decode(z, edge_index), value) edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) data = Data(edge_index=edge_index, num_nodes=11) transform = RandomLinkSplit(split_labels=True, add_negative_train_samples=False) train_data, val_data, test_data = transform(data) z = torch.randn(11, 16) loss = model.recon_loss(z, train_data.pos_edge_label_index) assert float(loss) > 0 if has_package('sklearn'): auc, ap = model.test(z, val_data.pos_edge_label_index, val_data.neg_edge_label_index) assert auc >= 0 and auc <= 1 and ap >= 0 and ap <= 1 def test_vgae(): model = VGAE(encoder=lambda x: (x, x)) x = torch.tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]]) model.encode(x) assert float(model.kl_loss()) > 0 model.eval() model.encode(x) if is_full_test(): jit = torch.jit.export(model) jit.encode(x) assert float(jit.kl_loss()) > 0 def test_arga(): model = ARGA(encoder=lambda x: x, discriminator=lambda x: T([0.5])) model.reset_parameters() x = torch.tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]]) z = model.encode(x) assert float(model.reg_loss(z)) > 0 assert float(model.discriminator_loss(z)) > 0 if is_full_test(): jit = torch.jit.export(model) assert torch.allclose(jit.encode(x), z) assert float(jit.reg_loss(z)) > 0 assert float(jit.discriminator_loss(z)) > 0 def test_argva(): model = ARGVA(encoder=lambda x: (x, x), discriminator=lambda x: T([0.5])) x = torch.tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]]) model.encode(x) model.reparametrize(model.__mu__, model.__logstd__) assert float(model.kl_loss()) > 0 if is_full_test(): jit = torch.jit.export(model) jit.encode(x) jit.reparametrize(jit.__mu__, jit.__logstd__) assert float(jit.kl_loss()) > 0 def test_init(): encoder = torch.nn.Linear(16, 32) decoder = torch.nn.Linear(32, 16) discriminator = torch.nn.Linear(32, 1) GAE(encoder, decoder) VGAE(encoder, decoder) 
ARGA(encoder, discriminator, decoder) ARGVA(encoder, discriminator, decoder) ================================================ FILE: test/nn/models/test_basic_gnn.py ================================================ import os import os.path as osp import random import warnings import pytest import torch import torch.nn.functional as F from torch_geometric.data import Data from torch_geometric.loader import NeighborLoader from torch_geometric.nn import SAGEConv from torch_geometric.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE from torch_geometric.profile import benchmark from torch_geometric.testing import ( onlyFullTest, onlyLinux, onlyNeighborSampler, onlyOnline, withDevice, withPackage, ) out_dims = [None, 8] dropouts = [0.0, 0.5] acts = [None, 'leaky_relu', torch.relu_, F.elu, torch.nn.ReLU()] norms = [None, 'batch_norm', 'layer_norm'] jks = [None, 'last', 'cat', 'max', 'lstm'] @pytest.mark.parametrize('out_dim', out_dims) @pytest.mark.parametrize('dropout', dropouts) @pytest.mark.parametrize('act', acts) @pytest.mark.parametrize('norm', norms) @pytest.mark.parametrize('jk', jks) def test_gcn(out_dim, dropout, act, norm, jk): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) out_channels = 16 if out_dim is None else out_dim model = GCN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, act=act, norm=norm, jk=jk) assert str(model) == f'GCN(8, {out_channels}, num_layers=3)' assert model(x, edge_index).size() == (3, out_channels) @pytest.mark.parametrize('out_dim', out_dims) @pytest.mark.parametrize('dropout', dropouts) @pytest.mark.parametrize('act', acts) @pytest.mark.parametrize('norm', norms) @pytest.mark.parametrize('jk', jks) def test_graph_sage(out_dim, dropout, act, norm, jk): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) out_channels = 16 if out_dim is None else out_dim model = GraphSAGE(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, act=act, norm=norm, jk=jk) assert 
str(model) == f'GraphSAGE(8, {out_channels}, num_layers=3)' assert model(x, edge_index).size() == (3, out_channels) @pytest.mark.parametrize('out_dim', out_dims) @pytest.mark.parametrize('dropout', dropouts) @pytest.mark.parametrize('act', acts) @pytest.mark.parametrize('norm', norms) @pytest.mark.parametrize('jk', jks) def test_gin(out_dim, dropout, act, norm, jk): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) out_channels = 16 if out_dim is None else out_dim model = GIN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, act=act, norm=norm, jk=jk) assert str(model) == f'GIN(8, {out_channels}, num_layers=3)' assert model(x, edge_index).size() == (3, out_channels) @pytest.mark.parametrize('out_dim', out_dims) @pytest.mark.parametrize('dropout', dropouts) @pytest.mark.parametrize('act', acts) @pytest.mark.parametrize('norm', norms) @pytest.mark.parametrize('jk', jks) def test_gat(out_dim, dropout, act, norm, jk): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) out_channels = 16 if out_dim is None else out_dim for v2 in [False, True]: model = GAT(8, 16, num_layers=3, out_channels=out_dim, v2=v2, dropout=dropout, act=act, norm=norm, jk=jk) assert str(model) == f'GAT(8, {out_channels}, num_layers=3)' assert model(x, edge_index).size() == (3, out_channels) model = GAT(8, 16, num_layers=3, out_channels=out_dim, v2=v2, dropout=dropout, act=act, norm=norm, jk=jk, heads=4) assert str(model) == f'GAT(8, {out_channels}, num_layers=3)' assert model(x, edge_index).size() == (3, out_channels) @pytest.mark.parametrize('out_dim', out_dims) @pytest.mark.parametrize('dropout', dropouts) @pytest.mark.parametrize('act', acts) @pytest.mark.parametrize('norm', norms) @pytest.mark.parametrize('jk', jks) def test_pna(out_dim, dropout, act, norm, jk): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) deg = torch.tensor([0, 2, 1]) out_channels = 16 if out_dim is None else out_dim aggregators = 
['mean', 'min', 'max', 'std', 'var', 'sum'] scalers = [ 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear' ] model = PNA(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, act=act, norm=norm, jk=jk, aggregators=aggregators, scalers=scalers, deg=deg) assert str(model) == f'PNA(8, {out_channels}, num_layers=3)' assert model(x, edge_index).size() == (3, out_channels) @pytest.mark.parametrize('out_dim', out_dims) @pytest.mark.parametrize('dropout', dropouts) @pytest.mark.parametrize('act', acts) @pytest.mark.parametrize('norm', norms) @pytest.mark.parametrize('jk', jks) def test_edge_cnn(out_dim, dropout, act, norm, jk): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) out_channels = 16 if out_dim is None else out_dim model = EdgeCNN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, act=act, norm=norm, jk=jk) assert str(model) == f'EdgeCNN(8, {out_channels}, num_layers=3)' assert model(x, edge_index).size() == (3, out_channels) def test_jit(): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) model = GCN(8, 16, num_layers=2) model = torch.jit.script(model) assert model(x, edge_index).size() == (3, 16) @pytest.mark.parametrize('out_dim', out_dims) @pytest.mark.parametrize('jk', jks) def test_one_layer_gnn(out_dim, jk): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) out_channels = 16 if out_dim is None else out_dim model = GraphSAGE(8, 16, num_layers=1, out_channels=out_dim, jk=jk) assert model(x, edge_index).size() == (3, out_channels) @pytest.mark.parametrize('norm', [ 'BatchNorm', 'GraphNorm', 'InstanceNorm', 'LayerNorm', ]) def test_batch(norm): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) batch = torch.tensor([0, 0, 1]) model = GraphSAGE(8, 16, num_layers=2, norm=norm) assert model.supports_norm_batch == (norm != 'BatchNorm') out = model(x, edge_index, batch=batch) assert out.size() == (3, 16) if 
model.supports_norm_batch: with pytest.raises(RuntimeError, match="out of bounds"): model(x, edge_index, batch=batch, batch_size=1) @onlyOnline @onlyNeighborSampler @pytest.mark.parametrize('jk', [None, 'last']) def test_basic_gnn_inference(get_dataset, jk): dataset = get_dataset(name='karate') data = dataset[0] model = GraphSAGE(dataset.num_features, hidden_channels=16, num_layers=2, out_channels=dataset.num_classes, jk=jk) model.eval() out1 = model(data.x, data.edge_index) assert out1.size() == (data.num_nodes, dataset.num_classes) loader = NeighborLoader(data, num_neighbors=[-1], batch_size=128) out2 = model.inference(loader) assert out1.size() == out2.size() assert torch.allclose(out1, out2, atol=1e-4) assert 'n_id' not in data @withDevice @onlyLinux @onlyFullTest @withPackage('torch>=2.0.0') def test_compile_basic(device): x = torch.randn(3, 8, device=device) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device) model = GCN(8, 16, num_layers=3).to(device) compiled_model = torch.compile(model) expected = model(x, edge_index) out = compiled_model(x, edge_index) assert torch.allclose(out, expected, atol=1e-6) def test_packaging(): warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*') os.makedirs(torch.hub._get_torch_home(), exist_ok=True) x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) model = GraphSAGE(8, 16, num_layers=3) path = osp.join(torch.hub._get_torch_home(), 'pyg_test_model.pt') torch.save(model, path) model = torch.load(path, weights_only=False) with torch.no_grad(): assert model(x, edge_index).size() == (3, 16) model = GraphSAGE(8, 16, num_layers=3) path = osp.join(torch.hub._get_torch_home(), 'pyg_test_package.pt') with torch.package.PackageExporter(path) as pe: pe.extern('torch_geometric.nn.**') pe.extern('torch_geometric.inspector') pe.extern('torch_geometric.utils._trim_to_layer') pe.extern('_operator') pe.save_pickle('models', 'model.pkl', model) pi = torch.package.PackageImporter(path) 
model = pi.load_pickle('models', 'model.pkl') with torch.no_grad(): assert model(x, edge_index).size() == (3, 16) @onlyLinux @withPackage('torch>=2.6.0') @withPackage('onnx', 'onnxruntime', 'onnxscript') def test_onnx(tmp_path: str) -> None: import onnx import onnxruntime as ort from torch_geometric import safe_onnx_export warnings.filterwarnings('ignore', '.*tensor to a Python boolean.*') warnings.filterwarnings('ignore', '.*shape inference of prim::Constant.*') class MyModel(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SAGEConv(8, 16) self.conv2 = SAGEConv(16, 16) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index) return x model = MyModel() x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]]) expected = model(x, edge_index) assert expected.size() == (3, 16) path = osp.join(tmp_path, 'model.onnx') success = safe_onnx_export( model, (x, edge_index), path, input_names=('x', 'edge_index'), opset_version=18, dynamo=True, # False is deprecated by PyTorch skip_on_error=True, # Skip gracefully in CI if upstream issue occurs ) if not success: # ONNX export was skipped due to known upstream issue # This allows CI to pass while the upstream bug exists warnings.warn( "ONNX export test skipped due to known upstream onnx_ir issue. 
" "This is expected and does not indicate a problem with PyTorch " "Geometric.", UserWarning, stacklevel=2) return onnx_model = onnx.load(path) onnx.checker.check_model(onnx_model) providers = ['CPUExecutionProvider'] ort_session = ort.InferenceSession(path, providers=providers) out = ort_session.run(None, { 'x': x.numpy(), 'edge_index': edge_index.numpy() })[0] out = torch.from_numpy(out) assert torch.allclose(out, expected, atol=1e-6) @withPackage('pyg_lib') def test_trim_to_layer(): x = torch.randn(14, 16) edge_index = torch.tensor([ [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], ]) data = Data(x=x, edge_index=edge_index) loader = NeighborLoader( data, num_neighbors=[1, 2, 4], batch_size=2, shuffle=False, ) batch = next(iter(loader)) model = GraphSAGE(in_channels=16, hidden_channels=16, num_layers=3) out1 = model(batch.x, batch.edge_index)[:2] assert out1.size() == (2, 16) out2 = model( batch.x, batch.edge_index, num_sampled_nodes_per_hop=batch.num_sampled_nodes, num_sampled_edges_per_hop=batch.num_sampled_edges, )[:2] assert out2.size() == (2, 16) assert torch.allclose(out1, out2, atol=1e-6) @withDevice @onlyLinux @withPackage('torch>=2.1.0') @pytest.mark.parametrize('Model', [GCN, GraphSAGE, GIN, GAT, EdgeCNN, PNA]) def test_compile_graph_breaks(Model, device): import torch._dynamo as dynamo x = torch.randn(3, 8, device=device) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device) kwargs = {} if Model in {GCN, GAT}: # Adding self-loops inside the model leads to graph breaks :( kwargs['add_self_loops'] = False if Model in {PNA}: # `PNA` requires additional arguments: kwargs['aggregators'] = ['sum', 'mean', 'min', 'max', 'var', 'std'] kwargs['scalers'] = ['identity', 'amplification', 'attenuation'] kwargs['deg'] = torch.tensor([1, 2, 1]) model = Model( in_channels=8, hidden_channels=16, num_layers=2, **kwargs, ).to(device) explanation = dynamo.explain(model)(x, edge_index) assert explanation.graph_break_count == 0 
@withPackage('pyg_lib') def test_basic_gnn_cache(): x = torch.randn(14, 16) edge_index = torch.tensor([ [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], ]) loader = NeighborLoader( Data(x=x, edge_index=edge_index), num_neighbors=[-1], batch_size=2, ) model = GCN(in_channels=16, hidden_channels=16, num_layers=2) model.eval() out1 = model.inference(loader, cache=False) out2 = model.inference(loader, cache=True) assert torch.allclose(out1, out2) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') parser.add_argument('--dynamic', action='store_true') args = parser.parse_args() if args.dynamic: min_num_nodes, max_num_nodes = 10_000, 15_000 min_num_edges, max_num_edges = 200_000, 300_000 else: min_num_nodes, max_num_nodes = 10_000, 10_000 min_num_edges, max_num_edges = 200_000, 200_000 def gen_args(): N = random.randint(min_num_nodes, max_num_nodes) E = random.randint(min_num_edges, max_num_edges) x = torch.randn(N, 64, device=args.device) edge_index = torch.randint(N, (2, E), device=args.device) return x, edge_index for Model in [GCN, GraphSAGE, GIN, EdgeCNN]: print(f'Model: {Model.__name__}') model = Model(64, 64, num_layers=3).to(args.device) compiled_model = torch.compile(model) benchmark( funcs=[model, compiled_model], func_names=['Vanilla', 'Compiled'], args=gen_args, num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) ================================================ FILE: test/nn/models/test_correct_and_smooth.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn.models import CorrectAndSmooth from torch_geometric.testing import noWindows from torch_geometric.typing import SparseTensor @noWindows def test_correct_and_smooth(): y_soft = torch.tensor([0.1, 0.5, 0.4]).repeat(6, 
1) y_true = torch.tensor([1, 0, 0, 2, 1, 1]) edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]]) mask = torch.randint(0, 2, (6, ), dtype=torch.bool) model = CorrectAndSmooth( num_correction_layers=2, correction_alpha=0.5, num_smoothing_layers=2, smoothing_alpha=0.5, ) assert str(model) == ('CorrectAndSmooth(\n' ' correct: num_layers=2, alpha=0.5\n' ' smooth: num_layers=2, alpha=0.5\n' ' autoscale=True, scale=1.0\n' ')') out = model.correct(y_soft, y_true[mask], mask, edge_index) assert out.size() == (6, 3) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6)) assert torch.allclose( out, model.correct(y_soft, y_true[mask], mask, adj.t())) out = model.smooth(y_soft, y_true[mask], mask, edge_index) assert out.size() == (6, 3) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose( out, model.smooth(y_soft, y_true[mask], mask, adj.t())) # Test without autoscale: model = CorrectAndSmooth( num_correction_layers=2, correction_alpha=0.5, num_smoothing_layers=2, smoothing_alpha=0.5, autoscale=False, ) out = model.correct(y_soft, y_true[mask], mask, edge_index) assert out.size() == (6, 3) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose( out, model.correct(y_soft, y_true[mask], mask, adj.t())) ================================================ FILE: test/nn/models/test_deep_graph_infomax.py ================================================ import torch from torch_geometric.nn import GCN, DeepGraphInfomax from torch_geometric.testing import has_package, is_full_test, withDevice @withDevice def test_infomax(device): def corruption(z): return z + 1 model = DeepGraphInfomax( hidden_channels=16, encoder=lambda x: x, summary=lambda z, *args: z.mean(dim=0), corruption=lambda x: x + 1, ).to(device) assert str(model) == 'DeepGraphInfomax(16)' x = torch.ones(20, 16, device=device) pos_z, neg_z, summary = model(x) assert pos_z.size() == (20, 16) assert neg_z.size() == (20, 16) assert 
summary.size() == (16, ) loss = model.loss(pos_z, neg_z, summary) assert float(loss) >= 0 if is_full_test(): jit = torch.jit.export(model) pos_z, neg_z, summary = jit(x) assert pos_z.size() == (20, 16) and neg_z.size() == (20, 16) assert summary.size() == (16, ) if has_package('sklearn'): acc = model.test( train_z=torch.ones(20, 16), train_y=torch.randint(10, (20, )), test_z=torch.ones(20, 16), test_y=torch.randint(10, (20, )), ) assert 0 <= acc <= 1 @withDevice def test_infomax_predefined_model(device): def corruption(x, edge_index, edge_weight): return ( x[torch.randperm(x.size(0), device=x.device)], edge_index, edge_weight, ) model = DeepGraphInfomax( hidden_channels=16, encoder=GCN(16, 16, num_layers=2), summary=lambda z, *args, **kwargs: z.mean(dim=0).sigmoid(), corruption=corruption, ).to(device) x = torch.randn(4, 16, device=device) edge_index = torch.tensor( [[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]], device=device, ) edge_weight = torch.rand(edge_index.size(1), device=device) pos_z, neg_z, summary = model(x, edge_index, edge_weight=edge_weight) assert pos_z.size() == (4, 16) assert neg_z.size() == (4, 16) assert summary.size() == (16, ) loss = model.loss(pos_z, neg_z, summary) assert float(loss) >= 0 ================================================ FILE: test/nn/models/test_deepgcn.py ================================================ import pytest import torch from torch.nn import ReLU from torch_geometric.nn import DeepGCNLayer, GENConv, LayerNorm @pytest.mark.parametrize( 'block_tuple', [('res+', 1), ('res', 1), ('dense', 2), ('plain', 1)], ) @pytest.mark.parametrize('ckpt_grad', [True, False]) def test_deepgcn(block_tuple, ckpt_grad): block, expansion = block_tuple x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) conv = GENConv(8, 8) norm = LayerNorm(8) act = ReLU() layer = DeepGCNLayer(conv, norm, act, block=block, ckpt_grad=ckpt_grad) assert str(layer) == f'DeepGCNLayer(block={block})' out = layer(x, edge_index) assert 
out.size() == (3, 8 * expansion) ================================================ FILE: test/nn/models/test_dimenet.py ================================================ import pytest import torch import torch.nn.functional as F from torch_geometric.nn import DimeNet, DimeNetPlusPlus from torch_geometric.nn.models.dimenet import ( BesselBasisLayer, Envelope, ResidualLayer, ) from torch_geometric.testing import is_full_test, withPackage def test_dimenet_modules(): env = Envelope(exponent=5) x = torch.randn(10, 3) assert env(x).size() == (10, 3) # Isotonic layer. bbl = BesselBasisLayer(5) x = torch.randn(10, 3) assert bbl(x).size() == (10, 3, 5) # Non-isotonic layer. rl = ResidualLayer(128, torch.nn.functional.relu) x = torch.randn(128, 128) assert rl(x).size() == (128, 128) # Isotonic layer. @withPackage('sympy') @withPackage('torch_sparse') # TODO `triplet` requires `SparseTensor` for now. @withPackage('torch-cluster') @pytest.mark.parametrize('Model', [DimeNet, DimeNetPlusPlus]) def test_dimenet(Model): z = torch.randint(1, 10, (20, )) pos = torch.randn(20, 3) if Model == DimeNet: kwargs = dict(num_bilinear=3) else: kwargs = dict(out_emb_channels=3, int_emb_size=5, basis_emb_size=5) model = Model( hidden_channels=5, out_channels=1, num_blocks=5, num_spherical=5, num_radial=5, **kwargs, ) model.reset_parameters() with torch.no_grad(): out = model(z, pos) assert out.size() == (1, ) jit = torch.jit.export(model) assert torch.allclose(jit(z, pos), out) if is_full_test(): optimizer = torch.optim.Adam(model.parameters(), lr=0.1) min_loss = float('inf') for _ in range(100): optimizer.zero_grad() out = model(z, pos) loss = F.l1_loss(out, torch.tensor([1.0])) loss.backward() optimizer.step() min_loss = min(float(loss), min_loss) assert min_loss < 2 ================================================ FILE: test/nn/models/test_gnnff.py ================================================ import torch from torch_geometric.nn import GNNFF from torch_geometric.testing import 
is_full_test, withPackage @withPackage('torch_sparse') # TODO `triplet` requires `SparseTensor` for now. @withPackage('torch-cluster') def test_gnnff(): z = torch.randint(1, 10, (20, )) pos = torch.randn(20, 3) model = GNNFF( hidden_node_channels=5, hidden_edge_channels=5, num_layers=5, ) model.reset_parameters() out = model(z, pos) assert out.size() == (20, 3) if is_full_test(): jit = torch.jit.export(model) assert torch.allclose(jit(z, pos), out) ================================================ FILE: test/nn/models/test_gpse.py ================================================ import pytest import torch from torch_geometric.data import Batch, Data from torch_geometric.nn import GPSE, GPSENodeEncoder from torch_geometric.nn.models.gpse import ( IdentityHead, gpse_loss, gpse_process, process_batch_idx, ) from torch_geometric.testing import is_full_test from torch_geometric.transforms import VirtualNode def test_gpse_training(): x = torch.randn(6, 20) y = torch.randn(6, 51) edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5], [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]]) data = Data(x=x, y=y, edge_index=edge_index) data = VirtualNode()(data) data.y_graph = torch.randn(11) batch = Batch.from_data_list([data]) model = GPSE() with torch.no_grad(): out = model(batch) assert out[0].size() == out[1].size() assert out[0].size() == (7, 62) if is_full_test(): optimizer = torch.optim.Adam(model.parameters(), lr=0.01) min_loss = float('inf') for _ in range(100): optimizer.zero_grad() pred, true = model(batch) batch_idx = process_batch_idx(batch.batch, true) loss, _ = gpse_loss(pred, true, batch_idx) loss.backward() optimizer.step() min_loss = min(float(loss), min_loss) assert min_loss < 2 def test_gpse_from_pretrained(): x = torch.randn(6, 4) edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5], [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]]) data = Data(x=x, edge_index=edge_index) data = VirtualNode()(data) model = GPSE() model.post_mp = IdentityHead() with torch.no_grad(): out = 
gpse_process(model, data, 'NormalSE') assert out.size() == (7, 512) @pytest.mark.parametrize('expand_x', [False, True]) def test_gpse_node_encoder(expand_x): x = torch.randn(6, 4) pestat_GPSE = torch.randn(6, 512) encoder = GPSENodeEncoder( dim_emb=128, dim_pe_in=512, dim_pe_out=64, dim_in=4, expand_x=expand_x, ) out = encoder(x, pestat_GPSE) assert out.size() == (6, 128) if expand_x else (6, 64) ================================================ FILE: test/nn/models/test_graph_mixer.py ================================================ import torch from torch_geometric.nn.models.graph_mixer import ( LinkEncoder, NodeEncoder, get_latest_k_edge_attr, ) def test_node_encoder(): x = torch.arange(4, dtype=torch.float).view(-1, 1) edge_index = torch.tensor([[1, 2, 0, 0, 1, 3], [0, 0, 1, 2, 2, 2]]) edge_time = torch.tensor([0, 1, 1, 1, 2, 3]) seed_time = torch.tensor([2, 2, 2, 2]) encoder = NodeEncoder(time_window=2) encoder.reset_parameters() assert str(encoder) == 'NodeEncoder(time_window=2)' out = encoder(x, edge_index, edge_time, seed_time) # Node 0 aggregates information from node 2 (excluding node 1). # Node 1 aggregates information from node 0. # Node 2 aggregates information from node 0 and node 1 (excluding node 3). # Node 3 aggregates no information. 
expected = torch.tensor([ [0 + 2], [1 + 0], [2 + 0.5 * (0 + 1)], [3], ]) assert torch.allclose(out, expected) def test_link_encoder(): num_nodes = 3 num_edges = 6 edge_attr = torch.rand((num_edges, 10)) edge_index = torch.randint(low=0, high=num_nodes, size=(2, num_edges)) edge_time = torch.rand(num_edges) seed_time = torch.ones(num_nodes) encoder = LinkEncoder( k=3, in_channels=edge_attr.size(1), hidden_channels=7, out_channels=11, time_channels=13, ) encoder.reset_parameters() assert str(encoder) == ('LinkEncoder(k=3, in_channels=10, ' 'hidden_channels=7, out_channels=11, ' 'time_channels=13, dropout=0.0)') out = encoder(edge_index, edge_attr, edge_time, seed_time) assert out.size() == (num_nodes, 11) def test_latest_k_edge_attr(): edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 0], [0, 1, 0, 1, 0, 1, 2]]) edge_time = torch.tensor([3, 1, 2, 3, 1, 2, 3]) edge_attr = torch.tensor([1, -1, 3, 4, -1, 6, 7]).view(-1, 1) k = 2 out = get_latest_k_edge_attr(k, edge_index, edge_attr, edge_time, num_nodes=3) expected = torch.tensor([[[1], [3]], [[4], [6]], [[7], [0]]]) assert out.size() == (3, 2, 1) assert torch.equal(out, expected) k = 1 out = get_latest_k_edge_attr(k, edge_index, edge_attr, edge_time, num_nodes=3) expected = torch.tensor([[[1]], [[4]], [[7]]]) assert out.size() == (3, 1, 1) assert torch.equal(out, expected) ================================================ FILE: test/nn/models/test_graph_unet.py ================================================ import torch from torch_geometric.nn import GraphUNet from torch_geometric.testing import is_full_test, onlyLinux @onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows. 
def test_graph_unet(): model = GraphUNet(16, 32, 8, depth=3) out = 'GraphUNet(16, 32, 8, depth=3, pool_ratios=[0.5, 0.5, 0.5])' assert str(model) == out x = torch.randn(3, 16) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) out = model(x, edge_index) assert out.size() == (3, 8) if is_full_test(): jit = torch.jit.export(model) out = jit(x, edge_index) assert out.size() == (3, 8) ================================================ FILE: test/nn/models/test_jumping_knowledge.py ================================================ import torch from torch_geometric.nn import HeteroJumpingKnowledge, JumpingKnowledge from torch_geometric.testing import is_full_test def test_jumping_knowledge(): num_nodes, channels, num_layers = 100, 17, 5 xs = list([torch.randn(num_nodes, channels) for _ in range(num_layers)]) model = JumpingKnowledge('cat') assert str(model) == 'JumpingKnowledge(cat)' out = model(xs) assert out.size() == (num_nodes, channels * num_layers) if is_full_test(): jit = torch.jit.script(model) assert torch.allclose(jit(xs), out) model = JumpingKnowledge('max') assert str(model) == 'JumpingKnowledge(max)' out = model(xs) assert out.size() == (num_nodes, channels) if is_full_test(): jit = torch.jit.script(model) assert torch.allclose(jit(xs), out) model = JumpingKnowledge('lstm', channels, num_layers) assert str(model) == (f'JumpingKnowledge(lstm, channels=' f'{channels}, layers={num_layers})') out = model(xs) assert out.size() == (num_nodes, channels) if is_full_test(): jit = torch.jit.script(model) assert torch.allclose(jit(xs), out) def test_hetero_jumping_knowledge(): num_nodes, channels, num_layers = 100, 17, 5 types = ["author", "paper"] xs_dict = { key: [torch.randn(num_nodes, channels) for _ in range(num_layers)] for key in types } model = HeteroJumpingKnowledge(types, mode='cat') model.reset_parameters() assert str(model) == 'HeteroJumpingKnowledge(num_types=2, mode=cat)' out_dict = model(xs_dict) for out in out_dict.values(): assert out.size() == 
(num_nodes, channels * num_layers) if is_full_test(): jit = torch.jit.script(model) jit_out = jit(xs_dict) for key in types: assert torch.allclose(jit_out[key], out_dict[key]) model = HeteroJumpingKnowledge(types, mode='max') assert str(model) == 'HeteroJumpingKnowledge(num_types=2, mode=max)' out_dict = model(xs_dict) for out in out_dict.values(): assert out.size() == (num_nodes, channels) if is_full_test(): jit = torch.jit.script(model) jit_out = jit(xs_dict) for key in types: assert torch.allclose(jit_out[key], out_dict[key]) model = HeteroJumpingKnowledge(types, mode='lstm', channels=channels, num_layers=num_layers) assert str(model) == (f'HeteroJumpingKnowledge(num_types=2, mode=lstm, ' f'channels={channels}, layers={num_layers})') out_dict = model(xs_dict) for out in out_dict.values(): assert out.size() == (num_nodes, channels) if is_full_test(): jit = torch.jit.script(model) jit_out = jit(xs_dict) for key in types: assert torch.allclose(jit_out[key], out_dict[key]) ================================================ FILE: test/nn/models/test_label_prop.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn.models import LabelPropagation from torch_geometric.typing import SparseTensor def test_label_prop(): y = torch.tensor([1, 0, 0, 2, 1, 1]) edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]]) mask = torch.randint(0, 2, (6, ), dtype=torch.bool) model = LabelPropagation(num_layers=2, alpha=0.5) assert str(model) == 'LabelPropagation(num_layers=2, alpha=0.5)' # Test without mask: out = model(y, edge_index) assert out.size() == (6, 3) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6)) assert torch.allclose(model(y, adj.t()), out) # Test with mask: out = model(y, edge_index, mask) assert out.size() == (6, 3) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(model(y, adj.t(), mask), out) # Test post step: out 
= model(y, edge_index, mask, post_step=lambda y: torch.zeros_like(y)) assert torch.sum(out) == 0. ================================================ FILE: test/nn/models/test_lightgcn.py ================================================ import pytest import torch from torch_geometric.nn.models import LightGCN @pytest.mark.parametrize('embedding_dim', [32, 64]) @pytest.mark.parametrize('with_edge_weight', [False, True]) @pytest.mark.parametrize('lambda_reg', [0, 1e-4]) @pytest.mark.parametrize('alpha', [0, .25, torch.tensor([0.4, 0.3, 0.2])]) def test_lightgcn_ranking(embedding_dim, with_edge_weight, lambda_reg, alpha): num_nodes = 500 num_edges = 400 edge_index = torch.randint(0, num_nodes, (2, num_edges)) edge_weight = torch.rand(num_edges) if with_edge_weight else None edge_label_index = torch.randint(0, num_nodes, (2, 100)) model = LightGCN(num_nodes, embedding_dim, num_layers=2, alpha=alpha) assert str(model) == f'LightGCN(500, {embedding_dim}, num_layers=2)' pred = model(edge_index, edge_label_index, edge_weight) assert pred.size() == (100, ) loss = model.recommendation_loss( pos_edge_rank=pred[:50], neg_edge_rank=pred[50:], node_id=edge_index.unique(), lambda_reg=lambda_reg, ) assert loss.dim() == 0 and loss > 0 out = model.recommend(edge_index, edge_weight, k=2) assert out.size() == (500, 2) assert out.min() >= 0 and out.max() < 500 src_index = torch.arange(0, 250) dst_index = torch.arange(250, 500) out = model.recommend(edge_index, edge_weight, src_index, dst_index, k=2) assert out.size() == (250, 2) assert out.min() >= 250 and out.max() < 500 @pytest.mark.parametrize('embedding_dim', [32, 64]) @pytest.mark.parametrize('with_edge_weight', [False, True]) @pytest.mark.parametrize('alpha', [0, .25, torch.tensor([0.4, 0.3, 0.2])]) def test_lightgcn_link_prediction(embedding_dim, with_edge_weight, alpha): num_nodes = 500 num_edges = 400 edge_index = torch.randint(0, num_nodes, (2, num_edges)) edge_weight = torch.rand(num_edges) if with_edge_weight else None 
edge_label_index = torch.randint(0, num_nodes, (2, 100)) edge_label = torch.randint(0, 2, (edge_label_index.size(1), )) model = LightGCN(num_nodes, embedding_dim, num_layers=2, alpha=alpha) assert str(model) == f'LightGCN(500, {embedding_dim}, num_layers=2)' pred = model(edge_index, edge_label_index, edge_weight) assert pred.size() == (100, ) loss = model.link_pred_loss(pred, edge_label) assert loss.dim() == 0 and loss > 0 prob = model.predict_link(edge_index, edge_label_index, edge_weight, prob=True) assert prob.size() == (100, ) assert prob.min() > 0 and prob.max() < 1 prob = model.predict_link(edge_index, edge_label_index, edge_weight, prob=False) assert prob.size() == (100, ) assert ((prob == 0) | (prob == 1)).sum() == 100 ================================================ FILE: test/nn/models/test_linkx.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import LINKX from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor @pytest.mark.parametrize('num_edge_layers', [1, 2]) def test_linkx(num_edge_layers): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]]) edge_weight = torch.rand(edge_index.size(1)) model = LINKX(num_nodes=4, in_channels=16, hidden_channels=32, out_channels=8, num_layers=2, num_edge_layers=num_edge_layers) assert str(model) == 'LINKX(num_nodes=4, in_channels=16, out_channels=8)' out = model(x, edge_index) assert out.size() == (4, 8) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(out, model(x, adj.t()), atol=1e-6) if is_full_test(): jit = torch.jit.script(model) assert torch.allclose(jit(x, edge_index), out) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj.t()), out, atol=1e-6) out = model(None, edge_index) assert out.size() == (4, 8) if torch_geometric.typing.WITH_TORCH_SPARSE: assert 
torch.allclose(out, model(None, adj.t()), atol=1e-6) out = model(x, edge_index, edge_weight) assert out.size() == (4, 8) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=(4, 4)) assert torch.allclose(model(x, adj.t()), out, atol=1e-6) out = model(None, edge_index, edge_weight) assert out.size() == (4, 8) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(model(None, adj.t()), out, atol=1e-6) ================================================ FILE: test/nn/models/test_lpformer.py ================================================ import torch from torch_geometric.nn import LPFormer from torch_geometric.testing import withPackage from torch_geometric.utils import to_undirected @withPackage('numba') # For ppr calculation def test_lpformer(): model = LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1) assert str( model ) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)' num_nodes = 20 x = torch.randn(num_nodes, 16) edges = torch.randint(0, num_nodes - 1, (2, 110)) edge_index, test_edges = edges[:, :100], edges[:, 100:] edge_index = to_undirected(edge_index) ppr_matrix = model.calc_sparse_ppr(edge_index, num_nodes, eps=1e-4) assert ppr_matrix.is_sparse assert ppr_matrix.size() == (num_nodes, num_nodes) assert ppr_matrix.sum().item() > 0 # Test with dense edge_index out = model(test_edges, x, edge_index, ppr_matrix) assert out.size() == (10, ) # Test with sparse edge_index adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1)), [num_nodes, num_nodes]) out2 = model(test_edges, x, adj, ppr_matrix) assert out2.size() == (10, ) ================================================ FILE: test/nn/models/test_mask_label.py ================================================ import torch from torch_geometric.nn import MaskLabel def test_mask_label(): model = MaskLabel(2, 10) assert str(model) == 'MaskLabel()' x = torch.rand(4, 10) y = torch.tensor([1, 0, 1, 0]) 
mask = torch.tensor([False, False, True, True]) out = model(x, y, mask) assert out.size() == (4, 10) assert torch.allclose(out[~mask], x[~mask]) model = MaskLabel(2, 10, method='concat') out = model(x, y, mask) assert out.size() == (4, 20) assert torch.allclose(out[:, :10], x) def test_ratio_mask(): mask = torch.tensor([True, True, True, True, False, False, False, False]) out = MaskLabel.ratio_mask(mask, 0.5) assert out[:4].sum() <= 4 and out[4:].sum() == 0 ================================================ FILE: test/nn/models/test_meta.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq from torch_geometric.nn import MetaLayer from torch_geometric.testing import is_full_test from torch_geometric.utils import scatter count = 0 def test_meta_layer(): assert str(MetaLayer()) == ('MetaLayer(\n' ' edge_model=None,\n' ' node_model=None,\n' ' global_model=None\n' ')') def dummy_model(*args): global count count += 1 return None x = torch.randn(20, 10) edge_index = torch.randint(0, high=10, size=(2, 20), dtype=torch.long) for edge_model in (dummy_model, None): for node_model in (dummy_model, None): for global_model in (dummy_model, None): model = MetaLayer(edge_model, node_model, global_model) out = model(x, edge_index) assert isinstance(out, tuple) and len(out) == 3 assert count == 12 def test_meta_layer_example(): class EdgeModel(torch.nn.Module): def __init__(self): super().__init__() self.edge_mlp = Seq(Lin(2 * 10 + 5 + 20, 5), ReLU(), Lin(5, 5)) def forward( self, src: Tensor, dst: Tensor, edge_attr: Optional[Tensor], u: Optional[Tensor], batch: Optional[Tensor], ) -> Tensor: assert edge_attr is not None assert u is not None assert batch is not None out = torch.cat([src, dst, edge_attr, u[batch]], 1) return self.edge_mlp(out) class NodeModel(torch.nn.Module): def __init__(self): super().__init__() 
self.node_mlp_1 = Seq(Lin(15, 10), ReLU(), Lin(10, 10)) self.node_mlp_2 = Seq(Lin(2 * 10 + 20, 10), ReLU(), Lin(10, 10)) def forward( self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor], u: Optional[Tensor], batch: Optional[Tensor], ) -> Tensor: assert edge_attr is not None assert u is not None assert batch is not None row = edge_index[0] col = edge_index[1] out = torch.cat([x[row], edge_attr], dim=1) out = self.node_mlp_1(out) out = scatter(out, col, dim=0, dim_size=x.size(0), reduce='mean') out = torch.cat([x, out, u[batch]], dim=1) return self.node_mlp_2(out) class GlobalModel(torch.nn.Module): def __init__(self): super().__init__() self.global_mlp = Seq(Lin(20 + 10, 20), ReLU(), Lin(20, 20)) def forward( self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor], u: Optional[Tensor], batch: Optional[Tensor], ) -> Tensor: assert u is not None assert batch is not None out = torch.cat([ u, scatter(x, batch, dim=0, reduce='mean'), ], dim=1) return self.global_mlp(out) op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel()) x = torch.randn(20, 10) edge_attr = torch.randn(40, 5) u = torch.randn(2, 20) batch = torch.tensor([0] * 10 + [1] * 10) edge_index = torch.randint(0, high=10, size=(2, 20), dtype=torch.long) edge_index = torch.cat([edge_index, 10 + edge_index], dim=1) x_out, edge_attr_out, u_out = op(x, edge_index, edge_attr, u, batch) assert x_out.size() == (20, 10) assert edge_attr_out.size() == (40, 5) assert u_out.size() == (2, 20) if is_full_test(): jit = torch.jit.script(op) x_out, edge_attr_out, u_out = jit(x, edge_index, edge_attr, u, batch) assert x_out.size() == (20, 10) assert edge_attr_out.size() == (40, 5) assert u_out.size() == (2, 20) ================================================ FILE: test/nn/models/test_metapath2vec.py ================================================ import torch from torch_geometric.nn import MetaPath2Vec from torch_geometric.testing import has_package, withDevice @withDevice def test_metapath2vec(device): 
edge_index_dict = { ('author', 'writes', 'paper'): torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]], device=device), ('paper', 'written_by', 'author'): torch.tensor([[0, 0, 1, 1], [0, 1, 1, 2]], device=device) } metapath = [ ('author', 'writes', 'paper'), ('paper', 'written_by', 'author'), ] model = MetaPath2Vec(edge_index_dict, embedding_dim=16, metapath=metapath, walk_length=2, context_size=2).to(device) assert str(model) == 'MetaPath2Vec(5, 16)' z = model('author') assert z.size() == (3, 16) z = model('paper') assert z.size() == (2, 16) z = model('author', torch.arange(2, device=device)) assert z.size() == (2, 16) pos_rw, neg_rw = model._sample(torch.arange(3)) loss = model.loss(pos_rw.to(device), neg_rw.to(device)) assert 0 <= loss.item() if has_package('sklearn'): acc = model.test(torch.ones(20, 16), torch.randint(10, (20, )), torch.ones(20, 16), torch.randint(10, (20, ))) assert 0 <= acc and acc <= 1 def test_metapath2vec_empty_edges(): num_nodes_dict = {'a': 3, 'b': 4} edge_index_dict = { ('a', 'to', 'b'): torch.empty((2, 0), dtype=torch.long), ('b', 'to', 'a'): torch.empty((2, 0), dtype=torch.long), } metapath = [('a', 'to', 'b'), ('b', 'to', 'a')] model = MetaPath2Vec( edge_index_dict, embedding_dim=16, metapath=metapath, walk_length=10, context_size=7, walks_per_node=5, num_negative_samples=5, num_nodes_dict=num_nodes_dict, ) loader = model.loader(batch_size=16, shuffle=True) next(iter(loader)) ================================================ FILE: test/nn/models/test_mlp.py ================================================ import pytest import torch from torch_geometric.nn import MLP @pytest.mark.parametrize('norm', ['batch_norm', None]) @pytest.mark.parametrize('act_first', [False, True]) @pytest.mark.parametrize('plain_last', [False, True]) def test_mlp(norm, act_first, plain_last): x = torch.randn(4, 16) torch.manual_seed(12345) mlp = MLP( [16, 32, 32, 64], norm=norm, act_first=act_first, plain_last=plain_last, ) assert str(mlp) == 'MLP(16, 32, 32, 64)' out = 
mlp(x) assert out.size() == (4, 64) jit = torch.jit.script(mlp) assert torch.allclose(jit(x), out) torch.manual_seed(12345) mlp = MLP( 16, hidden_channels=32, out_channels=64, num_layers=3, norm=norm, act_first=act_first, plain_last=plain_last, ) assert torch.allclose(mlp(x), out) @pytest.mark.parametrize('norm', [ 'BatchNorm', 'GraphNorm', 'InstanceNorm', 'LayerNorm', ]) def test_batch(norm): x = torch.randn(3, 8) batch = torch.tensor([0, 0, 1]) model = MLP( 8, hidden_channels=16, out_channels=32, num_layers=2, norm=norm, ) assert model.supports_norm_batch == (norm != 'BatchNorm') out = model(x, batch=batch) assert out.size() == (3, 32) if model.supports_norm_batch: with pytest.raises(RuntimeError, match="out of bounds"): model(x, batch=batch, batch_size=1) def test_mlp_return_emb(): x = torch.randn(4, 16) mlp = MLP([16, 32, 1]) out, emb = mlp(x, return_emb=True) assert out.size() == (4, 1) assert emb.size() == (4, 32) out, emb = mlp(x, return_emb=False) assert out.size() == (4, 1) assert emb is None @pytest.mark.parametrize('plain_last', [False, True]) def test_fine_grained_mlp(plain_last): mlp = MLP( [16, 32, 32, 64], dropout=[0.1, 0.2, 0.3], bias=[False, True, False], ) assert mlp(torch.randn(4, 16)).size() == (4, 64) ================================================ FILE: test/nn/models/test_neural_fingerprint.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import NeuralFingerprint from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor @pytest.mark.parametrize('batch', [None, torch.tensor([0, 1, 1])]) def test_neural_fingerprint(batch): x = torch.randn(3, 7) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) model = NeuralFingerprint(7, 16, out_channels=5, num_layers=4) assert str(model) == 'NeuralFingerprint(7, 5, num_layers=4)' model.reset_parameters() out = model(x, edge_index, batch) assert out.size() == (1, 5) if batch is None 
else (2, 5) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3)) assert torch.allclose(model(x, adj.t(), batch), out) if is_full_test(): jit = torch.jit.export(model) assert torch.allclose(jit(x, edge_index, batch), out) ================================================ FILE: test/nn/models/test_node2vec.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import Node2Vec from torch_geometric.testing import ( has_package, is_full_test, withDevice, withPackage, ) @withDevice @withPackage('pyg_lib|torch_cluster') @pytest.mark.parametrize('p', [1.0]) @pytest.mark.parametrize('q', [1.0, 0.5]) def test_node2vec(device, p, q): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device) kwargs = dict(embedding_dim=16, walk_length=2, context_size=2, p=p, q=q) if not torch_geometric.typing.WITH_TORCH_CLUSTER and q != 1.0: with pytest.raises(ImportError, match="requires the 'torch-cluster'"): model = Node2Vec(edge_index, **kwargs) return model = Node2Vec(edge_index, **kwargs).to(device) assert str(model) == 'Node2Vec(3, 16)' assert model(torch.arange(3, device=device)).size() == (3, 16) pos_rw, neg_rw = model.sample(torch.arange(3)) assert float(model.loss(pos_rw.to(device), neg_rw.to(device))) >= 0 if has_package('sklearn'): acc = model.test(torch.ones(20, 16), torch.randint(10, (20, )), torch.ones(20, 16), torch.randint(10, (20, ))) assert 0 <= acc and acc <= 1 if is_full_test(): jit = torch.jit.script(model) assert jit(torch.arange(3, device=device)).size() == (3, 16) pos_rw, neg_rw = jit.sample(torch.arange(3)) assert float(jit.loss(pos_rw.to(device), neg_rw.to(device))) >= 0 ================================================ FILE: test/nn/models/test_pmlp.py ================================================ import pytest import torch from torch_geometric.nn.models import PMLP def test_pmlp(): x = torch.randn(4, 16) edge_index 
= torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) pmlp = PMLP(in_channels=16, hidden_channels=32, out_channels=2, num_layers=4) assert str(pmlp) == 'PMLP(16, 2, num_layers=4)' pmlp.training = True assert pmlp(x).size() == (4, 2) pmlp.training = False assert pmlp(x, edge_index).size() == (4, 2) with pytest.raises(ValueError, match="'edge_index' needs to be present"): pmlp.training = False pmlp(x) ================================================ FILE: test/nn/models/test_polynormer.py ================================================ import pytest import torch from torch_geometric.nn.models import Polynormer @pytest.mark.parametrize('local_attn', [True, False]) @pytest.mark.parametrize('qk_shared', [True, False]) @pytest.mark.parametrize('pre_ln', [True, False]) @pytest.mark.parametrize('post_bn', [True, False]) def test_polynormer(local_attn, qk_shared, pre_ln, post_bn): x = torch.randn(10, 16) edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], ]) batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) model = Polynormer( in_channels=16, hidden_channels=128, out_channels=40, qk_shared=qk_shared, pre_ln=pre_ln, post_bn=post_bn, local_attn=local_attn, ) out = model(x, edge_index, batch) assert out.size() == (10, 40) model._global = True out = model(x, edge_index, batch) assert out.size() == (10, 40) ================================================ FILE: test/nn/models/test_re_net.py ================================================ import torch from torch_geometric.datasets.icews import EventDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import RENet from torch_geometric.testing import is_full_test class MyTestEventDataset(EventDataset): def __init__(self, root, seq_len): super().__init__(root, pre_transform=RENet.pre_transform(seq_len)) self.load(self.processed_paths[0]) @property def num_nodes(self): return 16 @property def num_rels(self): return 8 @property def processed_file_names(self): return 
'data.pt' def _download(self): pass def process_events(self): sub = torch.randint(self.num_nodes, (64, ), dtype=torch.long) rel = torch.randint(self.num_rels, (64, ), dtype=torch.long) obj = torch.randint(self.num_nodes, (64, ), dtype=torch.long) t = torch.arange(8, dtype=torch.long).view(-1, 1).repeat(1, 8).view(-1) return torch.stack([sub, rel, obj, t], dim=1) def process(self): data_list = self._process_data_list() self.save(data_list, self.processed_paths[0]) def test_re_net(tmp_path): dataset = MyTestEventDataset(tmp_path, seq_len=4) loader = DataLoader(dataset, 2, follow_batch=['h_sub', 'h_obj']) model = RENet(dataset.num_nodes, dataset.num_rels, hidden_channels=16, seq_len=4) if is_full_test(): jit = torch.jit.export(model) logits = torch.randn(6, 6) y = torch.tensor([0, 1, 2, 3, 4, 5]) mrr, hits1, hits3, hits10 = model.test(logits, y) assert 0.15 < mrr <= 1 assert hits1 <= hits3 and hits3 <= hits10 and hits10 == 1 for data in loader: log_prob_obj, log_prob_sub = model(data) if is_full_test(): log_prob_obj_jit, log_prob_sub_jit = jit(data) assert torch.allclose(log_prob_obj_jit, log_prob_obj) assert torch.allclose(log_prob_sub_jit, log_prob_sub) model.test(log_prob_obj, data.obj) model.test(log_prob_sub, data.sub) ================================================ FILE: test/nn/models/test_rect.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import RECT_L from torch_geometric.testing import is_full_test from torch_geometric.typing import SparseTensor def test_rect(): x = torch.randn(6, 8) y = torch.tensor([1, 0, 0, 2, 1, 1]) edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]]) mask = torch.randint(0, 2, (6, ), dtype=torch.bool) model = RECT_L(8, 16) assert str(model) == 'RECT_L(8, 16)' out = model(x, edge_index) assert out.size() == (6, 8) if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6)) assert 
torch.allclose(out, model(x, adj.t()), atol=1e-6) # Test `embed`: embed_out = model.embed(x, edge_index) assert embed_out.size() == (6, 16) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(embed_out, model.embed(x, adj.t()), atol=1e-6) # Test `get_semantic_labels`: labels_out = model.get_semantic_labels(x, y, mask) assert labels_out.size() == (int(mask.sum()), 8) if is_full_test(): jit = torch.jit.script(model) assert torch.allclose(jit(x, edge_index), out, atol=1e-6) assert torch.allclose(embed_out, jit.embed(x, edge_index), atol=1e-6) assert torch.allclose(labels_out, jit.get_semantic_labels(x, y, mask)) if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit(x, adj.t()), out, atol=1e-6) assert torch.allclose(embed_out, jit.embed(x, adj.t()), atol=1e-6) assert torch.allclose(labels_out, jit.get_semantic_labels(x, y, mask)) ================================================ FILE: test/nn/models/test_rev_gnn.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.nn import GraphConv, GroupAddRev, SAGEConv from torch_geometric.nn.dense.linear import Linear @pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) def test_revgnn_forward_inverse(num_groups): x = torch.randn(4, 32) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) lin = Linear(32, 32) conv = SAGEConv(32 // num_groups, 32 // num_groups) conv = GroupAddRev(conv, num_groups=num_groups) assert str(conv) == (f'GroupAddRev(SAGEConv({32 // num_groups}, ' f'{32 // num_groups}, aggr=mean), ' f'num_groups={num_groups})') h = lin(x) h_o = h.clone().detach() out = conv(h, edge_index) if torch_geometric.typing.WITH_PT20: assert h.untyped_storage().size() == 0 else: assert h.storage().size() == 0 h_rev = conv.inverse(out, edge_index) assert torch.allclose(h_o, h_rev, atol=0.001) @pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) def test_revgnn_backward(num_groups): x = torch.randn(4, 32) edge_index = 
torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) lin = Linear(32, 32) conv = SAGEConv(32 // num_groups, 32 // num_groups) conv = GroupAddRev(conv, num_groups=num_groups) h = lin(x) out = conv(h, edge_index) target = out.mean() target.backward() @pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) def test_revgnn_multi_backward(num_groups): x = torch.randn(4, 32) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) lin = Linear(32, 32) conv = SAGEConv(32 // num_groups, 32 // num_groups) conv = GroupAddRev(conv, num_groups=num_groups, num_bwd_passes=4) h = lin(x) out = conv(h, edge_index) target = out.mean() target.backward(retain_graph=True) target.backward(retain_graph=True) torch.autograd.grad(outputs=target, inputs=[h] + list(conv.parameters()), retain_graph=True) torch.autograd.grad(outputs=target, inputs=[h] + list(conv.parameters())) @pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) def test_revgnn_diable(num_groups): x = torch.randn(4, 32) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) lin = Linear(32, 32) conv = SAGEConv(32 // num_groups, 32 // num_groups) conv = GroupAddRev(conv, num_groups=num_groups, disable=True) h = lin(x) out = conv(h, edge_index) target = out.mean() target.backward() # Memory will not be freed if disable: if torch_geometric.typing.WITH_PT20: assert h.untyped_storage().size() == 4 * 4 * 32 else: assert h.storage().size() == 4 * 32 @pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) def test_revgnn_with_args(num_groups): x = torch.randn(4, 32) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_weight = torch.rand(4) lin = Linear(32, 32) conv = GraphConv(32 // num_groups, 32 // num_groups) conv = GroupAddRev(conv, num_groups=num_groups) h = lin(x) out = conv(h, edge_index, edge_weight) target = out.mean() target.backward() ================================================ FILE: test/nn/models/test_schnet.py ================================================ import pytest import torch from torch_geometric.data import 
Batch, Data from torch_geometric.nn import SchNet from torch_geometric.nn.models.schnet import RadiusInteractionGraph from torch_geometric.testing import is_full_test, withPackage def generate_data(): return Data( z=torch.randint(1, 10, (20, )), pos=torch.randn(20, 3), ) @withPackage('torch_cluster') @withPackage('ase') @pytest.mark.parametrize('use_interaction_graph', [False, True]) @pytest.mark.parametrize('use_atomref', [False, True]) def test_schnet(use_interaction_graph, use_atomref): data = generate_data() interaction_graph = None if use_interaction_graph: interaction_graph = RadiusInteractionGraph(cutoff=6.0) model = SchNet( hidden_channels=16, num_filters=16, num_interactions=2, interaction_graph=interaction_graph, num_gaussians=10, cutoff=6.0, dipole=True, atomref=torch.randn(100, 1) if use_atomref else None, ) assert str(model) == ('SchNet(hidden_channels=16, num_filters=16, ' 'num_interactions=2, num_gaussians=10, cutoff=6.0)') with torch.no_grad(): out = model(data.z, data.pos) assert out.size() == (1, 1) if is_full_test(): jit = torch.jit.export(model) out = jit(data.z, data.pos) assert out.size() == (1, 1) @withPackage('torch_cluster') def test_schnet_batch(): num_graphs = 3 batch = [generate_data() for _ in range(num_graphs)] batch = Batch.from_data_list(batch) model = SchNet( hidden_channels=16, num_filters=16, num_interactions=2, num_gaussians=10, cutoff=6.0, ) with torch.no_grad(): out = model(batch.z, batch.pos, batch.batch) assert out.size() == (num_graphs, 1) ================================================ FILE: test/nn/models/test_sgformer.py ================================================ import torch from torch_geometric.nn.models import SGFormer def test_sgformer(): x = torch.randn(10, 16) edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], ]) batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) model = SGFormer( in_channels=16, hidden_channels=128, out_channels=40, ) out = model(x, edge_index, 
batch) assert out.size() == (10, 40) ================================================ FILE: test/nn/models/test_signed_gcn.py ================================================ import torch from torch_geometric.nn import SignedGCN from torch_geometric.testing import has_package, is_full_test # @withPackage('sklearn') def test_signed_gcn(): model = SignedGCN(8, 16, num_layers=2, lamb=5) assert str(model) == 'SignedGCN(8, 16, num_layers=2)' pos_index = torch.randint(high=10, size=(2, 40), dtype=torch.long) neg_index = torch.randint(high=10, size=(2, 40), dtype=torch.long) train_pos_index, test_pos_index = model.split_edges(pos_index) train_neg_index, test_neg_index = model.split_edges(neg_index) assert train_pos_index.size() == (2, 32) assert test_pos_index.size() == (2, 8) assert train_neg_index.size() == (2, 32) assert test_neg_index.size() == (2, 8) if has_package('sklearn'): x = model.create_spectral_features( train_pos_index, train_neg_index, num_nodes=10, ) assert x.size() == (10, 8) else: x = torch.randn(10, 8) z = model(x, train_pos_index, train_neg_index) assert z.size() == (10, 16) loss = model.loss(z, train_pos_index, train_neg_index) assert loss.item() >= 0 if has_package('sklearn'): auc, f1 = model.test(z, test_pos_index, test_neg_index) assert auc >= 0 assert f1 >= 0 if is_full_test(): jit = torch.jit.export(model) assert torch.allclose(jit(x, train_pos_index, train_neg_index), z) ================================================ FILE: test/nn/models/test_tgn.py ================================================ import pytest import torch from torch_geometric.data import TemporalData from torch_geometric.loader import TemporalDataLoader from torch_geometric.nn import TGNMemory from torch_geometric.nn.models.tgn import ( IdentityMessage, LastAggregator, LastNeighborLoader, ) @pytest.mark.parametrize('neg_sampling_ratio', [0.0, 1.0]) def test_tgn(neg_sampling_ratio): memory_dim = 16 time_dim = 16 src = torch.tensor([0, 1, 0, 2, 0, 3, 1, 4, 2, 3]) dst = 
torch.tensor([1, 2, 1, 1, 3, 2, 4, 3, 3, 4]) t = torch.arange(10) msg = torch.randn(10, 16) data = TemporalData(src=src, dst=dst, t=t, msg=msg) loader = TemporalDataLoader( data, batch_size=5, neg_sampling_ratio=neg_sampling_ratio, ) neighbor_loader = LastNeighborLoader(data.num_nodes, size=3) assert neighbor_loader.cur_e_id == 0 assert neighbor_loader.e_id.size() == (data.num_nodes, 3) memory = TGNMemory( num_nodes=data.num_nodes, raw_msg_dim=msg.size(-1), memory_dim=memory_dim, time_dim=time_dim, message_module=IdentityMessage(msg.size(-1), memory_dim, time_dim), aggregator_module=LastAggregator(), ) assert memory.memory.size() == (data.num_nodes, memory_dim) assert memory.last_update.size() == (data.num_nodes, ) # Test TGNMemory training: for i, batch in enumerate(loader): n_id, edge_index, e_id = neighbor_loader(batch.n_id) z, last_update = memory(n_id) memory.update_state(batch.src, batch.dst, batch.t, batch.msg) neighbor_loader.insert(batch.src, batch.dst) if i == 0: assert n_id.size(0) >= 4 assert edge_index.numel() == 0 assert e_id.numel() == 0 assert z.size() == (n_id.size(0), memory_dim) assert torch.sum(last_update) == 0 else: assert n_id.size(0) == 5 assert edge_index.numel() == 12 assert e_id.numel() == 6 assert z.size() == (n_id.size(0), memory_dim) assert torch.equal(last_update, torch.tensor([4, 3, 3, 4, 0])) # Test TGNMemory inference: memory.eval() all_n_id = torch.arange(data.num_nodes) z, last_update = memory(all_n_id) assert z.size() == (data.num_nodes, memory_dim) assert torch.equal(last_update, torch.tensor([4, 6, 8, 9, 9])) post_src = torch.tensor([3, 4]) post_dst = torch.tensor([4, 3]) post_t = torch.tensor([10, 10]) post_msg = torch.randn(2, 16) memory.update_state(post_src, post_dst, post_t, post_msg) post_z, post_last_update = memory(all_n_id) assert torch.allclose(z[0:3], post_z[0:3]) assert torch.equal(post_last_update, torch.tensor([4, 6, 8, 10, 10])) memory.reset_state() assert memory.memory.sum() == 0 assert memory.last_update.sum() 
== 0 ================================================ FILE: test/nn/models/test_visnet.py ================================================ import pytest import torch from torch_geometric.nn import ViSNet from torch_geometric.testing import withPackage @withPackage('torch_cluster') @pytest.mark.parametrize('kwargs', [ dict(lmax=2, derivative=True, vecnorm_type=None, vertex=False), dict(lmax=1, derivative=False, vecnorm_type='max_min', vertex=True), ]) def test_visnet(kwargs): z = torch.randint(1, 10, (20, )) pos = torch.randn(20, 3) batch = torch.zeros(20, dtype=torch.long) model = ViSNet(**kwargs) model.reset_parameters() energy, forces = model(z, pos, batch) assert energy.size() == (1, 1) if kwargs['derivative']: assert forces.size() == (20, 3) else: assert forces is None ================================================ FILE: test/nn/norm/test_batch_norm.py ================================================ import pytest import torch from torch_geometric.nn import BatchNorm, HeteroBatchNorm from torch_geometric.testing import is_full_test, withDevice @withDevice @pytest.mark.parametrize('conf', [True, False]) def test_batch_norm(device, conf): x = torch.randn(100, 16, device=device) norm = BatchNorm(16, affine=conf, track_running_stats=conf, device=device) norm.reset_running_stats() norm.reset_parameters() assert str(norm) == (f'BatchNorm(16, eps=1e-05, momentum=0.1, ' f'affine={conf}, track_running_stats={conf})') if is_full_test(): torch.jit.script(norm) out = norm(x) assert out.size() == (100, 16) def test_batch_norm_single_element(): x = torch.randn(1, 16) norm = BatchNorm(16) with pytest.raises(ValueError, match="Expected more than 1 value"): norm(x) with pytest.raises(ValueError, match="requires 'track_running_stats'"): norm = BatchNorm(16, track_running_stats=False, allow_single_element=True) norm = BatchNorm(16, track_running_stats=True, allow_single_element=True) out = norm(x) assert torch.allclose(out, x) @withDevice @pytest.mark.parametrize('conf', [True, 
False]) def test_hetero_batch_norm(device, conf): x = torch.randn((100, 16), device=device) # Test single type: norm = BatchNorm(16, affine=conf, track_running_stats=conf, device=device) expected = norm(x) type_vec = torch.zeros(100, dtype=torch.long, device=device) norm = HeteroBatchNorm(16, num_types=1, affine=conf, track_running_stats=conf, device=device) norm.reset_running_stats() norm.reset_parameters() assert str(norm) == 'HeteroBatchNorm(16, num_types=1)' out = norm(x, type_vec) assert out.size() == (100, 16) assert torch.allclose(out, expected, atol=1e-3) # Test multiple types: type_vec = torch.randint(5, (100, ), device=device) norm = HeteroBatchNorm(16, num_types=5, affine=conf, track_running_stats=conf, device=device) out = norm(x, type_vec) assert out.size() == (100, 16) for i in range(5): # Check that mean=0 and std=1 across all types: mean = out[type_vec == i].mean() std = out[type_vec == i].std(unbiased=False) assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-7) assert torch.allclose(std, torch.ones_like(std), atol=1e-7) ================================================ FILE: test/nn/norm/test_diff_group_norm.py ================================================ import torch from torch_geometric.nn import DiffGroupNorm from torch_geometric.testing import is_full_test, withDevice @withDevice def test_diff_group_norm(device): x = torch.randn(6, 16, device=device) norm = DiffGroupNorm(16, groups=4, lamda=0, device=device) assert str(norm) == 'DiffGroupNorm(16, groups=4)' assert torch.allclose(norm(x), x) if is_full_test(): jit = torch.jit.script(norm) assert torch.allclose(jit(x), x) norm = DiffGroupNorm(16, groups=4, lamda=0.01, device=device) assert str(norm) == 'DiffGroupNorm(16, groups=4)' out = norm(x) assert out.size() == x.size() if is_full_test(): jit = torch.jit.script(norm) assert torch.allclose(jit(x), out) def test_group_distance_ratio(): x = torch.randn(6, 16) y = torch.tensor([0, 1, 0, 1, 1, 1]) assert 
DiffGroupNorm.group_distance_ratio(x, y) > 0 if is_full_test(): jit = torch.jit.script(DiffGroupNorm.group_distance_ratio) assert jit(x, y) > 0 ================================================ FILE: test/nn/norm/test_graph_norm.py ================================================ import torch from torch_geometric.nn import GraphNorm from torch_geometric.testing import is_full_test, withDevice @withDevice def test_graph_norm(device): torch.manual_seed(42) x = torch.randn(200, 16, device=device) batch = torch.arange(4, device=device).view(-1, 1).repeat(1, 50).view(-1) norm = GraphNorm(16, device=device) assert str(norm) == 'GraphNorm(16)' if is_full_test(): torch.jit.script(norm) out = norm(x) assert out.size() == (200, 16) assert torch.allclose(out.mean(dim=0), torch.zeros(16, device=device), atol=1e-6) assert torch.allclose(out.std(dim=0, unbiased=False), torch.ones(16, device=device), atol=1e-6) out = norm(x, batch) assert out.size() == (200, 16) assert torch.allclose(out[:50].mean(dim=0), torch.zeros(16, device=device), atol=1e-6) assert torch.allclose(out[:50].std(dim=0, unbiased=False), torch.ones(16, device=device), atol=1e-6) ================================================ FILE: test/nn/norm/test_graph_size_norm.py ================================================ import torch from torch_geometric.nn import GraphSizeNorm from torch_geometric.testing import is_full_test def test_graph_size_norm(): x = torch.randn(100, 16) batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long)) norm = GraphSizeNorm() assert str(norm) == 'GraphSizeNorm()' out = norm(x, batch) assert out.size() == (100, 16) if is_full_test(): jit = torch.jit.script(norm) assert torch.allclose(jit(x, batch), out) ================================================ FILE: test/nn/norm/test_instance_norm.py ================================================ import pytest import torch from torch_geometric.nn import InstanceNorm from torch_geometric.testing import is_full_test, withDevice 
@withDevice @pytest.mark.parametrize('conf', [True, False]) def test_instance_norm(conf, device): batch = torch.zeros(100, dtype=torch.long, device=device) x1 = torch.randn(100, 16, device=device) x2 = torch.randn(100, 16, device=device) norm1 = InstanceNorm(16, affine=conf, track_running_stats=conf, device=device) norm2 = InstanceNorm(16, affine=conf, track_running_stats=conf, device=device) assert str(norm1) == 'InstanceNorm(16)' if is_full_test(): torch.jit.script(norm1) out1 = norm1(x1) out2 = norm2(x1, batch) assert out1.size() == (100, 16) assert torch.allclose(out1, out2, atol=1e-7) if conf: assert torch.allclose(norm1.running_mean, norm2.running_mean) assert torch.allclose(norm1.running_var, norm2.running_var) out1 = norm1(x2) out2 = norm2(x2, batch) assert torch.allclose(out1, out2, atol=1e-7) if conf: assert torch.allclose(norm1.running_mean, norm2.running_mean) assert torch.allclose(norm1.running_var, norm2.running_var) norm1.eval() norm2.eval() out1 = norm1(x1) out2 = norm2(x1, batch) assert torch.allclose(out1, out2, atol=1e-7) out1 = norm1(x2) out2 = norm2(x2, batch) assert torch.allclose(out1, out2, atol=1e-7) out1 = norm2(x1) out2 = norm2(x2) out3 = norm2(torch.cat([x1, x2], dim=0), torch.cat([batch, batch + 1])) assert torch.allclose(out1, out3[:100], atol=1e-7) assert torch.allclose(out2, out3[100:], atol=1e-7) ================================================ FILE: test/nn/norm/test_layer_norm.py ================================================ import pytest import torch from torch_geometric.nn import HeteroLayerNorm, LayerNorm from torch_geometric.testing import is_full_test, withDevice @withDevice @pytest.mark.parametrize('affine', [True, False]) @pytest.mark.parametrize('mode', ['graph', 'node']) def test_layer_norm(device, affine, mode): x = torch.randn(100, 16, device=device) batch = torch.zeros(100, dtype=torch.long, device=device) norm = LayerNorm(16, affine=affine, mode=mode, device=device) assert str(norm) == f'LayerNorm(16, 
affine={affine}, mode={mode})' if is_full_test(): torch.jit.script(norm) out1 = norm(x) assert out1.size() == (100, 16) assert torch.allclose(norm(x, batch), out1, atol=1e-6) out2 = norm(torch.cat([x, x], dim=0), torch.cat([batch, batch + 1], dim=0)) assert torch.allclose(out1, out2[:100], atol=1e-6) assert torch.allclose(out1, out2[100:], atol=1e-6) @withDevice @pytest.mark.parametrize('affine', [False, True]) def test_hetero_layer_norm(device, affine): x = torch.randn((100, 16), device=device) expected = LayerNorm(16, affine=affine, mode='node', device=device)(x) # Test single type: type_vec = torch.zeros(100, dtype=torch.long, device=device) type_ptr = [0, 100] norm = HeteroLayerNorm(16, num_types=1, affine=affine, device=device) assert str(norm) == 'HeteroLayerNorm(16, num_types=1)' out = norm(x, type_vec) assert out.size() == (100, 16) assert torch.allclose(out, expected, atol=1e-3) assert torch.allclose(norm(out, type_ptr=type_ptr), expected, atol=1e-3) mean = out.mean(dim=-1) std = out.std(unbiased=False, dim=-1) assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-2) assert torch.allclose(std, torch.ones_like(std), atol=1e-2) # Test multiple types: type_vec = torch.arange(5, device=device) type_vec = type_vec.view(-1, 1).repeat(1, 20).view(-1) type_ptr = [0, 20, 40, 60, 80, 100] norm = HeteroLayerNorm(16, num_types=5, affine=affine, device=device) assert str(norm) == 'HeteroLayerNorm(16, num_types=5)' out = norm(x, type_vec) assert out.size() == (100, 16) assert torch.allclose(out, expected, atol=1e-3) assert torch.allclose(norm(out, type_ptr=type_ptr), expected, atol=1e-3) mean = out.mean(dim=-1) std = out.std(unbiased=False, dim=-1) assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-2) assert torch.allclose(std, torch.ones_like(std), atol=1e-2) ================================================ FILE: test/nn/norm/test_mean_subtraction_norm.py ================================================ import torch from torch_geometric.nn import 
MeanSubtractionNorm from torch_geometric.testing import is_full_test def test_mean_subtraction_norm(): x = torch.randn(6, 16) batch = torch.tensor([0, 0, 1, 1, 1, 2]) norm = MeanSubtractionNorm() assert str(norm) == 'MeanSubtractionNorm()' if is_full_test(): torch.jit.script(norm) out = norm(x) assert out.size() == (6, 16) assert torch.allclose(out.mean(), torch.tensor(0.), atol=1e-6) out = norm(x, batch) assert out.size() == (6, 16) assert torch.allclose(out[0:2].mean(), torch.tensor(0.), atol=1e-6) assert torch.allclose(out[0:2].mean(), torch.tensor(0.), atol=1e-6) ================================================ FILE: test/nn/norm/test_msg_norm.py ================================================ import torch from torch_geometric.nn import MessageNorm from torch_geometric.testing import is_full_test, withDevice @withDevice def test_message_norm(device): norm = MessageNorm(learn_scale=True, device=device) assert str(norm) == 'MessageNorm(learn_scale=True)' x = torch.randn(100, 16, device=device) msg = torch.randn(100, 16, device=device) out = norm(x, msg) assert out.size() == (100, 16) if is_full_test(): jit = torch.jit.script(norm) assert torch.allclose(jit(x, msg), out) norm = MessageNorm(learn_scale=False, device=device) assert str(norm) == 'MessageNorm(learn_scale=False)' out = norm(x, msg) assert out.size() == (100, 16) if is_full_test(): jit = torch.jit.script(norm) assert torch.allclose(jit(x, msg), out) ================================================ FILE: test/nn/norm/test_pair_norm.py ================================================ import pytest import torch from torch_geometric.nn import PairNorm from torch_geometric.testing import is_full_test @pytest.mark.parametrize('scale_individually', [False, True]) def test_pair_norm(scale_individually): x = torch.randn(100, 16) batch = torch.zeros(100, dtype=torch.long) norm = PairNorm(scale_individually=scale_individually) assert str(norm) == 'PairNorm()' if is_full_test(): torch.jit.script(norm) out1 = 
norm(x) assert out1.size() == (100, 16) out2 = norm(torch.cat([x, x], dim=0), torch.cat([batch, batch + 1], dim=0)) assert torch.allclose(out1, out2[:100], atol=1e-6) assert torch.allclose(out1, out2[100:], atol=1e-6) ================================================ FILE: test/nn/pool/connect/test_filter_edges.py ================================================ import torch from torch_geometric.nn.pool.connect import FilterEdges from torch_geometric.nn.pool.select import SelectOutput from torch_geometric.testing import is_full_test def test_filter_edges(): edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 1, 3, 2, 2]]) edge_attr = torch.tensor([1, 2, 3, 4, 5, 6]) batch = torch.tensor([0, 0, 1, 1]) select_output = SelectOutput( node_index=torch.tensor([1, 2]), num_nodes=4, cluster_index=torch.tensor([0, 1]), num_clusters=2, ) connect = FilterEdges() assert str(connect) == 'FilterEdges()' out1 = connect(select_output, edge_index, edge_attr, batch) assert out1.edge_index.tolist() == [[0, 1], [0, 1]] assert out1.edge_attr.tolist() == [3, 5] assert out1.batch.tolist() == [0, 1] if is_full_test(): jit = torch.jit.script(connect) out2 = jit(select_output, edge_index, edge_attr, batch) torch.equal(out1.edge_index, out2.edge_index) torch.equal(out1.edge_attr, out2.edge_attr) torch.equal(out1.batch, out2.batch) ================================================ FILE: test/nn/pool/select/test_select_topk.py ================================================ from itertools import product import pytest import torch from torch_geometric.nn.pool.select import SelectOutput, SelectTopK from torch_geometric.nn.pool.select.topk import topk from torch_geometric.profile import benchmark from torch_geometric.testing import is_full_test def test_topk_ratio(): x = torch.tensor([2.0, 4.0, 5.0, 6.0, 2.0, 9.0]) batch = torch.tensor([0, 0, 1, 1, 1, 1]) perm1 = topk(x, 0.5, batch) assert perm1.tolist() == [1, 5, 3] assert x[perm1].tolist() == [4.0, 9.0, 6.0] assert batch[perm1].tolist() == [0, 
1, 1] perm2 = topk(x, 2, batch) assert perm2.tolist() == [1, 0, 5, 3] assert x[perm2].tolist() == [4.0, 2.0, 9.0, 6.0] assert batch[perm2].tolist() == [0, 0, 1, 1] perm3 = topk(x, 3, batch) assert perm3.tolist() == [1, 0, 5, 3, 2] assert x[perm3].tolist() == [4.0, 2.0, 9.0, 6.0, 5.0] assert batch[perm3].tolist() == [0, 0, 1, 1, 1] if is_full_test(): jit = torch.jit.script(topk) assert torch.equal(jit(x, 0.5, batch), perm1) assert torch.equal(jit(x, 2, batch), perm2) assert torch.equal(jit(x, 3, batch), perm3) @pytest.mark.parametrize('min_score', [None, 2.0]) def test_select_topk(min_score): x = torch.randn(6, 16) batch = torch.tensor([0, 0, 1, 1, 1, 1]) pool = SelectTopK(16, min_score=min_score) if min_score is None: assert str(pool) == 'SelectTopK(16, ratio=0.5)' else: assert str(pool) == 'SelectTopK(16, min_score=2.0)' out = pool(x, batch) assert isinstance(out, SelectOutput) assert out.num_nodes == 6 assert out.num_clusters <= out.num_nodes assert out.node_index.min() >= 0 assert out.node_index.max() < out.num_nodes assert out.cluster_index.min() == 0 assert out.cluster_index.max() == out.num_clusters - 1 if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') args = parser.parse_args() BS = [2**i for i in range(6, 8)] NS = [2**i for i in range(8, 16)] funcs = [] func_names = [] args_list = [] for B, N in product(BS, NS): x = torch.randn(N, device=args.device) batch = torch.randint(0, B, (N, ), device=args.device).sort()[0] funcs.append(topk) func_names.append(f'B={B}, N={N}') args_list.append((x, 0.5, batch)) benchmark( funcs=funcs, func_names=func_names, args=args_list, num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, progress_bar=True, ) ================================================ FILE: test/nn/pool/test_approx_knn.py ================================================ import warnings import torch from torch_geometric.nn import 
approx_knn, approx_knn_graph from torch_geometric.testing import onlyFullTest, withPackage def to_set(edge_index): return {(i, j) for i, j in edge_index.t().tolist()} @onlyFullTest # JIT compile makes this test too slow :( @withPackage('pynndescent') def test_approx_knn(): warnings.filterwarnings('ignore', '.*find n_neighbors.*') x = torch.tensor([ [-1.0, -1.0], [-1.0, +1.0], [+1.0, +1.0], [+1.0, -1.0], [-1.0, -1.0], [-1.0, +1.0], [+1.0, +1.0], [+1.0, -1.0], ]) y = torch.tensor([ [+1.0, 0.0], [-1.0, 0.0], ]) batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) batch_y = torch.tensor([0, 1]) edge_index = approx_knn(x, y, 2) assert to_set(edge_index) == {(0, 2), (0, 3), (1, 0), (1, 1)} edge_index = approx_knn(x, y, 2, batch_x, batch_y) assert to_set(edge_index) == {(0, 2), (0, 3), (1, 4), (1, 5)} @onlyFullTest # JIT compile makes this test too slow :( @withPackage('pynndescent') def test_approx_knn_graph(): warnings.filterwarnings('ignore', '.*find n_neighbors.*') x = torch.tensor([ [-1.0, -1.0], [-1.0, +1.0], [+1.0, +1.0], [+1.0, -1.0], ]) edge_index = approx_knn_graph(x, k=2, flow='target_to_source') assert to_set(edge_index) == {(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), (2, 3), (3, 0), (3, 2)} edge_index = approx_knn_graph(x, k=2, flow='source_to_target') assert to_set(edge_index) == {(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)} ================================================ FILE: test/nn/pool/test_asap.py ================================================ import io import torch from torch_geometric.nn import ASAPooling, GCNConv, GraphConv from torch_geometric.testing import is_full_test, onlyFullTest, onlyLinux @onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows. 
def test_asap(): in_channels = 16 edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) num_nodes = edge_index.max().item() + 1 x = torch.randn((num_nodes, in_channels)) for GNN in [GraphConv, GCNConv]: pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN, add_self_loops=False) assert str(pool) == ('ASAPooling(16, ratio=0.5)') out = pool(x, edge_index) assert out[0].size() == (num_nodes // 2, in_channels) assert out[1].size() == (2, 2) if is_full_test(): torch.jit.script(pool) pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN, add_self_loops=True) assert str(pool) == ('ASAPooling(16, ratio=0.5)') out = pool(x, edge_index) assert out[0].size() == (num_nodes // 2, in_channels) assert out[1].size() == (2, 4) pool = ASAPooling(in_channels, ratio=2, GNN=GNN, add_self_loops=False) assert str(pool) == ('ASAPooling(16, ratio=2)') out = pool(x, edge_index) assert out[0].size() == (2, in_channels) assert out[1].size() == (2, 2) @onlyFullTest def test_asap_jit_save(): pool = ASAPooling(in_channels=16) torch.jit.save(torch.jit.script(pool), io.BytesIO()) ================================================ FILE: test/nn/pool/test_avg_pool.py ================================================ import torch from torch_geometric.data import Batch from torch_geometric.nn import avg_pool, avg_pool_neighbor_x, avg_pool_x from torch_geometric.testing import is_full_test def test_avg_pool_x(): cluster = torch.tensor([0, 1, 0, 1, 2, 2]) x = torch.tensor([ [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0], ]) batch = torch.tensor([0, 0, 0, 0, 1, 1]) out = avg_pool_x(cluster, x, batch) assert out[0].tolist() == [[3, 4], [5, 6], [10, 11]] assert out[1].tolist() == [0, 0, 1] if is_full_test(): jit = torch.jit.script(avg_pool_x) out = jit(cluster, x, batch) assert out[0].tolist() == [[3, 4], [5, 6], [10, 11]] assert out[1].tolist() == [0, 0, 1] out, _ = avg_pool_x(cluster, x, batch, size=2) assert out.tolist() == [[3, 4], [5, 
6], [10, 11], [0, 0]] batch_size = int(batch.max().item()) + 1 out2, _ = avg_pool_x(cluster, x, batch, batch_size=batch_size, size=2) assert torch.equal(out, out2) if is_full_test(): jit = torch.jit.script(avg_pool_x) out, _ = jit(cluster, x, batch, size=2) assert out.tolist() == [[3, 4], [5, 6], [10, 11], [0, 0]] out2, _ = jit(cluster, x, batch, batch_size=batch_size, size=2) assert torch.equal(out, out2) def test_avg_pool(): cluster = torch.tensor([0, 1, 0, 1, 2, 2]) x = torch.tensor([ [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0], ]) pos = torch.tensor([ [0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0], ]) edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) edge_attr = torch.ones(edge_index.size(1)) batch = torch.tensor([0, 0, 0, 0, 1, 1]) data = Batch(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, batch=batch) data = avg_pool(cluster, data, transform=lambda x: x) assert data.x.tolist() == [[3, 4], [5, 6], [10, 11]] assert data.pos.tolist() == [[1, 1], [2, 2], [4.5, 4.5]] assert data.edge_index.tolist() == [[0, 1], [1, 0]] assert data.edge_attr.tolist() == [4, 4] assert data.batch.tolist() == [0, 0, 1] def test_avg_pool_neighbor_x(): x = torch.tensor([ [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0], ]) edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) batch = torch.tensor([0, 0, 0, 0, 1, 1]) data = Batch(x=x, edge_index=edge_index, batch=batch) data = avg_pool_neighbor_x(data) assert data.x.tolist() == [ [4, 5], [4, 5], [4, 5], [4, 5], [10, 11], [10, 11], ] assert torch.equal(data.edge_index, edge_index) ================================================ FILE: test/nn/pool/test_cluster_pool.py ================================================ import pytest import torch from torch_geometric.nn import ClusterPooling from torch_geometric.testing import 
withPackage @withPackage('scipy') @pytest.mark.parametrize('edge_score_method', [ 'tanh', 'sigmoid', 'log_softmax', ]) def test_cluster_pooling(edge_score_method): x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [-1.0]]) edge_index = torch.tensor([ [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6], [1, 2, 3, 6, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0], ]) batch = torch.tensor([0, 0, 0, 0, 1, 1, 0]) op = ClusterPooling(in_channels=1, edge_score_method=edge_score_method) assert str(op) == 'ClusterPooling(1)' op.reset_parameters() x, edge_index, batch, unpool_info = op(x, edge_index, batch) assert x.size(0) <= 7 assert edge_index.size(0) == 2 if edge_index.numel() > 0: assert edge_index.min() >= 0 assert edge_index.max() < x.size(0) assert batch.size() == (x.size(0), ) ================================================ FILE: test/nn/pool/test_consecutive.py ================================================ import torch from torch_geometric.nn.pool.consecutive import consecutive_cluster def test_consecutive_cluster(): src = torch.tensor([8, 2, 10, 15, 100, 1, 100]) out, perm = consecutive_cluster(src) assert out.tolist() == [2, 1, 3, 4, 5, 0, 5] assert perm.tolist() == [5, 1, 0, 2, 3, 6] ================================================ FILE: test/nn/pool/test_decimation.py ================================================ import torch from torch_geometric.nn.pool.decimation import decimation_indices def test_decimation_basic(): N_1, N_2 = 4, 6 decimation_factor = 2 ptr = torch.tensor([0, N_1, N_1 + N_2]) idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor) expected_size = (N_1 // decimation_factor) + (N_2 // decimation_factor) assert idx_decim.size(0) == expected_size expected = torch.tensor([0, N_1 // decimation_factor, expected_size]) assert torch.equal(ptr_decim, expected) def test_decimation_single_cloud(): N_1 = 4 decimation_factor = 2 ptr = torch.tensor([0, N_1]) idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor) expected_size = N_1 
// decimation_factor assert idx_decim.size(0) == expected_size assert torch.equal(ptr_decim, torch.tensor([0, expected_size])) def test_decimation_almost_empty(): N_1 = 4 decimation_factor = 666 # greater than N_1 ptr = torch.tensor([0, N_1]) idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor) assert idx_decim.size(0) == 1 assert torch.equal(ptr_decim, torch.tensor([0, 1])) ================================================ FILE: test/nn/pool/test_edge_pool.py ================================================ import torch from torch_geometric.nn import EdgePooling from torch_geometric.testing import is_full_test from torch_geometric.utils import scatter def test_compute_edge_score_softmax(): edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) raw = torch.randn(edge_index.size(1)) e = EdgePooling.compute_edge_score_softmax(raw, edge_index, 6) assert torch.all(e >= 0) and torch.all(e <= 1) # Test whether all incoming edge scores sum up to one. 
assert torch.allclose( scatter(e, edge_index[1], reduce='sum'), torch.ones(6), ) if is_full_test(): jit = torch.jit.script(EdgePooling.compute_edge_score_softmax) assert torch.allclose(jit(raw, edge_index, 6), e) def test_compute_edge_score_tanh(): edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) raw = torch.randn(edge_index.size(1)) e = EdgePooling.compute_edge_score_tanh(raw, edge_index, 6) assert torch.all(e >= -1) and torch.all(e <= 1) assert torch.all(torch.argsort(raw) == torch.argsort(e)) if is_full_test(): jit = torch.jit.script(EdgePooling.compute_edge_score_tanh) assert torch.allclose(jit(raw, edge_index, 6), e) def test_compute_edge_score_sigmoid(): edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) raw = torch.randn(edge_index.size(1)) e = EdgePooling.compute_edge_score_sigmoid(raw, edge_index, 6) assert torch.all(e >= 0) and torch.all(e <= 1) assert torch.all(torch.argsort(raw) == torch.argsort(e)) if is_full_test(): jit = torch.jit.script(EdgePooling.compute_edge_score_sigmoid) assert torch.allclose(jit(raw, edge_index, 6), e) def test_edge_pooling(): x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [-1.0]]) edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0]]) batch = torch.tensor([0, 0, 0, 0, 1, 1, 0]) op = EdgePooling(in_channels=1) assert str(op) == 'EdgePooling(1)' # Setting parameters fixed so we can test the expected outcome: op.lin.weight.data.fill_(1.) op.lin.bias.data.fill_(0.) 
# Test pooling: new_x, new_edge_index, new_batch, unpool_info = op(x, edge_index, batch) assert new_x.size(0) == new_batch.size(0) == 4 assert new_edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [0, 1, 2, 1, 2, 2]] assert new_batch.tolist() == [1, 0, 0, 0] if is_full_test(): jit = torch.jit.script(op) out = jit(x, edge_index, batch) assert torch.allclose(new_x, out[0]) assert torch.equal(new_edge_index, out[1]) assert torch.equal(new_batch, out[2]) # Test unpooling: out = op.unpool(new_x, unpool_info) assert out[0].size() == x.size() assert out[0].tolist() == [[1], [1], [5], [5], [9], [9], [-1]] assert torch.equal(out[1], edge_index) assert torch.equal(out[2], batch) if is_full_test(): jit = torch.jit.export(op) out = jit.unpool(new_x, unpool_info) assert out[0].size() == x.size() assert out[0].tolist() == [[1], [1], [5], [5], [9], [9], [-1]] assert torch.equal(out[1], edge_index) assert torch.equal(out[2], batch) # Test edge cases. x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]) edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) batch = torch.tensor([0, 0, 0, 0, 1, 1]) new_x, new_edge_index, new_batch, _ = op(x, edge_index, batch) assert new_x.size(0) == new_batch.size(0) == 3 assert new_batch.tolist() == [1, 0, 0] assert new_edge_index.tolist() == [[0, 1, 1, 2, 2], [0, 1, 2, 1, 2]] ================================================ FILE: test/nn/pool/test_glob.py ================================================ import torch from torch_geometric.nn import ( global_add_pool, global_max_pool, global_mean_pool, ) def test_global_pool(): N_1, N_2 = 4, 6 x = torch.randn(N_1 + N_2, 4) batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)]) out = global_add_pool(x, batch) assert out.size() == (2, 4) assert torch.allclose(out[0], x[:4].sum(dim=0)) assert torch.allclose(out[1], x[4:].sum(dim=0)) out = global_add_pool(x, None) assert out.size() == (1, 4) assert torch.allclose(out, 
x.sum(dim=0, keepdim=True)) out = global_mean_pool(x, batch) assert out.size() == (2, 4) assert torch.allclose(out[0], x[:4].mean(dim=0)) assert torch.allclose(out[1], x[4:].mean(dim=0)) out = global_mean_pool(x, None) assert out.size() == (1, 4) assert torch.allclose(out, x.mean(dim=0, keepdim=True)) out = global_max_pool(x, batch) assert out.size() == (2, 4) assert torch.allclose(out[0], x[:4].max(dim=0)[0]) assert torch.allclose(out[1], x[4:].max(dim=0)[0]) out = global_max_pool(x, None) assert out.size() == (1, 4) assert torch.allclose(out, x.max(dim=0, keepdim=True)[0]) def test_permuted_global_pool(): N_1, N_2 = 4, 6 x = torch.randn(N_1 + N_2, 4) batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long) perm = torch.randperm(N_1 + N_2) px = x[perm] pbatch = batch[perm] px1 = px[pbatch == 0] px2 = px[pbatch == 1] out = global_add_pool(px, pbatch) assert out.size() == (2, 4) assert torch.allclose(out[0], px1.sum(dim=0)) assert torch.allclose(out[1], px2.sum(dim=0)) out = global_mean_pool(px, pbatch) assert out.size() == (2, 4) assert torch.allclose(out[0], px1.mean(dim=0)) assert torch.allclose(out[1], px2.mean(dim=0)) out = global_max_pool(px, pbatch) assert out.size() == (2, 4) assert torch.allclose(out[0], px1.max(dim=0)[0]) assert torch.allclose(out[1], px2.max(dim=0)[0]) def test_dense_global_pool(): x = torch.randn(3, 16, 32) assert torch.allclose(global_add_pool(x, None), x.sum(dim=1)) ================================================ FILE: test/nn/pool/test_graclus.py ================================================ import torch from torch_geometric.nn import graclus from torch_geometric.testing import withPackage @withPackage('torch_cluster') def test_graclus(): edge_index = torch.tensor([[0, 1], [1, 0]]) assert graclus(edge_index).tolist() == [0, 0] ================================================ FILE: test/nn/pool/test_knn.py ================================================ import pytest import torch from torch_geometric.nn import ( 
ApproxL2KNNIndex, ApproxMIPSKNNIndex, L2KNNIndex, MIPSKNNIndex, ) from torch_geometric.testing import withCUDA, withPackage @withCUDA @withPackage('faiss') @pytest.mark.parametrize('k', [2]) def test_l2(device, k): lhs = torch.randn(10, 16, device=device) rhs = torch.randn(100, 16, device=device) index = L2KNNIndex(rhs) assert index.get_emb().device == device assert torch.equal(index.get_emb(), rhs) out = index.search(lhs, k) assert out.score.device == device assert out.index.device == device assert out.score.size() == (10, k) assert out.index.size() == (10, k) mat = torch.linalg.norm(lhs.unsqueeze(1) - rhs.unsqueeze(0), dim=-1).pow(2) score, index = mat.sort(dim=-1) assert torch.allclose(out.score, score[:, :k]) assert torch.equal(out.index, index[:, :k]) @withCUDA @withPackage('faiss') @pytest.mark.parametrize('k', [2]) def test_mips(device, k): lhs = torch.randn(10, 16, device=device) rhs = torch.randn(100, 16, device=device) index = MIPSKNNIndex(rhs) assert index.get_emb().device == device assert torch.equal(index.get_emb(), rhs) out = index.search(lhs, k) assert out.score.device == device assert out.index.device == device assert out.score.size() == (10, k) assert out.index.size() == (10, k) mat = lhs @ rhs.t() score, index = mat.sort(dim=-1, descending=True) assert torch.allclose(out.score, score[:, :k]) assert torch.equal(out.index, index[:, :k]) @withCUDA @withPackage('faiss') @pytest.mark.parametrize('k', [2]) @pytest.mark.parametrize('reserve', [None, 100]) def test_approx_l2(device, k, reserve): lhs = torch.randn(10, 16, device=device) rhs = torch.randn(10_000, 16, device=device) index = ApproxL2KNNIndex( num_cells=10, num_cells_to_visit=10, bits_per_vector=8, emb=rhs, reserve=reserve, ) out = index.search(lhs, k) assert out.score.device == device assert out.index.device == device assert out.score.size() == (10, k) assert out.index.size() == (10, k) assert out.index.min() >= 0 and out.index.max() < 10_000 @withCUDA @withPackage('faiss') 
@pytest.mark.parametrize('k', [2]) @pytest.mark.parametrize('reserve', [None, 100]) def test_approx_mips(device, k, reserve): lhs = torch.randn(10, 16, device=device) rhs = torch.randn(10_000, 16, device=device) index = ApproxMIPSKNNIndex( num_cells=10, num_cells_to_visit=10, bits_per_vector=8, emb=rhs, reserve=reserve, ) out = index.search(lhs, k) assert out.score.device == device assert out.index.device == device assert out.score.size() == (10, k) assert out.index.size() == (10, k) assert out.index.min() >= 0 and out.index.max() < 10_000 @withCUDA @withPackage('faiss') @pytest.mark.parametrize('k', [50]) def test_mips_exclude(device, k): lhs = torch.randn(10, 16, device=device) rhs = torch.randn(100, 16, device=device) exclude_lhs = torch.randint(0, 10, (500, ), device=device) exclude_rhs = torch.randint(0, 100, (500, ), device=device) exclude_links = torch.stack([exclude_lhs, exclude_rhs], dim=0) exclude_links = exclude_links.unique(dim=1) index = MIPSKNNIndex(rhs) out = index.search(lhs, k, exclude_links) assert out.score.device == device assert out.index.device == device assert out.score.size() == (10, k) assert out.index.size() == (10, k) # Ensure that excluded links are not present in `out.index`: batch = torch.arange(lhs.size(0), device=device).repeat_interleave(k) knn_links = torch.stack([batch, out.index.view(-1)], dim=0) knn_links = knn_links[:, knn_links[1] >= 0] unique_links = torch.cat([knn_links, exclude_links], dim=1).unique(dim=1) assert unique_links.size(1) == knn_links.size(1) + exclude_links.size(1) ================================================ FILE: test/nn/pool/test_max_pool.py ================================================ import torch from torch_geometric.data import Batch from torch_geometric.nn import max_pool, max_pool_neighbor_x, max_pool_x from torch_geometric.testing import is_full_test def test_max_pool_x(): cluster = torch.tensor([0, 1, 0, 1, 2, 2]) x = torch.tensor([ [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], 
[11.0, 12.0], ]) batch = torch.tensor([0, 0, 0, 0, 1, 1]) out = max_pool_x(cluster, x, batch) assert out[0].tolist() == [[5, 6], [7, 8], [11, 12]] assert out[1].tolist() == [0, 0, 1] if is_full_test(): jit = torch.jit.script(max_pool_x) out = jit(cluster, x, batch) assert out[0].tolist() == [[5, 6], [7, 8], [11, 12]] assert out[1].tolist() == [0, 0, 1] out, _ = max_pool_x(cluster, x, batch, size=2) assert out.tolist() == [[5, 6], [7, 8], [11, 12], [0, 0]] batch_size = int(batch.max().item()) + 1 out2, _ = max_pool_x(cluster, x, batch, batch_size=batch_size, size=2) assert torch.equal(out, out2) if is_full_test(): jit = torch.jit.script(max_pool_x) out, _ = jit(cluster, x, batch, size=2) assert out.tolist() == [[5, 6], [7, 8], [11, 12], [0, 0]] out2, _ = jit(cluster, x, batch, batch_size=batch_size, size=2) assert torch.equal(out, out2) def test_max_pool(): cluster = torch.tensor([0, 1, 0, 1, 2, 2]) x = torch.tensor([ [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0], ]) pos = torch.tensor([ [0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0], ]) edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) edge_attr = torch.ones(edge_index.size(1)) batch = torch.tensor([0, 0, 0, 0, 1, 1]) data = Batch(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, batch=batch) data = max_pool(cluster, data, transform=lambda x: x) assert data.x.tolist() == [[5, 6], [7, 8], [11, 12]] assert data.pos.tolist() == [[1, 1], [2, 2], [4.5, 4.5]] assert data.edge_index.tolist() == [[0, 1], [1, 0]] assert data.edge_attr.tolist() == [4, 4] assert data.batch.tolist() == [0, 0, 1] def test_max_pool_neighbor_x(): x = torch.tensor([ [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0], ]) edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) batch = torch.tensor([0, 0, 0, 0, 1, 1]) data = Batch(x=x, 
edge_index=edge_index, batch=batch) data = max_pool_neighbor_x(data) assert data.x.tolist() == [ [7, 8], [7, 8], [7, 8], [7, 8], [11, 12], [11, 12], ] assert torch.equal(data.edge_index, edge_index) ================================================ FILE: test/nn/pool/test_mem_pool.py ================================================ import torch from torch_geometric.nn import MemPooling from torch_geometric.utils import to_dense_batch def test_mem_pool(): mpool1 = MemPooling(4, 8, heads=3, num_clusters=2) assert str(mpool1) == 'MemPooling(4, 8, heads=3, num_clusters=2)' mpool2 = MemPooling(8, 4, heads=2, num_clusters=1) x = torch.randn(17, 4) batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4]) _, mask = to_dense_batch(x, batch) out1, S = mpool1(x, batch) loss = MemPooling.kl_loss(S) with torch.autograd.set_detect_anomaly(True): loss.backward() out2, _ = mpool2(out1) assert out1.size() == (5, 2, 8) assert out2.size() == (5, 1, 4) assert S[~mask].sum() == 0 assert round(S[mask].sum().item()) == x.size(0) assert float(loss) > 0 ================================================ FILE: test/nn/pool/test_pan_pool.py ================================================ import torch from torch_geometric.nn import PANConv, PANPooling from torch_geometric.testing import is_full_test, withPackage @withPackage('torch_sparse') def test_pan_pooling(): edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) num_nodes = edge_index.max().item() + 1 x = torch.randn((num_nodes, 16)) conv = PANConv(16, 32, filter_size=2) pool = PANPooling(32, ratio=0.5) assert str(pool) == 'PANPooling(32, ratio=0.5, multiplier=1.0)' x, M = conv(x, edge_index) h, edge_index, edge_weight, batch, perm, score = pool(x, M) assert h.size() == (2, 32) assert edge_index.size() == (2, 4) assert edge_weight.size() == (4, ) assert perm.size() == (2, ) assert score.size() == (2, ) if is_full_test(): jit = torch.jit.script(pool) out = jit(x, M) assert 
torch.allclose(h, out[0]) assert torch.equal(edge_index, out[1]) assert torch.allclose(edge_weight, out[2]) assert torch.equal(batch, out[3]) assert torch.equal(perm, out[4]) assert torch.allclose(score, out[5]) ================================================ FILE: test/nn/pool/test_pool.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.nn import radius_graph from torch_geometric.testing import onlyFullTest, withPackage @onlyFullTest @withPackage('torch_cluster') def test_radius_graph_jit(): class Net(torch.nn.Module): def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor: return radius_graph(x, r=2.5, batch=batch, loop=False) x = torch.tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.float) batch = torch.tensor([0, 0, 0, 0]) model = Net() jit = torch.jit.script(model) assert model(x, batch).size() == jit(x, batch).size() ================================================ FILE: test/nn/pool/test_sag_pool.py ================================================ import torch from torch_geometric.nn import ( GATConv, GCNConv, GraphConv, SAGEConv, SAGPooling, ) from torch_geometric.testing import is_full_test def test_sag_pooling(): in_channels = 16 edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) num_nodes = edge_index.max().item() + 1 x = torch.randn((num_nodes, in_channels)) for GNN in [GraphConv, GCNConv, GATConv, SAGEConv]: pool1 = SAGPooling(in_channels, ratio=0.5, GNN=GNN) assert str(pool1) == (f'SAGPooling({GNN.__name__}, 16, ' f'ratio=0.5, multiplier=1.0)') out1 = pool1(x, edge_index) assert out1[0].size() == (num_nodes // 2, in_channels) assert out1[1].size() == (2, 2) pool2 = SAGPooling(in_channels, ratio=None, GNN=GNN, min_score=0.1) assert str(pool2) == (f'SAGPooling({GNN.__name__}, 16, ' f'min_score=0.1, multiplier=1.0)') out2 = pool2(x, edge_index) assert out2[0].size(0) <= x.size(0) and 
out2[0].size(1) == (16) assert out2[1].size(0) == 2 and out2[1].size(1) <= edge_index.size(1) pool3 = SAGPooling(in_channels, ratio=2, GNN=GNN) assert str(pool3) == (f'SAGPooling({GNN.__name__}, 16, ' f'ratio=2, multiplier=1.0)') out3 = pool3(x, edge_index) assert out3[0].size() == (2, in_channels) assert out3[1].size() == (2, 2) if is_full_test(): jit1 = torch.jit.script(pool1) assert torch.allclose(jit1(x, edge_index)[0], out1[0]) jit2 = torch.jit.script(pool2) assert torch.allclose(jit2(x, edge_index)[0], out2[0]) jit3 = torch.jit.script(pool3) assert torch.allclose(jit3(x, edge_index)[0], out3[0]) ================================================ FILE: test/nn/pool/test_topk_pool.py ================================================ import torch from torch_geometric.nn.pool import TopKPooling from torch_geometric.nn.pool.connect.filter_edges import filter_adj from torch_geometric.testing import is_full_test def test_filter_adj(): edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3], [1, 3, 0, 2, 1, 3, 0, 2]]) edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) perm = torch.tensor([2, 3]) out = filter_adj(edge_index, edge_attr, perm) assert out[0].tolist() == [[0, 1], [1, 0]] assert out[1].tolist() == [6.0, 8.0] if is_full_test(): jit = torch.jit.script(filter_adj) out = jit(edge_index, edge_attr, perm) assert out[0].tolist() == [[0, 1], [1, 0]] assert out[1].tolist() == [6.0, 8.0] def test_topk_pooling(): in_channels = 16 edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) num_nodes = edge_index.max().item() + 1 x = torch.randn((num_nodes, in_channels)) pool1 = TopKPooling(in_channels, ratio=0.5) assert str(pool1) == 'TopKPooling(16, ratio=0.5, multiplier=1.0)' out1 = pool1(x, edge_index) assert out1[0].size() == (num_nodes // 2, in_channels) assert out1[1].size() == (2, 2) pool2 = TopKPooling(in_channels, ratio=None, min_score=0.1) assert str(pool2) == 'TopKPooling(16, min_score=0.1, 
multiplier=1.0)' out2 = pool2(x, edge_index) assert out2[0].size(0) <= x.size(0) and out2[0].size(1) == (16) assert out2[1].size(0) == 2 and out2[1].size(1) <= edge_index.size(1) pool3 = TopKPooling(in_channels, ratio=2) assert str(pool3) == 'TopKPooling(16, ratio=2, multiplier=1.0)' out3 = pool3(x, edge_index) assert out3[0].size() == (2, in_channels) assert out3[1].size() == (2, 2) if is_full_test(): jit1 = torch.jit.script(pool1) assert torch.allclose(jit1(x, edge_index)[0], out1[0]) jit2 = torch.jit.script(pool2) assert torch.allclose(jit2(x, edge_index)[0], out2[0]) jit3 = torch.jit.script(pool3) assert torch.allclose(jit3(x, edge_index)[0], out3[0]) ================================================ FILE: test/nn/pool/test_voxel_grid.py ================================================ import torch from torch_geometric.data import Batch from torch_geometric.nn import avg_pool, voxel_grid from torch_geometric.testing import withPackage @withPackage('torch_cluster') def test_voxel_grid(): pos = torch.tensor([ [0.0, 0.0], [11.0, 9.0], [2.0, 8.0], [2.0, 2.0], [8.0, 3.0], ]) batch = torch.tensor([0, 0, 0, 1, 1]) assert voxel_grid(pos, size=5, batch=batch).tolist() == [0, 5, 3, 6, 7] assert voxel_grid(pos, size=5).tolist() == [0, 5, 3, 0, 1] cluster = voxel_grid(pos, size=5, batch=batch, start=-1, end=[18, 14]) assert cluster.tolist() == [0, 10, 4, 16, 17] cluster_no_batch = voxel_grid(pos, size=5, start=-1, end=[18, 14]) assert cluster_no_batch.tolist() == [0, 10, 4, 0, 1] @withPackage('torch_cluster') def test_single_voxel_grid(): pos = torch.tensor([ [0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], ]) edge_index = torch.tensor([[0, 0, 3], [1, 2, 4]]) batch = torch.tensor([0, 0, 0, 1, 1]) x = torch.randn(5, 16) cluster = voxel_grid(pos, size=5, batch=batch) assert cluster.tolist() == [0, 0, 0, 1, 1] data = Batch(x=x, edge_index=edge_index, pos=pos, batch=batch) data = avg_pool(cluster, data) cluster_no_batch = voxel_grid(pos, size=5) assert 
cluster_no_batch.tolist() == [0, 0, 0, 0, 0] data_no_batch = Batch(x=x, edge_index=edge_index, pos=pos) data_no_batch = avg_pool(cluster_no_batch, data_no_batch) ================================================ FILE: test/nn/test_compile_basic.py ================================================ import torch from torch_geometric.profile import benchmark from torch_geometric.testing import ( onlyFullTest, onlyLinux, withDevice, withPackage, ) from torch_geometric.utils import scatter # Basic "Gather-Apply-Scatter" patterns commonly used in PyG: def gather_scatter(x, edge_index, reduce='sum'): row, col = edge_index x_j = x[row] return scatter(x_j, col, dim_size=x.size(0), reduce=reduce) def gather_cat_scatter(x, edge_index, reduce='sum'): row, col = edge_index x_ij = torch.cat([x[col], x[row]], dim=-1) return scatter(x_ij, col, dim_size=x.size(0), reduce=reduce) def gather_weight_scatter(x, edge_index, edge_weight, reduce='sum'): row, col = edge_index x_j = x[row] * edge_weight.view(-1, 1) return scatter(x_j, col, dim_size=x.size(0), reduce=reduce) def gather_transform_scatter(x, edge_index, matrix, reduce='sum'): row, col = edge_index x_j = x[row] @ matrix return scatter(x_j, col, dim_size=x.size(0), reduce=reduce) def fused_gather_scatter(x, edge_index, reduce=('sum', 'mean', 'max')): row, col = edge_index x_j = x[row] outs = [scatter(x_j, col, dim_size=x.size(0), reduce=r) for r in reduce] return torch.cat(outs, dim=-1) @withDevice @onlyLinux @onlyFullTest @withPackage('torch>=2.0.0') def test_torch_compile(device): x = torch.randn(10, 16, device=device) edge_index = torch.randint(0, x.size(0), (2, 40), device=device) edge_weight = torch.rand(edge_index.size(1), device=device) matrix = torch.randn(x.size(-1), x.size(-1), device=device) expected = gather_scatter(x, edge_index) compiled_op = torch.compile(gather_scatter) out = compiled_op(x, edge_index) assert torch.allclose(out, expected, atol=1e-6) expected = gather_cat_scatter(x, edge_index) compiled_op = 
torch.compile(gather_cat_scatter) out = compiled_op(x, edge_index) assert torch.allclose(out, expected, atol=1e-6) expected = gather_weight_scatter(x, edge_index, edge_weight) compiled_op = torch.compile(gather_weight_scatter) out = compiled_op(x, edge_index, edge_weight) assert torch.allclose(out, expected, atol=1e-6) expected = gather_transform_scatter(x, edge_index, matrix) compiled_op = torch.compile(gather_transform_scatter) out = compiled_op(x, edge_index, matrix) assert torch.allclose(out, expected, atol=1e-6) expected = fused_gather_scatter(x, edge_index) compiled_op = torch.compile(fused_gather_scatter) out = compiled_op(x, edge_index) assert torch.allclose(out, expected, atol=1e-6) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() num_nodes, num_edges = 10_000, 200_000 x = torch.randn(num_nodes, 64, device=args.device) edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device) edge_weight = torch.rand(num_edges, device=args.device) matrix = torch.randn(64, 64, device=args.device) for reduce in ['sum', 'mean', 'max']: print(f'Aggregator: {reduce}') benchmark( funcs=[ gather_scatter, torch.compile(gather_scatter), ], func_names=['Vanilla', 'Compiled'], args=(x, edge_index, reduce), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) benchmark( funcs=[ gather_cat_scatter, torch.compile(gather_cat_scatter), ], func_names=['Vanilla Cat', 'Compiled Cat'], args=(x, edge_index, reduce), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) benchmark( funcs=[ gather_weight_scatter, torch.compile(gather_weight_scatter), ], func_names=['Vanilla Weight', 'Compiled Weight'], args=(x, edge_index, edge_weight, reduce), num_steps=50 if args.device == 'cpu' 
else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) benchmark( funcs=[ gather_transform_scatter, torch.compile(gather_transform_scatter), ], func_names=['Vanilla Transform', 'Compiled Transform'], args=(x, edge_index, matrix, reduce), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) benchmark( funcs=[ fused_gather_scatter, torch.compile(fused_gather_scatter), ], func_names=['Vanilla Fused', 'Compiled Fused'], args=(x, edge_index), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) ================================================ FILE: test/nn/test_compile_conv.py ================================================ import pytest import torch from torch import Tensor from torch_geometric import EdgeIndex from torch_geometric.nn import GCNConv, SAGEConv from torch_geometric.profile import benchmark from torch_geometric.testing import ( onlyFullTest, onlyLinux, withDevice, withPackage, ) from torch_geometric.utils import scatter class MySAGEConv(torch.nn.Module): def __init__(self, in_channels: int, out_channels: int): super().__init__() self.lin_src = torch.nn.Linear(in_channels, out_channels) self.lin_dst = torch.nn.Linear(in_channels, out_channels) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x_j = x[edge_index[0]] out = scatter(x_j, edge_index[1], dim_size=x.size(0), reduce='mean') return self.lin_src(out) + self.lin_dst(x) @withDevice @onlyLinux @onlyFullTest @withPackage('torch>=2.1.0') @pytest.mark.parametrize('Conv', [GCNConv, SAGEConv]) def test_compile_conv(device, Conv): import torch._dynamo as dynamo x = torch.randn(10, 16, device=device) edge_index = torch.randint(0, x.size(0), (2, 40), device=device) if Conv == GCNConv: conv = Conv(16, 32, add_self_loops=False).to(device) else: conv = Conv(16, 32).to(device) explanation = dynamo.explain(conv)(x, edge_index) assert 
explanation.graph_break_count == 0 out = torch.compile(conv)(x, edge_index) assert torch.allclose(conv(x, edge_index), out, atol=1e-6) @withDevice @onlyLinux @onlyFullTest @withPackage('torch==2.3') @pytest.mark.parametrize('Conv', [GCNConv, SAGEConv]) def test_compile_conv_edge_index(device, Conv): import torch._dynamo as dynamo x = torch.randn(10, 16, device=device) edge_index = torch.randint(0, x.size(0), (2, 40), device=device) edge_index = EdgeIndex(edge_index, sparse_size=(10, 10)) edge_index = edge_index.sort_by('col')[0] edge_index.fill_cache_() if Conv == GCNConv: conv = Conv(16, 32, normalize=False).to(device) else: conv = Conv(16, 32).to(device) explanation = dynamo.explain(conv)(x, edge_index) assert explanation.graph_break_count == 0 out = torch.compile(conv, fullgraph=True)(x, edge_index) assert torch.allclose(conv(x, edge_index), out, atol=1e-6) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() num_nodes, num_edges = 10_000, 200_000 x = torch.randn(num_nodes, 64, device=args.device) edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device) conv = MySAGEConv(64, 64).to(args.device) benchmark( funcs=[conv, torch.compile(conv)], func_names=['Vanilla', 'Compiled'], args=(x, edge_index), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) for Conv in [GCNConv, SAGEConv]: print(f'Conv: {Conv.__name__}') conv = Conv(64, 64).to(args.device) compiled_conv = torch.compile(conv) benchmark( funcs=[conv, compiled_conv], func_names=['Vanilla', 'Compiled'], args=(x, edge_index), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) ================================================ FILE: test/nn/test_compile_dynamic.py 
================================================ import random import torch from torch import Tensor from torch_geometric.testing import ( get_random_edge_index, onlyFullTest, onlyLinux, withDevice, withPackage, ) from torch_geometric.utils import scatter class MySAGEConv(torch.nn.Module): def __init__(self, in_channels: int, out_channels: int): super().__init__() self.lin_src = torch.nn.Linear(in_channels, out_channels) self.lin_dst = torch.nn.Linear(in_channels, out_channels) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x_j = x[edge_index[0]] out = scatter(x_j, edge_index[1], dim_size=x.size(0), reduce='mean') return self.lin_src(out) + self.lin_dst(x) @withDevice @onlyLinux @onlyFullTest @withPackage('torch>2.0.0') def test_dynamic_torch_compile(device): conv = MySAGEConv(64, 64).to(device) conv = torch.compile(conv, dynamic=True) optimizer = torch.optim.Adam(conv.parameters(), lr=0.01) for _ in range(10): N = random.randrange(100, 500) E = random.randrange(200, 1000) x = torch.randn(N, 64, device=device) edge_index = get_random_edge_index(N, N, E, device=device) optimizer.zero_grad() expected = conv(x, edge_index) expected.mean().backward() optimizer.step() ================================================ FILE: test/nn/test_data_parallel.py ================================================ import pytest import torch from torch_geometric.data import Data from torch_geometric.nn import DataParallel from torch_geometric.testing import onlyCUDA @onlyCUDA def test_data_parallel_single_gpu(): with pytest.warns(UserWarning, match="much slower"): module = DataParallel(torch.nn.Identity()) data_list = [Data(x=torch.randn(x, 1)) for x in [2, 3, 10, 4]] batches = module.scatter(data_list, device_ids=[0]) assert len(batches) == 1 @onlyCUDA @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs') def test_data_parallel_multi_gpu(): with pytest.warns(UserWarning, match="much slower"): module = DataParallel(torch.nn.Identity()) data_list = 
[Data(x=torch.randn(x, 1)) for x in [2, 3, 10, 4]] batches = module.scatter(data_list, device_ids=[0, 1, 0, 1]) assert len(batches) == 3 ================================================ FILE: test/nn/test_encoding.py ================================================ import torch from torch_geometric.nn import PositionalEncoding, TemporalEncoding from torch_geometric.testing import withDevice @withDevice def test_positional_encoding(device): encoder = PositionalEncoding(64, device=device) assert str(encoder) == 'PositionalEncoding(64)' x = torch.tensor([1.0, 2.0, 3.0], device=device) assert encoder(x).size() == (3, 64) @withDevice def test_temporal_encoding(device): encoder = TemporalEncoding(64, device=device) assert str(encoder) == 'TemporalEncoding(64)' x = torch.tensor([1.0, 2.0, 3.0], device=device) assert encoder(x).size() == (3, 64) ================================================ FILE: test/nn/test_fvcore.py ================================================ import torch from torch_geometric.nn import GraphSAGE from torch_geometric.testing import get_random_edge_index, withPackage @withPackage('fvcore') def test_fvcore(): from fvcore.nn import FlopCountAnalysis x = torch.randn(10, 16) edge_index = get_random_edge_index(10, 10, num_edges=100) model = GraphSAGE(16, 32, num_layers=2) flops = FlopCountAnalysis(model, (x, edge_index)) # TODO (matthias) Currently, aggregations are not properly registered. 
assert flops.by_module()['convs.0'] == 2 * 10 * 16 * 32 assert flops.by_module()['convs.1'] == 2 * 10 * 32 * 32 assert flops.total() == (flops.by_module()['convs.0'] + flops.by_module()['convs.1']) assert flops.by_operator()['linear'] == flops.total() ================================================ FILE: test/nn/test_fx.py ================================================ import torch import torch.nn.functional as F from torch import Tensor def test_dropout(): class MyModule(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: return F.dropout(x, p=1.0, training=self.training) module = MyModule() graph_module = torch.fx.symbolic_trace(module) graph_module.recompile() x = torch.randn(4) graph_module.train() assert torch.allclose(graph_module(x), torch.zeros_like(x)) # This is certainly undesired behavior due to tracing :( graph_module.eval() assert torch.allclose(graph_module(x), torch.zeros_like(x)) ================================================ FILE: test/nn/test_inits.py ================================================ import torch from torch.nn import Linear as Lin from torch.nn import ReLU from torch.nn import Sequential as Seq from torch_geometric.nn.inits import ( glorot, glorot_orthogonal, ones, reset, uniform, zeros, ) def test_inits(): x = torch.empty(1, 4) uniform(size=4, value=x) assert x.min() >= -0.5 assert x.max() <= 0.5 glorot(x) assert x.min() >= -1.1 assert x.max() <= 1.1 glorot_orthogonal(x, scale=1.0) assert x.min() >= -2.5 assert x.max() <= 2.5 zeros(x) assert x.tolist() == [[0, 0, 0, 0]] ones(x) assert x.tolist() == [[1, 1, 1, 1]] nn = Lin(16, 16) uniform(size=4, value=nn.weight) assert nn.weight[0].min() >= -0.5 assert nn.weight[0].max() <= 0.5 glorot(nn.weight) assert nn.weight[0].min() >= -0.45 assert nn.weight[0].max() <= 0.45 glorot_orthogonal(nn.weight, scale=1.0) assert nn.weight[0].min() >= -2.5 assert nn.weight[0].max() <= 2.5 def test_reset(): nn = Lin(16, 16) w = nn.weight.clone() reset(nn) assert not torch.allclose(nn.weight, 
w) nn = Seq(Lin(16, 16), ReLU(), Lin(16, 16)) w_1, w_2 = nn[0].weight.clone(), nn[2].weight.clone() reset(nn) assert not torch.allclose(nn[0].weight, w_1) assert not torch.allclose(nn[2].weight, w_2) ================================================ FILE: test/nn/test_model_hub.py ================================================ import os from pathlib import Path from unittest.mock import Mock import pytest import torch from torch_geometric.nn import GCN from torch_geometric.nn.model_hub import PyGModelHubMixin from torch_geometric.testing import withPackage REPO_NAME = 'pyg_hugging_test' MODEL_NAME = 'pyg_test_model' DATASET_NAME = 'pyg_dataset' CONFIG = {'hello': 'world'} class DummyModel(GCN, PyGModelHubMixin): def __init__(self, model_name, dataset_name, model_kwargs): GCN.__init__(self, in_channels=3, hidden_channels=5, num_layers=2) PyGModelHubMixin.__init__(self, model_name, dataset_name, model_kwargs) @pytest.fixture def model(): return DummyModel(MODEL_NAME, DATASET_NAME, CONFIG) @withPackage('huggingface_hub') def test_model_init(): model = DummyModel( MODEL_NAME, DATASET_NAME, model_kwargs={ **CONFIG, 'tensor': torch.randn([1, 2, 3]) }) assert model.model_config == CONFIG @withPackage('huggingface_hub') def test_save_pretrained(model, tmp_path): save_directory = f'{str(tmp_path / REPO_NAME)}' model.save_pretrained(save_directory) files = os.listdir(save_directory) assert 'model.pth' in files assert len(files) >= 1 @withPackage('huggingface_hub') def test_save_pretrained_internal(model, tmp_path): save_directory = f'{str(tmp_path / REPO_NAME)}' model._save_pretrained = Mock() model.save_pretrained(save_directory) model._save_pretrained.assert_called_with(Path(save_directory)) @withPackage('huggingface_hub') def test_save_pretrained_with_push_to_hub(model, tmp_path): save_directory = f'{str(tmp_path / REPO_NAME)}' model.push_to_hub = Mock() model.construct_model_card = Mock() model._save_pretrained = Mock() # disable _save_pretrained to speed-up # Not 
pushed to hub model.save_pretrained(save_directory) model.push_to_hub.assert_not_called() model.construct_model_card.assert_called_with(MODEL_NAME, DATASET_NAME) # Push to hub with repo_id model.save_pretrained(save_directory, push_to_hub=True, repo_id='CustomID', config=CONFIG) model.push_to_hub.assert_called_with( repo_id='CustomID', model_card_kwargs={}, config=CONFIG, ) # Push to hub with default repo_id (based on dir name) model.save_pretrained(save_directory, push_to_hub=True, config=CONFIG) model.push_to_hub.assert_called_with( repo_id=REPO_NAME, model_card_kwargs={}, config=CONFIG, ) @withPackage('huggingface_hub') def test_from_pretrained(model, tmp_path): save_directory = f'{str(tmp_path / REPO_NAME)}' model.save_pretrained(save_directory) model = model.from_pretrained(save_directory) assert isinstance(model, DummyModel) @withPackage('huggingface_hub') def test_from_pretrained_internal(model, monkeypatch): hf_hub_download = Mock(side_effect='model') monkeypatch.setattr('torch_geometric.nn.model_hub.hf_hub_download', hf_hub_download) monkeypatch.setattr('torch_geometric.nn.model_hub.fs.torch_load', lambda x, **kwargs: {'state_dict': 1}) model = model._from_pretrained( model_id=MODEL_NAME, revision=None, cache_dir=None, force_download=False, local_files_only=False, token=False, dataset_name=DATASET_NAME, model_name=MODEL_NAME, map_location='cpu', strict=False, **CONFIG, ) assert hf_hub_download.call_count == 1 assert model.model_config == CONFIG ================================================ FILE: test/nn/test_model_summary.py ================================================ from typing import Optional import pytest import torch from torch import Tensor, nn import torch_geometric.typing from torch_geometric.nn import Linear, SAGEConv, summary, to_hetero from torch_geometric.nn.models import GCN from torch_geometric.testing import withPackage from torch_geometric.typing import SparseTensor class GraphSAGE(torch.nn.Module): def __init__(self): 
super().__init__() self.lin1 = Linear(16, 16) self.conv1 = SAGEConv(16, 32) self.lin2 = Linear(32, 32) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x = self.lin1(x).relu() x = self.conv1(x, edge_index).relu() x = self.lin2(x) return x class ModuleDictModel(nn.Module): def __init__(self): super().__init__() self.acts = nn.ModuleDict({ "lrelu": nn.LeakyReLU(), "prelu": nn.PReLU() }) def forward(self, x: torch.Tensor, act_type: str) -> torch.Tensor: return self.acts[act_type](x) @pytest.fixture def gcn(): torch.manual_seed(1) model = GCN(32, 16, num_layers=2, out_channels=32) x = torch.randn(100, 32) edge_index = torch.randint(100, size=(2, 20)) adj_t: Optional[SparseTensor] = None if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t = SparseTensor.from_edge_index( edge_index, sparse_sizes=(100, 100), ).t() return dict(model=model, x=x, edge_index=edge_index, adj_t=adj_t) @withPackage('tabulate') def test_summary_basic(gcn): expected = """ +---------------------+--------------------+----------------+----------+ | Layer | Input Shape | Output Shape | #Param | |---------------------+--------------------+----------------+----------| | GCN | [100, 32], [2, 20] | [100, 32] | 1,072 | | ├─(dropout)Dropout | [100, 16] | [100, 16] | -- | | ├─(act)ReLU | [100, 16] | [100, 16] | -- | | ├─(convs)ModuleList | -- | -- | 1,072 | | │ └─(0)GCNConv | [100, 32], [2, 20] | [100, 16] | 528 | | │ └─(1)GCNConv | [100, 16], [2, 20] | [100, 32] | 544 | | ├─(norms)ModuleList | -- | -- | -- | | │ └─(0)Identity | [100, 16] | [100, 16] | -- | | │ └─(1)Identity | -- | -- | -- | +---------------------+--------------------+----------------+----------+ """ assert summary(gcn['model'], gcn['x'], gcn['edge_index']) == expected[1:-1] @withPackage('tabulate', 'torch_sparse') def test_summary_with_sparse_tensor(gcn): expected = """ +---------------------+-----------------------+----------------+----------+ | Layer | Input Shape | Output Shape | #Param | 
|---------------------+-----------------------+----------------+----------| | GCN | [100, 32], [100, 100] | [100, 32] | 1,072 | | ├─(dropout)Dropout | [100, 16] | [100, 16] | -- | | ├─(act)ReLU | [100, 16] | [100, 16] | -- | | ├─(convs)ModuleList | -- | -- | 1,072 | | │ └─(0)GCNConv | [100, 32], [100, 100] | [100, 16] | 528 | | │ └─(1)GCNConv | [100, 16], [100, 100] | [100, 32] | 544 | | ├─(norms)ModuleList | -- | -- | -- | | │ └─(0)Identity | [100, 16] | [100, 16] | -- | | │ └─(1)Identity | -- | -- | -- | +---------------------+-----------------------+----------------+----------+ """ assert summary(gcn['model'], gcn['x'], gcn['adj_t']) == expected[1:-1] @withPackage('tabulate') def test_lazy_gcn(): expected = """ +---------------------+--------------------+----------------+----------+ | Layer | Input Shape | Output Shape | #Param | |---------------------+--------------------+----------------+----------| | GCN | [100, 32], [2, 20] | [100, 32] | -1 | | ├─(dropout)Dropout | [100, 16] | [100, 16] | -- | | ├─(act)ReLU | [100, 16] | [100, 16] | -- | | ├─(convs)ModuleList | -- | -- | -1 | | │ └─(0)GCNConv | [100, 32], [2, 20] | [100, 16] | -1 | | │ └─(1)GCNConv | [100, 16], [2, 20] | [100, 32] | 544 | | ├─(norms)ModuleList | -- | -- | -- | | │ └─(0)Identity | [100, 16] | [100, 16] | -- | | │ └─(1)Identity | -- | -- | -- | +---------------------+--------------------+----------------+----------+ """ model = GCN(-1, 16, num_layers=2, out_channels=32) x = torch.randn(100, 32) edge_index = torch.randint(100, size=(2, 20)) assert summary(model, x, edge_index) == expected[1:-1] @withPackage('tabulate') def test_summary_with_max_depth(gcn): expected = """ +---------------------+--------------------+----------------+----------+ | Layer | Input Shape | Output Shape | #Param | |---------------------+--------------------+----------------+----------| | GCN | [100, 32], [2, 20] | [100, 32] | 1,072 | | ├─(dropout)Dropout | [100, 16] | [100, 16] | -- | | ├─(act)ReLU | [100, 16] | [100, 
16] | -- | | ├─(convs)ModuleList | -- | -- | 1,072 | | ├─(norms)ModuleList | -- | -- | -- | +---------------------+--------------------+----------------+----------+ """ assert summary( gcn['model'], gcn['x'], gcn['edge_index'], max_depth=1, ) == expected[1:-1] @withPackage('tabulate') def test_summary_with_leaf_module(gcn): expected = """# noqa: E501 +-----------------------------------------+--------------------+----------------+----------+ | Layer | Input Shape | Output Shape | #Param | |-----------------------------------------+--------------------+----------------+----------| | GCN | [100, 32], [2, 20] | [100, 32] | 1,072 | | ├─(dropout)Dropout | [100, 16] | [100, 16] | -- | | ├─(act)ReLU | [100, 16] | [100, 16] | -- | | ├─(convs)ModuleList | -- | -- | 1,072 | | │ └─(0)GCNConv | [100, 32], [2, 20] | [100, 16] | 528 | | │ │ └─(aggr_module)SumAggregation | [120, 16] | [100, 16] | -- | | │ │ └─(lin)Linear | [100, 32] | [100, 16] | 512 | | │ └─(1)GCNConv | [100, 16], [2, 20] | [100, 32] | 544 | | │ │ └─(aggr_module)SumAggregation | [120, 32] | [100, 32] | -- | | │ │ └─(lin)Linear | [100, 16] | [100, 32] | 512 | | ├─(norms)ModuleList | -- | -- | -- | | │ └─(0)Identity | [100, 16] | [100, 16] | -- | | │ └─(1)Identity | -- | -- | -- | +-----------------------------------------+--------------------+----------------+----------+ """ assert summary( gcn['model'], gcn['x'], gcn['edge_index'], leaf_module=None, ) == expected[13:-1] @withPackage('tabulate') def test_summary_with_reusing_layers(): act = nn.ReLU(inplace=True) model1 = nn.Sequential(act, nn.Identity(), act, nn.Identity(), act) model2 = nn.Sequential( nn.ReLU(inplace=True), nn.Identity(), nn.ReLU(inplace=True), nn.Identity(), nn.ReLU(inplace=True), ) x = torch.randn(10) assert summary(model1, x) == summary(model2, x) @withPackage('tabulate') def test_summary_with_to_hetero_model(): x_dict = { 'p': torch.randn(100, 16), 'a': torch.randn(100, 16), } edge_index_dict = { ('p', 'to', 'p'): torch.randint(100, (2, 
200)), ('p', 'to', 'a'): torch.randint(100, (2, 200)), ('a', 'to', 'p'): torch.randint(100, (2, 200)), } metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = to_hetero(GraphSAGE(), metadata) expected = """ +---------------------------+---------------------+----------------+----------+ | Layer | Input Shape | Output Shape | #Param | |---------------------------+---------------------+----------------+----------| | GraphModule | | | 5,824 | | ├─(lin1)ModuleDict | -- | -- | 544 | | │ └─(p)Linear | [100, 16] | [100, 16] | 272 | | │ └─(a)Linear | [100, 16] | [100, 16] | 272 | | ├─(conv1)ModuleDict | -- | -- | 3,168 | | │ └─(p__to__p)SAGEConv | [100, 16], [2, 200] | [100, 32] | 1,056 | | │ └─(p__to__a)SAGEConv | [2, 200] | [100, 32] | 1,056 | | │ └─(a__to__p)SAGEConv | [2, 200] | [100, 32] | 1,056 | | ├─(lin2)ModuleDict | -- | -- | 2,112 | | │ └─(p)Linear | [100, 32] | [100, 32] | 1,056 | | │ └─(a)Linear | [100, 32] | [100, 32] | 1,056 | +---------------------------+---------------------+----------------+----------+ """ assert summary(model, x_dict, edge_index_dict) == expected[1:-1] @withPackage('tabulate') def test_summary_with_module_dict_model(): model = ModuleDictModel() x = torch.randn(100, 32) expected = """ +-------------------------+---------------+----------------+----------+ | Layer | Input Shape | Output Shape | #Param | |-------------------------+---------------+----------------+----------| | ModuleDictModel | [100, 32] | [100, 32] | 1 | | ├─(acts)ModuleDict | -- | -- | 1 | | │ └─(lrelu)LeakyReLU | -- | -- | -- | | │ └─(prelu)PReLU | [100, 32] | [100, 32] | 1 | +-------------------------+---------------+----------------+----------+ """ assert summary(model, x, 'prelu') == expected[1:-1] @withPackage('tabulate') def test_summary_with_jit_model(): model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 8)) model = torch.jit.script(model) x = torch.randn(100, 32) expected = """ 
+----------------------------+---------------+----------------+----------+ | Layer | Input Shape | Output Shape | #Param | |----------------------------+---------------+----------------+----------| | RecursiveScriptModule | -- | -- | 664 | | ├─(0)RecursiveScriptModule | -- | -- | 528 | | ├─(1)RecursiveScriptModule | -- | -- | -- | | ├─(2)RecursiveScriptModule | -- | -- | 136 | +----------------------------+---------------+----------------+----------+ """ assert summary(model, x) == expected[1:-1] ================================================ FILE: test/nn/test_module_dict.py ================================================ import torch from torch_geometric.nn.module_dict import ModuleDict def test_internal_external_key_conversion(): assert ModuleDict.to_internal_key('a.b') == 'a#b' assert ModuleDict.to_internal_key('ab') == 'ab' assert ModuleDict.to_internal_key('a.b.c') == 'a#b#c' assert ModuleDict.to_internal_key(('a', 'b')) == '' assert ModuleDict.to_internal_key(('a.b', 'c')) == '' assert ModuleDict.to_internal_key('type') == '' assert ModuleDict.to_external_key('a#b') == 'a.b' assert ModuleDict.to_external_key('a#b#c') == 'a.b.c' assert ModuleDict.to_external_key('') == ('a', 'b') assert ModuleDict.to_external_key('') == ('a.b', 'c') assert ModuleDict.to_external_key('') == 'type' def test_dot_syntax_keys(): module_dict = ModuleDict({ 'lin1': torch.nn.Linear(16, 16), 'model.lin2': torch.nn.Linear(8, 8), 'model.sub_model.lin3': torch.nn.Linear(4, 4), }) expected_keys = {'lin1', 'model.lin2', 'model.sub_model.lin3'} assert set(module_dict.keys()) == expected_keys assert {key for key, _ in module_dict.items()} == expected_keys for key in expected_keys: assert key in module_dict del module_dict['model.lin2'] assert 'model.lin2' not in module_dict def test_tuple_keys(): module_dict = ModuleDict({ ('a', 'b'): torch.nn.Linear(16, 16), ('a.b', 'c'): torch.nn.Linear(8, 8), }) expected_keys = {('a', 'b'), ('a.b', 'c')} assert set(module_dict.keys()) == expected_keys 
assert {key for key, _ in module_dict.items()} == expected_keys for key in expected_keys: assert key in module_dict del module_dict['a', 'b'] assert ('a', 'b') not in module_dict def test_reserved_keys(): module_dict = ModuleDict({ 'type': torch.nn.Linear(16, 16), '__annotations__': torch.nn.Linear(8, 8), }) expected_keys = {'type', '__annotations__'} assert set(module_dict.keys()) == expected_keys assert {key for key, _ in module_dict.items()} == expected_keys for key in expected_keys: assert key in module_dict del module_dict['type'] assert 'type' not in module_dict ================================================ FILE: test/nn/test_parameter_dict.py ================================================ import torch from torch_geometric.nn.parameter_dict import ParameterDict def test_internal_external_key_conversion(): assert ParameterDict.to_internal_key('a.b') == 'a#b' assert ParameterDict.to_internal_key('ab') == 'ab' assert ParameterDict.to_internal_key('a.b.c') == 'a#b#c' assert ParameterDict.to_internal_key(('a', 'b')) == '' assert ParameterDict.to_internal_key(('a.b', 'c')) == '' assert ParameterDict.to_internal_key('type') == '' assert ParameterDict.to_external_key('a#b') == 'a.b' assert ParameterDict.to_external_key('a#b#c') == 'a.b.c' assert ParameterDict.to_external_key('') == ('a', 'b') assert ParameterDict.to_external_key('') == ('a.b', 'c') assert ParameterDict.to_external_key('') == 'type' def test_dot_syntax_keys(): parameter_dict = { 'param1': torch.nn.Parameter(torch.randn(16, 16)), 'model.param2': torch.nn.Parameter(torch.randn(8, 8)), 'model.sub_model.param3': torch.nn.Parameter(torch.randn(4, 4)), } parameter_dict = ParameterDict(parameter_dict) expected_keys = {'param1', 'model.param2', 'model.sub_model.param3'} assert set(parameter_dict.keys()) == expected_keys assert {key for key, _ in parameter_dict.items()} == expected_keys for key in expected_keys: assert key in parameter_dict del parameter_dict['model.param2'] assert 'model.param2' not in 
parameter_dict def test_tuple_keys(): parameter_dict = { ('a', 'b'): torch.nn.Parameter(torch.randn(16, 16)), ('a.b', 'c'): torch.nn.Parameter(torch.randn(8, 8)), } parameter_dict = ParameterDict(parameter_dict) expected_keys = {('a', 'b'), ('a.b', 'c')} assert set(parameter_dict.keys()) == expected_keys assert {key for key, _ in parameter_dict.items()} == expected_keys for key in expected_keys: assert key in parameter_dict del parameter_dict['a', 'b'] assert ('a', 'b') not in parameter_dict def test_reserved_keys(): parameter_dict = { 'type': torch.nn.Parameter(torch.randn(16, 16)), '__annotations__': torch.nn.Parameter(torch.randn(8, 8)), } parameter_dict = ParameterDict(parameter_dict) expected_keys = {'type', '__annotations__'} assert set(parameter_dict.keys()) == expected_keys assert {key for key, _ in parameter_dict.items()} == expected_keys for key in expected_keys: assert key in parameter_dict del parameter_dict['type'] assert 'type' not in parameter_dict ================================================ FILE: test/nn/test_reshape.py ================================================ import torch from torch_geometric.nn.reshape import Reshape def test_reshape(): x = torch.randn(10, 4) op = Reshape(5, 2, 4) assert str(op) == 'Reshape(5, 2, 4)' assert op(x).size() == (5, 2, 4) assert torch.equal(op(x).view(10, 4), x) ================================================ FILE: test/nn/test_resolver.py ================================================ import pytest import torch from torch.optim.lr_scheduler import ConstantLR, LambdaLR, ReduceLROnPlateau import torch_geometric from torch_geometric.nn.resolver import ( activation_resolver, aggregation_resolver, lr_scheduler_resolver, normalization_resolver, optimizer_resolver, ) def test_activation_resolver(): assert isinstance(activation_resolver(torch.nn.ELU()), torch.nn.ELU) assert isinstance(activation_resolver(torch.nn.ReLU()), torch.nn.ReLU) assert isinstance(activation_resolver(torch.nn.PReLU()), torch.nn.PReLU) 
assert isinstance(activation_resolver('elu'), torch.nn.ELU) assert isinstance(activation_resolver('relu'), torch.nn.ReLU) assert isinstance(activation_resolver('prelu'), torch.nn.PReLU) @pytest.mark.parametrize('aggr_tuple', [ (torch_geometric.nn.MeanAggregation, 'mean'), (torch_geometric.nn.SumAggregation, 'sum'), (torch_geometric.nn.SumAggregation, 'add'), (torch_geometric.nn.MaxAggregation, 'max'), (torch_geometric.nn.MinAggregation, 'min'), (torch_geometric.nn.MulAggregation, 'mul'), (torch_geometric.nn.VarAggregation, 'var'), (torch_geometric.nn.StdAggregation, 'std'), (torch_geometric.nn.SoftmaxAggregation, 'softmax'), (torch_geometric.nn.PowerMeanAggregation, 'powermean'), ]) def test_aggregation_resolver(aggr_tuple): aggr_module, aggr_repr = aggr_tuple assert isinstance(aggregation_resolver(aggr_module()), aggr_module) assert isinstance(aggregation_resolver(aggr_repr), aggr_module) def test_multi_aggregation_resolver(): aggr = aggregation_resolver(None) assert aggr is None aggr = aggregation_resolver(['sum', 'mean', None]) assert isinstance(aggr, torch_geometric.nn.MultiAggregation) assert len(aggr.aggrs) == 3 assert isinstance(aggr.aggrs[0], torch_geometric.nn.SumAggregation) assert isinstance(aggr.aggrs[1], torch_geometric.nn.MeanAggregation) assert aggr.aggrs[2] is None @pytest.mark.parametrize('norm_tuple', [ (torch_geometric.nn.BatchNorm, 'batch', (16, )), (torch_geometric.nn.BatchNorm, 'batch_norm', (16, )), (torch_geometric.nn.InstanceNorm, 'instance_norm', (16, )), (torch_geometric.nn.LayerNorm, 'layer_norm', (16, )), (torch_geometric.nn.GraphNorm, 'graph_norm', (16, )), (torch_geometric.nn.GraphSizeNorm, 'graphsize_norm', ()), (torch_geometric.nn.PairNorm, 'pair_norm', ()), (torch_geometric.nn.MessageNorm, 'message_norm', ()), (torch_geometric.nn.DiffGroupNorm, 'diffgroup_norm', (16, 4)), ]) def test_normalization_resolver(norm_tuple): norm_module, norm_repr, norm_args = norm_tuple assert isinstance(normalization_resolver(norm_module(*norm_args)), 
norm_module) assert isinstance(normalization_resolver(norm_repr, *norm_args), norm_module) def test_optimizer_resolver(): params = [torch.nn.Parameter(torch.randn(1))] assert isinstance(optimizer_resolver(torch.optim.SGD(params, lr=0.01)), torch.optim.SGD) assert isinstance(optimizer_resolver(torch.optim.Adam(params)), torch.optim.Adam) assert isinstance(optimizer_resolver(torch.optim.Rprop(params)), torch.optim.Rprop) assert isinstance(optimizer_resolver('sgd', params, lr=0.01), torch.optim.SGD) assert isinstance(optimizer_resolver('adam', params), torch.optim.Adam) assert isinstance(optimizer_resolver('rprop', params), torch.optim.Rprop) @pytest.mark.parametrize('scheduler_args', [ ('constant_with_warmup', LambdaLR), ('linear_with_warmup', LambdaLR), ('cosine_with_warmup', LambdaLR), ('cosine_with_warmup_restarts', LambdaLR), ('polynomial_with_warmup', LambdaLR), ('constant', ConstantLR), ('ReduceLROnPlateau', ReduceLROnPlateau), ]) def test_lr_scheduler_resolver(scheduler_args): scheduler_name, scheduler_cls = scheduler_args model = torch.nn.Linear(10, 5) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) lr_scheduler = lr_scheduler_resolver( scheduler_name, optimizer, num_training_steps=100, ) assert isinstance(lr_scheduler, scheduler_cls) ================================================ FILE: test/nn/test_sequential.py ================================================ from collections import OrderedDict import torch import torch.fx from torch.nn import Dropout, Linear, ReLU import torch_geometric.typing from torch_geometric.nn import ( GCNConv, JumpingKnowledge, MessagePassing, SAGEConv, Sequential, global_mean_pool, to_hetero, ) from torch_geometric.typing import SparseTensor def test_sequential_basic(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) batch = torch.zeros(4, dtype=torch.long) model = Sequential('x, edge_index', [ (GCNConv(16, 64), 'x, edge_index -> x'), ReLU(inplace=True), (GCNConv(64, 64), 'x, 
edge_index -> x'), ReLU(inplace=True), Linear(64, 7), ]).cpu() model.reset_parameters() assert len(model) == 5 assert str(model) == ( 'Sequential(\n' ' (0) - GCNConv(16, 64): x, edge_index -> x\n' ' (1) - ReLU(inplace=True): x -> x\n' ' (2) - GCNConv(64, 64): x, edge_index -> x\n' ' (3) - ReLU(inplace=True): x -> x\n' ' (4) - Linear(in_features=64, out_features=7, bias=True): x -> x\n' ')') assert isinstance(model[0], GCNConv) assert isinstance(model[1], ReLU) assert isinstance(model[2], GCNConv) assert isinstance(model[3], ReLU) assert isinstance(model[4], Linear) out = model(x, edge_index) assert out.size() == (4, 7) model = Sequential('x, edge_index, batch', [ (Dropout(p=0.5), 'x -> x'), (GCNConv(16, 64), 'x, edge_index -> x1'), ReLU(inplace=True), (GCNConv(64, 64), 'x1, edge_index -> x2'), ReLU(inplace=True), (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'), (JumpingKnowledge('cat', 64, num_layers=2), 'xs -> x'), (global_mean_pool, 'x, batch -> x'), Linear(2 * 64, 7), ]) model.reset_parameters() out = model(x, edge_index, batch) assert out.size() == (1, 7) def test_sequential_jit(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) model = Sequential('x: Tensor, edge_index: Tensor', [ (GCNConv(16, 64), 'x, edge_index -> x'), ReLU(inplace=True), (GCNConv(64, 64), 'x, edge_index -> x'), ReLU(inplace=True), Linear(64, 7), ]) torch.jit.script(model)(x, edge_index) if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t = SparseTensor.from_edge_index(edge_index).t() model = Sequential('x: Tensor, edge_index: SparseTensor', [ (GCNConv(16, 64), 'x, edge_index -> x'), ReLU(inplace=True), (GCNConv(64, 64), 'x, edge_index -> x'), ReLU(inplace=True), Linear(64, 7), ]) torch.jit.script(model)(x, adj_t) def symbolic_trace(module): class Tracer(torch.fx.Tracer): def is_leaf_module(self, module, *args, **kwargs) -> bool: return (isinstance(module, MessagePassing) or super().is_leaf_module(module, *args, **kwargs)) return 
torch.fx.GraphModule(module, Tracer().trace(module)) def test_sequential_tracable(): model = Sequential('x, edge_index', [ (GCNConv(16, 64), 'x, edge_index -> x1'), ReLU(inplace=True), (GCNConv(64, 64), 'x1, edge_index -> x2'), ReLU(inplace=True), (lambda x1, x2: x1 + x2, 'x1, x2 -> x'), Linear(64, 7), ]) symbolic_trace(model) def test_sequential_with_multiple_return_values(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) model = Sequential('x, edge_index', [ (GCNConv(16, 32), 'x, edge_index -> x1'), (GCNConv(32, 64), 'x1, edge_index -> x2'), (lambda x1, x2: (x1, x2), 'x1, x2 -> x1, x2'), ]) x1, x2 = model(x, edge_index) assert x1.size() == (4, 32) assert x2.size() == (4, 64) def test_sequential_with_ordered_dict(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) model = Sequential( 'x, edge_index', modules=OrderedDict([ ('conv1', (GCNConv(16, 32), 'x, edge_index -> x')), ('conv2', (GCNConv(32, 64), 'x, edge_index -> x')), ])) assert isinstance(model.conv1, GCNConv) assert isinstance(model.conv2, GCNConv) x = model(x, edge_index) assert x.size() == (4, 64) def test_sequential_to_hetero(): model = Sequential('x, edge_index', [ (SAGEConv((-1, -1), 32), 'x, edge_index -> x1'), ReLU(), (SAGEConv((-1, -1), 64), 'x1, edge_index -> x2'), ReLU(), ]) x_dict = { 'paper': torch.randn(100, 16), 'author': torch.randn(100, 16), } edge_index_dict = { ('paper', 'cites', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), ('paper', 'written_by', 'author'): torch.randint(100, (2, 200), dtype=torch.long), ('author', 'writes', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), } metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = to_hetero(model, metadata, debug=False) out_dict = model(x_dict, edge_index_dict) assert isinstance(out_dict, dict) and len(out_dict) == 2 assert out_dict['paper'].size() == (100, 64) assert out_dict['author'].size() == (100, 64) 
================================================ FILE: test/nn/test_to_fixed_size_transformer.py ================================================ import torch from torch_geometric.nn import SumAggregation from torch_geometric.nn.to_fixed_size_transformer import to_fixed_size class Model(torch.nn.Module): def __init__(self): super().__init__() self.aggr = SumAggregation() def forward(self, x, batch): return self.aggr(x, batch, dim=0) def test_to_fixed_size(): x = torch.randn(10, 16) batch = torch.zeros(10, dtype=torch.long) model = Model() assert model(x, batch).size() == (1, 16) model = to_fixed_size(model, batch_size=10) assert model(x, batch).size() == (10, 16) ================================================ FILE: test/nn/test_to_hetero_module.py ================================================ import pytest import torch from torch_geometric.nn.conv import SAGEConv from torch_geometric.nn.dense import Linear from torch_geometric.nn.to_hetero_module import ( ToHeteroLinear, ToHeteroMessagePassing, ) @pytest.mark.parametrize('LinearCls', [torch.nn.Linear, Linear]) def test_to_hetero_linear(LinearCls): x_dict = {'1': torch.randn(5, 16), '2': torch.randn(4, 16)} x = torch.cat([x_dict['1'], x_dict['2']], dim=0) type_vec = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1]) module = ToHeteroLinear(LinearCls(16, 32), list(x_dict.keys())) out_dict = module(x_dict) assert len(out_dict) == 2 assert out_dict['1'].size() == (5, 32) assert out_dict['2'].size() == (4, 32) out = module(x, type_vec) assert out.size() == (9, 32) assert torch.allclose(out_dict['1'], out[0:5]) assert torch.allclose(out_dict['2'], out[5:9]) def test_to_hetero_message_passing(): x_dict = {'1': torch.randn(5, 16), '2': torch.randn(4, 16)} x = torch.cat([x_dict['1'], x_dict['2']], dim=0) node_type = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1]) edge_index_dict = { ('1', 'to', '2'): torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 3]]), ('2', 'to', '1'): torch.tensor([[0, 0, 1, 2, 3], [0, 1, 2, 3, 4]]), } edge_index = 
torch.tensor([ [0, 1, 2, 3, 4, 5, 5, 6, 7, 8], [5, 5, 6, 7, 8, 0, 1, 2, 3, 4], ]) edge_type = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) module = ToHeteroMessagePassing(SAGEConv(16, 32), list(x_dict.keys()), list(edge_index_dict.keys())) out_dict = module(x_dict, edge_index_dict) assert len(out_dict) == 2 assert out_dict['1'].size() == (5, 32) assert out_dict['2'].size() == (4, 32) out = module(x, edge_index, node_type, edge_type) assert out.size() == (9, 32) assert torch.allclose(out_dict['1'], out[0:5]) assert torch.allclose(out_dict['2'], out[5:9]) ================================================ FILE: test/nn/test_to_hetero_transformer.py ================================================ from typing import Tuple import pytest import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Linear, ReLU, Sequential import torch_geometric.typing from torch_geometric.datasets import FakeHeteroDataset from torch_geometric.nn import ( GAT, BatchNorm, GATv2Conv, GCNConv, GINEConv, GraphSAGE, ) from torch_geometric.nn import Linear as LazyLinear from torch_geometric.nn import ( MeanAggregation, MessagePassing, RGCNConv, SAGEConv, to_hetero, ) from torch_geometric.testing import onlyCUDA from torch_geometric.typing import SparseTensor from torch_geometric.utils import dropout_edge torch.fx.wrap('dropout_edge') class Net1(torch.nn.Module): def __init__(self): super().__init__() self.lin1 = Linear(16, 32) self.lin2 = Linear(8, 16) def forward(self, x: Tensor, edge_attr: Tensor) -> Tuple[Tensor, Tensor]: x = self.lin1(x) edge_attr = self.lin2(edge_attr) return x, edge_attr class Net2(torch.nn.Module): def __init__(self): super().__init__() self.lin1 = Linear(16, 16) self.conv1 = SAGEConv(16, 32) self.lin2 = Linear(32, 32) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x = self.lin1(x).relu_() x = self.conv1(x, edge_index).relu_() x = self.lin2(x) return x class Net3(torch.nn.Module): def __init__(self): super().__init__() self.lin1 
= Linear(8, 16) self.conv1 = GINEConv(nn=Linear(16, 32)) def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor: x = self.conv1(x, edge_index, self.lin1(edge_attr)) return x class Net4(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SAGEConv(16, 16) self.conv2 = SAGEConv(16, 16) self.lin1 = Linear(3 * 16, 32) def forward(self, x0: Tensor, edge_index: Tensor) -> Tensor: x1 = self.conv1(x0, edge_index).relu_() x2 = self.conv2(x1, edge_index).relu_() return self.lin1(torch.cat([x0, x1, x2], dim=-1)) class Net5(torch.nn.Module): def __init__(self, num_layers): super().__init__() self.lins = torch.nn.ModuleList() self.convs = torch.nn.ModuleList() for _ in range(num_layers): self.lins.append(Linear(16, 16)) self.convs.append(SAGEConv(16, 16)) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: for lin, conv in zip(self.lins, self.convs): x = (conv(x, edge_index) + lin(x)) return x class Net6(torch.nn.Module): def __init__(self, num_layers): super().__init__() self.lins = torch.nn.ModuleDict() self.convs = torch.nn.ModuleDict() for i in range(num_layers): self.lins[str(i)] = Linear(16, 16) self.convs[str(i)] = SAGEConv(16, 16) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: for i in range(len(self.lins)): x = (self.convs[str(i)](x, edge_index) + self.lins[str(i)](x)) return x class Net7(torch.nn.Module): def __init__(self): super().__init__() self.mlp1 = Sequential(Linear(16, 16), ReLU(), Linear(16, 16)) self.conv1 = SAGEConv(16, 32) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x = self.mlp1(x) x = self.conv1(x, edge_index) return x class Net8(torch.nn.Module): def __init__(self): super().__init__() self.lin1 = LazyLinear(-1, 32) def forward(self, x: Tensor) -> Tensor: x = self.lin1(x) return x class Net9(torch.nn.Module): def __init__(self): super().__init__() self.batch_norm = BatchNorm(16) def forward(self, x: Tensor) -> Tensor: return self.batch_norm(x) class Net10(torch.nn.Module): def 
__init__(self): super().__init__() self.conv = SAGEConv(16, 32) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x = F.dropout(x, p=0.5, training=self.training) edge_index, _ = dropout_edge(edge_index, p=0.5, training=self.training) return self.conv(x, edge_index) class Net11(torch.nn.Module): def __init__(self): super().__init__() self.conv = SAGEConv(16, 16) self.num_layers = 3 def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: xs = [x] for _ in range(self.num_layers): xs.append(self.conv(xs[-1], edge_index)) return torch.cat(xs, dim=-1) class Net12(torch.nn.Module): def __init__(self): super().__init__() self.conv = Net8() def forward(self, x: Tensor) -> Tensor: return self.conv(x) def test_to_hetero_basic(): x_dict = { 'paper': torch.randn(100, 16), 'author': torch.randn(100, 16), } edge_index_dict = { ('paper', 'cites', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), ('paper', 'written_by', 'author'): torch.randint(100, (2, 200), dtype=torch.long), ('author', 'writes', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), } edge_attr_dict = { ('paper', 'cites', 'paper'): torch.randn(200, 8), ('paper', 'written_by', 'author'): torch.randn(200, 8), ('author', 'writes', 'paper'): torch.randn(200, 8), } if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t_dict = {} for edge_type, (row, col) in edge_index_dict.items(): adj_t_dict[edge_type] = SparseTensor( row=col, col=row, sparse_sizes=(100, 100), ) metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = Net1() model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_attr_dict) assert isinstance(out, tuple) and len(out) == 2 assert isinstance(out[0], dict) and len(out[0]) == 2 assert out[0]['paper'].size() == (100, 32) assert out[0]['author'].size() == (100, 32) assert isinstance(out[1], dict) and len(out[1]) == 3 assert out[1][('paper', 'cites', 'paper')].size() == (200, 16) assert out[1][('paper', 'written_by', 'author')].size() == (200, 16) assert 
out[1][('author', 'writes', 'paper')].size() == (200, 16) assert sum(p.numel() for p in model.parameters()) == 1520 for aggr in ['sum', 'mean', 'min', 'max', 'mul']: model = Net2() model = to_hetero(model, metadata, aggr=aggr, debug=False) assert sum(p.numel() for p in model.parameters()) == 5824 out1 = model(x_dict, edge_index_dict) assert isinstance(out1, dict) and len(out1) == 2 assert out1['paper'].size() == (100, 32) assert out1['author'].size() == (100, 32) if torch_geometric.typing.WITH_TORCH_SPARSE: out2 = model(x_dict, adj_t_dict) assert isinstance(out2, dict) and len(out2) == 2 for key in x_dict.keys(): assert torch.allclose(out1[key], out2[key], atol=1e-6) model = Net3() model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict, edge_attr_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 32) assert out['author'].size() == (100, 32) model = Net4() model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 32) assert out['author'].size() == (100, 32) model = Net5(num_layers=2) model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 16) assert out['author'].size() == (100, 16) model = Net6(num_layers=2) model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 16) assert out['author'].size() == (100, 16) model = Net7() model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 32) assert out['author'].size() == (100, 32) model = Net8() model = to_hetero(model, metadata, debug=False) out = model({'paper': torch.randn(4, 8), 'author': torch.randn(8, 16)}) assert 
isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (4, 32) assert out['author'].size() == (8, 32) model = Net9() model = to_hetero(model, metadata, debug=False) out = model({'paper': torch.randn(4, 16), 'author': torch.randn(8, 16)}) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (4, 16) assert out['author'].size() == (8, 16) model = Net10() with pytest.warns(UserWarning, match="with keyword argument 'training'"): model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 32) assert out['author'].size() == (100, 32) model = Net11() model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 64) assert out['author'].size() == (100, 64) model = Net12() with pytest.warns(UserWarning, match="parameters cannot be reset"): model = to_hetero(model, metadata, debug=False) out = model({'paper': torch.randn(4, 8), 'author': torch.randn(8, 16)}) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (4, 32) assert out['author'].size() == (8, 32) class GCN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(16, 32) self.conv2 = GCNConv(32, 64) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index).relu() return x def test_to_hetero_with_gcn(): x_dict = { 'paper': torch.randn(100, 16), } edge_index_dict = { ('paper', 'cites', 'paper'): torch.randint(100, (2, 200)), ('paper', 'rev_cites', 'paper'): torch.randint(100, (2, 200)), } metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = GCN() model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 1 assert out['paper'].size() == (100, 64) def 
test_to_hetero_with_basic_model(): x_dict = { 'paper': torch.randn(100, 16), 'author': torch.randn(100, 16), } edge_index_dict = { ('paper', 'cites', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), ('paper', 'written_by', 'author'): torch.randint(100, (2, 200), dtype=torch.long), ('author', 'writes', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), } metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = GraphSAGE((-1, -1), 32, num_layers=3) model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 model = GAT((-1, -1), 32, num_layers=3, add_self_loops=False) model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 class GraphConv(MessagePassing): def __init__(self, in_channels, out_channels): super().__init__(aggr='sum') self.lin = Linear(in_channels, out_channels, bias=False) def reset_parameters(self): self.lin.reset_parameters() def forward(self, x, edge_index): if isinstance(x, Tensor): x = (x, x) return self.propagate(edge_index, x=(self.lin(x[0]), x[1])) class RGCN(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv = GraphConv(in_channels, out_channels) self.lin = Linear(in_channels, out_channels, bias=True) def forward(self, x, edge_index): return self.lin(x) + self.conv(x, edge_index) def test_to_hetero_and_rgcn_equal_output(): torch.manual_seed(1234) # Run `RGCN`: x = torch.randn(10, 16) # 6 paper nodes, 4 author nodes adj = (torch.rand(10, 10) > 0.5) adj[6:, 6:] = False edge_index = adj.nonzero(as_tuple=False).t().contiguous() row, col = edge_index # # 0 = paper<->paper, 1 = paper->author, 2 = author->paper edge_type = torch.full((edge_index.size(1), ), -1, dtype=torch.long) edge_type[(row < 6) & (col < 6)] = 0 edge_type[(row < 6) & (col >= 6)] = 1 edge_type[(row >= 6) & (col < 6)] = 2 assert edge_type.min() == 0 conv = RGCNConv(16, 
32, num_relations=3, aggr='sum') out1 = conv(x, edge_index, edge_type) # Run `to_hetero`: x_dict = { 'paper': x[:6], 'author': x[6:], } edge_index_dict = { ('paper', '_', 'paper'): edge_index[:, edge_type == 0], ('paper', '_', 'author'): edge_index[:, edge_type == 1] - torch.tensor([[0], [6]]), ('author', '_', 'paper'): edge_index[:, edge_type == 2] - torch.tensor([[6], [0]]), } if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t_dict = { key: SparseTensor.from_edge_index(edge_index).t() for key, edge_index in edge_index_dict.items() } node_types, edge_types = list(x_dict.keys()), list(edge_index_dict.keys()) model = to_hetero(RGCN(16, 32), (node_types, edge_types)) # Set model weights: for i, edge_type in enumerate(edge_types): weight = model.conv['__'.join(edge_type)].lin.weight weight.data = conv.weight[i].data.t() for node_type in node_types: model.lin[node_type].weight.data = conv.root.data.t() model.lin[node_type].bias.data = conv.bias.data out2 = model(x_dict, edge_index_dict) out2 = torch.cat([out2['paper'], out2['author']], dim=0) assert torch.allclose(out1, out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: out3 = model(x_dict, adj_t_dict) out3 = torch.cat([out3['paper'], out3['author']], dim=0) assert torch.allclose(out1, out3, atol=1e-6) class GraphLevelGNN(torch.nn.Module): def __init__(self): super().__init__() self.conv = SAGEConv(16, 32) self.pool = MeanAggregation() self.lin = Linear(32, 64) def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor: x = self.conv(x, edge_index) x = self.pool(x, batch) x = self.lin(x) return x def test_graph_level_to_hetero(): x_dict = { 'paper': torch.randn(100, 16), 'author': torch.randn(100, 16), } edge_index_dict = { ('paper', 'written_by', 'author'): torch.randint(100, (2, 200), dtype=torch.long), ('author', 'writes', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), } batch_dict = { 'paper': torch.zeros(100, dtype=torch.long), 'author': torch.zeros(100, dtype=torch.long), } 
metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = GraphLevelGNN() model = to_hetero(model, metadata, aggr='mean', debug=False) out = model(x_dict, edge_index_dict, batch_dict) assert out.size() == (1, 64) class MessagePassingLoops(MessagePassing): def __init__(self): super().__init__() self.add_self_loops = True def forward(self, x): return x class ModelLoops(torch.nn.Module): def __init__(self): super().__init__() self.conv = MessagePassingLoops() def forward(self, x): return self.conv(x) def test_hetero_transformer_self_loop_error(): to_hetero(ModelLoops(), metadata=(['a'], [('a', 'to', 'a')])) with pytest.raises(ValueError, match="incorrect message passing"): to_hetero(ModelLoops(), metadata=(['a', 'b'], [('a', 'to', 'b'), ('b', 'to', 'a')])) def test_to_hetero_validate(): model = Net1() metadata = (['my test'], [('my test', 'rel', 'my test')]) with pytest.warns(UserWarning, match="letters, numbers and underscores"): model = to_hetero(model, metadata, debug=False) def test_to_hetero_on_static_graphs(): x_dict = { 'paper': torch.randn(4, 100, 16), 'author': torch.randn(4, 100, 16), } edge_index_dict = { ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)), ('author', 'writes', 'paper'): torch.randint(100, (2, 200)), } metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = to_hetero(Net4(), metadata, debug=False) out_dict = model(x_dict, edge_index_dict) assert len(out_dict) == 2 assert out_dict['paper'].size() == (4, 100, 32) assert out_dict['author'].size() == (4, 100, 32) @onlyCUDA def test_to_hetero_lazy_cuda(): class Model(torch.nn.Module): def __init__(self): super().__init__() self.conv = GATv2Conv( (-1, -1), out_channels=2, add_self_loops=False, edge_dim=-1, heads=1, ).to('cuda') def forward(self, x, edge_index, edge_attr): return self.conv(x, edge_index, edge_attr) data = FakeHeteroDataset(edge_dim=10)[0].to('cuda') model = to_hetero(Model(), data.metadata()) out_dict = model(data.x_dict, data.edge_index_dict, 
data.edge_attr_dict) assert len(out_dict) == len(data.node_types) for out in out_dict.values(): assert out.is_cuda assert out.size(-1) == 2 ================================================ FILE: test/nn/test_to_hetero_with_bases_transformer.py ================================================ import os.path as osp from typing import Tuple import pytest import torch from torch import Tensor from torch.nn import Linear, ReLU, Sequential import torch_geometric.typing from torch_geometric.nn import ( GINEConv, MessagePassing, RGCNConv, SAGEConv, to_hetero_with_bases, ) from torch_geometric.typing import SparseTensor class Net1(torch.nn.Module): def __init__(self): super().__init__() self.lin1 = Linear(16, 32) self.lin2 = Linear(8, 16) def forward(self, x: Tensor, edge_attr: Tensor) -> Tuple[Tensor, Tensor]: x = self.lin1(x) edge_attr = self.lin2(edge_attr) return x, edge_attr class Net2(torch.nn.Module): def __init__(self): super().__init__() self.lin1 = Linear(16, 16) self.conv1 = SAGEConv(16, 32) self.lin2 = Linear(32, 32) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x = self.lin1(x).relu_() x = self.conv1(x, edge_index).relu_() x = self.lin2(x) return x class Net3(torch.nn.Module): def __init__(self): super().__init__() self.lin1 = Linear(8, 16) self.conv1 = GINEConv(nn=Linear(16, 32)) def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor: x = self.conv1(x, edge_index, self.lin1(edge_attr)) return x class Net4(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SAGEConv(16, 16) self.conv2 = SAGEConv(16, 16) self.lin1 = Linear(3 * 16, 32) def forward(self, x0: Tensor, edge_index: Tensor) -> Tensor: x1 = self.conv1(x0, edge_index).relu_() x2 = self.conv2(x1, edge_index).relu_() return self.lin1(torch.cat([x0, x1, x2], dim=-1)) class Net5(torch.nn.Module): def __init__(self, num_layers): super().__init__() self.lins = torch.nn.ModuleList() self.convs = torch.nn.ModuleList() for _ in range(num_layers): 
self.lins.append(Linear(16, 16)) self.convs.append(SAGEConv(16, 16)) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: for lin, conv in zip(self.lins, self.convs): x = (conv(x, edge_index) + lin(x)) return x class Net6(torch.nn.Module): def __init__(self, num_layers): super().__init__() self.lins = torch.nn.ModuleDict() self.convs = torch.nn.ModuleDict() for i in range(num_layers): self.lins[str(i)] = Linear(16, 16) self.convs[str(i)] = SAGEConv(16, 16) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: for i in range(len(self.lins)): x = (self.convs[str(i)](x, edge_index) + self.lins[str(i)](x)) return x class Net7(torch.nn.Module): def __init__(self): super().__init__() self.mlp1 = Sequential(Linear(16, 16), ReLU(), Linear(16, 16)) self.conv1 = SAGEConv(16, 32) def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: x = self.mlp1(x) x = self.conv1(x, edge_index) return x def test_to_hetero_with_bases(): metadata = (['paper', 'author'], [('paper', 'cites', 'paper'), ('paper', 'written_by', 'author'), ('author', 'writes', 'paper')]) x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 8)} edge_index_dict = { ('paper', 'cites', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), ('paper', 'written_by', 'author'): torch.randint(100, (2, 200), dtype=torch.long), ('author', 'writes', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), } edge_attr_dict = { ('paper', 'cites', 'paper'): torch.randn(200, 8), ('paper', 'written_by', 'author'): torch.randn(200, 8), ('author', 'writes', 'paper'): torch.randn(200, 8), } model = Net1() in_channels = {'x': 16, 'edge_attr': 8} model = to_hetero_with_bases(model, metadata, num_bases=4, in_channels=in_channels, debug=False) out = model(x_dict, edge_attr_dict) assert isinstance(out, tuple) and len(out) == 2 assert isinstance(out[0], dict) and len(out[0]) == 2 assert out[0]['paper'].size() == (100, 32) assert out[0]['author'].size() == (100, 32) assert isinstance(out[1], dict) and 
len(out[1]) == 3 assert out[1][('paper', 'cites', 'paper')].size() == (200, 16) assert out[1][('paper', 'written_by', 'author')].size() == (200, 16) assert out[1][('author', 'writes', 'paper')].size() == (200, 16) assert sum(p.numel() for p in model.parameters()) == 1264 model = Net2() in_channels = {'x': 16} model = to_hetero_with_bases(model, metadata, num_bases=4, in_channels=in_channels, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 32) assert out['author'].size() == (100, 32) assert sum(p.numel() for p in model.parameters()) == 5948 model = Net3() in_channels = {'x': 16, 'edge_attr': 8} model = to_hetero_with_bases(model, metadata, num_bases=4, in_channels=in_channels, debug=False) out = model(x_dict, edge_index_dict, edge_attr_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 32) assert out['author'].size() == (100, 32) model = Net4() in_channels = {'x0': 16} model = to_hetero_with_bases(model, metadata, num_bases=4, in_channels=in_channels, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 32) assert out['author'].size() == (100, 32) model = Net5(num_layers=2) in_channels = {'x': 16} model = to_hetero_with_bases(model, metadata, num_bases=4, in_channels=in_channels, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 16) assert out['author'].size() == (100, 16) model = Net6(num_layers=2) in_channels = {'x': 16} model = to_hetero_with_bases(model, metadata, num_bases=4, in_channels=in_channels, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 16) assert out['author'].size() == (100, 16) model = Net7() in_channels = {'x': 16} model = to_hetero_with_bases(model, metadata, num_bases=4, 
in_channels=in_channels, debug=False) out = model(x_dict, edge_index_dict) assert isinstance(out, dict) and len(out) == 2 assert out['paper'].size() == (100, 32) assert out['author'].size() == (100, 32) class GraphConv(MessagePassing): def __init__(self, in_channels, out_channels): super().__init__(aggr='add') self.lin = Linear(in_channels, out_channels, bias=False) def reset_parameters(self): self.lin.reset_parameters() def forward(self, x, edge_index): if isinstance(x, Tensor): x = (x, x) return self.propagate(edge_index, x=(self.lin(x[0]), x[1])) class RGCN(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv = GraphConv(in_channels, out_channels) self.lin = Linear(in_channels, out_channels, bias=True) def forward(self, x, edge_index): return self.lin(x) + self.conv(x, edge_index) def test_to_hetero_with_bases_and_rgcn_equal_output(): torch.manual_seed(1234) # Run `RGCN` with basis decomposition: x = torch.randn(10, 16) # 6 paper nodes, 4 author nodes adj = (torch.rand(10, 10) > 0.5) adj[6:, 6:] = False edge_index = adj.nonzero(as_tuple=False).t().contiguous() row, col = edge_index # # 0 = paper<->paper, 1 = author->paper, 2 = paper->author edge_type = torch.full((edge_index.size(1), ), -1, dtype=torch.long) edge_type[(row < 6) & (col < 6)] = 0 edge_type[(row < 6) & (col >= 6)] = 1 edge_type[(row >= 6) & (col < 6)] = 2 assert edge_type.min() == 0 num_bases = 4 conv = RGCNConv(16, 32, num_relations=3, num_bases=num_bases, aggr='add') out1 = conv(x, edge_index, edge_type) # Run `to_hetero_with_bases`: x_dict = { 'paper': x[:6], 'author': x[6:], } edge_index_dict = { ('paper', '_', 'paper'): edge_index[:, edge_type == 0], ('paper', '_', 'author'): edge_index[:, edge_type == 1] - torch.tensor([[0], [6]]), ('author', '_', 'paper'): edge_index[:, edge_type == 2] - torch.tensor([[6], [0]]), } if torch_geometric.typing.WITH_TORCH_SPARSE: adj_t_dict = { key: SparseTensor.from_edge_index(edge_index).t() for key, edge_index in 
edge_index_dict.items() } metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) model = to_hetero_with_bases(RGCN(16, 32), metadata, num_bases=num_bases, debug=False) # Set model weights: for i in range(num_bases): model.conv.convs[i].lin.weight.data = conv.weight[i].data.t() model.conv.convs[i].edge_type_weight.data = conv.comp[:, i].data.t() model.lin.weight.data = conv.root.data.t() model.lin.bias.data = conv.bias.data out2 = model(x_dict, edge_index_dict) out2 = torch.cat([out2['paper'], out2['author']], dim=0) assert torch.allclose(out1, out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: out3 = model(x_dict, adj_t_dict) out3 = torch.cat([out3['paper'], out3['author']], dim=0) assert torch.allclose(out1, out3, atol=1e-6) def test_to_hetero_with_bases_validate(): model = Net1() metadata = (['my test'], [('my test', 'rel', 'my test')]) with pytest.warns(UserWarning, match="letters, numbers and underscores"): model = to_hetero_with_bases(model, metadata, num_bases=4, debug=False) def test_to_hetero_with_bases_on_static_graphs(): x_dict = { 'paper': torch.randn(4, 100, 16), 'author': torch.randn(4, 100, 16), } edge_index_dict = { ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)), ('author', 'writes', 'paper'): torch.randint(100, (2, 200)), } metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = to_hetero_with_bases(Net4(), metadata, num_bases=4, in_channels={'x0': 16}, debug=False) out_dict = model(x_dict, edge_index_dict) assert len(out_dict) == 2 assert out_dict['paper'].size() == (4, 100, 32) def test_to_hetero_with_bases_save(tmp_path): x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 8)} edge_index_dict = { ('paper', 'cites', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), ('paper', 'written_by', 'author'): torch.randint(100, (2, 200), dtype=torch.long), ('author', 'writes', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), } model = to_hetero_with_bases( Net2(), 
(list(x_dict.keys()), list(edge_index_dict.keys())), num_bases=4, in_channels={'x': 16}, debug=False, ) model(x_dict, edge_index_dict) path = osp.join(tmp_path, 'model.pt') torch.save(model, path) ================================================ FILE: test/nn/unpool/test_knn_interpolate.py ================================================ import torch from torch_geometric.nn import knn_interpolate from torch_geometric.testing import withPackage @withPackage('torch_cluster') def test_knn_interpolate(): x = torch.tensor([[1.0], [10.0], [100.0], [-1.0], [-10.0], [-100.0]]) pos_x = torch.tensor([ [-1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [-2.0, 0.0], [0.0, 0.0], [2.0, 0.0], ]) pos_y = torch.tensor([ [-1.0, -1.0], [1.0, 1.0], [-2.0, -2.0], [2.0, 2.0], ]) batch_x = torch.tensor([0, 0, 0, 1, 1, 1]) batch_y = torch.tensor([0, 0, 1, 1]) y = knn_interpolate(x, pos_x, pos_y, batch_x, batch_y, k=2) assert y.tolist() == [[4.0], [70.0], [-4.0], [-70.0]] ================================================ FILE: test/profile/test_benchmark.py ================================================ import torch from torch_geometric.profile import benchmark from torch_geometric.testing import withPackage @withPackage('tabulate') def test_benchmark(capfd): def add(x, y): return x + y benchmark( funcs=[add], args=(torch.randn(10), torch.randn(10)), num_steps=1, num_warmups=1, backward=True, ) out, _ = capfd.readouterr() assert '| Name | Forward | Backward | Total |' in out assert '| add |' in out ================================================ FILE: test/profile/test_nvtx.py ================================================ from unittest.mock import call, patch from torch_geometric.profile import nvtxit def _setup_mock(torch_cuda_mock): torch_cuda_mock.is_available.return_value = True torch_cuda_mock.cudart.return_value.cudaProfilerStart.return_value = None torch_cuda_mock.cudart.return_value.cudaProfilerStop.return_value = None return torch_cuda_mock @patch('torch_geometric.profile.nvtx.torch.cuda') 
def test_nvtxit_base(torch_cuda_mock): torch_cuda_mock = _setup_mock(torch_cuda_mock) # dummy func calls a calls b @nvtxit() def call_b(): assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 return 42 @nvtxit() def call_a(): assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 return call_b() def dummy_func(): assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 return call_a() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 dummy_func() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 assert torch_cuda_mock.nvtx.range_push.call_args_list == [ call('call_a_0'), call('call_b_0') ] @patch('torch_geometric.profile.nvtx.torch.cuda') def test_nvtxit_rename(torch_cuda_mock): torch_cuda_mock = _setup_mock(torch_cuda_mock) # dummy func calls a calls b @nvtxit() def call_b(): assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 return 42 @nvtxit('a_nvtx') def call_a(): assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 return call_b() def dummy_func(): assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 assert 
torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 return call_a() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 dummy_func() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 assert torch_cuda_mock.nvtx.range_push.call_args_list == [ call('a_nvtx_0'), call('call_b_0') ] @patch('torch_geometric.profile.nvtx.torch.cuda') def test_nvtxit_iters(torch_cuda_mock): torch_cuda_mock = _setup_mock(torch_cuda_mock) # dummy func calls a calls b @nvtxit(n_iters=1) def call_b(): return 42 @nvtxit() def call_a(): return call_b() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 call_b() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 call_a() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 2 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 2 # noqa: E501 assert torch_cuda_mock.nvtx.range_push.call_args_list == [ call('call_b_0'), call('call_a_0') ] @patch('torch_geometric.profile.nvtx.torch.cuda') def test_nvtxit_warmups(torch_cuda_mock): torch_cuda_mock = _setup_mock(torch_cuda_mock) # dummy func calls a calls b @nvtxit(n_warmups=1) def call_b(): return 42 @nvtxit() def call_a(): return call_b() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 call_b() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 
assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 call_a() assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 assert torch_cuda_mock.nvtx.range_push.call_args_list == [ call('call_a_0'), call('call_b_1') ] ================================================ FILE: test/profile/test_profile.py ================================================ import os import os.path as osp import warnings import pytest import torch import torch.nn.functional as F from torch_geometric.nn import GraphSAGE from torch_geometric.profile import ( get_stats_summary, profileit, rename_profile_file, timeit, ) from torch_geometric.profile.profile import torch_profile, xpu_profile from torch_geometric.testing import ( onlyCUDA, onlyLinux, onlyOnline, onlyXPU, withDevice, withPackage, ) @withDevice @onlyLinux def test_timeit(device): x = torch.randn(100, 16, device=device) lin = torch.nn.Linear(16, 32).to(device) with timeit(log=False) as t: assert not hasattr(t, 'duration') with torch.no_grad(): lin(x) t.reset() assert t.duration > 0 del t.duration assert not hasattr(t, 'duration') assert t.duration > 0 @onlyCUDA @onlyOnline @withPackage('pytorch_memlab') def test_profileit_cuda(get_dataset): warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*') dataset = get_dataset(name='karate') data = dataset[0].cuda() model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3, out_channels=dataset.num_classes).cuda() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) @profileit('cuda') def train(model, x, edge_index, y): model.train() optimizer.zero_grad() out = model(x, edge_index) loss = F.cross_entropy(out, y) loss.backward() return float(loss.detach()) stats_list = [] for epoch in range(5): _, stats = train(model, data.x, data.edge_index, data.y) assert stats.time > 0 assert stats.max_allocated_gpu > 0 
assert stats.max_reserved_gpu > 0 assert stats.max_active_gpu > 0 assert stats.nvidia_smi_free_cuda >= 0 assert stats.nvidia_smi_used_cuda >= 0 if epoch >= 2: # Warm-up stats_list.append(stats) stats_summary = get_stats_summary(stats_list) assert stats_summary.time_mean > 0 assert stats_summary.time_std > 0 assert stats_summary.max_allocated_gpu > 0 assert stats_summary.max_reserved_gpu > 0 assert stats_summary.max_active_gpu > 0 assert stats_summary.min_nvidia_smi_free_cuda >= 0 assert stats_summary.max_nvidia_smi_used_cuda >= 0 @onlyXPU def test_profileit_xpu(get_dataset): warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*') dataset = get_dataset(name='karate') device = torch.device('xpu') data = dataset[0].to(device) model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3, out_channels=dataset.num_classes).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) @profileit('xpu') def train(model, x, edge_index, y): model.train() optimizer.zero_grad() out = model(x, edge_index) loss = F.cross_entropy(out, y) loss.backward() return float(loss.detach()) stats_list = [] for epoch in range(5): _, stats = train(model, data.x, data.edge_index, data.y) assert stats.time > 0 assert stats.max_allocated_gpu > 0 assert stats.max_reserved_gpu > 0 assert stats.max_active_gpu > 0 assert not hasattr(stats, 'nvidia_smi_free_cuda') assert not hasattr(stats, 'nvidia_smi_used_cuda') if epoch >= 2: # Warm-up stats_list.append(stats) stats_summary = get_stats_summary(stats_list) assert stats_summary.time_mean > 0 assert stats_summary.time_std > 0 assert stats_summary.max_allocated_gpu > 0 assert stats_summary.max_reserved_gpu > 0 assert stats_summary.max_active_gpu > 0 assert not hasattr(stats_summary, 'min_nvidia_smi_free_cuda') assert not hasattr(stats_summary, 'max_nvidia_smi_used_cuda') @withDevice @onlyOnline def test_torch_profile(capfd, get_dataset, device): dataset = get_dataset(name='karate') data = dataset[0].to(device) model = 
GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3, out_channels=dataset.num_classes).to(device) with torch_profile(): model(data.x, data.edge_index) out, _ = capfd.readouterr() assert 'Self CPU time total' in out if data.x.is_cuda: assert 'Self CUDA time total' in out rename_profile_file('test_profile') assert osp.exists('profile-test_profile.json') os.remove('profile-test_profile.json') @onlyXPU @onlyOnline @pytest.mark.parametrize('export_chrome_trace', [False, True]) def test_xpu_profile(capfd, get_dataset, export_chrome_trace): dataset = get_dataset(name='karate') device = torch.device('xpu') data = dataset[0].to(device) model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3, out_channels=dataset.num_classes).to(device) with xpu_profile(export_chrome_trace): model(data.x, data.edge_index) out, _ = capfd.readouterr() assert 'Self CPU' in out if data.x.is_xpu: assert 'Self XPU' in out f_name = 'timeline.json' f_exists = osp.exists(f_name) if not export_chrome_trace: assert not f_exists else: assert f_exists os.remove(f_name) ================================================ FILE: test/profile/test_profile_utils.py ================================================ import torch from torch.nn import Linear from torch_geometric.data import Data from torch_geometric.profile import ( count_parameters, get_cpu_memory_from_gc, get_data_size, get_gpu_memory_from_gc, get_gpu_memory_from_ipex, get_gpu_memory_from_nvidia_smi, get_model_size, ) from torch_geometric.profile.utils import ( byte_to_megabyte, medibyte_to_megabyte, ) from torch_geometric.testing import onlyCUDA, onlyXPU, withPackage from torch_geometric.typing import SparseTensor def test_count_parameters(): assert count_parameters(Linear(32, 128)) == 32 * 128 + 128 def test_get_model_size(): model_size = get_model_size(Linear(32, 128, bias=False)) assert model_size >= 32 * 128 * 4 and model_size < 32 * 128 * 4 + 2000 def test_get_data_size(): x = torch.randn(10, 128) data = 
Data(x=x, y=x) data_size = get_data_size(data) assert data_size == 10 * 128 * 4 @withPackage('torch_sparse') def test_get_data_size_with_sparse_tensor(): x = torch.randn(10, 128) row, col = torch.randint(0, 10, (2, 100), dtype=torch.long) adj_t = SparseTensor(row=row, col=col, value=None, sparse_sizes=(10, 10)) data = Data(x=x, y=x, adj_t=adj_t) data_size = get_data_size(data) assert data_size == 10 * 128 * 4 + 11 * 8 + 100 * 8 def test_get_cpu_memory_from_gc(): old_mem = get_cpu_memory_from_gc() _ = torch.randn(10, 128) new_mem = get_cpu_memory_from_gc() assert new_mem - old_mem == 10 * 128 * 4 @onlyCUDA def test_get_gpu_memory_from_gc(): old_mem = get_gpu_memory_from_gc() _ = torch.randn(10, 128, device='cuda') new_mem = get_gpu_memory_from_gc() assert new_mem - old_mem == 10 * 128 * 4 @onlyCUDA def test_get_gpu_memory_from_nvidia_smi(): free_mem, used_mem = get_gpu_memory_from_nvidia_smi(device=0, digits=2) assert free_mem >= 0 assert used_mem >= 0 @onlyXPU def test_get_gpu_memory_from_ipex(): max_allocated, max_reserved, max_active = get_gpu_memory_from_ipex() assert max_allocated >= 0 assert max_reserved >= 0 assert max_active >= 0 def test_bytes_function(): assert byte_to_megabyte(1024 * 1024) == 1.00 assert medibyte_to_megabyte(1 / 1.0485) == 1.00 ================================================ FILE: test/profile/test_profiler.py ================================================ import torch import torch_geometric.typing from torch_geometric.nn import GraphSAGE from torch_geometric.profile.profiler import Profiler from torch_geometric.testing import withDevice @withDevice def test_profiler(capfd, get_dataset, device): x = torch.randn(10, 16, device=device) edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8], ], device=device) model = GraphSAGE(16, hidden_channels=32, num_layers=2).to(device) with Profiler(model, profile_memory=True, use_cuda=x.is_cuda) as prof: model(x, 
edge_index) _, err = capfd.readouterr() if not torch_geometric.typing.WITH_PT24: assert 'Completed Stage' in err _, heading_list, raw_results, layer_names, layer_stats = prof.get_trace() assert 'Self CPU total' in heading_list assert 'aten::relu' in raw_results assert '-act--aten::relu' in layer_names ================================================ FILE: test/sampler/test_sampler_base.py ================================================ import pytest import torch from torch_geometric.sampler.base import ( HeteroSamplerOutput, NumNeighbors, SamplerOutput, ) from torch_geometric.sampler.utils import global_to_local_node_idx from torch_geometric.testing import get_random_edge_index from torch_geometric.utils import is_undirected def test_homogeneous_num_neighbors(): with pytest.raises(ValueError, match="'default' must be set to 'None'"): num_neighbors = NumNeighbors([25, 10], default=[-1, -1]) num_neighbors = NumNeighbors([25, 10]) assert str(num_neighbors) == 'NumNeighbors(values=[25, 10], default=None)' assert num_neighbors.get_values() == [25, 10] assert num_neighbors.__dict__['_values'] == [25, 10] assert num_neighbors.get_values() == [25, 10] # Test caching. assert num_neighbors.get_mapped_values() == [25, 10] assert num_neighbors.__dict__['_mapped_values'] == [25, 10] assert num_neighbors.get_mapped_values() == [25, 10] # Test caching. assert num_neighbors.num_hops == 2 assert num_neighbors.__dict__['_num_hops'] == 2 assert num_neighbors.num_hops == 2 # Test caching. 
''' Merge and collate tests use the following graph: ############# ########### # Alice (0) # -> "works with" -> # Bob (1) # ############# ########### | v "leads" | v ############# ############ # Carol (2) # -> "works with" -> # Dave (3) # ############# ############ ''' def _init_merge_sampler_outputs(hetero=False, disjoint=False): if not hetero: output1 = SamplerOutput( node=torch.tensor([0, 1, 2]), row=torch.tensor([0, 0]), col=torch.tensor([1, 2]), edge=torch.tensor([0, 1]), batch=torch.tensor([0, 0, 0]) if disjoint else None, num_sampled_nodes=list([1, 2]), num_sampled_edges=list([2]), orig_row=None, orig_col=None, metadata=(None, None), ) output2 = SamplerOutput( node=torch.tensor([0, 2, 3]), row=torch.tensor([0, 1]), col=torch.tensor([1, 2]), edge=torch.tensor([1, 2]), batch=torch.tensor([0, 0, 0]) if disjoint else None, num_sampled_nodes=list([1, 1, 1]), num_sampled_edges=list([1, 1]), orig_row=None, orig_col=None, metadata=(None, None), ) return output1, output2 else: # TODO(zaristei) raise NotImplementedError("Heterogeneous merge not implemented") @pytest.mark.parametrize("disjoint", [True, False]) @pytest.mark.parametrize("bidirectional", [True, False]) def test_homogeneous_merge(disjoint, bidirectional): """Merge an output representing 1<-0->2 with one representing 0->2->3.""" output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint) if bidirectional: output1 = output1.to_bidirectional(keep_orig_edges=True) output2 = output2.to_bidirectional(keep_orig_edges=True) expected_output = SamplerOutput( node=torch.tensor([0, 1, 2, 3]), row=torch.tensor([0, 0, 2]), col=torch.tensor([1, 2, 3]), edge=torch.tensor([0, 1, 2]), batch=torch.tensor([0, 0, 0, 0]) if disjoint else None, num_sampled_nodes=[1, 2, 0, 0, 1], num_sampled_edges=[2, 0, 1], orig_row=None, orig_col=None, metadata=[(None, None), (None, None)], ) if bidirectional: expected_output = expected_output.to_bidirectional( keep_orig_edges=True) merged_output = output1.merge_with(output2) assert 
str(merged_output) == str(expected_output) @pytest.mark.parametrize("disjoint", [True, False]) @pytest.mark.parametrize("bidirectional", [True, False]) def test_homogeneous_merge_no_replace(disjoint, bidirectional): """Merge an output representing 1<-0->2 with one representing 0->2->3. replace=True makes it so that merged output is a simple concatenation instead of removing already sampled nodes/edges. """ output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint) if bidirectional: output1 = output1.to_bidirectional(keep_orig_edges=True) output2 = output2.to_bidirectional(keep_orig_edges=True) expected_output = SamplerOutput( node=torch.tensor([0, 1, 2, 0, 2, 3]), row=torch.tensor([0, 0, 3, 4]), col=torch.tensor([1, 2, 4, 5]), edge=torch.tensor([0, 1, 1, 2]), batch=torch.tensor([0, 0, 0, 3, 3, 3]) if disjoint else None, num_sampled_nodes=[1, 2, 1, 1, 1], num_sampled_edges=[2, 1, 1], orig_row=None, orig_col=None, metadata=[(None, None), (None, None)], ) if bidirectional: expected_output = expected_output.to_bidirectional( keep_orig_edges=True) merged_output = output1.merge_with(output2, replace=False) assert str(merged_output) == str(expected_output) def _init_collate_sampler_outputs(disjoint=False): output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint) # new edge not present in graph above output3 = SamplerOutput( node=torch.tensor([3, 4]), row=torch.tensor([0]), col=torch.tensor([1]), edge=torch.tensor([3]), batch=torch.tensor([0, 0]) if disjoint else None, num_sampled_nodes=list([1, 1]), num_sampled_edges=list([1]), orig_row=None, orig_col=None, metadata=(None, None), ) return [output1, output2, output3] @pytest.mark.parametrize("replace", [True, False]) @pytest.mark.parametrize("disjoint", [True, False]) def test_homogeneous_collate(disjoint, replace): output1, output2, output3 = _init_collate_sampler_outputs(disjoint) collated = SamplerOutput.collate([output1, output2, output3], replace=replace) assert str(collated) == str( 
(output1.merge_with(output2, replace=replace)).merge_with( output3, replace=replace)) def test_homogeneous_collate_empty(): with pytest.raises(ValueError, match="Cannot collate an empty list of SamplerOutputs"): SamplerOutput.collate([]) def test_homogeneous_collate_single(): output, _ = _init_merge_sampler_outputs() collated = SamplerOutput.collate([output]) assert str(collated) == str(output) def test_homogeneous_collate_missing_fields(): output1, output2, output3 = _init_collate_sampler_outputs() output3.edge = None with pytest.raises( ValueError, match="Output 3 has a different field than the first output"): SamplerOutput.collate([output1, output2, output3]) def test_heterogeneous_num_neighbors_list(): num_neighbors = NumNeighbors([25, 10]) values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) assert values == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]} values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')]) assert values == {'A__to__B': [25, 10], 'B__to__A': [25, 10]} assert num_neighbors.num_hops == 2 def test_heterogeneous_num_neighbors_dict_and_default(): num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1]) with pytest.raises(ValueError, match="hops must be the same across all"): values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1, -1]) with pytest.raises(ValueError, match="Not all edge types"): num_neighbors.get_values([('A', 'C'), ('B', 'A')]) values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) assert values == {('A', 'B'): [25, 10], ('B', 'A'): [-1, -1]} values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')]) assert values == {'A__to__B': [25, 10], 'B__to__A': [-1, -1]} assert num_neighbors.num_hops == 2 def test_heterogeneous_num_neighbors_empty_dict(): num_neighbors = NumNeighbors({}, default=[25, 10]) values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) assert values == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]} values = 
num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')]) assert values == {'A__to__B': [25, 10], 'B__to__A': [25, 10]} assert num_neighbors.num_hops == 2 def test_homogeneous_to_bidirectional(): edge_index = get_random_edge_index(10, 10, num_edges=20) obj = SamplerOutput( node=torch.arange(10), row=edge_index[0], col=edge_index[0], edge=torch.arange(edge_index.size(1)), ).to_bidirectional() assert is_undirected(torch.stack([obj.row, obj.col], dim=0)) def test_heterogeneous_to_bidirectional(): edge_index1 = get_random_edge_index(10, 5, num_edges=20) edge_index2 = get_random_edge_index(5, 10, num_edges=20) edge_index3 = get_random_edge_index(10, 10, num_edges=20) obj = HeteroSamplerOutput( node={ 'v1': torch.arange(10), 'v2': torch.arange(5) }, row={ ('v1', 'to', 'v2'): edge_index1[0], ('v2', 'rev_to', 'v1'): edge_index2[0], ('v1', 'to', 'v1'): edge_index3[0], }, col={ ('v1', 'to', 'v2'): edge_index1[1], ('v2', 'rev_to', 'v1'): edge_index2[1], ('v1', 'to', 'v1'): edge_index3[1], }, edge={}, ).to_bidirectional() assert torch.equal( obj.row['v1', 'to', 'v2'].sort().values, obj.col['v2', 'rev_to', 'v1'].sort().values, ) assert torch.equal( obj.col['v1', 'to', 'v2'].sort().values, obj.row['v2', 'rev_to', 'v1'].sort().values, ) assert is_undirected( torch.stack([obj.row['v1', 'to', 'v1'], obj.col['v1', 'to', 'v1']], 0)) def test_homogeneous_sampler_output_global_fields(): output = SamplerOutput( node=torch.tensor([0, 2, 3]), row=torch.tensor([0, 1]), col=torch.tensor([1, 2]), edge=torch.tensor([1, 2]), batch=torch.tensor([0, 0, 0]), num_sampled_nodes=[1, 1, 1], num_sampled_edges=[1, 1], orig_row=None, orig_col=None, metadata=(None, None), ) local_values = [] global_values = [] global_row, global_col = output.global_row, output.global_col assert torch.equal(global_row, torch.tensor([0, 2])) assert torch.equal(global_col, torch.tensor([2, 3])) local_values.append(output.row) local_values.append(output.col) global_values.append(global_row) global_values.append(global_col) 
seed_node = output.seed_node assert torch.equal(seed_node, torch.tensor([0, 0, 0])) local_values.append(output.batch) global_values.append(seed_node) output_bidirectional = output.to_bidirectional(keep_orig_edges=True) global_bidir_row, global_bidir_col = \ output_bidirectional.global_row, output_bidirectional.global_col assert torch.equal(global_bidir_row, torch.tensor([2, 0, 3, 2])) assert torch.equal(global_bidir_col, torch.tensor([0, 2, 2, 3])) local_values.append(output_bidirectional.row) local_values.append(output_bidirectional.col) global_values.append(global_bidir_row) global_values.append(global_bidir_col) assert torch.equal(output.global_row, output_bidirectional.global_orig_row) assert torch.equal(output.global_col, output_bidirectional.global_orig_col) # Make sure reverse mapping is correct for local_value, global_value in zip(local_values, global_values): assert torch.equal(global_to_local_node_idx(output.node, global_value), local_value) def test_heterogeneous_sampler_output_global_fields(): def _tensor_dict_equal(dict1, dict2): is_equal = True is_equal &= dict1.keys() == dict2.keys() for key in dict1.keys(): is_equal &= torch.equal(dict1[key], dict2[key]) return is_equal output = HeteroSamplerOutput( node={"person": torch.tensor([0, 2, 3])}, row={ ("person", "works_with", "person"): torch.tensor([1]), ("person", "leads", "person"): torch.tensor([0]) }, col={ ("person", "works_with", "person"): torch.tensor([2]), ("person", "leads", "person"): torch.tensor([1]) }, edge={ ("person", "works_with", "person"): torch.tensor([1]), ("person", "leads", "person"): torch.tensor([0]) }, batch={"person": torch.tensor([0, 0, 0])}, num_sampled_nodes={"person": torch.tensor([1, 1, 1])}, num_sampled_edges={ ("person", "works_with", "person"): torch.tensor([1]), ("person", "leads", "person"): torch.tensor([1]) }, orig_row=None, orig_col=None, metadata=(None, None), ) global_row, global_col = output.global_row, output.global_col assert _tensor_dict_equal( global_row, { 
("person", "works_with", "person"): torch.tensor([2]), ("person", "leads", "person"): torch.tensor([0]) }) assert _tensor_dict_equal( global_col, { ("person", "works_with", "person"): torch.tensor([3]), ("person", "leads", "person"): torch.tensor([2]) }) local_row_dict = { k: global_to_local_node_idx(output.node[k[0]], v) for k, v in global_row.items() } assert _tensor_dict_equal(local_row_dict, output.row) local_col_dict = { k: global_to_local_node_idx(output.node[k[2]], v) for k, v in global_col.items() } assert _tensor_dict_equal(local_col_dict, output.col) seed_node = output.seed_node assert _tensor_dict_equal(seed_node, {"person": torch.tensor([0, 0, 0])}) local_batch_dict = { k: global_to_local_node_idx(output.node[k], v) for k, v in seed_node.items() } assert _tensor_dict_equal(local_batch_dict, output.batch) output_bidirectional = output.to_bidirectional(keep_orig_edges=True) global_bidir_row, global_bidir_col = \ output_bidirectional.global_row, output_bidirectional.global_col assert _tensor_dict_equal( global_bidir_row, { ("person", "works_with", "person"): torch.tensor([3, 2]), ("person", "leads", "person"): torch.tensor([2, 0]) }) assert _tensor_dict_equal( global_bidir_col, { ("person", "works_with", "person"): torch.tensor([2, 3]), ("person", "leads", "person"): torch.tensor([0, 2]) }) local_bidir_row_dict = { k: global_to_local_node_idx(output_bidirectional.node[k[0]], v) for k, v in global_bidir_row.items() } assert _tensor_dict_equal(local_bidir_row_dict, output_bidirectional.row) local_bidir_col_dict = { k: global_to_local_node_idx(output_bidirectional.node[k[2]], v) for k, v in global_bidir_col.items() } assert _tensor_dict_equal(local_bidir_col_dict, output_bidirectional.col) assert _tensor_dict_equal(output.global_row, output_bidirectional.global_orig_row) assert _tensor_dict_equal(output.global_col, output_bidirectional.global_orig_col) ================================================ FILE: test/sampler/test_sampler_neighbor_sampler.py 
================================================ import pytest import torch from torch_geometric.data import Data, HeteroData from torch_geometric.sampler.base import NodeSamplerInput, SamplerOutput from torch_geometric.sampler.neighbor_sampler import ( BidirectionalNeighborSampler, NeighborSampler, ) from torch_geometric.testing import ( MyFeatureStore, MyGraphStore, onlyNeighborSampler, ) def _init_sample_graph(hetero=False): """Initializes the following graph. ############# ########### # Alice (0) # -> "works with" -> # Bob (1) # ############# ########### | v "leads" | v ############# ############ # Carol (2) # -> "works with" -> # Dave (3) # ############# ############ """ sample_attr = None sample_edge_attr = None sample_edge_indices = None if not hetero: sample_attr = torch.tensor([[0], [1], [2], [3]]) sample_edge_attr = torch.tensor([[1], [2], [3]]) sample_edge_indices = torch.tensor([[0, 0, 2], [1, 2, 3]]) else: sample_attr = dict({ "person": dict({"x": torch.tensor([[1], [2], [3]])}), "manager": dict({"x": torch.tensor([[0]])}) }) sample_edge_attr = dict({ ('person', 'works_with', 'person'): dict({"edge_attr": torch.tensor([[3]])}), ('manager', 'leads', 'person'): dict({"edge_attr": torch.tensor([[1]])}), ('manager', 'works_with', 'person'): dict({"edge_attr": torch.tensor([[2]])}) }) sample_edge_indices = dict({ ('person', 'works_with', 'person'): dict({"edge_index": torch.tensor([[1], [2]])}), ('manager', 'leads', 'person'): dict({"edge_index": torch.tensor([[0], [1]])}), ('manager', 'works_with', 'person'): dict({"edge_index": torch.tensor([[0], [0]])}) }) return sample_attr, sample_edge_attr, sample_edge_indices def _init_graph_to_sample(graph_dtype, hetero=False, reverse=False): sample_attr, sample_edge_attr, sample_edge_indices = _init_sample_graph( hetero) if reverse: if not hetero: sample_edge_indices = sample_edge_indices.flip(0) else: reversed_edge_indices = dict() reversed_edge_attr = dict() for edge_type, edge_index in 
sample_edge_indices.items(): edge_index = edge_index["edge_index"] edge_attr = sample_edge_attr[edge_type]["edge_attr"] flipped_edge_index = edge_index.flip(0) flipped_edge_type = (edge_type[2], edge_type[1], edge_type[0]) reversed_edge_indices[flipped_edge_type] = dict( {"edge_index": flipped_edge_index}) reversed_edge_attr[flipped_edge_type] = dict( {"edge_attr": edge_attr}) sample_edge_indices = reversed_edge_indices sample_edge_attr = reversed_edge_attr graph_to_sample = None if graph_dtype == 'data' and not hetero: graph_to_sample = Data(edge_index=sample_edge_indices, x=sample_attr, time=sample_attr.squeeze(-1), edge_attr=sample_edge_attr.squeeze(-1)) elif graph_dtype == 'remote' and not hetero: graph_store = MyGraphStore() graph_store.put_edge_index(sample_edge_indices, edge_type=None, layout='coo', is_sorted=True, size=(4, 4)) feature_store = MyFeatureStore() feature_store.put_tensor(sample_attr, group_name='default', attr_name='x', index=None) # temporal node sampling on (fs, gs) needs 'time' attr feature_store.put_tensor(sample_attr.squeeze(-1), group_name='default', attr_name='time', index=None) feature_store.put_tensor(sample_edge_attr.squeeze(-1), group_name='default', attr_name='edge_attr', index=None) graph_to_sample = (feature_store, graph_store) elif graph_dtype == 'data' and hetero: graph_to_sample = HeteroData() for node_type, node_attr in sample_attr.items(): graph_to_sample[node_type].x = node_attr['x'] graph_to_sample[node_type].time = node_attr['x'].squeeze(-1) for edge_type in sample_edge_indices.keys(): graph_to_sample[edge_type].edge_index = sample_edge_indices[ edge_type]["edge_index"] graph_to_sample[edge_type].edge_attr = sample_edge_attr[edge_type][ "edge_attr"].squeeze(-1) elif graph_dtype == 'remote' and hetero: graph_store = MyGraphStore() for edge_type, edge_index in sample_edge_indices.items(): edge_index = edge_index["edge_index"] graph_store.put_edge_index( edge_index, edge_type=edge_type, layout='coo', is_sorted=True, 
size=(len(sample_attr[edge_type[0]]["x"]), len(sample_attr[edge_type[2]]["x"]))) feature_store = MyFeatureStore() for node_type, node_attr in sample_attr.items(): feature_store.put_tensor(node_attr["x"], group_name=node_type, attr_name='x', index=None) # temporal node sampling on (fs, gs) needs 'time' attr feature_store.put_tensor(node_attr["x"].squeeze(-1), group_name=node_type, attr_name='time', index=None) for edge_type, edge_attr in sample_edge_attr.items(): feature_store.put_tensor(edge_attr["edge_attr"].squeeze(-1), group_name=edge_type, attr_name='edge_attr', index=None) graph_to_sample = (feature_store, graph_store) return graph_to_sample @onlyNeighborSampler @pytest.mark.parametrize('input_type', ['data', 'remote']) def test_homogeneous_neighbor_sampler_basic(input_type): graph_to_sample = _init_graph_to_sample(input_type, hetero=False) # NeighborSampler default parameters 1 node # disjoint = False # replace = False sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [1], } # Sampling from Bob should yield only Alice node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([1])) expected_output = SamplerOutput( node=torch.tensor([1, 0]), row=torch.tensor([1]), col=torch.tensor([0]), edge=torch.tensor([0]), batch=None, num_sampled_nodes=[1, 1], num_sampled_edges=[1], orig_row=None, orig_col=None, metadata=(None, None)) sampler = NeighborSampler(**sampler_kwargs) sampler_output = sampler.sample_from_nodes(node_sampler_input) assert str(sampler_output) == str(expected_output) # Sampling Alice should yield no edges node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([0])) expected_output = SamplerOutput(node=torch.tensor([0]), row=torch.empty( 0, dtype=torch.int64), col=torch.empty(0, dtype=torch.int64), edge=torch.empty(0, dtype=torch.int64), batch=None, num_sampled_nodes=[1, 0], num_sampled_edges=[0], orig_row=None, orig_col=None, metadata=(None, None)) sampler = NeighborSampler(**sampler_kwargs) sampler_output = 
sampler.sample_from_nodes(node_sampler_input) assert str(sampler_output) == str(expected_output) @onlyNeighborSampler @pytest.mark.parametrize('input_type', ['data', 'remote']) def test_heterogeneous_neighbor_sampler_basic(input_type): graph_to_sample = _init_graph_to_sample(input_type, hetero=True) # NeighborSampler default parameters 1 node # disjoint = False # replace = False sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [1], } # Sampling from Bob should yield only Alice node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([0]), input_type="person") sampler = NeighborSampler(**sampler_kwargs) sampler_output = sampler.sample_from_nodes(node_sampler_input) assert sampler_output.node['person'].tolist() == [0] assert sampler_output.node['manager'].tolist() == [0] assert sampler_output.row[('manager', 'works_with', 'person')] == torch.tensor([0]) assert sampler_output.row[('manager', 'leads', 'person')].numel() == 0 assert sampler_output.row[('person', 'works_with', 'person')].numel() == 0 assert sampler_output.col[('manager', 'works_with', 'person')] == torch.tensor([0]) assert sampler_output.col[('manager', 'leads', 'person')].numel() == 0 assert sampler_output.col[('person', 'works_with', 'person')].numel() == 0 assert sampler_output.edge[('manager', 'works_with', 'person')] == torch.tensor([0]) assert sampler_output.edge[('manager', 'leads', 'person')].numel() == 0 assert sampler_output.edge[('person', 'works_with', 'person')].numel() == 0 # Sampling Alice should yield no edges node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([0]), input_type="manager") sampler_output = sampler.sample_from_nodes(node_sampler_input) assert sampler_output.node['manager'].tolist() == [0] assert sampler_output.node['person'].numel() == 0 assert sampler_output.row[('manager', 'works_with', 'person')].numel() == 0 assert sampler_output.row[('manager', 'leads', 'person')].numel() == 0 assert sampler_output.row[('person', 
'works_with', 'person')].numel() == 0 assert sampler_output.col[('manager', 'works_with', 'person')].numel() == 0 assert sampler_output.col[('manager', 'leads', 'person')].numel() == 0 assert sampler_output.col[('person', 'works_with', 'person')].numel() == 0 assert sampler_output.edge[('manager', 'works_with', 'person')].numel() == 0 assert sampler_output.edge[('manager', 'leads', 'person')].numel() == 0 assert sampler_output.edge[('person', 'works_with', 'person')].numel() == 0 @onlyNeighborSampler @pytest.mark.parametrize('input_type', ['data', 'remote']) def test_homogeneous_neighbor_sampler_backwards(input_type): graph_to_sample = _init_graph_to_sample(input_type, hetero=False) sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [1], } node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([2])) sampler = NeighborSampler(**sampler_kwargs) # This output should have Carol and Alice sampler_output = sampler.sample_from_nodes(node_sampler_input) backward_sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [1], 'sample_direction': 'backward', } backward_sampler = NeighborSampler(**backward_sampler_kwargs) # This output should have Carol and Dave backward_sampler_output = backward_sampler.sample_from_nodes( node_sampler_input) reverse_graph_to_sample = _init_graph_to_sample(input_type, hetero=False, reverse=True) reverse_sampler_kwargs = { 'data': reverse_graph_to_sample, 'num_neighbors': [1], } reverse_sampler = NeighborSampler(**reverse_sampler_kwargs) # This output should have Carol and Dave reverse_sampler_output = reverse_sampler.sample_from_nodes( node_sampler_input) reverse_backward_sampler_kwargs = { 'data': reverse_graph_to_sample, 'num_neighbors': [1], 'sample_direction': 'backward', } reverse_backward_sampler = NeighborSampler( **reverse_backward_sampler_kwargs) # This output should have Carol and Alice reverse_backward_sampler_output = \ reverse_backward_sampler.sample_from_nodes(node_sampler_input) assert 
torch.equal(sampler_output.node, reverse_backward_sampler_output.node) assert torch.equal(sampler_output.row, reverse_backward_sampler_output.col) assert torch.equal(sampler_output.col, reverse_backward_sampler_output.row) assert torch.equal(sampler_output.edge, reverse_backward_sampler_output.edge) assert torch.equal(backward_sampler_output.node, reverse_sampler_output.node) assert torch.equal(backward_sampler_output.row, reverse_sampler_output.col) assert torch.equal(backward_sampler_output.col, reverse_sampler_output.row) assert torch.equal(backward_sampler_output.edge, reverse_sampler_output.edge) @onlyNeighborSampler @pytest.mark.parametrize('input_type', ['data', 'remote']) def test_homogeneous_neighbor_sampler_weighted_backwards(input_type): graph_to_sample = _init_graph_to_sample(input_type, hetero=False) reverse_graph_to_sample = _init_graph_to_sample(input_type, hetero=False, reverse=True) sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [1, 1], 'weight_attr': 'weight', 'sample_direction': 'backward' } reverse_sampler_kwargs = { 'data': reverse_graph_to_sample, 'num_neighbors': [1, 1], 'weight_attr': 'weight', 'sample_direction': 'forward', } if input_type == 'remote': with pytest.raises(NotImplementedError): NeighborSampler(**sampler_kwargs) return graph_to_sample['weight'] = torch.tensor([1.0, 0.0, 1.0]) reverse_graph_to_sample['weight'] = torch.tensor([1.0, 0.0, 1.0]) node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([0])) # Sampling from Alice should yield Bob backward_sampler = NeighborSampler(**sampler_kwargs) backward_sampler_output = backward_sampler.sample_from_nodes( node_sampler_input) reverse_sampler = NeighborSampler(**reverse_sampler_kwargs) reverse_sampler_output = reverse_sampler.sample_from_nodes( node_sampler_input) assert torch.equal(backward_sampler_output.node, reverse_sampler_output.node) assert torch.equal(backward_sampler_output.row, reverse_sampler_output.col) assert 
torch.equal(backward_sampler_output.col, reverse_sampler_output.row) assert torch.equal(backward_sampler_output.edge, reverse_sampler_output.edge) graph_to_sample['weight'] = torch.tensor([0.0, 1.0, 1.0]) reverse_graph_to_sample['weight'] = torch.tensor([0.0, 1.0, 1.0]) # Sampling from Alice should yield Carol and Dave backward_sampler = NeighborSampler(**sampler_kwargs) backward_sampler_output = backward_sampler.sample_from_nodes( node_sampler_input) reverse_sampler = NeighborSampler(**reverse_sampler_kwargs) reverse_sampler_output = reverse_sampler.sample_from_nodes( node_sampler_input) assert torch.equal(backward_sampler_output.node, reverse_sampler_output.node) assert torch.equal(backward_sampler_output.row, reverse_sampler_output.col) assert torch.equal(backward_sampler_output.col, reverse_sampler_output.row) assert torch.equal(backward_sampler_output.edge, reverse_sampler_output.edge) @onlyNeighborSampler @pytest.mark.parametrize('input_type', ['data', 'remote']) @pytest.mark.parametrize('time_attr', ['time', 'edge_attr']) def test_homogeneous_neighbor_sampler_temporal_backwards( input_type, time_attr): graph_to_sample = _init_graph_to_sample(input_type, hetero=False) reverse_graph_to_sample = _init_graph_to_sample(input_type, hetero=False, reverse=True) sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [2, 2], 'time_attr': time_attr, } reverse_sampler_kwargs = { 'data': reverse_graph_to_sample, 'num_neighbors': [2, 2], 'time_attr': time_attr, } node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([1]), time=torch.tensor([1])) reverse_node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([0]), time=torch.tensor([1])) # sampling from Dave should yield Carol, Alice sampler = NeighborSampler(**sampler_kwargs) sampler_output = sampler.sample_from_nodes(node_sampler_input) reverse_sampler = NeighborSampler(**reverse_sampler_kwargs) reverse_sampler_output = reverse_sampler.sample_from_nodes( reverse_node_sampler_input) 
assert torch.equal(sampler_output.node, torch.tensor([1, 0])) assert torch.equal(reverse_sampler_output.node, torch.tensor([0, 1])) """ TODO (zaristei) Negative cases for temporal sampling, then verify that the output is correct for backwards sampling. """ pytest.skip("still TODO") @onlyNeighborSampler @pytest.mark.parametrize('input_type', ['data', 'remote']) def test_heterogeneous_neighbor_sampler_backwards(input_type): graph_to_sample = _init_graph_to_sample(input_type, hetero=True) sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [1], } node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([1]), input_type="person") sampler = NeighborSampler(**sampler_kwargs) # This output should have Carol and Alice sampler_output = sampler.sample_from_nodes(node_sampler_input) backward_sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [1], 'sample_direction': 'backward', } backward_sampler = NeighborSampler(**backward_sampler_kwargs) # This output should have Carol and Dave backward_sampler_output = backward_sampler.sample_from_nodes( node_sampler_input) reverse_graph_to_sample = _init_graph_to_sample(input_type, hetero=True, reverse=True) reverse_sampler_kwargs = { 'data': reverse_graph_to_sample, 'num_neighbors': [1], } reverse_sampler = NeighborSampler(**reverse_sampler_kwargs) # This output should have Carol and Dave reverse_sampler_output = reverse_sampler.sample_from_nodes( node_sampler_input) reverse_backward_sampler_kwargs = { 'data': reverse_graph_to_sample, 'num_neighbors': [1], 'sample_direction': 'backward', } reverse_backward_sampler = NeighborSampler( **reverse_backward_sampler_kwargs) # This output should have Carol and Alice reverse_backward_sampler_output = \ reverse_backward_sampler.sample_from_nodes(node_sampler_input) def reverse_key(key): return (key[2], key[1], key[0]) assert sampler_output.node.keys( ) == reverse_backward_sampler_output.node.keys() assert reverse_sampler_output.node.keys( ) == 
backward_sampler_output.node.keys() for key in sampler_output.node.keys(): assert torch.equal(sampler_output.node[key], reverse_backward_sampler_output.node[key]) for key in reverse_sampler_output.node.keys(): assert torch.equal(reverse_sampler_output.node[key], backward_sampler_output.node[key]) assert len(sampler_output.row.keys()) == len( reverse_backward_sampler_output.row.keys()) for key in sampler_output.row.keys(): assert reverse_key(key) in reverse_backward_sampler_output.row.keys() assert torch.equal( sampler_output.row[key], reverse_backward_sampler_output.col[reverse_key(key)]) assert torch.equal( sampler_output.col[key], reverse_backward_sampler_output.row[reverse_key(key)]) assert torch.equal( sampler_output.edge[key], reverse_backward_sampler_output.edge[reverse_key(key)]) assert len(reverse_sampler_output.row.keys()) == len( backward_sampler_output.row.keys()) for key in reverse_sampler_output.row.keys(): assert reverse_key(key) in backward_sampler_output.row.keys() assert torch.equal(reverse_sampler_output.row[key], backward_sampler_output.col[reverse_key(key)]) assert torch.equal(reverse_sampler_output.col[key], backward_sampler_output.row[reverse_key(key)]) assert torch.equal(reverse_sampler_output.edge[key], backward_sampler_output.edge[reverse_key(key)]) @onlyNeighborSampler @pytest.mark.parametrize('input_type', ['data', 'remote']) def test_bidirectional_neighbor_sampler(input_type): graph_to_sample = _init_graph_to_sample(input_type, hetero=False) sampler_kwargs = { 'data': graph_to_sample, 'num_neighbors': [1], } node_sampler_input = NodeSamplerInput(input_id=None, node=torch.tensor([2])) sampler = BidirectionalNeighborSampler(**sampler_kwargs) sampler_output = sampler.sample_from_nodes(node_sampler_input) expected_output = SamplerOutput( # Union between forward and backward nodes node=torch.tensor([2, 0, 3]), # Reindexed to be relative to new nodes field row=torch.tensor([1, 0]), # Reindexed to be relative to new nodes field 
col=torch.tensor([0, 2]),
        # Union between forward and backward edges
        edge=torch.tensor([1, 2]),
        # Will be part of node uid if disjoint=True
        batch=None,
        # nodes are only counted on their first sample
        num_sampled_nodes=[1, 1, 0, 1],
        # edges are only counted on their first sample
        num_sampled_edges=[1, 1],
        # Will be used as edge uid if bidirectional=True with
        # keep_orig_edges=True
        orig_row=None,
        # Will be used as edge uid if bidirectional=True with
        # keep_orig_edges=True
        orig_col=None,
        # simple concat of forward and backward metadata
        metadata=(None, None))
    assert str(sampler_output) == str(expected_output)

    # Multi-seed, multi-hop case: seeds Bob (1) and Dave (3), four hops with
    # up to two neighbors each.
    adv_sampler_kwargs = {
        'data': graph_to_sample,
        'num_neighbors': [2, 2, 2, 2],
    }
    adv_sampler_input = NodeSamplerInput(input_id=None,
                                         node=torch.tensor([1, 3]))
    adv_sampler = BidirectionalNeighborSampler(**adv_sampler_kwargs)
    adv_sampler_output = adv_sampler.sample_from_nodes(adv_sampler_input)
    adv_expected_output = SamplerOutput(
        node=torch.tensor([1, 3, 0, 2]),
        row=torch.tensor([2, 3, 2]),
        col=torch.tensor([0, 1, 3]),
        edge=torch.tensor([0, 2, 1]),
        batch=None,
        # 8 _sample calls total, each have 2 num_sampled_nodes slots
        num_sampled_nodes=[2, 2] + [0] * 14,
        num_sampled_edges=[2, 0, 1, 0, 0, 0, 0, 0],
        orig_row=None,
        orig_col=None,
        metadata=(None, None))
    assert str(adv_sampler_output) == str(adv_expected_output)

    # Same query with disjoint=True: nodes may now repeat across seeds, and
    # `batch` records the seed index (0 for Bob, 1 for Dave) of each node.
    adv_sampler_kwargs['disjoint'] = True
    adv_sampler_disjoint = BidirectionalNeighborSampler(**adv_sampler_kwargs)
    adv_sampler_disjoint_output = adv_sampler_disjoint.sample_from_nodes(
        adv_sampler_input)
    adv_expected_disjoint_output = SamplerOutput(
        node=torch.tensor([1, 3, 0, 2, 0, 2, 1, 3]),
        row=torch.tensor([2, 3, 4, 2, 4, 5]),
        col=torch.tensor([0, 1, 3, 5, 6, 7]),
        edge=torch.tensor([0, 2, 1, 1, 0, 2]),
        batch=torch.tensor([0, 1, 0, 1, 1, 0, 1, 0]),
        num_sampled_nodes=[
            # First forward iteration:
            # (Bob, Dave) -> (Alice, Carol)
            2, 2,
            # First backward iteration:
            # (Bob (seen), Dave (seen)) -> (None, None)
            0, 0,
            # Second forward iteration:
            # (Alice (seen), Carol (seen)) -> (None, Alice)
            0, 1,
            # Second backward iteration:
            # (Alice (seen), Carol (seen)) ->
            # (Bob (seen) and Carol, Alice and Dave (seen))
            0, 1,
            # Third forward iteration:
            # (Carol (seen), Alice (seen)) -> (Alice (seen), None)
            0, 0,
            # Third backward iteration:
            # (Alice (seen), Carol (seen)) -> (Bob and Carol (seen), Dave)
            0, 2,
            # Fourth forward iteration:
            # (Bob (seen), Dave (seen)) -> (Alice (seen), Carol (seen))
            0, 0,
            # Fourth backward iteration:
            # (Bob (seen), Dave (seen)) -> (None, None)
            0, 0
        ],
        num_sampled_edges=[2, 0, 1, 1, 0, 2, 0, 0],
        orig_row=None,
        orig_col=None,
        metadata=(None, None))
    assert str(adv_sampler_disjoint_output) == str(
        adv_expected_disjoint_output)


@pytest.mark.skip(
    reason="BidirectionalSampler not implemented yet for heterogeneous graphs."
)
@onlyNeighborSampler
@pytest.mark.parametrize('input_type', ['data', 'remote'])
def test_bidirectional_neighbor_sampler_hetero(input_type):
    # Placeholder; skipped by the marker above until heterogeneous support
    # exists.
    raise NotImplementedError


@onlyNeighborSampler
@pytest.mark.parametrize('input_type', ['data', 'remote'])
@pytest.mark.parametrize('hetero', [False, True])
def test_neighbor_sampler_backwards_not_supported(input_type, hetero):
    # Combining backward sampling with temporal sampling (`time_attr`) is
    # expected to be rejected when the sampler is constructed.
    graph_to_sample = _init_graph_to_sample(input_type, hetero=hetero)
    sampler_kwargs = {
        'data': graph_to_sample,
        'num_neighbors': [1],
        'sample_direction': 'backward',
        'time_attr': 'time'
    }
    with pytest.raises(NotImplementedError):
        NeighborSampler(**sampler_kwargs)

================================================
FILE: test/test_config_mixin.py
================================================
from dataclasses import asdict, dataclass, is_dataclass
from typing import Sequence

import pytest
import torch

from torch_geometric.config_mixin import ConfigMixin
from torch_geometric.config_store import clear_config_store, register


@pytest.fixture(scope="session", autouse=True)
def teardown_once():
    # Session-scoped and autouse: clears the global config store once, after
    # all tests of the session have run.
    yield
    # This allows tests to run before teardown is executed
    clear_config_store()


@dataclass
class Dataclass:
    # Plain dataclass used as a nested constructor argument in the tests.
    x: int
    y: int


class Base(torch.nn.Module,
ConfigMixin): pass @register(with_target=True) class Module(Base): def __init__(self, x: int, data: Dataclass): super().__init__() self.x = x self.data = data @register(with_target=True) class SubModule(Base): def __init__(self, p: float): super().__init__() self.p = p @register(with_target=True) class CompoundModule(torch.nn.Module, ConfigMixin): def __init__( self, z: str, module: Module, submodules: list[SubModule], key_modules: dict[str, torch.nn.Module], ): super().__init__() self.z = z self.module = module self.submodules = torch.nn.ModuleList(submodules) self.key_modules = torch.nn.ModuleDict(key_modules) def test_config_mixin() -> None: x = 0 data = Dataclass(x=1, y=2) model = Module(x, data) cfg = model.config() assert is_dataclass(cfg) assert cfg.x == 0 assert isinstance(cfg.data, Dataclass) assert cfg.data.x == 1 assert cfg.data.y == 2 assert cfg._target_ == 'test_config_mixin.Module' model = Module.from_config(cfg) assert isinstance(model, Module) assert model.x == 0 assert isinstance(model.data, Dataclass) assert model.data.x == 1 assert model.data.y == 2 model = Base.from_config(cfg) assert isinstance(model, Module) assert model.x == 0 assert isinstance(model.data, Dataclass) assert model.data.x == 1 assert model.data.y == 2 model = Base.from_config(cfg, 3) assert isinstance(model, Module) assert model.x == 3 assert isinstance(model.data, Dataclass) assert model.data.x == 1 assert model.data.y == 2 model = Base.from_config(cfg, data=Dataclass(x=2, y=3)) assert isinstance(model, Module) assert model.x == 0 assert isinstance(model.data, Dataclass) assert model.data.x == 2 assert model.data.y == 3 cfg = asdict(cfg) model = Module.from_config(cfg) assert isinstance(model, Module) assert model.x == 0 assert isinstance(model.data, dict) assert model.data['x'] == 1 assert model.data['y'] == 2 model = Base.from_config(cfg) assert isinstance(model, Module) assert model.x == 0 assert isinstance(model.data, dict) assert model.data['x'] == 1 assert 
model.data['y'] == 2 def test_config_mixin_compound() -> None: module = Module(x=0, data=Dataclass(x=1, y=2)) submodules = [SubModule(1.41), SubModule(3.14)] key_modules = { "key1": Module(x=10, data=Dataclass(x=11, y=12)), "key2": SubModule(2.71), } model = CompoundModule(z="foo", module=module, submodules=submodules, key_modules=key_modules) cfg = model.config() assert is_dataclass(cfg) assert cfg._target_ == 'test_config_mixin.CompoundModule' assert cfg.z == "foo" assert cfg.module._target_ == 'test_config_mixin.Module' assert cfg.module.x == 0 assert isinstance(cfg.module.data, Dataclass) assert cfg.module.data.x == 1 assert cfg.module.data.y == 2 assert len(cfg.submodules) == 2 assert isinstance(cfg.submodules, Sequence) assert cfg.submodules[0]._target_ == 'test_config_mixin.SubModule' assert cfg.submodules[0].p == 1.41 assert cfg.submodules[1]._target_ == 'test_config_mixin.SubModule' assert cfg.submodules[1].p == 3.14 assert len(cfg.key_modules) == 2 assert cfg.key_modules["key1"]._target_ == 'test_config_mixin.Module' assert cfg.key_modules["key1"].x == 10 assert isinstance(cfg.key_modules["key1"].data, Dataclass) assert cfg.key_modules["key1"].data.x == 11 assert cfg.key_modules["key1"].data.y == 12 assert cfg.key_modules["key2"]._target_ == 'test_config_mixin.SubModule' assert cfg.key_modules["key2"].p == 2.71 model = CompoundModule.from_config(cfg) assert isinstance(model, CompoundModule) assert model.z == "foo" assert isinstance(model.module, Module) assert model.module.x == 0 assert isinstance(model.module.data, Dataclass) assert model.module.data.x == 1 assert model.module.data.y == 2 assert isinstance(model.submodules, torch.nn.ModuleList) assert len(model.submodules) == 2 assert isinstance(model.submodules[0], SubModule) assert model.submodules[0].p == 1.41 assert isinstance(model.submodules[1], SubModule) assert model.submodules[1].p == 3.14 assert isinstance(model.key_modules, torch.nn.ModuleDict) assert len(model.key_modules) == 2 assert 
isinstance(model.key_modules["key1"], Module) assert model.key_modules["key1"].x == 10 assert isinstance(model.key_modules["key1"].data, Dataclass) assert model.key_modules["key1"].data.x == 11 assert model.key_modules["key1"].data.y == 12 assert isinstance(model.key_modules["key2"], SubModule) assert model.key_modules["key2"].p == 2.71 ================================================ FILE: test/test_config_store.py ================================================ from typing import Any, Dict, List, Tuple from torch_geometric.config_store import ( class_from_dataclass, clear_config_store, dataclass_from_class, fill_config_store, get_config_store, map_annotation, register, to_dataclass, ) from torch_geometric.testing import withPackage from torch_geometric.transforms import AddSelfLoops def teardown_function(): clear_config_store() def test_to_dataclass(): from torch_geometric.transforms import AddSelfLoops AddSelfLoopsConfig = to_dataclass(AddSelfLoops, with_target=True) assert AddSelfLoopsConfig.__name__ == 'AddSelfLoops' fields = AddSelfLoopsConfig.__dataclass_fields__ assert fields['attr'].name == 'attr' assert fields['attr'].type == str assert fields['attr'].default == 'edge_weight' assert fields['fill_value'].name == 'fill_value' assert fields['fill_value'].type == Any assert fields['fill_value'].default == 1.0 assert fields['_target_'].name == '_target_' assert fields['_target_'].type == str assert fields['_target_'].default == ( 'torch_geometric.transforms.add_self_loops.AddSelfLoops') cfg = AddSelfLoopsConfig() assert str(cfg) == ("AddSelfLoops(attr='edge_weight', fill_value=1.0, " "_target_='torch_geometric.transforms.add_self_loops." 
"AddSelfLoops')") def test_map_annotation(): mapping = {int: Any} assert map_annotation(dict[str, int], mapping) == dict[str, Any] assert map_annotation(Dict[str, float], mapping) == Dict[str, float] assert map_annotation(List[str], mapping) == List[str] assert map_annotation(List[int], mapping) == List[Any] assert map_annotation(Tuple[int], mapping) == Tuple[Any] assert map_annotation(dict[str, int], mapping) == dict[str, Any] assert map_annotation(dict[str, float], mapping) == dict[str, float] assert map_annotation(list[str], mapping) == list[str] assert map_annotation(list[int], mapping) == list[Any] assert map_annotation(tuple[int], mapping) == tuple[Any] def test_register(): register(AddSelfLoops, group='transform') assert 'transform' in get_config_store().repo AddSelfLoopsConfig = dataclass_from_class('AddSelfLoops') Cls = class_from_dataclass('AddSelfLoops') assert Cls == AddSelfLoops Cls = class_from_dataclass(AddSelfLoopsConfig) assert Cls == AddSelfLoops ConfigCls = dataclass_from_class('AddSelfLoops') assert ConfigCls == AddSelfLoopsConfig ConfigCls = dataclass_from_class(ConfigCls) assert ConfigCls == AddSelfLoopsConfig def test_fill_config_store(): fill_config_store() assert { 'transform', 'dataset', 'model', 'optimizer', 'lr_scheduler', }.issubset(get_config_store().repo.keys()) @withPackage('hydra') def test_hydra_config_store(): import hydra from omegaconf import DictConfig fill_config_store() with hydra.initialize(config_path='.', version_base='1.1'): cfg = hydra.compose(config_name='my_config') assert len(cfg) == 4 assert 'dataset' in cfg assert 'model' in cfg assert 'optimizer' in cfg assert 'lr_scheduler' in cfg # Check `cfg.dataset`: assert len(cfg.dataset) == 2 assert cfg.dataset._target_.split('.')[-1] == 'KarateClub' # Check `cfg.dataset.transform`: assert isinstance(cfg.dataset.transform, DictConfig) assert len(cfg.dataset.transform) == 2 assert 'NormalizeFeatures' in cfg.dataset.transform assert 'AddSelfLoops' in cfg.dataset.transform 
assert isinstance(cfg.dataset.transform.NormalizeFeatures, DictConfig) assert (cfg.dataset.transform.NormalizeFeatures._target_.split('.')[-1] == 'NormalizeFeatures') assert cfg.dataset.transform.NormalizeFeatures.attrs == ['x'] assert isinstance(cfg.dataset.transform.AddSelfLoops, DictConfig) assert (cfg.dataset.transform.AddSelfLoops._target_.split('.')[-1] == 'AddSelfLoops') assert cfg.dataset.transform.AddSelfLoops.attr == 'edge_weight' assert cfg.dataset.transform.AddSelfLoops.fill_value == 1.0 # Check `cfg.model`: assert len(cfg.model) == 12 assert cfg.model._target_.split('.')[-1] == 'GCN' assert cfg.model.in_channels == 34 assert cfg.model.out_channels == 4 assert cfg.model.hidden_channels == 16 assert cfg.model.num_layers == 2 assert cfg.model.dropout == 0.0 assert cfg.model.act == 'relu' assert cfg.model.norm is None assert cfg.model.norm_kwargs is None assert cfg.model.jk is None assert not cfg.model.act_first assert cfg.model.act_kwargs is None # Check `cfg.optimizer`: assert cfg.optimizer._target_.split('.')[-1] == 'Adam' assert cfg.optimizer.lr == 0.001 assert cfg.optimizer.betas == [0.9, 0.999] assert cfg.optimizer.eps == 1e-08 assert cfg.optimizer.weight_decay == 0 assert not cfg.optimizer.amsgrad if hasattr(cfg.optimizer, 'maximize'): assert not cfg.optimizer.maximize # Check `cfg.lr_scheduler`: assert cfg.lr_scheduler._target_.split('.')[-1] == 'ReduceLROnPlateau' assert cfg.lr_scheduler.mode == 'min' assert cfg.lr_scheduler.factor == 0.1 assert cfg.lr_scheduler.patience == 10 assert cfg.lr_scheduler.threshold == 0.0001 assert cfg.lr_scheduler.threshold_mode == 'rel' assert cfg.lr_scheduler.cooldown == 0 assert cfg.lr_scheduler.min_lr == 0 assert cfg.lr_scheduler.eps == 1e-08
================================================ FILE: test/test_debug.py ================================================
from torch_geometric import debug, is_debug_enabled, set_debug


def test_debug():
    """Toggle debug mode in three ways and check ``is_debug_enabled()``
    after each step: ``set_debug`` as a plain call, ``set_debug`` as a
    context manager, and the ``debug()`` context manager. Each context
    manager is expected to restore the previous state on exit."""
    # `set_debug` as a plain setter:
    assert is_debug_enabled() is False
    set_debug(True)
    assert is_debug_enabled() is True
    set_debug(False)
    assert is_debug_enabled() is False

    # `set_debug(True)` as a context manager, starting from disabled:
    assert is_debug_enabled() is False
    with set_debug(True):
        assert is_debug_enabled() is True
    assert is_debug_enabled() is False

    # `set_debug(False)` as a context manager, starting from enabled:
    assert is_debug_enabled() is False
    set_debug(True)
    assert is_debug_enabled() is True
    with set_debug(False):
        assert is_debug_enabled() is False
    assert is_debug_enabled() is True
    set_debug(False)
    assert is_debug_enabled() is False

    # The dedicated `debug()` context manager:
    assert is_debug_enabled() is False
    with debug():
        assert is_debug_enabled() is True
    assert is_debug_enabled() is False
================================================ FILE: test/test_edge_index.py ================================================
import os.path as osp
import warnings
from typing import List, Optional

import numpy as np
import pytest
import torch
from torch import Tensor, tensor

import torch_geometric
from torch_geometric import EdgeIndex, Index
from torch_geometric.edge_index import (
    ReduceType,
    SortReturnType,
    _scatter_spmm,
    _torch_sparse_spmm,
    _TorchSPMM,
    set_tuple_item,
)
from torch_geometric.io import fs
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
    onlyCUDA,
    onlyLinux,
    withCUDA,
    withoutExtensions,
    withPackage,
)
from torch_geometric.typing import INDEX_DTYPES, SparseTensor
from torch_geometric.utils import scatter

# One test case per supported index dtype. `str(dtype)[6:]` strips the
# leading "torch." so that test IDs read e.g. "int64":
DTYPES = [pytest.param(dtype, id=str(dtype)[6:]) for dtype in INDEX_DTYPES]
# Run graph tests once as directed and once as undirected:
IS_UNDIRECTED = [
    pytest.param(False, id='directed'),
    pytest.param(True, id='undirected'),
]
# Run sparse-matmul tests with and without transposing the adjacency:
TRANSPOSE = [
    pytest.param(False, id=''),
    pytest.param(True, id='transpose'),
]


@withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_basic(dtype, device): kwargs = dict(dtype=dtype, device=device, sparse_size=(3, 3)) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) adj.validate() assert isinstance(adj, EdgeIndex) assert str(adj).startswith('EdgeIndex([[0, 1, 1, 2],\n' ' [1, 0, 2, 1]], ') assert 'sparse_size=(3, 3), nnz=4' in str(adj) assert (f"device='{device}'" in str(adj)) == adj.is_cuda
assert (f'dtype={dtype}' in str(adj)) == (dtype != torch.long) assert adj.dtype == dtype assert adj.device == device assert adj.sparse_size() == (3, 3) assert adj.sparse_size(0) == 3 assert adj.sparse_size(-1) == 3 assert adj.sort_order is None assert not adj.is_sorted assert not adj.is_sorted_by_row assert not adj.is_sorted_by_col assert not adj.is_undirected out = adj.as_tensor() assert not isinstance(out, EdgeIndex) assert out.dtype == dtype assert out.device == device out = adj * 1 assert not isinstance(out, EdgeIndex) assert out.dtype == dtype assert out.device == device


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_identity(dtype, device, is_undirected):
    """Wrapping an ``EdgeIndex`` in ``EdgeIndex(...)`` reuses its storage
    (same ``data_ptr``) and copies its metadata. Metadata passed explicitly
    takes precedence."""
    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **kwargs)

    out = EdgeIndex(adj)
    assert not isinstance(out.as_tensor(), EdgeIndex)
    assert out.data_ptr() == adj.data_ptr()
    assert out.dtype == adj.dtype
    assert out.device == adj.device
    assert out.sparse_size() == adj.sparse_size()
    assert out.sort_order == adj.sort_order
    assert out.is_undirected == adj.is_undirected

    # Explicit arguments override the wrapped tensor's metadata:
    out = EdgeIndex(adj, sparse_size=(4, 4), sort_order='row')
    assert out.sparse_size() == (4, 4)
    assert out.sort_order == 'row'


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_sparse_tensor(dtype, device):
    """Build ``EdgeIndex`` from PyTorch COO/CSR/CSC sparse tensors and check
    the inferred sort order, sparse size and ``_indptr`` cache."""
    kwargs = dict(dtype=dtype, device=device, is_undirected=True)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)

    # COO input: no `indptr` is cached.
    out = EdgeIndex(adj.to_sparse_coo())
    assert out.equal(adj)
    assert out.sort_order == 'row'
    assert out.sparse_size() == (3, 3)
    assert out._indptr is None

    # CSR input: row-sorted, and the CSR row pointer becomes `_indptr`.
    out = EdgeIndex(adj.to_sparse_csr())
    assert out.equal(adj)
    assert out.sort_order == 'row'
    assert out.sparse_size() == (3, 3)
    assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))

    # CSC input: column-sorted, and the CSC column pointer becomes `_indptr`.
    out = EdgeIndex(adj.to_sparse_csc())
    assert out.equal(adj.sort_by('col')[0])
    assert out.sort_order == 'col'
    assert out.sparse_size() == (3, 3)
    assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))


def test_set_tuple_item():
    """``set_tuple_item`` returns a tuple with one position replaced. The
    input tuple is reused unchanged across calls. Negative indices are
    supported, and ``IndexError`` is raised when the index is out of
    range."""
    tmp = (0, 1, 2)

    # Non-negative indices:
    assert set_tuple_item(tmp, 0, 3) == (3, 1, 2)
    assert set_tuple_item(tmp, 1, 3) == (0, 3, 2)
    assert set_tuple_item(tmp, 2, 3) == (0, 1, 3)
    with pytest.raises(IndexError, match="tuple index out of range"):
        set_tuple_item(tmp, 3, 3)

    # Negative indices:
    assert set_tuple_item(tmp, -1, 3) == (0, 1, 3)
    assert set_tuple_item(tmp, -2, 3) == (0, 3, 2)
    assert set_tuple_item(tmp, -3, 3) == (3, 1, 2)
    with pytest.raises(IndexError, match="tuple index out of range"):
        set_tuple_item(tmp, -4, 3)


def test_validate(): with pytest.raises(TypeError, match="tensors of a single element"): EdgeIndex([torch.tensor([0, 1]), torch.tensor([1, 0])]) with pytest.raises(ValueError, match="unsupported data type"): EdgeIndex([[0.0, 1.0], [1.0, 0.0]]) with pytest.raises(ValueError, match="needs to be two-dimensional"): EdgeIndex([[[0], [1]], [[1], [0]]]) with pytest.raises(ValueError, match="needs to have a shape of"): EdgeIndex([[0, 1], [1, 0], [1, 1]]) with pytest.raises(ValueError, match="received a non-symmetric size"): EdgeIndex([[0, 1], [1, 0]], is_undirected=True, sparse_size=(2, 3)) with pytest.raises(TypeError, match="invalid combination of arguments"): EdgeIndex(tensor([[0, 1], [1, 0]]), torch.long) with pytest.raises(TypeError, match="invalid keyword arguments"): EdgeIndex(tensor([[0, 1], [1, 0]]), dtype=torch.long) with pytest.raises(ValueError, match="contains negative indices"): EdgeIndex([[-1, 0], [0, 1]]).validate() with pytest.raises(ValueError, match="than its number of rows"): EdgeIndex([[0, 10], [1, 0]], sparse_size=(2, 2)).validate() with pytest.raises(ValueError, match="than its number of columns"): EdgeIndex([[0, 1], [10, 0]], sparse_size=(2, 2)).validate() with pytest.raises(ValueError, match="not sorted by row indices"): EdgeIndex([[1, 0], [0, 1]], sort_order='row').validate() with pytest.raises(ValueError, match="not sorted by column indices"): EdgeIndex([[0,
1], [1, 0]], sort_order='col').validate()


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_undirected(dtype, device):
    """An undirected ``EdgeIndex`` keeps its sparse size square. The size is
    inferred lazily, or filled in from whichever side is given.
    ``validate()`` rejects non-symmetric edges."""
    kwargs = dict(dtype=dtype, device=device, is_undirected=True)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)
    assert isinstance(adj, EdgeIndex)
    assert adj.is_undirected

    # The size is unknown until it is computed on demand:
    assert adj.sparse_size() == (None, None)
    adj.get_num_rows()
    assert adj.sparse_size() == (3, 3)
    adj.validate()

    # A single known side is mirrored onto the other:
    adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(3, None), **kwargs)
    assert adj.sparse_size() == (3, 3)
    adj.validate()

    adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(None, 3), **kwargs)
    assert adj.sparse_size() == (3, 3)
    adj.validate()

    with pytest.raises(ValueError, match="'EdgeIndex' is not undirected"):
        EdgeIndex([[0, 1, 1, 2], [0, 0, 1, 1]], **kwargs).validate()


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_fill_cache_(dtype, device, is_undirected):
    """``fill_cache_()`` populates ``_indptr`` and the transpose caches
    (``_T_perm``, ``_T_index``, ``_T_indptr``) for row- and column-sorted
    input, with the index dtype. ``_T_indptr`` stays ``None`` for
    undirected input.

    Where two tensors are accepted, both tie orders are allowed. This is
    presumably because the underlying sort is not guaranteed to be stable.
    """
    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)

    # Row-sorted input:
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)
    adj.validate().fill_cache_()
    assert adj.sparse_size() == (3, 3)
    assert adj._indptr.dtype == dtype
    assert adj._indptr.equal(tensor([0, 1, 3, 4], device=device))
    assert adj._T_perm.dtype == dtype
    assert (adj._T_perm.equal(tensor([1, 0, 3, 2], device=device))
            or adj._T_perm.equal(tensor([1, 3, 0, 2], device=device)))
    assert adj._T_index[0].dtype == dtype
    assert (adj._T_index[0].equal(tensor([1, 0, 2, 1], device=device))
            or adj._T_index[0].equal(tensor([1, 2, 0, 1], device=device)))
    assert adj._T_index[1].dtype == dtype
    assert adj._T_index[1].equal(tensor([0, 1, 1, 2], device=device))
    if is_undirected:
        assert adj._T_indptr is None
    else:
        assert adj._T_indptr.dtype == dtype
        assert adj._T_indptr.equal(tensor([0, 1, 3, 4], device=device))

    # Column-sorted input:
    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs)
    adj.validate().fill_cache_()
    assert adj.sparse_size() == (3, 3)
    assert adj._indptr.dtype == dtype
    assert adj._indptr.equal(tensor([0, 1, 3, 4], device=device))
    assert (adj._T_perm.equal(tensor([1, 0, 3, 2], device=device))
            or adj._T_perm.equal(tensor([1, 3, 0, 2], device=device)))
    assert adj._T_index[0].dtype == dtype
    assert adj._T_index[0].equal(tensor([0, 1, 1, 2], device=device))
    assert adj._T_index[1].dtype == dtype
    assert (adj._T_index[1].equal(tensor([1, 0, 2, 1], device=device))
            or adj._T_index[1].equal(tensor([1, 2, 0, 1], device=device)))
    if is_undirected:
        assert adj._T_indptr is None
    else:
        assert adj._T_indptr.dtype == dtype
        assert adj._T_indptr.equal(tensor([0, 1, 3, 4], device=device))


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_clone(dtype, device, is_undirected):
    """``adj.clone()`` and ``torch.clone(adj)`` both keep the ``EdgeIndex``
    subclass, dtype, device, sort order and undirectedness."""
    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)

    out = adj.clone()
    assert isinstance(out, EdgeIndex)
    assert out.dtype == dtype
    assert out.device == device
    assert out.is_sorted_by_row
    assert out.is_undirected == is_undirected

    out = torch.clone(adj)
    assert isinstance(out, EdgeIndex)
    assert out.dtype == dtype
    assert out.device == device
    assert out.is_sorted_by_row
    assert out.is_undirected == is_undirected


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_to_function(dtype, device, is_undirected):
    """Device moves and dtype casts via ``.to()``, ``.cpu()``, ``.long()``
    and ``.int()``.

    Device moves carry the caches along. Casting to ``torch.int`` keeps the
    ``EdgeIndex`` subclass only when ``WITH_PT20`` is set. Casting to a
    floating-point dtype always drops it.
    """
    kwargs = dict(dtype=dtype, is_undirected=is_undirected)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)
    adj.fill_cache_()

    # Moving devices also moves the cached tensors:
    adj = adj.to(device)
    assert isinstance(adj, EdgeIndex)
    assert adj.device == device
    assert adj._indptr.dtype == dtype
    assert adj._indptr.device == device
    assert adj._T_perm.dtype == dtype
    assert adj._T_perm.device == device

    out = adj.cpu()
    assert isinstance(out, EdgeIndex)
    assert out.device == torch.device('cpu')

    out = adj.to(torch.int)
    assert out.dtype == torch.int
    if torch_geometric.typing.WITH_PT20:
        assert isinstance(out, EdgeIndex)
        assert out._indptr.dtype == torch.int
        assert out._T_perm.dtype == torch.int
    else:
        assert not isinstance(out, EdgeIndex)

    out = adj.to(torch.float)
    assert not isinstance(out, EdgeIndex)
    assert out.dtype == torch.float

    out = adj.long()
    assert isinstance(out, EdgeIndex)
    assert out.dtype == torch.int64

    out = adj.int()
    assert out.dtype == torch.int
    if torch_geometric.typing.WITH_PT20:
        assert isinstance(out, EdgeIndex)
    else:
        assert not isinstance(out, EdgeIndex)


@onlyCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_cpu_cuda(dtype):
    """``.cuda()`` and ``.cpu()`` preserve the ``EdgeIndex`` subclass."""
    kwargs = dict(dtype=dtype)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)
    assert adj.is_cpu

    out = adj.cuda()
    assert isinstance(out, EdgeIndex)
    assert out.is_cuda

    out = out.cpu()
    assert isinstance(out, EdgeIndex)
    assert out.is_cpu


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_share_memory(dtype, device):
    """``share_memory_()`` moves both the data and the cached ``_indptr`` to
    shared memory without changing the data pointer."""
    kwargs = dict(dtype=dtype, device=device)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)
    adj.fill_cache_()

    out = adj.share_memory_()
    assert isinstance(out, EdgeIndex)
    assert out.is_shared()
    assert out._data.is_shared()
    assert out._indptr.is_shared()
    assert out.data_ptr() == adj.data_ptr()


@onlyCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_pin_memory(dtype):
    """``pin_memory()`` returns a pinned tensor."""
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=dtype)
    assert not adj.is_pinned()
    out = adj.pin_memory()
    assert out.is_pinned()


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_contiguous(dtype, device):
    """Non-contiguous input is rejected. ``.contiguous()`` keeps the
    subclass."""
    kwargs = dict(dtype=dtype, device=device)
    # Transposing an (E, 2) tensor yields a non-contiguous (2, E) view:
    data = tensor([[0, 1], [1, 0], [1, 2], [2, 1]], **kwargs).t()

    with pytest.raises(ValueError, match="needs to be contiguous"):
        EdgeIndex(data)

    adj = EdgeIndex(data.contiguous()).contiguous()
    assert isinstance(adj, EdgeIndex)
    assert adj.is_contiguous()


@withCUDA @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) def test_sort_by(dtype, device, is_undirected): kwargs = dict(dtype=dtype,
device=device, is_undirected=is_undirected) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) out = adj.sort_by('row') assert isinstance(out, SortReturnType) assert isinstance(out.values, EdgeIndex) assert not isinstance(out.indices, EdgeIndex) assert out.values.equal(adj) assert out.indices is None adj = EdgeIndex([[0, 1, 2, 1], [1, 0, 1, 2]], **kwargs) out = adj.sort_by('row') assert isinstance(out, SortReturnType) assert isinstance(out.values, EdgeIndex) assert not isinstance(out.indices, EdgeIndex) assert out.values[0].equal(tensor([0, 1, 1, 2], device=device)) assert (out.values[1].equal(tensor([1, 0, 2, 1], device=device)) or out.values[1].equal(tensor([1, 2, 0, 1], device=device))) assert (out.indices.equal(tensor([0, 1, 3, 2], device=device)) or out.indices.equal(tensor([0, 3, 1, 2], device=device))) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) out, perm = adj.sort_by('col') assert adj._T_perm is not None # Check caches. assert adj._T_index[0] is not None and adj._T_index[1] is not None assert (out[0].equal(tensor([1, 0, 2, 1], device=device)) or out[0].equal(tensor([1, 2, 0, 1], device=device))) assert out[1].equal(tensor([0, 1, 1, 2], device=device)) assert (perm.equal(tensor([1, 0, 3, 2], device=device)) or perm.equal(tensor([1, 3, 0, 2], device=device))) assert out._T_perm is None assert out._T_index[0] is None and out._T_index[1] is None out, perm = out.sort_by('row') assert out[0].equal(tensor([0, 1, 1, 2], device=device)) assert (out[1].equal(tensor([1, 0, 2, 1], device=device)) or out[1].equal(tensor([1, 2, 0, 1], device=device))) assert (perm.equal(tensor([1, 0, 3, 2], device=device)) or perm.equal(tensor([2, 3, 0, 1], device=device))) assert out._T_perm is None assert out._T_index[0] is None and out._T_index[1] is None


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_cat(dtype, device, is_undirected):
    """``torch.cat`` along ``dim=1`` keeps the ``EdgeIndex`` subclass and
    records per-input ``_cat_metadata``.

    Concatenating with an input of unknown size yields an unknown size and
    a directed result. Concatenating along ``dim=0``, or into an ``out=``
    buffer, returns a plain tensor.
    """
    args = dict(dtype=dtype, device=device, is_undirected=is_undirected)
    adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **args)
    adj2 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], sparse_size=(4, 4), **args)
    adj3 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], dtype=dtype, device=device)

    out = torch.cat([adj1, adj2], dim=1)
    assert out.size() == (2, 8)
    assert isinstance(out, EdgeIndex)
    assert out.sparse_size() == (4, 4)
    assert not out.is_sorted
    assert out.is_undirected == is_undirected

    assert out._cat_metadata.nnz == [4, 4]
    assert out._cat_metadata.sparse_size == [(3, 3), (4, 4)]
    assert out._cat_metadata.sort_order == [None, None]
    assert out._cat_metadata.is_undirected == [is_undirected, is_undirected]

    # `adj3` has no known sparse size and is directed:
    out = torch.cat([adj1, adj2, adj3], dim=1)
    assert out.size() == (2, 12)
    assert isinstance(out, EdgeIndex)
    assert out.sparse_size() == (None, None)
    assert not out.is_sorted
    assert not out.is_undirected

    out = torch.cat([adj1, adj2], dim=0)
    assert out.size() == (4, 4)
    assert not isinstance(out, EdgeIndex)

    inplace = torch.empty(2, 8, dtype=dtype, device=device)
    out = torch.cat([adj1, adj2], dim=1, out=inplace)
    assert out.data_ptr() == inplace.data_ptr()
    assert not isinstance(out, EdgeIndex)
    assert not isinstance(inplace, EdgeIndex)


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_flip(dtype, device, is_undirected):
    """Flipping along ``dim=0`` swaps row and column, so row-sorted becomes
    column-sorted and vice versa. The cached ``indptr`` is kept as
    ``_T_indptr``. Flipping along both dims loses sortedness and that
    cache."""
    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)
    adj.fill_cache_()

    out = adj.flip(0)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[1, 0, 2, 1], [0, 1, 1, 2]], device=device))
    assert out.is_sorted_by_col
    assert out.is_undirected == is_undirected
    assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))

    out = adj.flip([0, 1])
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[1, 2, 0, 1], [2, 1, 1, 0]], device=device))
    assert not out.is_sorted
    assert out.is_undirected == is_undirected
    assert out._T_indptr is None

    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs)
    out = adj.flip(0)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device))
    assert out.is_sorted_by_row
    assert out.is_undirected == is_undirected


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_index_select(dtype, device, is_undirected):
    """Selecting edges (``dim=1``) keeps the subclass but drops sortedness
    and undirectedness. Selecting along ``dim=0``, or writing into
    ``out=``, returns a plain tensor."""
    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)

    index = tensor([1, 3], device=device)
    out = adj.index_select(1, index)
    assert out.equal(tensor([[1, 2], [0, 1]], device=device))
    assert isinstance(out, EdgeIndex)
    assert not out.is_sorted
    assert not out.is_undirected

    index = tensor([0], device=device)
    out = adj.index_select(0, index)
    assert out.equal(tensor([[0, 1, 1, 2]], device=device))
    assert not isinstance(out, EdgeIndex)

    index = tensor([1, 3], device=device)
    inplace = torch.empty(2, 2, dtype=dtype, device=device)
    out = torch.index_select(adj, 1, index, out=inplace)
    assert out.data_ptr() == inplace.data_ptr()
    assert not isinstance(out, EdgeIndex)
    assert not isinstance(inplace, EdgeIndex)


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_narrow(dtype, device, is_undirected):
    """Narrowing edges (``dim=1``) keeps the subclass and row-sortedness but
    not undirectedness. Narrowing along ``dim=0`` returns a plain
    tensor."""
    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)

    out = adj.narrow(dim=1, start=1, length=2)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[1, 1], [0, 2]], device=device))
    assert out.is_sorted_by_row
    assert not out.is_undirected

    out = adj.narrow(dim=0, start=0, length=1)
    assert not isinstance(out, EdgeIndex)
    assert out.equal(tensor([[0, 1, 1, 2]], device=device))


@withCUDA @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) def
test_getitem(dtype, device, is_undirected): kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) out = adj[:, tensor([False, True, False, True], device=device)] assert isinstance(out, EdgeIndex) assert out.equal(tensor([[1, 2], [0, 1]], device=device)) assert out.is_sorted_by_row assert not out.is_undirected out = adj[..., tensor([1, 3], device=device)] assert isinstance(out, EdgeIndex) assert out.equal(tensor([[1, 2], [0, 1]], device=device)) assert not out.is_sorted assert not out.is_undirected out = adj[..., 1::2] assert isinstance(out, EdgeIndex) assert out.equal(tensor([[1, 2], [0, 1]], device=device)) assert out.is_sorted_by_row assert not out.is_undirected out = adj[...] assert isinstance(out, EdgeIndex) assert out.equal(tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)) assert out.is_sorted_by_row assert out.is_undirected == is_undirected out = adj[None] assert not isinstance(out, EdgeIndex) assert out.equal(tensor([[[0, 1, 1, 2], [1, 0, 2, 1]]], device=device)) out = adj[0, 0] assert not isinstance(out, EdgeIndex) assert out.equal(tensor(0, device=device)) out = adj[:, 0] assert not isinstance(out, EdgeIndex) out = adj[tensor([0], device=device)] assert not isinstance(out, EdgeIndex) out = adj[tensor([0], device=device), tensor([0], device=device)] assert not isinstance(out, EdgeIndex)


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_select(dtype, device):
    """Indexing a single row of the ``(2, E)`` tensor returns an ``Index``.
    Its ``dim_size`` is taken from the matching sparse-size side.

    Only the side that matches the sort order is reported as sorted and
    carries the cached ``_indptr``. Slicing that side further keeps it
    sorted but drops the ``_indptr``.
    """
    kwargs = dict(dtype=dtype, device=device)

    # Row-sorted input: row 0 is the sorted side.
    adj = EdgeIndex(
        [[0, 1, 1, 2], [1, 0, 2, 1]],
        sort_order='row',
        sparse_size=(4, 5),
        **kwargs,
    ).fill_cache_()

    out = adj[0]
    assert isinstance(out, Index)
    assert out.equal(tensor([0, 1, 1, 2], device=device))
    assert out.dim_size == 4
    assert out.is_sorted
    assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))

    out = adj[-1]
    assert isinstance(out, Index)
    assert out.equal(tensor([1, 0, 2, 1], device=device))
    assert out.dim_size == 5
    assert not out.is_sorted
    assert out._indptr is None

    out = adj[-2, 2:4]
    assert isinstance(out, Index)
    assert out.equal(tensor([1, 2], device=device))
    assert out.dim_size == 4
    assert out.is_sorted
    assert out._indptr is None

    # Column-sorted input: row 1 is the sorted side.
    adj = EdgeIndex(
        [[1, 0, 2, 1], [0, 1, 1, 2]],
        sort_order='col',
        sparse_size=(5, 4),
        **kwargs,
    ).fill_cache_()

    out = adj[1]
    assert isinstance(out, Index)
    assert out.equal(tensor([0, 1, 1, 2], device=device))
    assert out.dim_size == 4
    assert out.is_sorted
    assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))

    out = adj[-2]
    assert isinstance(out, Index)
    assert out.equal(tensor([1, 0, 2, 1], device=device))
    assert out.dim_size == 5
    assert not out.is_sorted
    assert out._indptr is None

    out = adj[-1, 2:4]
    assert isinstance(out, Index)
    assert out.equal(tensor([1, 2], device=device))
    assert out.dim_size == 4
    assert out.is_sorted
    assert out._indptr is None


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_unbind(dtype, device):
    """Unpacking ``row, col = adj`` yields two ``Index`` objects with the
    same metadata as indexing ``adj[0]`` / ``adj[1]``."""
    kwargs = dict(dtype=dtype, device=device)
    adj = EdgeIndex(
        [[0, 1, 1, 2], [1, 0, 2, 1]],
        sort_order='row',
        sparse_size=(4, 5),
        **kwargs,
    ).fill_cache_()

    row, col = adj

    assert isinstance(row, Index)
    assert row.equal(tensor([0, 1, 1, 2], device=device))
    assert row.dim_size == 4
    assert row.is_sorted
    assert row._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))

    assert isinstance(col, Index)
    assert col.equal(tensor([1, 0, 2, 1], device=device))
    assert col.dim_size == 5
    assert not col.is_sorted
    assert col._indptr is None


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('value_dtype', [None, torch.double])
def test_to_dense(dtype, device, value_dtype):
    """``to_dense`` scatters ones, or the given per-edge values, into a
    dense matrix. Multi-dimensional values add trailing dimensions."""
    kwargs = dict(dtype=dtype, device=device)
    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs)

    # Without values, every edge contributes 1:
    out = adj.to_dense(dtype=value_dtype)
    assert isinstance(out, Tensor)
    assert out.size() == (3, 3)
    expected = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
    assert out.equal(tensor(expected, dtype=value_dtype, device=device))

    # With 1-D per-edge values:
    value = torch.arange(1, 5, dtype=value_dtype or torch.float, device=device)
    out = adj.to_dense(value)
    assert isinstance(out, Tensor)
    assert out.size() == (3, 3)
    expected = [[0.0, 2.0, 0.0], [1.0, 0.0, 4.0], [0.0, 3.0, 0.0]]
    assert out.equal(tensor(expected, dtype=value_dtype, device=device))

    # With (E, 1) per-edge values, a trailing dimension is kept:
    value = torch.arange(1, 5, dtype=value_dtype or torch.float, device=device)
    out = adj.to_dense(value.view(-1, 1))
    assert isinstance(out, Tensor)
    assert out.size() == (3, 3, 1)
    expected = [
        [[0.0], [2.0], [0.0]],
        [[1.0], [0.0], [4.0]],
        [[0.0], [3.0], [0.0]],
    ]
    assert out.equal(tensor(expected, dtype=value_dtype, device=device))


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_to_sparse_coo(dtype, device):
    """Conversion to a ``torch.sparse_coo`` tensor with float values. The
    result is reported as coalesced only when the input is row-sorted."""
    kwargs = dict(dtype=dtype, device=device)
    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs)

    if torch_geometric.typing.WITH_PT20:
        with pytest.raises(ValueError, match="Unexpected tensor layout"):
            adj.to_sparse(layout='int64')

    if torch_geometric.typing.WITH_PT20:
        out = adj.to_sparse(layout=torch.sparse_coo)
    else:
        out = adj.to_sparse()
    assert isinstance(out, Tensor)
    assert out.dtype == torch.float
    assert out.device == device
    assert out.layout == torch.sparse_coo
    assert out.size() == (3, 3)
    assert adj.equal(out._indices())
    assert not out.is_coalesced()

    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs)
    out = adj.to_sparse_coo()
    assert isinstance(out, Tensor)
    assert out.dtype == torch.float
    assert out.device == device
    assert out.layout == torch.sparse_coo
    assert out.size() == (3, 3)
    assert adj.equal(out._indices())
    assert not out.is_coalesced()

    # Row-sorted input yields a coalesced COO tensor:
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)
    out = adj.to_sparse_coo()
    assert isinstance(out, Tensor)
    assert out.dtype == torch.float
    assert out.device == device
    assert out.layout == torch.sparse_coo
    assert out.size() == (3, 3)
    assert adj.equal(out._indices())
    assert out.is_coalesced()


@withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_to_sparse_csr(dtype, device): kwargs = dict(dtype=dtype, device=device)
with pytest.raises(ValueError, match="not sorted"): EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs).to_sparse_csr() adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) if torch_geometric.typing.WITH_PT20: out = adj.to_sparse(layout=torch.sparse_csr) else: out = adj.to_sparse_csr() assert isinstance(out, Tensor) assert out.dtype == torch.float assert out.device == device assert out.layout == torch.sparse_csr assert out.size() == (3, 3) assert adj._indptr.equal(out.crow_indices()) assert adj[1].equal(out.col_indices())


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_to_sparse_csc(dtype, device):
    """CSC conversion requires column-sorted input. The cached ``_indptr``
    becomes ``ccol_indices`` and the row indices become ``row_indices``."""
    kwargs = dict(dtype=dtype, device=device)

    with pytest.raises(ValueError, match="not sorted"):
        EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs).to_sparse_csc()

    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs)
    if torch_geometric.typing.WITH_PT20:
        out = adj.to_sparse(layout=torch.sparse_csc)
    else:
        out = adj.to_sparse_csc()
    assert isinstance(out, Tensor)
    assert out.dtype == torch.float
    assert out.device == device
    assert out.layout == torch.sparse_csc
    assert out.size() == (3, 3)
    assert adj._indptr.equal(out.ccol_indices())
    assert adj[0].equal(out.row_indices())


@withCUDA
@withPackage('torch_sparse')
def test_to_sparse_tensor(device):
    """Conversion to a ``torch_sparse.SparseTensor`` keeps the row/column
    indices and infers a ``3 x 3`` size."""
    kwargs = dict(device=device)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)
    out = adj.to_sparse_tensor()
    assert isinstance(out, SparseTensor)
    assert out.sizes() == [3, 3]
    row, col, _ = out.coo()
    assert row.equal(adj[0])
    assert col.equal(adj[1])


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_add(dtype, device, is_undirected):
    """Adding to an ``EdgeIndex`` shifts its indices and grows the sparse
    size by the same amount.

    Undirectedness survives only when rows and columns are shifted
    equally. Adding another ``EdgeIndex`` sums the sparse sizes. In-place
    ``+=`` with a float is rejected.
    """
    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **kwargs)

    # `alpha=2` scales the scalar, so indices and size grow by 4:
    out = torch.add(adj, 2, alpha=2)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[4, 5, 5, 6], [5, 4, 6, 5]], device=device))
    assert out.is_undirected == is_undirected
    assert out.sparse_size() == (7, 7)

    out = adj + tensor([2], dtype=dtype, device=device)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    assert out.is_undirected == is_undirected
    assert out.sparse_size() == (5, 5)

    # Unequal row/column shifts break undirectedness:
    out = adj + tensor([[2], [1]], dtype=dtype, device=device)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[2, 3, 3, 4], [2, 1, 3, 2]], device=device))
    assert not out.is_undirected
    assert out.sparse_size() == (5, 4)

    out = adj + tensor([[2], [2]], dtype=dtype, device=device)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    assert out.is_undirected == is_undirected
    assert out.sparse_size() == (5, 5)

    out = adj.add(adj)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[0, 2, 2, 4], [2, 0, 4, 2]], device=device))
    assert not out.is_undirected
    assert out.sparse_size() == (6, 6)

    # In-place addition mutates `adj` for the rest of this test:
    adj += 2
    assert isinstance(adj, EdgeIndex)
    assert adj.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    assert adj.is_undirected == is_undirected
    assert adj.sparse_size() == (5, 5)

    with pytest.raises(RuntimeError, match="can't be cast"):
        adj += 2.5


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_sub(dtype, device, is_undirected):
    """Subtraction mirrors ``test_add``: indices and sparse size shrink by
    the subtracted amount.

    Subtracting another ``EdgeIndex`` leaves the sparse size unknown. In-place
    ``-=`` with a float is rejected.
    """
    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)
    adj = EdgeIndex([[4, 5, 5, 6], [5, 4, 6, 5]], sparse_size=(7, 7), **kwargs)

    # `alpha=2` scales the scalar, so indices and size shrink by 4:
    out = torch.sub(adj, 2, alpha=2)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device))
    assert out.is_undirected == is_undirected
    assert out.sparse_size() == (3, 3)

    out = adj - tensor([2], dtype=dtype, device=device)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    assert out.is_undirected == is_undirected
    assert out.sparse_size() == (5, 5)

    # Unequal row/column shifts break undirectedness:
    out = adj - tensor([[2], [1]], dtype=dtype, device=device)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[2, 3, 3, 4], [4, 3, 5, 4]], device=device))
    assert not out.is_undirected
    assert out.sparse_size() == (5, 6)

    out = adj - tensor([[2], [2]], dtype=dtype, device=device)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    assert out.is_undirected == is_undirected
    assert out.sparse_size() == (5, 5)

    out = adj.sub(adj)
    assert isinstance(out, EdgeIndex)
    assert out.equal(tensor([[0, 0, 0, 0], [0, 0, 0, 0]], device=device))
    assert not out.is_undirected
    assert out.sparse_size() == (None, None)

    # In-place subtraction mutates `adj` for the rest of this test:
    adj -= 2
    assert isinstance(adj, EdgeIndex)
    assert adj.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    assert adj.is_undirected == is_undirected
    assert adj.sparse_size() == (5, 5)

    with pytest.raises(RuntimeError, match="can't be cast"):
        adj -= 2.5


@withCUDA @withPackage('torch_sparse') @pytest.mark.parametrize('reduce', ReduceType.__args__) @pytest.mark.parametrize('transpose', TRANSPOSE) @pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) def test_torch_sparse_spmm(device, reduce, transpose, is_undirected): if is_undirected: kwargs = dict(is_undirected=True) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs) else: adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device) adj = adj.sort_by('col' if transpose else 'row').values # Basic: x = torch.randn(3, 1, device=device) out = _torch_sparse_spmm(adj, x, None, reduce, transpose) exp = _scatter_spmm(adj, x, None, reduce, transpose) assert out.allclose(exp, atol=1e-6) # With non-zero values: x = torch.randn(3, 1, device=device) value = torch.rand(adj.size(1), device=device) out = _torch_sparse_spmm(adj, x, value, reduce, transpose) exp = _scatter_spmm(adj, x, value, reduce, transpose) assert out.allclose(exp, atol=1e-6) # Gradients w.r.t.
other: x1 = torch.randn(3, 1, device=device, requires_grad=True) x2 = x1.detach().requires_grad_() grad = torch.randn_like(x1) out = _torch_sparse_spmm(adj, x1, None, reduce, transpose) out.backward(grad) exp = _scatter_spmm(adj, x2, None, reduce, transpose) exp.backward(grad) assert x1.grad.allclose(x2.grad, atol=1e-6) # Gradients w.r.t. value: x = torch.randn(3, 1, device=device) value1 = torch.rand(adj.size(1), device=device, requires_grad=True) value2 = value1.detach().requires_grad_() grad = torch.randn_like(x) out = _torch_sparse_spmm(adj, x, value1, reduce, transpose) out.backward(grad) exp = _scatter_spmm(adj, x, value2, reduce, transpose) exp.backward(grad) assert value1.grad.allclose(value2.grad, atol=1e-6) @withCUDA @pytest.mark.parametrize('reduce', ReduceType.__args__) @pytest.mark.parametrize('transpose', TRANSPOSE) @pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) def test_torch_spmm(device, reduce, transpose, is_undirected): if is_undirected: kwargs = dict(is_undirected=True) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs) else: adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device) adj, perm = adj.sort_by('col' if transpose else 'row') # Basic: x = torch.randn(3, 2, device=device) if ((not x.is_cuda and torch_geometric.typing.WITH_PT20) or reduce in ['sum', 'add']): out = _TorchSPMM.apply(adj, x, None, reduce, transpose) exp = _scatter_spmm(adj, x, None, reduce, transpose) assert out.allclose(exp) else: with pytest.raises(AssertionError): _TorchSPMM.apply(adj, x, None, reduce, transpose) # With non-zero values: x = torch.randn(3, 1, device=device) value = torch.rand(adj.size(1), device=device) if ((not x.is_cuda and torch_geometric.typing.WITH_PT20) or reduce in ['sum', 'add']): out = _TorchSPMM.apply(adj, x, value, reduce, transpose) exp = _scatter_spmm(adj, x, value, reduce, transpose) assert out.allclose(exp) else: with pytest.raises(AssertionError): _TorchSPMM.apply(adj, x, value, reduce, transpose) # 
Gradients w.r.t. other: x1 = torch.randn(3, 1, device=device, requires_grad=True) x2 = x1.detach().requires_grad_() grad = torch.randn_like(x1) if reduce in ['sum', 'add']: out = _TorchSPMM.apply(adj, x1, None, reduce, transpose) out.backward(grad) exp = _scatter_spmm(adj, x2, None, reduce, transpose) exp.backward(grad) assert x1.grad.allclose(x2.grad) else: with pytest.raises(AssertionError): out = _TorchSPMM.apply(adj, x1, None, reduce, transpose) out.backward(grad) # Gradients w.r.t. value: x = torch.randn(3, 1, device=device) value1 = torch.rand(adj.size(1), device=device, requires_grad=True) grad = torch.randn_like(x) with pytest.raises((AssertionError, NotImplementedError)): out = _TorchSPMM.apply(adj, x, value1, reduce, transpose) out.backward(grad) @withCUDA @withoutExtensions @pytest.mark.parametrize('reduce', ReduceType.__args__) @pytest.mark.parametrize('transpose', TRANSPOSE) @pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) def test_spmm(without_extensions, device, reduce, transpose, is_undirected): warnings.filterwarnings('ignore', '.*can be accelerated via.*') if is_undirected: kwargs = dict(is_undirected=True) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs) else: adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device) adj = adj.sort_by('col' if transpose else 'row').values # Basic: x = torch.randn(3, 1, device=device) with pytest.raises(ValueError, match="to be sorted by"): adj.matmul(x, reduce=reduce, transpose=not transpose) out = adj.matmul(x, reduce=reduce, transpose=transpose) exp = _scatter_spmm(adj, x, None, reduce, transpose) assert out.allclose(exp) # With non-zero values: x = torch.randn(3, 1, device=device) value = torch.rand(adj.size(1), device=device) with pytest.raises(ValueError, match="'other_value' not supported"): adj.matmul(x, reduce=reduce, other_value=value, transpose=transpose) out = adj.matmul(x, value, reduce=reduce, transpose=transpose) exp = _scatter_spmm(adj, x, value, reduce, transpose) 
assert out.allclose(exp) # Gradients w.r.t. other: x1 = torch.randn(3, 1, device=device, requires_grad=True) x2 = x1.detach().requires_grad_() grad = torch.randn_like(x1) out = adj.matmul(x1, reduce=reduce, transpose=transpose) out.backward(grad) exp = _scatter_spmm(adj, x2, None, reduce, transpose) exp.backward(grad) assert x1.grad.allclose(x2.grad) # Gradients w.r.t. value: x = torch.randn(3, 1, device=device) value1 = torch.rand(adj.size(1), device=device, requires_grad=True) value2 = value1.detach().requires_grad_() grad = torch.randn_like(x) out = adj.matmul(x, value1, reduce=reduce, transpose=transpose) out.backward(grad) exp = _scatter_spmm(adj, x, value2, reduce, transpose) exp.backward(grad) assert value1.grad.allclose(value2.grad) @withCUDA @pytest.mark.parametrize('reduce', ReduceType.__args__) @pytest.mark.parametrize('transpose', TRANSPOSE) @pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) def test_spspmm(device, reduce, transpose, is_undirected): if is_undirected: kwargs = dict(device=device, sort_order='row', is_undirected=True) adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) else: kwargs = dict(device=device, sort_order='row') adj1 = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], **kwargs) adj1_dense = adj1.to_dense().t() if transpose else adj1.to_dense() adj2 = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', device=device) adj2_dense = adj2.to_dense() if reduce in ['sum', 'add']: out, value = adj1.matmul(adj2, reduce=reduce, transpose=transpose) assert isinstance(out, EdgeIndex) assert out.is_sorted_by_row assert out._sparse_size == (3, 3) if not torch_geometric.typing.NO_MKL: assert out._indptr is not None assert torch.allclose(out.to_dense(value), adj1_dense @ adj2_dense) else: with pytest.raises(NotImplementedError, match="not yet supported"): adj1.matmul(adj2, reduce=reduce, transpose=transpose) @withCUDA @withoutExtensions def test_matmul(without_extensions, device): kwargs = dict(sort_order='row', device=device) adj = 
EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) x = torch.randn(3, 1, device=device) expected = adj.to_dense() @ x out = adj @ x assert torch.allclose(out, expected) out = adj.matmul(x) assert torch.allclose(out, expected) out = torch.mm(adj, x) assert torch.allclose(out, expected) out = torch.matmul(adj, x) assert torch.allclose(out, expected) if torch_geometric.typing.WITH_PT20: out = torch.sparse.mm(adj, x, reduce='sum') else: with pytest.raises(TypeError, match="got an unexpected keyword"): torch.sparse.mm(adj, x, reduce='sum') out = torch.sparse.mm(adj, x) assert torch.allclose(out, expected) @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_sparse_row_narrow(dtype, device): kwargs = dict(dtype=dtype, device=device) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) out = adj.sparse_narrow(dim=0, start=1, length=1) assert out.equal(tensor([[0, 0], [0, 2]], device=device)) assert out.sparse_size() == (1, None) assert out.sort_order == 'row' assert out._indptr.equal(tensor([0, 2], device=device)) out = adj.sparse_narrow(dim=0, start=2, length=0) assert out.equal(tensor([[], []], device=device)) assert out.sparse_size() == (0, None) assert out.sort_order == 'row' assert out._indptr is None out = adj.sparse_narrow(dim=1, start=1, length=1) assert out.equal(tensor([[0, 2], [0, 0]], device=device)) assert out.sparse_size() == (3, 1) assert out.sort_order == 'row' assert out._indptr is None out = adj.sparse_narrow(dim=1, start=2, length=0) assert out.equal(tensor([[], []], device=device)) assert out.sparse_size() == (3, 0) assert out.sort_order == 'row' assert out._indptr is None @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_sparse_col_narrow(dtype, device): kwargs = dict(dtype=dtype, device=device) adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs) out = adj.sparse_narrow(dim=1, start=1, length=1) assert out.equal(tensor([[0, 2], [0, 0]], device=device)) assert out.sparse_size() == (None, 1) 
assert out.sort_order == 'col' assert out._indptr.equal(tensor([0, 2], device=device)) out = adj.sparse_narrow(dim=1, start=2, length=0) assert out.equal(tensor([[], []], device=device)) assert out.sparse_size() == (None, 0) assert out.sort_order == 'col' assert out._indptr is None out = adj.sparse_narrow(dim=0, start=1, length=1) assert out.equal(tensor([[0, 0], [0, 2]], device=device)) assert out.sparse_size() == (1, 3) assert out.sort_order == 'col' assert out._indptr is None out = adj.sparse_narrow(dim=0, start=2, length=0) assert out.equal(tensor([[], []], device=device)) assert out.sparse_size() == (0, 3) assert out.sort_order == 'col' assert out._indptr is None @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_sparse_resize(dtype, device): adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=dtype, device=device) out = adj.sort_by('row')[0].fill_cache_() assert out.sparse_size() == (3, 3) assert out._indptr.equal(tensor([0, 1, 3, 4], device=device)) assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device)) out = out.sparse_resize_(4, 5) assert out.sparse_size() == (4, 5) assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device)) assert out._T_indptr.equal(tensor([0, 1, 3, 4, 4, 4], device=device)) out = out.sparse_resize_(3, 3) assert out.sparse_size() == (3, 3) assert out._indptr.equal(tensor([0, 1, 3, 4], device=device)) assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device)) out = out.sparse_resize_(None, None) assert out.sparse_size() == (None, None) assert out._indptr is None assert out._T_indptr is None out = adj.sort_by('col')[0].fill_cache_() assert out.sparse_size() == (3, 3) assert out._indptr.equal(tensor([0, 1, 3, 4], device=device)) assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device)) out = out.sparse_resize_(4, 5) assert out.sparse_size() == (4, 5) assert out._indptr.equal(tensor([0, 1, 3, 4, 4, 4], device=device)) assert out._T_indptr.equal(tensor([0, 1, 3, 4, 4], device=device)) out = 
out.sparse_resize_(3, 3) assert out.sparse_size() == (3, 3) assert out._indptr.equal(tensor([0, 1, 3, 4], device=device)) assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device)) out = out.sparse_resize_(None, None) assert out.sparse_size() == (None, None) assert out._indptr is None assert out._T_indptr is None def test_tolist(): data = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) adj = EdgeIndex(data) assert adj.tolist() == data.tolist() def test_numpy(): data = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) adj = EdgeIndex(data) assert np.array_equal(adj.numpy(), data.numpy()) @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_global_mapping(device, dtype): adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, dtype=dtype) n_id = tensor([10, 20, 30], device=device, dtype=dtype) expected = tensor([[10, 20, 20, 30], [20, 10, 30, 20]], device=device) out = n_id[adj] assert not isinstance(out, EdgeIndex) assert out.equal(expected) @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_to_vector(device, dtype): adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, dtype=dtype) out = adj.to_vector() assert not isinstance(out, EdgeIndex) assert out.equal(tensor([1, 3, 5, 7], device=device)) @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_save_and_load(dtype, device, tmp_path): kwargs = dict(dtype=dtype, device=device) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) adj.fill_cache_() assert adj.sort_order == 'row' assert adj._indptr is not None path = osp.join(tmp_path, 'edge_index.pt') torch.save(adj, path) out = fs.torch_load(path) assert isinstance(out, EdgeIndex) assert out.equal(adj) assert out.sort_order == 'row' assert out._indptr.equal(adj._indptr) def _collate_fn(edge_indices: List[EdgeIndex]) -> List[EdgeIndex]: return edge_indices @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('num_workers', [0, 2]) @pytest.mark.parametrize('pin_memory', [False, True]) def 
test_data_loader(dtype, num_workers, pin_memory): kwargs = dict(dtype=dtype) adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) adj.fill_cache_() loader = torch.utils.data.DataLoader( [adj] * 4, batch_size=2, num_workers=num_workers, collate_fn=_collate_fn, pin_memory=pin_memory, drop_last=True, ) assert len(loader) == 2 for batch in loader: assert isinstance(batch, list) assert len(batch) == 2 for adj in batch: assert isinstance(adj, EdgeIndex) assert adj.dtype == dtype assert adj.is_shared() != (num_workers == 0) or pin_memory assert adj._data.is_shared() != (num_workers == 0) or pin_memory def test_torch_script(): class Model(torch.nn.Module): def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: row, col = edge_index[0], edge_index[1] x_j = x[row] out = scatter(x_j, col, dim_size=edge_index.num_cols) return out x = torch.randn(3, 8) # Test that `num_cols` gets picked up by making last node isolated. edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 0, 1]], sparse_size=(3, 3)) model = Model() expected = model(x, edge_index) assert expected.size() == (3, 8) # `torch.jit.script` does not support inheritance at the `Tensor` level :( with pytest.raises(RuntimeError, match="attribute or method 'num_cols'"): torch.jit.script(model) # A valid workaround is to treat `EdgeIndex` as a regular PyTorch tensor # whenever we are in script mode: class ScriptableModel(torch.nn.Module): def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: row, col = edge_index[0], edge_index[1] x_j = x[row] dim_size: Optional[int] = None if (not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex)): dim_size = edge_index.num_cols out = scatter(x_j, col, dim_size=dim_size) return out script_model = torch.jit.script(ScriptableModel()) out = script_model(x, edge_index) assert out.size() == (2, 8) assert torch.allclose(out, expected[:2]) @onlyLinux @withPackage('torch>=2.3') @pytest.mark.skip(reason="Does not work currently") def test_compile_basic(): 
import torch._dynamo as dynamo class Model(torch.nn.Module): def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: x_j = x[edge_index[0]] out = scatter(x_j, edge_index[1], dim_size=edge_index.num_cols) return out x = torch.randn(3, 8) # Test that `num_cols` gets picked up by making last node isolated. edge_index = EdgeIndex( [[0, 1, 1, 2], [1, 0, 0, 1]], sparse_size=(3, 3), sort_order='row', ).fill_cache_() model = Model() expected = model(x, edge_index) assert expected.size() == (3, 8) explanation = dynamo.explain(model)(x, edge_index) assert explanation.graph_break_count == 0 compiled_model = torch.compile(model, fullgraph=True) out = compiled_model(x, edge_index) assert torch.allclose(out, expected) @onlyLinux @withPackage('torch>=2.3') @pytest.mark.skip(reason="Does not work currently") def test_compile_create_edge_index(): import torch._dynamo as dynamo class Model(torch.nn.Module): def forward(self) -> EdgeIndex: # Wait for: https://github.com/pytorch/pytorch/issues/117806 edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]]) return edge_index model = Model() explanation = dynamo.explain(model)() assert explanation.graph_break_count == 0 compiled_model = torch.compile(model, fullgraph=True) assert compiled_model() is None if __name__ == '__main__': import argparse warnings.filterwarnings('ignore', ".*Sparse CSR tensor support.*") parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() channels = 128 num_nodes = 20_000 num_edges = 200_000 x = torch.randn(num_nodes, channels, device=args.device) edge_index = EdgeIndex( torch.randint(0, num_nodes, size=(2, num_edges), device=args.device), sparse_size=(num_nodes, num_nodes), ).sort_by('row')[0] edge_index.fill_cache_() adj1 = edge_index.to_sparse_csr() adj2 = SparseTensor( row=edge_index[0], col=edge_index[1], sparse_sizes=(num_nodes, num_nodes), ) def edge_index_mm(edge_index, x, 
reduce): return edge_index.matmul(x, reduce=reduce) def torch_sparse_mm(adj, x): return adj @ x def sparse_tensor_mm(adj, x, reduce): return adj.matmul(x, reduce=reduce) def scatter_mm(edge_index, x, reduce): return _scatter_spmm(edge_index, x, reduce=reduce) funcs = [edge_index_mm, torch_sparse_mm, sparse_tensor_mm, scatter_mm] func_names = ['edge_index', 'torch.sparse', 'SparseTensor', 'scatter'] for reduce in ['sum', 'mean', 'amin', 'amax']: func_args = [(edge_index, x, reduce), (adj1, x), (adj2, x, reduce), (edge_index, x, reduce)] print(f"reduce='{reduce}':") benchmark( funcs=funcs, func_names=func_names, args=func_args, num_steps=100 if args.device == 'cpu' else 1000, num_warmups=50 if args.device == 'cpu' else 500, backward=args.backward, ) ================================================ FILE: test/test_experimental.py ================================================ import pytest from torch_geometric import ( experimental_mode, is_experimental_mode_enabled, set_experimental_mode, ) @pytest.mark.parametrize('options', ['disable_dynamic_shapes']) def test_experimental_mode(options): assert is_experimental_mode_enabled(options) is False with experimental_mode(options): assert is_experimental_mode_enabled(options) is True assert is_experimental_mode_enabled(options) is False with set_experimental_mode(True, options): assert is_experimental_mode_enabled(options) is True assert is_experimental_mode_enabled(options) is False with set_experimental_mode(False, options): assert is_experimental_mode_enabled(options) is False assert is_experimental_mode_enabled(options) is False set_experimental_mode(True, options) assert is_experimental_mode_enabled(options) is True set_experimental_mode(False, options) assert is_experimental_mode_enabled(options) is False ================================================ FILE: test/test_hash_tensor.py ================================================ import os.path as osp from typing import List import numpy as np import pytest import 
torch from torch import Tensor import torch_geometric.typing from torch_geometric import HashTensor from torch_geometric.io import fs from torch_geometric.testing import ( onlyCUDA, onlyLinux, withCUDA, withHashTensor, withPackage, ) KEY_DTYPES = [ pytest.param(torch.bool, id='bool'), pytest.param(torch.uint8, id='uint8'), pytest.param(torch.int8, id='int8'), pytest.param(torch.int16, id='int16'), pytest.param(torch.int32, id='int32'), pytest.param(torch.int64, id='int64'), pytest.param(torch.float16, id='float16'), pytest.param(torch.bfloat16, id='bfloat16'), pytest.param(torch.float32, id='float32'), pytest.param(torch.float64, id='float64'), ] @withCUDA @withHashTensor @pytest.mark.parametrize('dtype', KEY_DTYPES) def test_basic(dtype, device): if dtype != torch.bool: key = torch.tensor([2, 1, 0], dtype=dtype, device=device) else: key = torch.tensor([True, False], device=device) tensor = HashTensor(key) if tensor.is_cuda: assert str(tensor) == (f"HashTensor({tensor.as_tensor().tolist()}, " f"device='{tensor.device}')") else: assert str(tensor) == f"HashTensor({tensor.as_tensor().tolist()})" assert tensor.dtype == torch.int64 assert tensor.device == device assert tensor.size() == (key.size(0), ) value = torch.randn(key.size(0), 2, device=device) tensor = HashTensor(key, value) assert str(tensor).startswith("HashTensor([") assert tensor.dtype == torch.float assert tensor.device == device assert tensor.size() == (key.size(0), 2) @withCUDA @withHashTensor @pytest.mark.parametrize('dtype', KEY_DTYPES) def test_empty(dtype, device): key = torch.empty(0, dtype=dtype, device=device) tensor = HashTensor(key) assert tensor.dtype == torch.int64 assert tensor.device == device assert tensor.size() == (0, ) out = tensor.index_select(0, torch.empty(0, dtype=dtype, device=device)) assert not isinstance(out, HashTensor) assert out.dtype == torch.int64 assert out.device == device assert out.size() == (0, ) value = torch.empty(0, device=device) tensor = HashTensor(key, value) 
assert tensor.dtype == value.dtype assert tensor.device == device assert tensor.size() == (0, ) out = tensor.index_select(0, torch.empty(0, dtype=dtype, device=device)) assert not isinstance(out, HashTensor) assert out.dtype == value.dtype assert out.device == device assert out.size() == (0, ) @withCUDA @withHashTensor def test_string_key(device): tensor = HashTensor(['1', '2', '3'], device=device) out = tensor[['3', '2', '4']] assert out.equal(torch.tensor([2, 1, -1], device=device)) @withCUDA @withHashTensor def test_clone(device): key = torch.tensor([2, 1, 0], device=device) value = torch.randn(key.size(0), 2, device=device) tensor = HashTensor(key, value) out = tensor.clone() assert isinstance(out, HashTensor) assert out.dtype == tensor.dtype assert out.device == tensor.device assert out._value.data_ptr() != tensor._value.data_ptr() out = torch.clone(tensor) assert isinstance(out, HashTensor) assert out.dtype == tensor.dtype assert out.device == tensor.device assert out._value.data_ptr() != tensor._value.data_ptr() @withCUDA @withHashTensor def test_share_memory(device): key = torch.tensor([2, 1, 0], device=device) value = torch.randn(key.size(0), 2, device=device) tensor = HashTensor(key, value) out = tensor.share_memory_() assert isinstance(out, HashTensor) assert out.is_shared() assert out._value.is_shared() assert out.data_ptr() == tensor.data_ptr() @onlyCUDA @withHashTensor def test_pin_memory(): key = torch.tensor([2, 1, 0]) value = torch.randn(key.size(0), 2) tensor = HashTensor(key, value) assert not tensor.is_pinned() out = tensor.pin_memory() assert isinstance(out, HashTensor) assert out.is_pinned() @withCUDA @withHashTensor def test_detach(device): key = torch.tensor([2, 1, 0], device=device) value = torch.randn(key.size(0), 2, device=device, requires_grad=True) tensor = HashTensor(key, value) assert tensor.requires_grad out = tensor.detach() assert isinstance(out, HashTensor) assert not out.requires_grad assert not out._value.requires_grad 
tensor.detach_() assert not tensor.requires_grad assert not tensor._value.requires_grad @withCUDA @withHashTensor def test_contiguous(device): key = torch.tensor([2, 1, 0], device=device) value = torch.randn(2, key.size(0), device=device).t() assert not value.is_contiguous() tensor = HashTensor(key, value) assert not tensor.is_contiguous() out = tensor.contiguous() assert out.is_contiguous() assert out._value.is_contiguous() @withCUDA @withHashTensor def test_save_and_load(device, tmp_path): key = torch.tensor([2, 1, 0], device=device) value = torch.randn(key.size(0), 2, device=device) tensor = HashTensor(key, value) path = osp.join(tmp_path, 'hash_tensor.pt') torch.save(tensor, path) out = fs.torch_load(path) assert isinstance(out, HashTensor) assert out._value.equal(value) assert out._min_key.equal(key.min()) assert out._max_key.equal(key.max()) @withCUDA @withHashTensor def test_to_function(device): key = torch.tensor([2, 1, 0], device=device) value = torch.randn(key.size(0), 2, device=device) tensor = HashTensor(key, value) out = tensor.to(device) assert isinstance(out, HashTensor) assert id(out) == id(tensor) assert out.device == device assert out._value.device == device assert out._min_key.device == device assert out._max_key.device == device out = tensor.to('cpu') assert isinstance(out, HashTensor) if key.is_cuda: assert id(out) != id(tensor) else: assert id(out) == id(tensor) assert out.device == torch.device('cpu') assert out._value.device == torch.device('cpu') assert out._min_key.device == torch.device('cpu') assert out._max_key.device == torch.device('cpu') out = tensor.double() assert isinstance(out, HashTensor) assert out._value.dtype == torch.double @withCUDA @withHashTensor def test_unsqueeze(device): key = torch.tensor([2, 1, 0], device=device) tensor = HashTensor(key) with pytest.raises(IndexError, match="in the first dimension"): tensor.unsqueeze(0) with pytest.raises(IndexError, match="in the first dimension"): tensor.unsqueeze(-2) with 
pytest.raises(IndexError, match="out of range"): tensor.unsqueeze(2) with pytest.raises(IndexError, match="out of range"): tensor.unsqueeze(-3) out = tensor.unsqueeze(-1) assert out.size() == (3, 1) assert out._value is not None out = tensor[..., None] assert out.size() == (3, 1) assert out._value is not None out = tensor[..., None, None] assert out.size() == (3, 1, 1) assert out._value is not None value = torch.randn(key.size(0), 2, device=device) tensor = HashTensor(key, value) out = tensor.unsqueeze(-1) assert out.size() == (3, 2, 1) assert out._value is not None out = tensor[..., None] assert out.size() == (3, 2, 1) assert out._value is not None out = tensor[..., None, None] assert out.size() == (3, 2, 1, 1) assert out._value is not None out = tensor.unsqueeze(1) assert out.size() == (3, 1, 2) assert out._value is not None @withCUDA @withHashTensor @pytest.mark.parametrize('num_keys', [3, 1]) def test_squeeze(num_keys, device): key = torch.tensor([2, 1, 0][:num_keys], device=device) tensor = HashTensor(key) out = tensor.squeeze() assert isinstance(out, HashTensor) assert out.size() == (num_keys, ) out = tensor.squeeze(0) assert isinstance(out, HashTensor) assert out.size() == (num_keys, ) out = tensor.squeeze(-1) assert isinstance(out, HashTensor) assert out.size() == (num_keys, ) if torch_geometric.typing.WITH_PT20: out = tensor.squeeze([0]) assert isinstance(out, HashTensor) assert out.size() == (num_keys, ) with pytest.raises(IndexError, match="out of range"): tensor.squeeze(1) with pytest.raises(IndexError, match="out of range"): tensor.squeeze(-2) value = torch.randn(key.size(0), 1, 1, device=device) tensor = HashTensor(key, value) out = tensor.squeeze() assert isinstance(out, HashTensor) assert out.size() == (num_keys, ) out = tensor.squeeze(0) assert isinstance(out, HashTensor) assert out.size() == (num_keys, 1, 1) out = tensor.squeeze(-1) assert isinstance(out, HashTensor) assert out.size() == (num_keys, 1) if torch_geometric.typing.WITH_PT20: out = 
tensor.squeeze([0, 1, 2]) assert isinstance(out, HashTensor) assert out.size() == (num_keys, ) @withCUDA @withHashTensor def test_slice(device): key = torch.tensor([2, 1, 0], device=device) tensor = HashTensor(key) with pytest.raises(IndexError, match="out of range"): torch.narrow(tensor, dim=-2, start=0, length=2) out = tensor[:] assert isinstance(out, HashTensor) assert out._value is None out = tensor[-2:4] assert isinstance(out, HashTensor) assert out.as_tensor().equal(torch.tensor([1, 2], device=device)) out = tensor[..., 0:2] assert isinstance(out, HashTensor) assert out.as_tensor().equal(torch.tensor([0, 1], device=device)) out = torch.narrow(tensor, dim=0, start=2, length=1) assert isinstance(out, HashTensor) assert out.as_tensor().equal(torch.tensor([2], device=device)) out = tensor.narrow(dim=0, start=1, length=2) assert isinstance(out, HashTensor) assert out.as_tensor().equal(torch.tensor([1, 2], device=device)) value = torch.randn(key.size(0), 4, device=device) tensor = HashTensor(key, value) out = tensor[0:2] assert isinstance(out, HashTensor) assert out.as_tensor().equal(value[0:2]) out = tensor[..., 0:2] assert isinstance(out, HashTensor) assert out.as_tensor().equal(value[..., 0:2]) out = torch.narrow(tensor, dim=1, start=2, length=1) assert isinstance(out, HashTensor) assert out.as_tensor().equal(value[..., 2:3]) @withCUDA @withHashTensor @pytest.mark.parametrize('dtype', KEY_DTYPES) def test_index_select(dtype, device): if dtype != torch.bool: key = torch.tensor([2, 1, 0], dtype=dtype, device=device) query = torch.tensor([0, 3, 2], dtype=dtype, device=device) else: key = torch.tensor([True, False], device=device) query = torch.tensor([False, True], device=device) tensor = HashTensor(key) out = torch.index_select(tensor, 0, query) assert not isinstance(out, HashTensor) if dtype != torch.bool: assert out.equal(torch.tensor([2, -1, 0], device=device)) else: assert out.equal(torch.tensor([1, 0], device=device)) out = tensor.index_select(-1, query) 
assert not isinstance(out, HashTensor) if dtype != torch.bool: assert out.equal(torch.tensor([2, -1, 0], device=device)) else: assert out.equal(torch.tensor([1, 0], device=device)) with pytest.raises(IndexError, match="out of range"): torch.index_select(tensor, 1, query) with pytest.raises(IndexError, match="out of range"): tensor.index_select(-2, query) value = torch.randn(key.size(0), 2, device=device) tensor = HashTensor(key, value) out = torch.index_select(tensor, 0, query) assert not isinstance(out, HashTensor) if dtype != torch.bool: expected = torch.full_like(value, float('NaN')) expected[0] = value[2] expected[2] = value[0] assert out.allclose(expected, equal_nan=True) else: assert out.allclose(value.flip(dims=[0])) index = torch.tensor([1], device=device) out = tensor.index_select(1, index) assert isinstance(out, HashTensor) assert out.size() == (3 if dtype != torch.bool else 2, 1) assert out.as_tensor().allclose(value[:, 1:]) @withCUDA @withHashTensor def test_select(device): key = torch.tensor([2, 1, 0], device=device) tensor = HashTensor(key) out = tensor[0] assert not isinstance(out, HashTensor) assert out.dim() == 0 assert int(out) == 2 out = tensor.select(0, 4) assert not isinstance(out, HashTensor) assert out.dim() == 0 assert int(out) == -1 with pytest.raises(IndexError, match="out of range"): torch.select(tensor, 1, 0) with pytest.raises(IndexError, match="out of range"): tensor.select(-2, 0) value = torch.randn(key.size(0), 2, device=device) tensor = HashTensor(key, value) out = tensor[0] assert not isinstance(out, HashTensor) assert out.equal(value[2]) out = tensor.select(-1, 0) assert isinstance(out, HashTensor) assert out.as_tensor().equal(value[:, 0]) @withHashTensor def test_tolist(): key = torch.tensor([2, 1, 0]) value = torch.randn(key.size(0), 2) assert HashTensor(key, value).tolist() == value.tolist() @withHashTensor def test_numpy(): key = torch.tensor([2, 1, 0]) value = torch.randn(key.size(0), 2) assert np.allclose(HashTensor(key, 
value).numpy(), value.numpy()) def _collate_fn(hash_tensors: List[HashTensor]) -> List[HashTensor]: return hash_tensors @pytest.mark.parametrize('num_workers', [0, 2]) @pytest.mark.parametrize('pin_memory', [False, True]) def test_data_loader(num_workers, pin_memory): key = torch.tensor([2, 1, 0]) value = torch.randn(key.size(0), 2) tensor = HashTensor(key, value) loader = torch.utils.data.DataLoader( [tensor] * 4, batch_size=2, num_workers=num_workers, collate_fn=_collate_fn, pin_memory=pin_memory, drop_last=True, ) assert len(loader) == 2 for batch in loader: assert isinstance(batch, list) assert len(batch) == 2 for tensor in batch: assert isinstance(tensor, HashTensor) assert tensor.dtype == value.dtype assert tensor.is_shared() != (num_workers == 0) or pin_memory @withCUDA @withHashTensor @pytest.mark.parametrize('dtype', KEY_DTYPES[:1]) def test_getitem(dtype, device): if dtype != torch.bool: key = torch.tensor([20, 10, 0], dtype=dtype, device=device) else: key = torch.tensor([True, False], device=device) value = torch.randn(key.size(0), 2, 4, device=device) tensor = HashTensor(key, value) if dtype != torch.bool: out = tensor[10, :, None, torch.tensor([1, 2])] else: out = tensor[False, :, None, torch.tensor([1, 2])] assert not isinstance(out, HashTensor) assert out.allclose(value[1, :, None, torch.tensor([1, 2])]) if dtype != torch.bool: out = tensor[..., 10, :, None, torch.tensor([1, 2])] else: out = tensor[..., False, :, None, torch.tensor([1, 2])] assert not isinstance(out, HashTensor) assert out.allclose(value[1, :, None, torch.tensor([1, 2])]) if dtype != torch.bool: out = tensor[..., [10, 20], 1, None, 0:2] else: out = tensor[..., [False, True], 1, None, 0:2] assert not isinstance(out, HashTensor) assert out.allclose(value[torch.tensor([1, 0]), 1, None, 0:2]) if dtype != torch.bool: out = tensor[[10, 20], 1, None, 0:2] else: out = tensor[[False, True], 1, None, 0:2] assert not isinstance(out, HashTensor) assert out.allclose(value[torch.tensor([1, 0]), 1, 
None, 0:2]) out = tensor[..., None, torch.tensor([1, 2])] assert isinstance(out, HashTensor) assert out.as_tensor().allclose(value[..., None, torch.tensor([1, 2])]) out = tensor[...] assert isinstance(out, HashTensor) assert out.size() == value.size() out = tensor[:2] assert isinstance(out, HashTensor) assert out.size() == (2, ) + value.size()[1:] @onlyLinux @withHashTensor @withPackage('torch>=2.3') @pytest.mark.skip(reason="Does not work currently") def test_compile_basic(): import torch._dynamo as dynamo class Model(torch.nn.Module): def forward(self, key: Tensor, query: Tensor) -> Tensor: _map = HashTensor(key) return _map[query] key = torch.randperm(10) query = key[:5] model = Model() expected = model(key, query) assert expected.equal(torch.arange(query.numel())) explanation = dynamo.explain(model)(key, query) assert explanation.graph_break_count == 0 ================================================ FILE: test/test_home.py ================================================ import os import os.path as osp from torch_geometric import get_home_dir, set_home_dir from torch_geometric.home import DEFAULT_CACHE_DIR def test_home(): os.environ.pop('PYG_HOME', None) home_dir = osp.expanduser(DEFAULT_CACHE_DIR) assert get_home_dir() == home_dir home_dir = '/tmp/test_pyg1' os.environ['PYG_HOME'] = home_dir assert get_home_dir() == home_dir home_dir = '/tmp/test_pyg2' set_home_dir(home_dir) assert get_home_dir() == home_dir ================================================ FILE: test/test_index.py ================================================ import os.path as osp from typing import List import numpy as np import pytest import torch from torch import tensor import torch_geometric.typing from torch_geometric import Index from torch_geometric.io import fs from torch_geometric.testing import onlyCUDA, withCUDA from torch_geometric.typing import INDEX_DTYPES DTYPES = [pytest.param(dtype, id=str(dtype)[6:]) for dtype in INDEX_DTYPES] @withCUDA @pytest.mark.parametrize('dtype', 
DTYPES) def test_basic(dtype, device): kwargs = dict(dtype=dtype, device=device, dim_size=3) index = Index([0, 1, 1, 2], **kwargs) index.validate() assert isinstance(index, Index) assert str(index).startswith('Index([0, 1, 1, 2], ') assert 'dim_size=3' in str(index) assert (f"device='{device}'" in str(index)) == index.is_cuda assert (f'dtype={dtype}' in str(index)) == (dtype != torch.long) assert index.dtype == dtype assert index.device == device assert index.dim_size == 3 assert not index.is_sorted out = index.as_tensor() assert not isinstance(out, Index) assert out.dtype == dtype assert out.device == device out = index * 1 assert not isinstance(out, Index) assert out.dtype == dtype assert out.device == device @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_identity(dtype, device): kwargs = dict(dtype=dtype, device=device, dim_size=3, is_sorted=True) index = Index([0, 1, 1, 2], **kwargs) out = Index(index) assert not isinstance(out.as_tensor(), Index) assert out.data_ptr() == index.data_ptr() assert out.dtype == index.dtype assert out.device == index.device assert out.dim_size == index.dim_size assert out.is_sorted == index.is_sorted out = Index(index, dim_size=4, is_sorted=False) assert out.dim_size == 4 assert out.is_sorted == index.is_sorted def test_validate(): with pytest.raises(ValueError, match="unsupported data type"): Index([0.0, 1.0]) with pytest.raises(ValueError, match="needs to be one-dimensional"): Index([[0], [1]]) with pytest.raises(TypeError, match="invalid combination of arguments"): Index(tensor([0, 1]), torch.long) with pytest.raises(TypeError, match="invalid keyword arguments"): Index(tensor([0, 1]), dtype=torch.long) with pytest.raises(ValueError, match="contains negative indices"): Index([-1, 0]).validate() with pytest.raises(ValueError, match="than its registered size"): Index([0, 10], dim_size=2).validate() with pytest.raises(ValueError, match="not sorted"): Index([1, 0], is_sorted=True).validate() @withCUDA 
@pytest.mark.parametrize('dtype', DTYPES) def test_fill_cache_(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], is_sorted=True, **kwargs) index.validate().fill_cache_() assert index.dim_size == 3 assert index._indptr.dtype == dtype assert index._indptr.equal(tensor([0, 1, 3, 4], device=device)) index = Index([1, 0, 2, 1], **kwargs) index.validate().fill_cache_() assert index.dim_size == 3 assert index._indptr is None @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_dim_resize(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], is_sorted=True, **kwargs).fill_cache_() assert index.dim_size == 3 assert index._indptr.equal(tensor([0, 1, 3, 4], device=device)) out = index.dim_resize_(4) assert out.dim_size == 4 assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device)) out = index.dim_resize_(3) assert out.dim_size == 3 assert out._indptr.equal(tensor([0, 1, 3, 4], device=device)) out = index.dim_resize_(None) assert out.dim_size is None assert out._indptr is None @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_clone(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], is_sorted=True, dim_size=3, **kwargs) out = index.clone() assert isinstance(out, Index) assert out.dtype == dtype assert out.device == device assert out.dim_size == 3 assert out.is_sorted out = torch.clone(index) assert isinstance(out, Index) assert out.dtype == dtype assert out.device == device assert out.dim_size == 3 assert out.is_sorted @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_to_function(dtype, device): kwargs = dict(dtype=dtype) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) index.fill_cache_() index = index.to(device) assert isinstance(index, Index) assert index.device == device assert index._indptr.dtype == dtype assert index._indptr.device == device out = index.cpu() assert isinstance(out, Index) assert out.device == 
torch.device('cpu') out = index.to(torch.int) assert out.dtype == torch.int if torch_geometric.typing.WITH_PT20: assert isinstance(out, Index) assert out._indptr.dtype == torch.int else: assert not isinstance(out, Index) out = index.to(torch.float) assert not isinstance(out, Index) assert out.dtype == torch.float out = index.long() assert isinstance(out, Index) assert out.dtype == torch.int64 out = index.int() assert out.dtype == torch.int if torch_geometric.typing.WITH_PT20: assert isinstance(out, Index) else: assert not isinstance(out, Index) @onlyCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_cpu_cuda(dtype): kwargs = dict(dtype=dtype) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) assert index.is_cpu out = index.cuda() assert isinstance(out, Index) assert out.is_cuda out = out.cpu() assert isinstance(out, Index) assert out.is_cpu @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_share_memory(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) index.fill_cache_() out = index.share_memory_() assert isinstance(out, Index) assert out.is_shared() assert out._data.is_shared() assert out._indptr.is_shared() assert out.data_ptr() == index.data_ptr() @onlyCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_pin_memory(dtype): index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, dtype=dtype) assert not index.is_pinned() out = index.pin_memory() assert out.is_pinned() @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_contiguous(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) assert index.is_contiguous out = index.contiguous() assert isinstance(out, Index) assert out.is_contiguous @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_sort(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([1, 0, 2, 1], dim_size=3, **kwargs) index, _ = 
index.sort() assert isinstance(index, Index) assert index.equal(tensor([0, 1, 1, 2], device=device)) assert index.dim_size == 3 assert index.is_sorted out, perm = index.sort() assert isinstance(out, Index) assert out._data.data_ptr() == index._data.data_ptr() assert perm.equal(tensor([0, 1, 2, 3], device=device)) assert out.dim_size == 3 index, _ = index.sort(descending=True) assert isinstance(index, Index) assert index.equal(tensor([2, 1, 1, 0], device=device)) assert index.dim_size == 3 assert not index.is_sorted @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_sort_stable(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([1, 0, 2, 1], dim_size=3, **kwargs) index, perm = index.sort(stable=True) assert isinstance(index, Index) assert index.equal(tensor([0, 1, 1, 2], device=device)) assert perm.equal(tensor([1, 0, 3, 2], device=device)) assert index.dim_size == 3 assert index.is_sorted out, perm = index.sort(stable=True) assert isinstance(out, Index) assert out._data.data_ptr() == index._data.data_ptr() assert perm.equal(tensor([0, 1, 2, 3], device=device)) assert out.dim_size == 3 index, perm = index.sort(descending=True, stable=True) assert isinstance(index, Index) assert index.equal(tensor([2, 1, 1, 0], device=device)) assert perm.equal(tensor([3, 1, 2, 0], device=device)) assert index.dim_size == 3 assert not index.is_sorted @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_cat(dtype, device): kwargs = dict(dtype=dtype, device=device) index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) index2 = Index([1, 2, 2, 3], dim_size=4, is_sorted=True, **kwargs) index3 = Index([1, 2, 2, 3], **kwargs) out = torch.cat([index1, index2]) assert out.equal(tensor([0, 1, 1, 2, 1, 2, 2, 3], device=device)) assert out.size() == (8, ) assert isinstance(out, Index) assert out.dim_size == 4 assert not out.is_sorted assert out._cat_metadata.nnz == [4, 4] assert out._cat_metadata.dim_size == [3, 4] assert 
out._cat_metadata.is_sorted == [True, True] out = torch.cat([index1, index2, index3]) assert out.size() == (12, ) assert isinstance(out, Index) assert out.dim_size is None assert not out.is_sorted out = torch.cat([index1, index2.as_tensor()]) assert out.size() == (8, ) assert not isinstance(out, Index) inplace = torch.empty(8, dtype=dtype, device=device) out = torch.cat([index1, index2], out=inplace) assert out.equal(tensor([0, 1, 1, 2, 1, 2, 2, 3], device=device)) assert out.data_ptr() == inplace.data_ptr() assert not isinstance(out, Index) assert not isinstance(inplace, Index) @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_flip(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) out = index.flip(0) assert isinstance(out, Index) assert out.equal(tensor([2, 1, 1, 0], device=device)) assert out.dim_size == 3 assert not out.is_sorted @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_index_select(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) i = tensor([1, 3], device=device) out = index.index_select(0, i) assert out.equal(tensor([1, 2], device=device)) assert isinstance(out, Index) assert out.dim_size == 3 assert not out.is_sorted inplace = torch.empty(2, dtype=dtype, device=device) out = torch.index_select(index, 0, i, out=inplace) assert out.equal(tensor([1, 2], device=device)) assert out.data_ptr() == inplace.data_ptr() assert not isinstance(out, Index) assert not isinstance(inplace, Index) @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_narrow(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) out = index.narrow(0, start=1, length=2) assert isinstance(out, Index) assert out.equal(tensor([1, 1], device=device)) assert out.dim_size == 3 assert out.is_sorted @withCUDA @pytest.mark.parametrize('dtype', DTYPES) 
def test_getitem(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) out = index[:] assert isinstance(out, Index) assert out._data.data_ptr() == index._data.data_ptr() assert out.equal(tensor([0, 1, 1, 2], device=device)) assert out.dim_size == 3 assert out.is_sorted out = index[tensor([False, True, False, True], device=device)] assert isinstance(out, Index) assert out.equal(tensor([1, 2], device=device)) assert out.dim_size == 3 assert out.is_sorted out = index[tensor([1, 3], device=device)] assert isinstance(out, Index) assert out.equal(tensor([1, 2], device=device)) assert out.dim_size == 3 assert not out.is_sorted out = index[1:3] assert isinstance(out, Index) assert out.equal(tensor([1, 1], device=device)) assert out.dim_size == 3 assert out.is_sorted out = index[...] assert isinstance(out, Index) assert out._data.data_ptr() == index._data.data_ptr() assert out.equal(tensor([0, 1, 1, 2], device=device)) assert out.dim_size == 3 assert out.is_sorted out = index[..., 1:3] assert isinstance(out, Index) assert out.equal(tensor([1, 1], device=device)) assert out.dim_size == 3 assert out.is_sorted out = index[None, 1:3] assert not isinstance(out, Index) assert out.equal(tensor([[1, 1]], device=device)) out = index[1:3, None] assert not isinstance(out, Index) assert out.equal(tensor([[1], [1]], device=device)) out = index[0] assert not isinstance(out, Index) assert out.equal(tensor(0, device=device)) tmp = torch.randn(3, device=device) out = tmp[index] assert not isinstance(out, Index) assert out.equal(tmp[index.as_tensor()]) @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_add(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) out = torch.add(index, 2, alpha=2) assert isinstance(out, Index) assert out.equal(tensor([4, 5, 5, 6], device=device)) assert out.dim_size == 7 assert out.is_sorted out = index + 
tensor([2], dtype=dtype, device=device) assert isinstance(out, Index) assert out.equal(tensor([2, 3, 3, 4], device=device)) assert out.dim_size == 5 assert out.is_sorted out = tensor([2], dtype=dtype, device=device) + index assert isinstance(out, Index) assert out.equal(tensor([2, 3, 3, 4], device=device)) assert out.dim_size == 5 assert out.is_sorted out = index.add(index) assert isinstance(out, Index) assert out.equal(tensor([0, 2, 2, 4], device=device)) assert out.dim_size == 6 assert not out.is_sorted index += 2 assert isinstance(index, Index) assert index.equal(tensor([2, 3, 3, 4], device=device)) assert index.dim_size == 5 assert index.is_sorted with pytest.raises(RuntimeError, match="can't be cast"): index += 2.5 @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_sub(dtype, device): kwargs = dict(dtype=dtype, device=device) index = Index([4, 5, 5, 6], dim_size=7, is_sorted=True, **kwargs) out = torch.sub(index, 2, alpha=2) assert isinstance(out, Index) assert out.equal(tensor([0, 1, 1, 2], device=device)) assert out.dim_size == 3 assert out.is_sorted out = index - tensor([2], dtype=dtype, device=device) assert isinstance(out, Index) assert out.equal(tensor([2, 3, 3, 4], device=device)) assert out.dim_size == 5 assert out.is_sorted out = tensor([6], dtype=dtype, device=device) - index assert isinstance(out, Index) assert out.equal(tensor([2, 1, 1, 0], device=device)) assert out.dim_size is None assert not out.is_sorted out = index.sub(index) assert isinstance(out, Index) assert out.equal(tensor([0, 0, 0, 0], device=device)) assert out.dim_size is None assert not out.is_sorted index -= 2 assert isinstance(index, Index) assert index.equal(tensor([2, 3, 3, 4], device=device)) assert index.dim_size == 5 assert not out.is_sorted with pytest.raises(RuntimeError, match="can't be cast"): index -= 2.5 def test_to_list(): data = torch.tensor([0, 1, 1, 2]) index = Index(data) assert index.tolist() == data.tolist() def test_numpy(): data = torch.tensor([0, 1, 
1, 2]) index = Index(data) assert np.array_equal(index.numpy(), data.numpy()) @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_save_and_load(dtype, device, tmp_path): kwargs = dict(dtype=dtype, device=device) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) index.fill_cache_() path = osp.join(tmp_path, 'edge_index.pt') torch.save(index, path) out = fs.torch_load(path) assert isinstance(out, Index) assert out.equal(index) assert out.dim_size == 3 assert out.is_sorted assert out._indptr.equal(index._indptr) def _collate_fn(indices: List[Index]) -> List[Index]: return indices @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('num_workers', [0, 2]) @pytest.mark.parametrize('pin_memory', [False, True]) def test_data_loader(dtype, num_workers, pin_memory): kwargs = dict(dtype=dtype) index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs) index.fill_cache_() loader = torch.utils.data.DataLoader( [index] * 4, batch_size=2, num_workers=num_workers, collate_fn=_collate_fn, pin_memory=pin_memory, drop_last=True, ) assert len(loader) == 2 for batch in loader: assert isinstance(batch, list) assert len(batch) == 2 for index in batch: assert isinstance(index, Index) assert index.dtype == dtype assert index.is_shared() != (num_workers == 0) or pin_memory assert index._data.is_shared() != (num_workers == 0) or pin_memory ================================================ FILE: test/test_inspector.py ================================================ import inspect from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union import torch from torch import Tensor from torch_geometric.inspector import Inspector, Parameter, Signature from torch_geometric.nn import GATConv, SAGEConv from torch_geometric.typing import OptPairTensor def test_eval_type() -> None: inspector = Inspector(SAGEConv) assert inspector.eval_type('Tensor') == Tensor assert inspector.eval_type('List[Tensor]') == List[Tensor] assert 
inspector.eval_type('Tuple[Tensor, int]') == Tuple[Tensor, int] assert inspector.eval_type('Tuple[int, ...]') == Tuple[int, ...] def test_type_repr() -> None: inspector = Inspector(SAGEConv) assert inspector.type_repr(Any) == 'typing.Any' assert inspector.type_repr(Final) == 'typing.Final' assert inspector.type_repr(OptPairTensor) == ( 'Tuple[Tensor, Optional[Tensor]]') assert inspector.type_repr( Final[Optional[Tensor]]) == ('typing.Final[Optional[Tensor]]') assert inspector.type_repr(Union[None, Tensor]) == 'Optional[Tensor]' assert inspector.type_repr(Optional[Tensor]) == 'Optional[Tensor]' assert inspector.type_repr(Set[Tensor]) == 'typing.Set[Tensor]' assert inspector.type_repr(List) == 'List' assert inspector.type_repr(Tuple) == 'Tuple' assert inspector.type_repr(Set) == 'typing.Set' assert inspector.type_repr(Dict) == 'typing.Dict' assert inspector.type_repr(Dict[str, Tuple[Tensor, Tensor]]) == ( # 'typing.Dict[str, Tuple[Tensor, Tensor]]') assert inspector.type_repr(Tuple[int, ...]) == 'Tuple[int, ...]' assert inspector.type_repr(Union[int, str, None]) == ( # 'Union[int, str, None]') def test_inspector_sage_conv() -> None: inspector = Inspector(SAGEConv) assert str(inspector) == 'Inspector(SAGEConv)' assert inspector.implements('message') assert inspector.implements('message_and_aggregate') out = inspector.inspect_signature(SAGEConv.message) assert isinstance(out, Signature) assert out.param_dict == { 'x_j': Parameter('x_j', Tensor, 'Tensor', inspect._empty) } assert out.return_type == Tensor assert inspector.get_flat_params(['message', 'message']) == [ Parameter('x_j', Tensor, 'Tensor', inspect._empty), ] assert inspector.get_flat_param_names(['message']) == ['x_j'] kwargs = {'x_j': torch.randn(5), 'x_i': torch.randn(5)} data = inspector.collect_param_data('message', kwargs) assert len(data) == 1 assert torch.allclose(data['x_j'], kwargs['x_j']) assert inspector.get_params_from_method_call(SAGEConv.propagate) == { 'x': Parameter('x', OptPairTensor, 
'OptPairTensor', inspect._empty), } def test_inspector_gat_conv() -> None: inspector = Inspector(GATConv) assert str(inspector) == 'Inspector(GATConv)' assert inspector.implements('message') assert not inspector.implements('message_and_aggregate') out = inspector.inspect_signature(GATConv.message) assert isinstance(out, Signature) assert out.param_dict == { 'x_j': Parameter('x_j', Tensor, 'Tensor', inspect._empty), 'alpha': Parameter('alpha', Tensor, 'Tensor', inspect._empty), } assert out.return_type == Tensor assert inspector.get_flat_params(['message', 'message']) == [ Parameter('x_j', Tensor, 'Tensor', inspect._empty), Parameter('alpha', Tensor, 'Tensor', inspect._empty), ] assert inspector.get_flat_param_names(['message']) == ['x_j', 'alpha'] kwargs = {'x_j': torch.randn(5), 'alpha': torch.randn(5)} data = inspector.collect_param_data('message', kwargs) assert len(data) == 2 assert torch.allclose(data['x_j'], kwargs['x_j']) assert torch.allclose(data['alpha'], kwargs['alpha']) assert inspector.get_params_from_method_call(SAGEConv.propagate) == { 'x': Parameter('x', OptPairTensor, 'OptPairTensor', inspect._empty), 'alpha': Parameter('alpha', Tensor, 'Tensor', inspect._empty), } def test_get_params_from_method_call() -> None: class FromMethodCall1: propagate_type = {'x': Tensor} inspector = Inspector(FromMethodCall1) assert inspector.get_params_from_method_call('propagate') == { 'x': Parameter('x', Tensor, 'Tensor', inspect._empty), } class FromMethodCall2: # propagate_type: (x: Tensor) pass inspector = Inspector(FromMethodCall2) assert inspector.get_params_from_method_call('propagate') == { 'x': Parameter('x', Tensor, 'Tensor', inspect._empty), } class FromMethodCall3: def forward(self) -> None: self.propagate( # type: ignore torch.randn(5, 5), x=None, size=None, ) inspector = Inspector(FromMethodCall3) exclude = [0, 'size'] assert inspector.get_params_from_method_call('propagate', exclude) == { 'x': Parameter('x', Tensor, 'Tensor', inspect._empty), } class 
FromMethodCall4: pass inspector = Inspector(FromMethodCall4) assert inspector.get_params_from_method_call('propagate') == {} ================================================ FILE: test/test_isinstance.py ================================================ import torch from torch_geometric import is_torch_instance from torch_geometric.testing import onlyLinux, withPackage def test_basic(): assert is_torch_instance(torch.nn.Linear(1, 1), torch.nn.Linear) @onlyLinux @withPackage('torch>=2.0.0') def test_compile(): model = torch.compile(torch.nn.Linear(1, 1)) assert not isinstance(model, torch.nn.Linear) assert is_torch_instance(model, torch.nn.Linear) ================================================ FILE: test/test_onnx.py ================================================ import os import tempfile import warnings from typing import Any from unittest.mock import patch import pytest import torch from torch_geometric import is_in_onnx_export, safe_onnx_export # Global mock to prevent ANY real ONNX calls in tests # This ensures no deprecation warnings or real ONNX issues pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") class SimpleModel(torch.nn.Module): """Simple model for testing ONNX export.""" def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(4, 2) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) def test_is_in_onnx_export() -> None: """Test is_in_onnx_export function.""" assert not is_in_onnx_export() def test_safe_onnx_export_ci_resilient() -> None: """Test safe_onnx_export handles CI environment issues gracefully.""" model = SimpleModel() x = torch.randn(3, 4) # Use mocking to prevent real ONNX calls and deprecation warnings with patch('torch.onnx.export', return_value=None) as mock_export: with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: # Test with skip_on_error=True - should never fail result = safe_onnx_export(model, (x, ), f.name, skip_on_error=True) # Should always 
succeed with mocking assert result is True # Verify the mock was called correctly mock_export.assert_called_once() call_args = mock_export.call_args[0] assert call_args[0] is model assert isinstance(call_args[1], tuple) assert call_args[2] == f.name finally: if os.path.exists(f.name): try: os.unlink(f.name) except (PermissionError, OSError): pass # Ignore file lock issues def test_safe_onnx_export_success() -> None: """Test successful ONNX export with pure mocking.""" model = SimpleModel() x = torch.randn(3, 4) # Use comprehensive mocking to avoid any real ONNX calls with patch('torch.onnx.export', return_value=None) as mock_export: with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: # Test with tuple args - should succeed with mock result = safe_onnx_export(model, (x, ), f.name) assert result is True # Verify torch.onnx.export was called with correct args mock_export.assert_called() call_args = mock_export.call_args[0] assert call_args[0] is model # model assert isinstance(call_args[1], tuple) # args as tuple assert call_args[2] == f.name # file path # Reset mock for second test mock_export.reset_mock() # Test with single tensor (should be converted to tuple) result = safe_onnx_export(model, x, f.name) assert result is True # Verify single tensor was converted to tuple call_args = mock_export.call_args[0] assert isinstance(call_args[1], tuple) finally: if os.path.exists(f.name): try: try: os.unlink(f.name) except (PermissionError, OSError): pass except (PermissionError, OSError): pass def test_safe_onnx_export_with_skip_on_error() -> None: """Test safe_onnx_export with skip_on_error=True.""" model = SimpleModel() x = torch.randn(3, 4) # Mock torch.onnx.export to raise SerdeError with patch('torch.onnx.export') as mock_export: mock_export.side_effect = Exception( "onnx_ir.serde.SerdeError: allowzero") with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: # Should return False instead of raising result = 
safe_onnx_export(model, (x, ), f.name, skip_on_error=True) assert result is False finally: if os.path.exists(f.name): try: os.unlink(f.name) except (PermissionError, OSError): pass def test_serde_error_patterns() -> None: """Test detection of various SerdeError patterns.""" model = SimpleModel() x = torch.randn(3, 4) error_patterns = [ "onnx_ir.serde.SerdeError: allowzero attribute", "ValueError: Value out of range: 1", "serialize_model_into failed", "serialize_attribute_into failed" ] for error_msg in error_patterns: # Use multiple patch targets to ensure comprehensive mocking with patch('torch.onnx.export') as mock_export, \ patch('torch_geometric._onnx.torch.onnx.export') as mock_export2: mock_export.side_effect = Exception(error_msg) mock_export2.side_effect = Exception(error_msg) with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: result = safe_onnx_export(model, (x, ), f.name, skip_on_error=True) assert result is False finally: if os.path.exists(f.name): try: try: os.unlink(f.name) except (PermissionError, OSError): pass except (PermissionError, OSError): pass # Ignore file lock issues def test_non_serde_error_reraise() -> None: """Test that non-SerdeError exceptions are re-raised.""" model = SimpleModel() x = torch.randn(3, 4) # Use comprehensive mocking to prevent real ONNX calls with patch('torch.onnx.export') as mock_export: mock_export.side_effect = ValueError("Some other error") with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: with pytest.raises(ValueError, match="Some other error"): safe_onnx_export(model, (x, ), f.name) finally: if os.path.exists(f.name): try: os.unlink(f.name) except (PermissionError, OSError): pass def test_dynamo_fallback() -> None: """Test dynamo=False fallback strategy.""" model = SimpleModel() x = torch.randn(3, 4) call_count = 0 def mock_export_side_effect(*_args: Any, **kwargs: Any) -> None: nonlocal call_count call_count += 1 if call_count == 1: # First call fails raise 
Exception("onnx_ir.serde.SerdeError: allowzero") elif call_count == 2 and not kwargs.get('dynamo', True): # Second call succeeds with dynamo=False return None else: raise Exception("Unexpected call") with patch('torch.onnx.export', side_effect=mock_export_side_effect): with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: result = safe_onnx_export(model, (x, ), f.name, dynamo=True) assert result is True assert call_count == 2 finally: if os.path.exists(f.name): try: os.unlink(f.name) except (PermissionError, OSError): pass def test_opset_fallback() -> None: """Test opset version fallback strategy.""" model = SimpleModel() x = torch.randn(3, 4) call_count = 0 def mock_export_side_effect(*_args: Any, **kwargs: Any) -> None: nonlocal call_count call_count += 1 # Fail until we get to opset_version=17 if kwargs.get('opset_version') == 17: # This call succeeds return None else: # All other calls fail raise Exception("onnx_ir.serde.SerdeError: allowzero") with patch('torch.onnx.export', side_effect=mock_export_side_effect): with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: result = safe_onnx_export(model, (x, ), f.name, opset_version=18) # Should succeed when opset_version=17 is tried assert result is True finally: if os.path.exists(f.name): try: try: os.unlink(f.name) except (PermissionError, OSError): pass except (PermissionError, OSError): pass def test_all_strategies_fail() -> None: """Test when all workaround strategies fail.""" model = SimpleModel() x = torch.randn(3, 4) with patch('torch.onnx.export') as mock_export: mock_export.side_effect = Exception( "onnx_ir.serde.SerdeError: allowzero") with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: # Should raise RuntimeError when skip_on_error=False with pytest.raises(RuntimeError, match="Failed to export model to ONNX"): safe_onnx_export(model, (x, ), f.name, skip_on_error=False) # Should return False when skip_on_error=True result = 
safe_onnx_export(model, (x, ), f.name, skip_on_error=True) assert result is False finally: if os.path.exists(f.name): try: os.unlink(f.name) except (PermissionError, OSError): pass def test_pytest_environment_detection() -> None: """Test pytest environment detection for better error messages.""" model = SimpleModel() x = torch.randn(3, 4) with patch('torch.onnx.export') as mock_export: mock_export.side_effect = Exception( "onnx_ir.serde.SerdeError: allowzero") # Set pytest environment variable with patch.dict(os.environ, {'PYTEST_CURRENT_TEST': 'test_something'}): with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: with pytest.raises(RuntimeError) as exc_info: safe_onnx_export(model, (x, ), f.name, skip_on_error=False) # Should contain pytest-specific guidance assert "pytest environments" in str(exc_info.value) assert "torch.jit.script()" in str(exc_info.value) finally: if os.path.exists(f.name): try: os.unlink(f.name) except (PermissionError, OSError): pass def test_warnings_emitted() -> None: """Test that appropriate warnings are emitted during workarounds.""" model = SimpleModel() x = torch.randn(3, 4) call_count = 0 def mock_export_side_effect(*_args: Any, **_kwargs: Any) -> None: nonlocal call_count call_count += 1 if call_count == 1: raise Exception("onnx_ir.serde.SerdeError: allowzero") elif call_count == 2: return None # Success on dynamo fallback else: raise Exception("Unexpected call") with patch('torch.onnx.export', side_effect=mock_export_side_effect): with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") result = safe_onnx_export(model, (x, ), f.name, dynamo=True) assert result is True assert len(w) >= 2 # Initial error + dynamo fallback assert any("allowzero boolean attribute bug" in str( warning.message) for warning in w) assert any( "dynamo=False as workaround" in str(warning.message) for warning in w) finally: if 
os.path.exists(f.name): try: os.unlink(f.name) except (PermissionError, OSError): pass @pytest.mark.parametrize( "args_input", [ torch.randn(3, 4), # Single tensor (torch.randn(3, 4), ), # Tuple with one tensor (torch.randn(3, 4), torch.randn(3, 2)), # Tuple with multiple tensors ]) def test_args_conversion(args_input: Any) -> None: """Test that args are properly converted to tuple format.""" model = SimpleModel() with patch('torch.onnx.export') as mock_export: mock_export.return_value = None with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: try: result = safe_onnx_export(model, args_input, f.name) assert result is True # Check that torch.onnx.export was called with tuple args mock_export.assert_called_once() call_args = mock_export.call_args[0] assert isinstance(call_args[1], tuple) # args should be tuple finally: if os.path.exists(f.name): try: os.unlink(f.name) except (PermissionError, OSError): pass ================================================ FILE: test/test_seed.py ================================================ import random import numpy as np import torch from torch_geometric import seed_everything def test_seed_everything(): seed_everything(0) assert random.randint(0, 100) == 49 assert random.randint(0, 100) == 97 assert np.random.randint(0, 100) == 44 assert np.random.randint(0, 100) == 47 assert int(torch.randint(0, 100, (1, ))) == 44 assert int(torch.randint(0, 100, (1, ))) == 39 ================================================ FILE: test/test_typing.py ================================================ import pytest from torch_geometric.typing import EdgeTypeStr def test_edge_type_str(): edge_type_str = EdgeTypeStr('a__links__b') assert isinstance(edge_type_str, str) assert edge_type_str == 'a__links__b' assert edge_type_str.to_tuple() == ('a', 'links', 'b') edge_type_str = EdgeTypeStr('a', 'b') assert isinstance(edge_type_str, str) assert edge_type_str == 'a__to__b' assert edge_type_str.to_tuple() == ('a', 'to', 'b') 
edge_type_str = EdgeTypeStr(('a', 'b')) assert isinstance(edge_type_str, str) assert edge_type_str == 'a__to__b' assert edge_type_str.to_tuple() == ('a', 'to', 'b') edge_type_str = EdgeTypeStr('a', 'links', 'b') assert isinstance(edge_type_str, str) assert edge_type_str == 'a__links__b' assert edge_type_str.to_tuple() == ('a', 'links', 'b') edge_type_str = EdgeTypeStr(('a', 'links', 'b')) assert isinstance(edge_type_str, str) assert edge_type_str == 'a__links__b' assert edge_type_str.to_tuple() == ('a', 'links', 'b') with pytest.raises(ValueError, match="invalid edge type"): EdgeTypeStr('a', 'b', 'c', 'd') with pytest.raises(ValueError, match="Cannot convert the edge type"): EdgeTypeStr('a__b__c__d').to_tuple() ================================================ FILE: test/test_warnings.py ================================================ import warnings from unittest.mock import patch import pytest from torch_geometric.warnings import WarningCache, warn def test_warn(): with pytest.warns(UserWarning, match='test'): warn('test') @patch('torch_geometric.is_compiling', return_value=True) def test_no_warn_if_compiling(_): """No warning should be raised to avoid graph breaks when compiling.""" with warnings.catch_warnings(): warnings.simplefilter('error') warn('test') def test_warning_cache(): cache = WarningCache() assert len(cache) == 0 cache.warn('test') assert len(cache) == 1 assert 'test' in cache cache.warn('test') assert len(cache) == 1 cache.warn('test2') assert len(cache) == 2 assert 'test2' in cache ================================================ FILE: test/testing/test_decorators.py ================================================ import torch_geometric.typing from torch_geometric.testing import disableExtensions def test_enable_extensions(): try: import pyg_lib # noqa assert torch_geometric.typing.WITH_PYG_LIB except (ImportError, OSError): assert not torch_geometric.typing.WITH_PYG_LIB try: import torch_scatter # noqa assert 
torch_geometric.typing.WITH_TORCH_SCATTER
    except (ImportError, OSError):
        assert not torch_geometric.typing.WITH_TORCH_SCATTER

    # The `WITH_TORCH_SPARSE` flag must mirror whether `torch_sparse` can
    # actually be imported in this environment.
    try:
        import torch_sparse  # noqa
        assert torch_geometric.typing.WITH_TORCH_SPARSE
    except (ImportError, OSError):
        assert not torch_geometric.typing.WITH_TORCH_SPARSE


@disableExtensions
def test_disable_extensions():
    # Under `@disableExtensions`, every optional-backend flag must read as
    # unavailable, regardless of what is installed.
    assert not torch_geometric.typing.WITH_PYG_LIB
    assert not torch_geometric.typing.WITH_TORCH_SCATTER
    assert not torch_geometric.typing.WITH_TORCH_SPARSE


================================================
FILE: test/transforms/test_add_gpse.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.nn import GPSE
from torch_geometric.nn.models.gpse import IdentityHead
from torch_geometric.transforms import AddGPSE

# Size of the toy graph, and the encoding width the test expects once the
# model's output head is replaced by `IdentityHead`.
num_nodes = 6
gpse_inner_dim = 512


def test_gpse():
    x = torch.randn(num_nodes, 4)
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(x=x, edge_index=edge_index)

    # Swap the output head for `IdentityHead`; the stored encoding is then
    # expected to have `gpse_inner_dim` columns (one row per node).
    model = GPSE()
    model.post_mp = IdentityHead()
    transform = AddGPSE(model)
    assert str(transform) == 'AddGPSE()'
    out = transform(data)
    assert out.pestat_GPSE.size() == (num_nodes, gpse_inner_dim)


================================================
FILE: test/transforms/test_add_metapaths.py
================================================
import torch
from torch import tensor

from torch_geometric.data import HeteroData
from torch_geometric.transforms import AddMetaPaths, AddRandomMetaPaths
from torch_geometric.utils import coalesce


def generate_data() -> HeteroData:
    # Small heterogeneous graph: 5 'p' nodes, 6 'a' nodes and 3 'c' nodes.
    # Each reverse edge type is built by flipping the forward `edge_index`.
    data = HeteroData()
    data['p'].x = torch.ones(5)
    data['a'].x = torch.ones(6)
    data['c'].x = torch.ones(3)
    data['p', 'p'].edge_index = tensor([[0, 1, 2, 3], [1, 2, 4, 2]])
    data['p', 'a'].edge_index = tensor([[0, 1, 2, 3, 4], [2, 2, 5, 2, 5]])
    data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])
    data['c', 'p'].edge_index = tensor([[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]])
    data['p', 'c'].edge_index = data['c', 'p'].edge_index.flip([0])
    return data


def test_add_metapaths() -> None:
    data = generate_data()
    # Test transform options:
    metapaths = [[('p', 'c'), ('c', 'p')]]

    # Default options: the original edge types are kept.
    transform = AddMetaPaths(metapaths)
    assert str(transform) == 'AddMetaPaths()'
    meta1 = transform(data)

    # Only the new metapath edge type remains.
    transform = AddMetaPaths(metapaths, drop_orig_edge_types=True)
    assert str(transform) == 'AddMetaPaths()'
    meta2 = transform(data)

    # Original edge types between identical node types ('p', 'to', 'p')
    # survive the drop.
    transform = AddMetaPaths(metapaths, drop_orig_edge_types=True,
                             keep_same_node_type=True)
    assert str(transform) == 'AddMetaPaths()'
    meta3 = transform(data)

    # Node types no longer touched by any edge type ('a', 'c') are removed.
    transform = AddMetaPaths(metapaths, drop_orig_edge_types=True,
                             keep_same_node_type=True,
                             drop_unconnected_node_types=True)
    assert str(transform) == 'AddMetaPaths()'
    meta4 = transform(data)

    # p -> c -> p: 'c' nodes 0, 1 and 2 link {p0, p1}, {p2} and {p3, p4}, so
    # the metapath yields 2*2 + 1*1 + 2*2 = 9 edges (self-pairs included).
    assert meta1['metapath_0'].edge_index.size() == (2, 9)
    assert meta2['metapath_0'].edge_index.size() == (2, 9)
    assert meta3['metapath_0'].edge_index.size() == (2, 9)
    assert meta4['metapath_0'].edge_index.size() == (2, 9)

    assert all([i in meta1.edge_types for i in data.edge_types])
    assert meta2.edge_types == [('p', 'metapath_0', 'p')]
    assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]
    assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]

    assert meta3.node_types == ['p', 'a', 'c']
    assert meta4.node_types == ['p']

    # Test a 2-hop and a 4-hop metapath together:
    metapaths = [
        [('a', 'p'), ('p', 'c')],
        [('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')],
    ]
    transform = AddMetaPaths(metapaths)
    meta = transform(data)
    # New edge types are named `metapath_<i>` and connect the first source
    # node type to the last destination node type of each metapath.
    new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]
    assert meta['metapath_0'].edge_index.size() == (2, 4)
    assert meta['metapath_1'].edge_index.size() == (2, 4)

    # Test `metapath_dict` information:
    assert list(meta.metapath_dict.values()) == metapaths
    assert list(meta.metapath_dict.keys()) == new_edge_types


def test_add_metapaths_max_sample() -> None:
    torch.manual_seed(12345)

    data = generate_data()

    # Capping sampling with `max_sample=1` must produce fewer than the 9
    # edges the unrestricted p -> c -> p metapath yields.
    metapaths = [[('p', 'c'), ('c', 'p')]]
    transform = AddMetaPaths(metapaths, max_sample=1)

    meta = transform(data)
    assert meta['metapath_0'].edge_index.size(1) < 9


def test_add_weighted_metapaths() -> None:
    torch.manual_seed(12345)

    data = HeteroData()
    data['a'].num_nodes = 2
    data['b'].num_nodes = 3
    data['c'].num_nodes = 2
    data['d'].num_nodes = 2

    data['a', 'b'].edge_index = tensor([[0, 1, 1], [0, 1, 2]])
    data['b', 'a'].edge_index = data['a', 'b'].edge_index.flip([0])
    data['b', 'c'].edge_index = tensor([[0, 1, 2], [0, 1, 1]])
    data['c', 'b'].edge_index = data['b', 'c'].edge_index.flip([0])
    data['c', 'd'].edge_index = tensor([[0, 1], [0, 0]])
    data['d', 'c'].edge_index = data['c', 'd'].edge_index.flip([0])

    metapaths = [
        [('a', 'b'), ('b', 'c')],
        [('a', 'b'), ('b', 'c'), ('c', 'd')],
        [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'c'), ('c', 'b'),
         ('b', 'a')],
    ]
    transform = AddMetaPaths(metapaths, weighted=True)
    out = transform(data)

    # The expected weights equal the number of distinct walks along each
    # metapath, e.g. node a1 reaches c1 via both b1 and b2 (weight 2).
    # The results are looked up by (source, destination) node type pair.

    # Make sure manually added metapaths compute the correct number of edges:
    edge_index = out['a', 'a'].edge_index
    edge_weight = out['a', 'a'].edge_weight
    edge_index, edge_weight = coalesce(edge_index, edge_weight)
    assert edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
    assert edge_weight.tolist() == [1, 2, 2, 4]

    edge_index = out['a', 'c'].edge_index
    edge_weight = out['a', 'c'].edge_weight
    edge_index, edge_weight = coalesce(edge_index, edge_weight)
    assert edge_index.tolist() == [[0, 1], [0, 1]]
    assert edge_weight.tolist() == [1, 2]

    edge_index = out['a', 'd'].edge_index
    edge_weight = out['a', 'd'].edge_weight
    edge_index, edge_weight = coalesce(edge_index, edge_weight)
    assert edge_index.tolist() == [[0, 1], [0, 0]]
    assert edge_weight.tolist() == [1, 2]

    # Compute intra-table metapaths efficiently:
    # First build the weighted a -> d metapath, then mirror it as d -> a.
    # A second a -> d -> a metapath over those two edge types must reproduce
    # the weights of the full 6-hop metapath asserted above.
    metapaths = [[('a', 'b'), ('b', 'c'), ('c', 'd')]]
    out = AddMetaPaths(metapaths, weighted=True)(data)
    out['d', 'a'].edge_index = out['a', 'd'].edge_index.flip([0])
    out['d', 'a'].edge_weight = out['a', 'd'].edge_weight
    metapaths = [[('a', 'd'), ('d', 'a')]]
    out = AddMetaPaths(metapaths, weighted=True)(out)
    edge_index = out['a', 'a'].edge_index
    edge_weight = out['a', 'a'].edge_weight
    edge_index, edge_weight = coalesce(edge_index, edge_weight)
    assert edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
    assert edge_weight.tolist() == [1, 2, 2, 4]


def test_add_random_metapaths() -> None:
    data = generate_data()

    # Test transform options:
    metapaths = [[('p', 'c'), ('c', 'p')]]
    # NOTE(review): the exact edge counts asserted below appear to rely on
    # this single seed and on the order of the transform calls that follow.
    torch.manual_seed(12345)

    transform = AddRandomMetaPaths(metapaths)
    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '
                              'walks_per_node=[1])')
    meta1 = transform(data)

    transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True)
    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '
                              'walks_per_node=[1])')
    meta2 = transform(data)

    transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True,
                                   keep_same_node_type=True)
    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '
                              'walks_per_node=[1])')
    meta3 = transform(data)

    transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True,
                                   keep_same_node_type=True,
                                   drop_unconnected_node_types=True)
    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '
                              'walks_per_node=[1])')
    meta4 = transform(data)

    transform = AddRandomMetaPaths(metapaths, sample_ratio=0.8,
                                   drop_orig_edge_types=True,
                                   keep_same_node_type=True,
                                   drop_unconnected_node_types=True)
    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=0.8, '
                              'walks_per_node=[1])')
    meta5 = transform(data)

    # An integer `walks_per_node` is reported as a one-element list.
    transform = AddRandomMetaPaths(metapaths, walks_per_node=5,
                                   drop_orig_edge_types=True,
                                   keep_same_node_type=True,
                                   drop_unconnected_node_types=True)
    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '
                              'walks_per_node=[5])')
    meta6 = transform(data)

    assert meta1['metapath_0'].edge_index.size() == (2, 5)
    assert meta2['metapath_0'].edge_index.size() == (2, 5)
    assert meta3['metapath_0'].edge_index.size() == (2, 5)
    assert meta4['metapath_0'].edge_index.size() == (2, 5)
    assert meta5['metapath_0'].edge_index.size() == (2, 4)
    assert meta6['metapath_0'].edge_index.size() == (2, 7)

    assert all([i in meta1.edge_types for i in data.edge_types])
    assert meta2.edge_types == [('p', 'metapath_0', 'p')]
    assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]
    assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]

    assert meta3.node_types == ['p', 'a', 'c']
    assert meta4.node_types == ['p']

    # Test a 2-hop and a 4-hop metapath together:
    metapaths = [
        [('a', 'p'), ('p', 'c')],
        [('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')],
    ]
    # With several metapaths, the default `walks_per_node` is one per path.
    transform = AddRandomMetaPaths(metapaths)
    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '
                              'walks_per_node=[1, 1])')

    meta1 = transform(data)
    new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]
    assert meta1['metapath_0'].edge_index.size() == (2, 2)
    assert meta1['metapath_1'].edge_index.size() == (2, 2)

    # Test `metapath_dict` information:
    assert list(meta1.metapath_dict.values()) == metapaths
    assert list(meta1.metapath_dict.keys()) == new_edge_types

    # Per-metapath walk counts:
    transform = AddRandomMetaPaths(metapaths, walks_per_node=[2, 5])
    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '
                              'walks_per_node=[2, 5])')

    meta2 = transform(data)
    new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]
    assert meta2['metapath_0'].edge_index.size() == (2, 2)
    assert meta2['metapath_1'].edge_index.size() == (2, 3)


================================================
FILE: test/transforms/test_add_positional_encoding.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withPackage
from torch_geometric.transforms import (
    AddLaplacianEigenvectorPE,
    AddRandomWalkPE,
)


@withPackage('scipy')
def test_add_laplacian_eigenvector_pe():
    x = torch.randn(6, 4)
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(x=x, edge_index=edge_index)

    # Default: the encoding is stored in its own attribute.
    transform = AddLaplacianEigenvectorPE(k=3)
    assert str(transform) == 'AddLaplacianEigenvectorPE()'
    out = transform(data)
    assert out.laplacian_eigenvector_pe.size() == (6, 3)

    # `attr_name=None`: the encoding is concatenated to the existing `x`.
    transform = AddLaplacianEigenvectorPE(k=3, attr_name=None)
    out = transform(data)
    assert out.x.size() == (6, 4 + 3)

    # `attr_name='x'`: the encoding replaces `x`.
    transform = AddLaplacianEigenvectorPE(k=3, attr_name='x')
    out = transform(data)
    assert out.x.size() == (6, 3)

    # Output tests:
    # This edge list forms two disjoint triangles, {0, 1, 4} and {2, 3, 5}.
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5, 2, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3, 5, 2]])
    data = Data(x=x, edge_index=edge_index)

    transform1 = AddLaplacianEigenvectorPE(k=1, is_undirected=True)
    transform2 = AddLaplacianEigenvectorPE(k=1, is_undirected=False)

    # Clustering test with first non-trivial eigenvector (Fiedler vector):
    # the value must be constant within each triangle and differ between
    # the two triangles.
    pe = transform1(data).laplacian_eigenvector_pe
    pe_cluster_1 = pe[[0, 1, 4]]
    pe_cluster_2 = pe[[2, 3, 5]]
    assert not torch.allclose(pe_cluster_1, pe_cluster_2)
    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())
    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())

    pe = transform2(data).laplacian_eigenvector_pe
    pe_cluster_1 = pe[[0, 1, 4]]
    pe_cluster_2 = pe[[2, 3, 5]]
    assert not torch.allclose(pe_cluster_1, pe_cluster_2)
    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())
    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())


@withPackage('scipy')
def test_eigenvector_permutation_invariance():
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(edge_index=edge_index, num_nodes=6)

    perm = torch.randperm(data.num_nodes)
    transform = AddLaplacianEigenvectorPE(
        k=2,
        is_undirected=True,
        attr_name='x',
    )
    out1 = transform(data)

    transform = AddLaplacianEigenvectorPE(
        k=2,
        is_undirected=True,
        attr_name='x',
    )
    out2 = transform(data.subgraph(perm))

    # Relabelling the nodes must permute the encoding accordingly. Absolute
    # values are compared, presumably because eigenvectors are only defined
    # up to sign.
    assert torch.allclose(out1.x[perm].abs(), out2.x.abs(), atol=1e-6)


def test_add_random_walk_pe():
    x = torch.randn(6, 4)
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(x=x, edge_index=edge_index)

    transform = AddRandomWalkPE(walk_length=3)
    assert str(transform) == 'AddRandomWalkPE()'
    out = transform(data)
    assert out.random_walk_pe.size() == (6, 3)

    transform = AddRandomWalkPE(walk_length=3, attr_name=None)
    out = transform(data)
    assert out.x.size() == (6, 4 + 3)

    transform = AddRandomWalkPE(walk_length=3, attr_name='x')
    out = transform(data)
    assert out.x.size() == (6, 3)

    # Output tests:
    # Column j appears to hold the probability of returning to the start node
    # after j + 1 random-walk steps. For example, node 3's neighbours 2 and 5
    # both have degree 1, so its 2-step return probability is 1.0.
    assert out.x.tolist() == [
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.00],
        [0.0, 1.0, 0.00],
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.00],
    ]

    # Nodes 0-2 carry only a self-loop, so every step returns to them; node 3
    # is isolated and gets all zeros.
    edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
    data = Data(edge_index=edge_index, num_nodes=4)
    out = transform(data)

    assert out.x.tolist() == [
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
        [0.0, 0.0, 0.0],
    ]


================================================
FILE: test/transforms/test_add_remaining_self_loops.py
================================================
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import AddRemainingSelfLoops


def test_add_remaining_self_loops():
    assert str(AddRemainingSelfLoops()) == 'AddRemainingSelfLoops()'

    # An empty data object stays empty.
    assert len(AddRemainingSelfLoops()(Data())) == 0

    # No self-loops in `edge_index`.
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_weight = torch.tensor([1, 2, 3, 4])
    edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])

    data = Data(edge_index=edge_index, num_nodes=3)
    data = AddRemainingSelfLoops()(data)
    assert len(data) == 2
    assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],
                                        [1, 0, 2, 1, 0, 1, 2]]
    assert data.num_nodes == 3

    # Single self-loop in `edge_index`.
    # The existing loop (0, 0) moves behind the non-loop edges, and loops
    # for nodes 1 and 2 are added after it.
    edge_index = torch.tensor([[0, 0, 1, 2], [1, 0, 2, 1]])
    data = Data(edge_index=edge_index, num_nodes=3)
    data = AddRemainingSelfLoops()(data)
    assert len(data) == 2
    assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2],
                                        [1, 2, 1, 0, 1, 2]]
    assert data.num_nodes == 3

    # The existing loop keeps its weight (2); only new loops get `fill_value`.
    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)
    data = AddRemainingSelfLoops(attr='edge_weight', fill_value=5)(data)
    assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2],
                                        [1, 2, 1, 0, 1, 2]]
    assert data.num_nodes == 3
    assert data.edge_weight.tolist() == [1, 3, 4, 2, 5, 5]

    # With `fill_value='add'`, each new loop's attribute matches the sum over
    # edges pointing into that node: node 1 gets [1, 2] + [7, 8] = [8, 10].
    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)
    data = AddRemainingSelfLoops(attr='edge_attr', fill_value='add')(data)
    assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2],
                                        [1, 2, 1, 0, 1, 2]]
    assert data.num_nodes == 3
    assert data.edge_attr.tolist() == [[1, 2], [5, 6], [7, 8], [3, 4], [8, 10],
                                       [5, 6]]


def test_add_remaining_self_loops_all_loops_exist():
    # All self-loops already exist in the data object.
    edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
    data = Data(edge_index=edge_index, num_nodes=3)
    data = AddRemainingSelfLoops()(data)
    assert data.edge_index.tolist() == edge_index.tolist()

    # All self-loops already exist in the data object, some of them appear
    # multiple times. The duplicates collapse to one loop per node.
    edge_index = torch.tensor([[0, 0, 1, 1, 2], [0, 0, 1, 1, 2]])
    data = Data(edge_index=edge_index, num_nodes=3)
    data = AddRemainingSelfLoops()(data)
    assert data.edge_index.tolist() == [[0, 1, 2], [0, 1, 2]]


def test_hetero_add_remaining_self_loops():
    edge_index = torch.tensor([[0, 0, 1, 2], [1, 0, 2, 1]])

    data = HeteroData()
    data['v'].num_nodes = 3
    data['w'].num_nodes = 3
    data['v', 'v'].edge_index = edge_index
    data['v', 'w'].edge_index = edge_index
    data = AddRemainingSelfLoops()(data)
    # Only the edge type connecting a node type to itself receives loops.
    assert data['v', 'v'].edge_index.tolist() == [[0, 1, 2, 0, 1, 2],
                                                  [1, 2, 1, 0, 1, 2]]
    assert data['v', 'w'].edge_index.tolist() == edge_index.tolist()


================================================
FILE: test/transforms/test_add_self_loops.py
================================================
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import AddSelfLoops


def test_add_self_loops():
    assert str(AddSelfLoops()) == 'AddSelfLoops()'

    assert len(AddSelfLoops()(Data())) == 0

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_weight = torch.tensor([1, 2, 3, 4])
    edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])

    # One loop per node is appended after the original edges.
    data = Data(edge_index=edge_index, num_nodes=3)
    data = AddSelfLoops()(data)
    assert len(data) == 2
    assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],
                                        [1, 0, 2, 1, 0, 1, 2]]
    assert data.num_nodes == 3

    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)
    data = AddSelfLoops(attr='edge_weight', fill_value=5)(data)
    assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],
                                        [1, 0, 2, 1, 0, 1, 2]]
    assert data.num_nodes == 3
    assert data.edge_weight.tolist() == [1, 2, 3, 4, 5, 5, 5]

    # `fill_value='add'`: node 0's loop gets [3, 4] (edge 1 -> 0), node 1's
    # gets [1, 2] + [7, 8] = [8, 10], and node 2's gets [5, 6] (edge 1 -> 2).
    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)
    data = AddSelfLoops(attr='edge_attr', fill_value='add')(data)
    assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],
                                        [1, 0, 2, 1, 0, 1, 2]]
    assert data.num_nodes == 3
    assert data.edge_attr.tolist() == [[1, 2], [3, 4], [5, 6], [7, 8], [3, 4],
                                       [8, 10], [5, 6]]


def test_add_self_loops_with_existing_self_loops():
    # Unlike `AddRemainingSelfLoops`, existing loops are not deduplicated:
    # every node receives a second loop.
    edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
    data = Data(edge_index=edge_index, num_nodes=3)
    data = AddSelfLoops()(data)
    assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2],
                                        [0, 1, 2, 0, 1, 2]]
    assert data.num_nodes == 3


def test_hetero_add_self_loops():
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    data = HeteroData()
    data['v'].num_nodes = 3
    data['w'].num_nodes = 3
    data['v', 'v'].edge_index = edge_index
    data['v', 'w'].edge_index = edge_index
    data = AddSelfLoops()(data)
    # Loops are only added to the ('v', 'v') edge type.
    assert data['v', 'v'].edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],
                                                  [1, 0, 2, 1, 0, 1, 2]]
    assert data['v', 'w'].edge_index.tolist() == edge_index.tolist()


================================================
FILE: test/transforms/test_cartesian.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import Cartesian


def test_cartesian():
    assert str(Cartesian()) == 'Cartesian(norm=True, max_value=None)'

    pos = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]])
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0])

    # For an edge (i, j) the expected offset is pos[i] - pos[j], e.g. edge
    # (0, 1) gives -1 - 0 = -1.
    data = Data(edge_index=edge_index, pos=pos)
    data = Cartesian(norm=False)(data)
    assert len(data) == 3
    assert torch.equal(data.pos, pos)
    assert torch.equal(data.edge_index, edge_index)
    assert torch.allclose(
        data.edge_attr,
        torch.tensor([[-1.0, 0.0], [1.0, 0.0], [-2.0, 0.0], [2.0, 0.0]]),
    )

    # With `norm=True`, the expected offsets equal offset / (2 * 2) + 0.5,
    # where 2 is the largest absolute offset. They are concatenated after the
    # existing 1-D `edge_attr`.
    data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)
    data = Cartesian(norm=True)(data)
    assert len(data) == 3
    assert torch.equal(data.pos, pos)
    assert torch.equal(data.edge_index, edge_index)
    assert torch.allclose(
        data.edge_attr,
        torch.tensor([
            [1, 0.25, 0.5],
            [2, 0.75, 0.5],
            [3, 0, 0.5],
            [4, 1, 0.5],
        ]),
    )


================================================
FILE: test/transforms/test_center.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import Center


def test_center():
    transform = Center()
    assert str(transform) == 'Center()'

    # The mean position (2, 0) is subtracted from every point.
    pos = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])
    data = Data(pos=pos)
    data = transform(data)
    assert len(data) == 1
    assert data.pos.tolist() == [[-2, 0], [0, 0], [2, 0]]


================================================
FILE: test/transforms/test_compose.py
================================================
import torch

import torch_geometric.transforms as T
from torch_geometric.data import Data


def test_compose():
    transform = T.Compose([T.Center(), T.AddSelfLoops()])
    assert str(transform) == ('Compose([\n'
                              ' Center(),\n'
                              ' AddSelfLoops()\n'
                              '])')

    pos = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    data = Data(edge_index=edge_index, pos=pos)

    # Both transforms are applied: positions are centred, and 3 self-loops
    # are added to the 4 edges.
    data = transform(data)
    assert len(data) == 2
    assert data.pos.tolist() == [[-2.0, 0.0], [0.0, 0.0], [2.0, 0.0]]
    assert data.edge_index.size() == (2, 7)


def test_compose_data_list():
    transform = T.Compose([T.Center(), T.AddSelfLoops()])

    pos = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    data_list = [Data(edge_index=edge_index, pos=pos) for _ in range(3)]

    # A list input is transformed element-wise.
    data_list = transform(data_list)
    assert len(data_list) == 3
    for data in data_list:
        assert len(data) == 2
        assert data.pos.tolist() == [[-2.0, 0.0], [0.0, 0.0], [2.0, 0.0]]
        assert data.edge_index.size() == (2, 7)


def test_compose_filters():
    # A data object passes only if every filter returns True.
    filter_fn = T.ComposeFilters([
        lambda d: d.num_nodes > 2,
        lambda d: d.num_edges > 2,
    ])
    assert str(filter_fn)[:16] == 'ComposeFilters(['

    # Fails the edge filter (no edges).
    data1 = Data(x=torch.arange(3))
    assert not filter_fn(data1)

    # Fails the node filter (2 nodes).
    data2 = Data(x=torch.arange(2), edge_index=torch.tensor([
        [0, 0, 1],
        [0, 1, 1],
    ]))
    assert not filter_fn(data2)

    data3 = Data(x=torch.arange(3), edge_index=torch.tensor([
        [0, 0, 1],
        [0, 1, 1],
    ]))
    assert filter_fn(data3)

    # Test tuple of data objects. The result is False here because data1 and
    # data2 fail; presumably every element of the tuple must pass.
    assert filter_fn((data1, data2, data3)) is False
================================================ FILE: test/transforms/test_constant.py ================================================ import torch from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import Constant def test_constant(): assert str(Constant()) == 'Constant(value=1.0)' x = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float) edge_index = torch.tensor([[0, 1], [1, 2]]) data = Data(edge_index=edge_index, num_nodes=3) data = Constant()(data) assert len(data) == 3 assert data.edge_index.tolist() == edge_index.tolist() assert data.x.tolist() == [[1], [1], [1]] assert data.num_nodes == 3 data = Data(edge_index=edge_index, x=x) data = Constant()(data) assert len(data) == 2 assert data.edge_index.tolist() == edge_index.tolist() assert data.x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]] data = HeteroData() data['v'].x = x data = Constant()(data) assert len(data) == 1 assert data['v'].x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]] data = HeteroData() data['v'].x = x data['w'].x = x data = Constant(node_types='w')(data) assert len(data) == 1 assert data['v'].x.tolist() == x.tolist() assert data['w'].x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]] ================================================ FILE: test/transforms/test_delaunay.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.testing import withPackage from torch_geometric.transforms import Delaunay def assert_one_point(transform: Delaunay) -> None: data = Data(pos=torch.rand(1, 2)) data = transform(data) assert len(data) == 2 assert data.edge_index.tolist() == [[], []] def assert_two_points(transform: Delaunay) -> None: data = Data(pos=torch.rand(2, 2)) data = transform(data) assert len(data) == 2 assert data.edge_index.tolist() == [[0, 1], [1, 0]] def assert_three_points(transform: Delaunay) -> None: data = Data(pos=torch.rand(3, 2)) data = transform(data) assert len(data) == 2 assert 
data.face.tolist() == [[0], [1], [2]] def assert_four_points(transform: Delaunay) -> None: pos = torch.tensor([ [-1.0, -1.0], [-1.0, 1.0], [1.0, 1.0], [0.5, -0.5], ]) data = Data(pos=pos) data = transform(data) assert len(data) == 2 # The order of the simplices does not matter, therefore assert sets. faces = set(map(tuple, data.face.tolist())) assert faces == {(3, 1), (1, 3), (0, 2)} @withPackage('scipy') def test_qhull_delaunay() -> None: transform = Delaunay() assert str(transform) == 'Delaunay()' assert_one_point(transform) assert_two_points(transform) assert_three_points(transform) assert_four_points(transform) @withPackage('torch_delaunay') def test_shull_delaunay() -> None: transform = Delaunay() assert str(transform) == 'Delaunay()' assert_one_point(transform) assert_two_points(transform) assert_three_points(transform) assert_four_points(transform) ================================================ FILE: test/transforms/test_distance.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import Distance def test_distance(): assert str(Distance()) == 'Distance(norm=True, max_value=None)' pos = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]]) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_attr = torch.tensor([1.0, 1.0, 1.0, 1.0]) data = Data(edge_index=edge_index, pos=pos) data = Distance(norm=False)(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert data.edge_attr.tolist() == [[1.0], [1.0], [2.0], [2.0]] data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr) data = Distance(norm=True)(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert data.edge_attr.tolist() == [ [1.0, 0.5], [1.0, 0.5], [1.0, 1.0], [1.0, 1.0], ] ================================================ FILE: 
test/transforms/test_face_to_edge.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import FaceToEdge def test_2d_face_to_edge() -> None: transform = FaceToEdge() assert str(transform) == 'FaceToEdge()' face = torch.tensor([[0, 0], [1, 1], [2, 3]]) data = Data(face=face, num_nodes=4) data = transform(data) assert len(data) == 2 assert data.edge_index.tolist() == [ [0, 0, 0, 1, 1, 1, 2, 2, 3, 3], [1, 2, 3, 0, 2, 3, 0, 1, 0, 1], ] assert data.num_nodes == 4 def test_3d_face_to_edge() -> None: transform = FaceToEdge() assert str(transform) == 'FaceToEdge()' face = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]).t() data = Data(face=face, num_nodes=5) data = transform(data) assert data.edge_index.tolist() == [ [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4], [1, 2, 3, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 1, 2, 3], ] assert data.num_nodes == 5 ================================================ FILE: test/transforms/test_feature_propagation.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import FeaturePropagation, ToSparseTensor def test_feature_propagation(): x = torch.randn(6, 4) x[0, 1] = float('nan') x[2, 3] = float('nan') missing_mask = torch.isnan(x) edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5], [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]]) transform = FeaturePropagation(missing_mask) assert str(transform) == ('FeaturePropagation(missing_features=8.3%, ' 'num_iterations=40)') data1 = Data(x=x, edge_index=edge_index) assert torch.isnan(data1.x).sum() == 2 data1 = FeaturePropagation(missing_mask)(data1) assert torch.isnan(data1.x).sum() == 0 assert data1.x.size() == x.size() data2 = Data(x=x, edge_index=edge_index) assert torch.isnan(data2.x).sum() == 2 data2 = ToSparseTensor()(data2) data2 = transform(data2) assert torch.isnan(data2.x).sum() == 0 assert torch.allclose(data1.x, data2.x) 
================================================ FILE: test/transforms/test_fixed_points.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import FixedPoints def test_fixed_points(): assert str(FixedPoints(1024)) == 'FixedPoints(1024, replace=True)' data = Data( pos=torch.randn(100, 3), x=torch.randn(100, 16), y=torch.randn(1), edge_attr=torch.randn(100, 3), num_nodes=100, ) out = FixedPoints(50, replace=True)(data) assert len(out) == 5 assert out.pos.size() == (50, 3) assert out.x.size() == (50, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) assert out.num_nodes == 50 out = FixedPoints(200, replace=True)(data) assert len(out) == 5 assert out.pos.size() == (200, 3) assert out.x.size() == (200, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) assert out.num_nodes == 200 out = FixedPoints(50, replace=False, allow_duplicates=False)(data) assert len(out) == 5 assert out.pos.size() == (50, 3) assert out.x.size() == (50, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) assert out.num_nodes == 50 out = FixedPoints(200, replace=False, allow_duplicates=False)(data) assert len(out) == 5 assert out.pos.size() == (100, 3) assert out.x.size() == (100, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) assert out.num_nodes == 100 out = FixedPoints(50, replace=False, allow_duplicates=True)(data) assert len(out) == 5 assert out.pos.size() == (50, 3) assert out.x.size() == (50, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) assert out.num_nodes == 50 out = FixedPoints(200, replace=False, allow_duplicates=True)(data) assert len(out) == 5 assert out.pos.size() == (200, 3) assert out.x.size() == (200, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) assert out.num_nodes == 200 ================================================ FILE: 
test/transforms/test_gcn_norm.py
================================================
import torch

import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.transforms import GCNNorm
from torch_geometric.typing import SparseTensor


def test_gcn_norm():
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_weight = torch.ones(edge_index.size(1))

    transform = GCNNorm()
    assert str(transform) == 'GCNNorm(add_self_loops=True)'

    # Self-loops are appended after the original edges. The expected weights
    # equal 1 / sqrt(deg_i * deg_j) with self-loops counted in the degree
    # (degrees 2, 3, 2): e.g. 1 / sqrt(2 * 3) ~ 0.4082, and node 1's loop
    # gets 1 / 3.
    expected_edge_index = [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]
    expected_edge_weight = torch.tensor(
        [0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000])

    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)
    data = transform(data)
    assert len(data) == 3
    assert data.num_nodes == 3
    assert data.edge_index.tolist() == expected_edge_index
    assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)

    # Without `edge_weight`, the same result is expected; all-ones weights
    # appear to be the implied default.
    data = Data(edge_index=edge_index, num_nodes=3)
    data = transform(data)
    assert len(data) == 3
    assert data.num_nodes == 3
    assert data.edge_index.tolist() == expected_edge_index
    assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)

    # For `SparseTensor`, expected outputs will be sorted:
    if torch_geometric.typing.WITH_TORCH_SPARSE:
        expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]]
        expected_edge_weight = torch.tensor(
            [0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000])

        adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t()
        data = Data(adj_t=adj_t)
        data = transform(data)
        assert len(data) == 1
        row, col, value = data.adj_t.coo()
        assert row.tolist() == expected_edge_index[0]
        assert col.tolist() == expected_edge_index[1]
        assert torch.allclose(value, expected_edge_weight, atol=1e-4)


================================================
FILE: test/transforms/test_gdc.py
================================================
import torch

from torch_geometric.datasets import KarateClub
from torch_geometric.testing import withPackage
from torch_geometric.transforms import GDC
from torch_geometric.utils import to_dense_adj


@withPackage('numba')
def test_gdc():
    # Every configuration below densifies the diffused graph and checks that
    # all entries are non-negative (up to 1e-8). Each configuration also
    # checks a property tied to its options: symmetry for 'sym' output,
    # columns (rows) summing to 1 or 0 for 'col' ('row') output, and exactly
    # k = 2 non-zeros per column (row) under top-k sparsification.
    data = KarateClub()[0]

    gdc = GDC(
        self_loop_weight=1,
        normalization_in='sym',
        normalization_out='sym',
        diffusion_kwargs=dict(method='ppr', alpha=0.15),
        sparsification_kwargs=dict(method='threshold', avg_degree=2),
        exact=True,
    )
    out = gdc(data)
    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()
    assert torch.all(mat >= -1e-8)
    assert torch.allclose(mat, mat.t(), atol=1e-4)

    gdc = GDC(
        self_loop_weight=1,
        normalization_in='sym',
        normalization_out='sym',
        diffusion_kwargs=dict(method='heat', t=10),
        sparsification_kwargs=dict(method='threshold', avg_degree=2),
        exact=True,
    )
    out = gdc(data)
    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()
    assert torch.all(mat >= -1e-8)
    assert torch.allclose(mat, mat.t(), atol=1e-4)

    gdc = GDC(
        self_loop_weight=1,
        normalization_in='col',
        normalization_out='col',
        diffusion_kwargs=dict(method='heat', t=10),
        sparsification_kwargs=dict(method='topk', k=2, dim=0),
        exact=True,
    )
    out = gdc(data)
    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()
    col_sum = mat.sum(0)
    assert torch.all(mat >= -1e-8)
    assert torch.all(
        torch.isclose(col_sum, torch.tensor(1.0))
        | torch.isclose(col_sum, torch.tensor(0.0)))
    assert torch.all((~torch.isclose(mat, torch.tensor(0.0))).sum(0) == 2)

    gdc = GDC(
        self_loop_weight=1,
        normalization_in='row',
        normalization_out='row',
        diffusion_kwargs=dict(method='heat', t=5),
        sparsification_kwargs=dict(method='topk', k=2, dim=1),
        exact=True,
    )
    out = gdc(data)
    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()
    row_sum = mat.sum(1)
    assert torch.all(mat >= -1e-8)
    assert torch.all(
        torch.isclose(row_sum, torch.tensor(1.0))
        | torch.isclose(row_sum, torch.tensor(0.0)))
    assert torch.all((~torch.isclose(mat, torch.tensor(0.0))).sum(1) == 2)

    gdc = GDC(
        self_loop_weight=1,
        normalization_in='row',
        normalization_out='row',
        diffusion_kwargs=dict(method='coeff', coeffs=[0.8, 0.3, 0.1]),
        sparsification_kwargs=dict(method='threshold', eps=0.1),
        exact=True,
    )
    out = gdc(data)
    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()
    row_sum = mat.sum(1)
    assert torch.all(mat >= -1e-8)
    assert torch.all(
        torch.isclose(row_sum, torch.tensor(1.0))
        | torch.isclose(row_sum, torch.tensor(0.0)))

    # Approximate (`exact=False`) PPR diffusion with an `eps` tolerance:
    gdc = GDC(
        self_loop_weight=1,
        normalization_in='sym',
        normalization_out='col',
        diffusion_kwargs=dict(method='ppr', alpha=0.15, eps=1e-4),
        sparsification_kwargs=dict(method='threshold', avg_degree=2),
        exact=False,
    )
    out = gdc(data)
    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()
    col_sum = mat.sum(0)
    assert torch.all(mat >= -1e-8)
    assert torch.all(
        torch.isclose(col_sum, torch.tensor(1.0))
        | torch.isclose(col_sum, torch.tensor(0.0)))


================================================
FILE: test/transforms/test_generate_mesh_normals.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import GenerateMeshNormals


def test_generate_mesh_normals():
    transform = GenerateMeshNormals()
    assert str(transform) == 'GenerateMeshNormals()'

    # All vertices lie in the z = 0 plane, and the four triangles fan out
    # from vertex 0. Every vertex is expected to get the normal (0, 0, -1).
    pos = torch.tensor([
        [0.0, 0.0, 0.0],
        [-2.0, 1.0, 0.0],
        [-1.0, 1.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
        [2.0, 1.0, 0.0],
    ])
    face = torch.tensor([
        [0, 0, 0, 0],
        [1, 2, 3, 4],
        [2, 3, 4, 5],
    ])

    data = transform(Data(pos=pos, face=face))
    assert len(data) == 3
    assert data.pos.tolist() == pos.tolist()
    assert data.face.tolist() == face.tolist()
    assert data.norm.tolist() == [[0.0, 0.0, -1.0]] * 6


================================================
FILE: test/transforms/test_grid_sampling.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withPackage
from torch_geometric.transforms import GridSampling


@withPackage('torch_cluster')
def test_grid_sampling():
    assert str(GridSampling(5)) == 'GridSampling(size=5)'

    pos = torch.tensor([
        [0.0, 2.0],
        [3.0, 2.0],
        [3.0, 2.0],
        [2.0, 8.0],
        [2.0, 6.0],
    ])
    y = torch.tensor([0, 1, 1, 2, 2])
    batch = torch.tensor([0, 0, 0, 0, 0])
    data = Data(pos=pos, y=y, batch=batch)

    # With 5x5 cells starting at 0, points 0-2 share one cell (mean (2, 2))
    # and points 3-4 share another (mean (2, 7)). Cell labels [0, 1, 1] -> 1
    # and [2, 2] -> 2 are consistent with a majority vote.
    data = GridSampling(size=5, start=0)(data)
    assert len(data) == 3
    assert data.pos.tolist() == [[2, 2], [2, 7]]
    assert data.y.tolist() == [1, 2]
    assert data.batch.tolist() == [0, 0]


================================================
FILE: test/transforms/test_half_hop.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import HalfHop


def test_half_hop():
    edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]])
    x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                     dtype=torch.float)
    data = Data(x=x, edge_index=edge_index)

    # With p=1.0, each of the 4 non-self-loop edges gets a new "slow" node
    # (3-6). With alpha=0.5 its features are the average of the edge's
    # endpoints, e.g. (x[0] + x[1]) / 2 = [3, 4, 5, 6]. The self-loops on
    # nodes 0-2 are kept as-is.
    transform = HalfHop()
    assert str(transform) == 'HalfHop(alpha=0.5, p=1.0)'
    data = transform(data)

    expected_edge_index = [[0, 1, 2, 0, 1, 1, 2, 3, 4, 5, 6, 1, 0, 2, 1],
                           [0, 1, 2, 3, 4, 5, 6, 1, 0, 2, 1, 3, 4, 5, 6]]
    expected_x = torch.tensor(
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [3, 4, 5, 6],
         [3, 4, 5, 6], [7, 8, 9, 10], [7, 8, 9, 10]], dtype=torch.float)
    assert len(data) == 3
    assert data.num_nodes == 7
    assert data.edge_index.tolist() == expected_edge_index
    assert torch.allclose(data.x, expected_x, atol=1e-4)
    assert data.slow_node_mask.tolist() == [
        False, False, False, True, True, True, True
    ]

    # With p=0.5, only some edges are split; the expected result depends on
    # the seed set here.
    torch.manual_seed(1)
    data = Data(x=x, edge_index=edge_index)
    transform = HalfHop(p=0.5)
    assert str(transform) == 'HalfHop(alpha=0.5, p=0.5)'
    data = transform(data)

    expected_edge_index = [[1, 0, 1, 2, 0, 1, 2, 3, 4, 5, 1, 2, 1],
                           [0, 0, 1, 2, 3, 4, 5, 1, 2, 1, 3, 4, 5]]
    expected_x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
                               [3, 4, 5, 6], [7, 8, 9, 10], [7, 8, 9, 10]],
                              dtype=torch.float)
    assert data.num_nodes == 6
    assert data.edge_index.tolist() == expected_edge_index
    assert torch.allclose(data.x, expected_x, atol=1e-4)
    assert data.slow_node_mask.tolist() == [
        False, False, False, True, True, True
    ]


================================================
FILE: test/transforms/test_knn_graph.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withPackage
from torch_geometric.transforms import KNNGraph


@withPackage('torch_cluster')
def test_knn_graph():
    assert str(KNNGraph()) == 'KNNGraph(k=6)'

    pos = torch.tensor([
        [0.0, 0.0],
        [1.0, 0.0],
        [2.0, 0.0],
        [0.0, 1.0],
        [-2.0, 0.0],
        [0.0, -2.0],
    ])

    # Expected edges after the 2-NN graph is made undirected, listed
    # row-major by source node.
    expected_row = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5]
    expected_col = [1, 2, 3, 4, 5, 0, 2, 3, 5, 0, 1, 0, 1, 4, 0, 3, 0, 1]

    data = Data(pos=pos)
    data = KNNGraph(k=2, force_undirected=True)(data)
    assert len(data) == 2
    assert data.pos.tolist() == pos.tolist()
    assert data.edge_index[0].tolist() == expected_row
    assert data.edge_index[1].tolist() == expected_col


================================================
FILE: test/transforms/test_laplacian_lambda_max.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withPackage
from torch_geometric.transforms import LaplacianLambdaMax


@withPackage('scipy')
def test_laplacian_lambda_max():
    out = str(LaplacianLambdaMax())
    assert out == 'LaplacianLambdaMax(normalization=None)'

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
    edge_attr = torch.tensor([1, 1, 2, 2], dtype=torch.float)

    # Unnormalized Laplacian L = [[1, -1, 0], [-1, 3, -2], [0, -2, 2]] has
    # eigenvalues 0 and 3 +/- sqrt(3), so lambda_max = 3 + sqrt(3) ~ 4.732049.
    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)
    out = LaplacianLambdaMax(normalization=None, is_undirected=True)(data)
    assert len(out) == 4
    assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(4.732049))

    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)
    out = LaplacianLambdaMax(normalization='sym', is_undirected=True)(data)
    assert len(out) == 4
    assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(2.0))

    data
= Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) out = LaplacianLambdaMax(normalization='rw', is_undirected=True)(data) assert len(out) == 4 assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(2.0)) data = Data(edge_index=edge_index, edge_attr=torch.randn(4, 2), num_nodes=3) out = LaplacianLambdaMax(normalization=None)(data) assert len(out) == 4 assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(3.0)) ================================================ FILE: test/transforms/test_largest_connected_components.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.testing import withPackage from torch_geometric.transforms import LargestConnectedComponents @withPackage('scipy') def test_largest_connected_components(): assert str(LargestConnectedComponents()) == 'LargestConnectedComponents(1)' edge_index = torch.tensor([ [0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 5, 6, 8, 9], [1, 2, 0, 2, 0, 1, 3, 2, 4, 3, 6, 7, 9, 8], ]) data = Data(edge_index=edge_index, num_nodes=10) # Testing without `connection` specified: transform = LargestConnectedComponents(num_components=2) out = transform(data) assert out.num_nodes == 8 assert out.edge_index.tolist() == data.edge_index[:, :12].tolist() # Testing with `connection = strong`: transform = LargestConnectedComponents(num_components=2, connection='strong') out = transform(data) assert out.num_nodes == 7 assert out.edge_index.tolist() == [[0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 5, 6], [1, 2, 0, 2, 0, 1, 3, 2, 4, 3, 6, 5]] edge_index = torch.tensor([ [0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3], ]) data = Data(edge_index=edge_index, num_nodes=5) # Testing without `num_components` and `connection` specified: transform = LargestConnectedComponents() out = transform(data) assert out.num_nodes == 3 assert out.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] # Testing with larger `num_components` than actual number of components: transform = 
LargestConnectedComponents(num_components=3) out = transform(data) assert out.num_nodes == 5 assert out.edge_index.tolist() == data.edge_index.tolist() ================================================ FILE: test/transforms/test_line_graph.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import LineGraph def test_line_graph(): transform = LineGraph() assert str(transform) == 'LineGraph()' # Directed. edge_index = torch.tensor([ [0, 1, 2, 2, 3], [1, 2, 0, 3, 0], ]) data = Data(edge_index=edge_index, num_nodes=4) data = transform(data) assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4], [1, 2, 3, 0, 4, 0]] assert data.num_nodes == data.edge_index.max().item() + 1 # Undirected. edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4], [1, 2, 3, 0, 4, 0, 3, 0, 2, 4, 1, 3]]) edge_attr = torch.ones(edge_index.size(1)) data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=5) data = transform(data) assert data.edge_index.max().item() + 1 == data.x.size(0) assert data.edge_index.tolist() == [ [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5], [1, 2, 3, 0, 2, 4, 0, 1, 4, 5, 0, 5, 1, 2, 5, 2, 3, 4], ] assert data.x.tolist() == [2, 2, 2, 2, 2, 2] assert data.num_nodes == data.edge_index.max().item() + 1 ================================================ FILE: test/transforms/test_linear_transformation.py ================================================ import pytest import torch from torch_geometric.data import Data from torch_geometric.transforms import LinearTransformation @pytest.mark.parametrize('matrix', [ [[2.0, 0.0], [0.0, 2.0]], torch.tensor([[2.0, 0.0], [0.0, 2.0]]), ]) def test_linear_transformation(matrix): pos = torch.tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]]) transform = LinearTransformation(matrix) assert str(transform) == ('LinearTransformation(\n' '[[2. 0.]\n' ' [0. 
2.]]\n' ')') out = transform(Data(pos=pos)) assert len(out) == 1 assert torch.allclose(out.pos, 2 * pos) out = transform(Data()) assert len(out) == 0 ================================================ FILE: test/transforms/test_local_cartesian.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import LocalCartesian def test_local_cartesian(): transform = LocalCartesian() assert str(transform) == 'LocalCartesian()' pos = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]]) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0]) data = Data(edge_index=edge_index, pos=pos) data = transform(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert data.edge_attr.tolist() == [[0.25, 0.5], [1.0, 0.5], [0.0, 0.5], [1.0, 0.5]] data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr) data = transform(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert data.edge_attr.tolist() == [[1, 0.25, 0.5], [2, 1.0, 0.5], [3, 0.0, 0.5], [4, 1.0, 0.5]] ================================================ FILE: test/transforms/test_local_degree_profile.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import LocalDegreeProfile def test_target_indegree(): assert str(LocalDegreeProfile()) == 'LocalDegreeProfile()' edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) x = torch.tensor([[1.0], [1.0], [1.0], [1.0]]) # One isolated node. 
expected = torch.tensor([ [1, 2, 2, 2, 0], [2, 1, 1, 1, 0], [1, 2, 2, 2, 0], [0, 0, 0, 0, 0], ], dtype=torch.float) data = Data(edge_index=edge_index, num_nodes=x.size(0)) data = LocalDegreeProfile()(data) assert torch.allclose(data.x, expected, atol=1e-2) data = Data(edge_index=edge_index, x=x) data = LocalDegreeProfile()(data) assert torch.allclose(data.x[:, :1], x) assert torch.allclose(data.x[:, 1:], expected, atol=1e-2) ================================================ FILE: test/transforms/test_mask_transform.py ================================================ import torch from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import IndexToMask, MaskToIndex def test_index_to_mask(): assert str(IndexToMask()) == ('IndexToMask(attrs=None, sizes=None, ' 'replace=False)') edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3]]) train_index = torch.arange(0, 3) test_index = torch.arange(3, 5) data = Data(edge_index=edge_index, train_index=train_index, test_index=test_index, num_nodes=5) out = IndexToMask(replace=True)(data) assert len(out) == len(data) assert out.train_mask.tolist() == [True, True, True, False, False] assert out.test_mask.tolist() == [False, False, False, True, True] out = IndexToMask(replace=False)(data) assert len(out) == len(data) + 2 out = IndexToMask(sizes=6, replace=True)(data) assert out.train_mask.tolist() == [True, True, True, False, False, False] assert out.test_mask.tolist() == [False, False, False, True, True, False] out = IndexToMask(attrs='train_index')(data) assert len(out) == len(data) + 1 assert 'train_index' in out assert 'train_mask' in out assert 'test_index' in out assert 'test_mask' not in out def test_mask_to_index(): assert str(MaskToIndex()) == 'MaskToIndex(attrs=None, replace=False)' train_mask = torch.tensor([True, True, True, False, False]) test_mask = torch.tensor([False, False, False, True, True]) data = Data(train_mask=train_mask, test_mask=test_mask) out = 
MaskToIndex(replace=True)(data) assert len(out) == len(data) assert out.train_index.tolist() == [0, 1, 2] assert out.test_index.tolist() == [3, 4] out = MaskToIndex(replace=False)(data) assert len(out) == len(data) + 2 out = MaskToIndex(attrs='train_mask')(data) assert len(out) == len(data) + 1 assert 'train_mask' in out assert 'train_index' in out assert 'test_mask' in out assert 'test_index' not in out def test_hetero_index_to_mask(): data = HeteroData() data['u'].train_index = torch.arange(0, 3) data['u'].test_index = torch.arange(3, 5) data['u'].num_nodes = 5 data['v'].train_index = torch.arange(0, 3) data['v'].test_index = torch.arange(3, 5) data['v'].num_nodes = 5 out = IndexToMask()(data) assert len(out) == len(data) + 2 assert 'train_mask' in out['u'] assert 'test_mask' in out['u'] assert 'train_mask' in out['v'] assert 'test_mask' in out['v'] def test_hetero_mask_to_index(): data = HeteroData() data['u'].train_mask = torch.tensor([True, True, True, False, False]) data['u'].test_mask = torch.tensor([False, False, False, True, True]) data['v'].train_mask = torch.tensor([True, True, True, False, False]) data['v'].test_mask = torch.tensor([False, False, False, True, True]) out = MaskToIndex()(data) assert len(out) == len(data) + 2 assert 'train_index' in out['u'] assert 'test_index' in out['u'] assert 'train_index' in out['v'] assert 'test_index' in out['v'] ================================================ FILE: test/transforms/test_node_property_split.py ================================================ import pytest import torch from torch_geometric.datasets import graph_generator from torch_geometric.testing import withPackage from torch_geometric.transforms import NodePropertySplit @withPackage('networkx', 'scipy') @pytest.mark.parametrize('property_name', [ 'popularity', 'locality', 'density', ]) def test_node_property_split(property_name): ratios = [0.3, 0.1, 0.1, 0.2, 0.3] transform = NodePropertySplit(property_name, ratios) assert str(transform) == 
f'NodePropertySplit({property_name})' data = graph_generator.ERGraph(num_nodes=100, edge_prob=0.4)() data = transform(data) node_ids = [] for name, ratio in zip([ 'id_train_mask', 'id_val_mask', 'id_test_mask', 'ood_val_mask', 'ood_test_mask', ], ratios): assert data[name].dtype == torch.bool assert data[name].size() == (100, ) assert int(data[name].sum()) == 100 * ratio node_ids.extend(data[name].nonzero().view(-1).tolist()) # Check that masks are non-intersecting and cover all nodes: node_ids = torch.tensor(node_ids) assert node_ids.numel() == torch.unique(node_ids).numel() == 100 ================================================ FILE: test/transforms/test_normalize_features.py ================================================ import torch from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import NormalizeFeatures def test_normalize_scale(): transform = NormalizeFeatures() assert str(transform) == 'NormalizeFeatures()' x = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float) data = Data(x=x) data = transform(data) assert len(data) == 1 assert data.x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]] def test_hetero_normalize_scale(): x = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float) data = HeteroData() data['v'].x = x data['w'].x = x data = NormalizeFeatures()(data) assert data['v'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]] assert data['w'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]] ================================================ FILE: test/transforms/test_normalize_rotation.py ================================================ from math import sqrt import torch from torch_geometric.data import Data from torch_geometric.transforms import NormalizeRotation def test_normalize_rotation(): assert str(NormalizeRotation()) == 'NormalizeRotation()' pos = torch.tensor([ [-2.0, -2.0], [-1.0, -1.0], [0.0, 0.0], [1.0, 1.0], [2.0, 2.0], ]) normal = torch.tensor([ [-1.0, 1.0], [-1.0, 1.0], [-1.0, 
1.0], [-1.0, 1.0], [-1.0, 1.0], ]) data = Data(pos=pos) data.normal = normal data = NormalizeRotation()(data) assert len(data) == 2 expected_pos = torch.tensor([ [-2 * sqrt(2), 0.0], [-sqrt(2), 0.0], [0.0, 0.0], [sqrt(2), 0.0], [2 * sqrt(2), 0.0], ]) expected_normal = [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]] assert torch.allclose(data.pos, expected_pos, atol=1e-04) assert data.normal.tolist() == expected_normal data = Data(pos=pos) data.normal = normal data = NormalizeRotation(max_points=3)(data) assert len(data) == 2 assert torch.allclose(data.pos, expected_pos, atol=1e-04) assert data.normal.tolist() == expected_normal data = Data(pos=pos) data.normal = normal data = NormalizeRotation(sort=True)(data) assert len(data) == 2 assert torch.allclose(data.pos, expected_pos, atol=1e-04) assert data.normal.tolist() == expected_normal ================================================ FILE: test/transforms/test_normalize_scale.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import NormalizeScale def test_normalize_scale(): transform = NormalizeScale() assert str(transform) == 'NormalizeScale()' pos = torch.randn((10, 3)) data = Data(pos=pos) data = transform(data) assert len(data) == 1 assert data.pos.min().item() > -1 assert data.pos.max().item() < 1 ================================================ FILE: test/transforms/test_one_hot_degree.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import OneHotDegree def test_one_hot_degree(): assert str(OneHotDegree(max_degree=3)) == 'OneHotDegree(3)' edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) x = torch.tensor([1.0, 1.0, 1.0, 1.0]) data = Data(edge_index=edge_index, num_nodes=4) data = OneHotDegree(max_degree=3)(data) assert len(data) == 3 assert data.edge_index.tolist() == edge_index.tolist() assert data.x.tolist() == [ [0.0, 0.0, 0.0, 
1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], ] assert data.num_nodes == 4 data = Data(edge_index=edge_index, x=x) data = OneHotDegree(max_degree=3)(data) assert len(data) == 2 assert data.edge_index.tolist() == edge_index.tolist() assert data.x.tolist() == [ [1.0, 0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 1.0, 0.0, 0.0], ] ================================================ FILE: test/transforms/test_pad.py ================================================ import numbers from typing import Dict, Generator, List, Optional, Tuple, Union import pytest import torch from torch_geometric.data import Data, HeteroData from torch_geometric.datasets import FakeDataset, FakeHeteroDataset from torch_geometric.transforms import Pad from torch_geometric.transforms.pad import ( AttrNamePadding, EdgeTypePadding, NodeTypePadding, Padding, UniformPadding, ) from torch_geometric.typing import EdgeType, NodeType def fake_data() -> Data: return FakeDataset(avg_num_nodes=10, avg_degree=5, edge_dim=2)[0] def fake_hetero_data(node_types=2, edge_types=5) -> HeteroData: return FakeHeteroDataset(num_node_types=node_types, num_edge_types=edge_types, avg_num_nodes=10, edge_dim=2)[0] def _generate_homodata_node_attrs(data: Data) -> Generator[str, None, None]: for attr in data.keys(): if data.is_node_attr(attr): yield attr def _generate_homodata_edge_attrs(data: Data) -> Generator[str, None, None]: for attr in data.keys(): if data.is_edge_attr(attr): yield attr def _generate_heterodata_nodes( data: HeteroData ) -> Generator[Tuple[NodeType, str, torch.Tensor], None, None]: for node_type, store in data.node_items(): for attr in store.keys(): yield node_type, attr def _generate_heterodata_edges( data: HeteroData ) -> Generator[Tuple[EdgeType, str, torch.Tensor], None, None]: for edge_type, store in data.edge_items(): for attr in store.keys(): yield edge_type, attr def _check_homo_data_nodes( original: Data, padded: Data, 
max_num_nodes: Union[int, Dict[NodeType, int]], node_pad_value: Optional[Padding] = None, is_mask_available: bool = False, exclude_keys: Optional[List[str]] = None, ): assert padded.num_nodes == max_num_nodes compare_pad_start_idx = original.num_nodes if is_mask_available: assert padded.pad_node_mask.numel() == padded.num_nodes assert torch.all(padded.pad_node_mask[:compare_pad_start_idx]) assert not torch.any(padded.pad_node_mask[compare_pad_start_idx:]) for attr in _generate_homodata_node_attrs(original): if attr in exclude_keys: assert attr not in padded.keys() continue assert attr in padded.keys() if not isinstance(padded[attr], torch.Tensor): continue assert padded[attr].shape[0] == max_num_nodes # Check values in padded area. pad_value = node_pad_value.get_value( None, attr) if node_pad_value is not None else 0.0 assert all( i == pad_value for i in torch.flatten(padded[attr][compare_pad_start_idx:])) # Check values in non-padded area. assert torch.equal(original[attr], padded[attr][:compare_pad_start_idx]) def _check_homo_data_edges( original: Data, padded: Data, max_num_edges: Optional[int] = None, edge_pad_value: Optional[Padding] = None, is_mask_available: bool = False, exclude_keys: Optional[List[str]] = None, ): # Check edge index attribute. if max_num_edges is None: max_num_edges = padded.num_nodes**2 assert padded.num_edges == max_num_edges assert padded.edge_index.shape[1] == max_num_edges assert padded.edge_index.shape[1] == max_num_edges compare_pad_start_idx = original.num_edges expected_node = original.num_nodes # Check values in padded area. assert all( padded.edge_index[0, i] == padded.edge_index[1, i] == expected_node for i in range(compare_pad_start_idx, max_num_edges)) # Check values in non-padded area. 
assert torch.equal(original.edge_index, padded.edge_index[:, :compare_pad_start_idx]) if is_mask_available: assert padded.pad_edge_mask.numel() == padded.num_edges assert torch.all(padded.pad_edge_mask[:compare_pad_start_idx]) assert not torch.any(padded.pad_edge_mask[compare_pad_start_idx:]) # Check other attributes. for attr in _generate_homodata_edge_attrs(original): if attr == 'edge_index': continue if attr in exclude_keys: assert attr not in padded.keys() continue assert attr in padded.keys() if not isinstance(padded[attr], torch.Tensor): continue assert padded[attr].shape[0] == max_num_edges # Check values in padded area. pad_value = edge_pad_value.get_value( None, attr) if edge_pad_value is not None else 0.0 assert all( i == pad_value for i in torch.flatten(padded[attr][compare_pad_start_idx:, :])) # Check values in non-padded area. assert torch.equal(original[attr], padded[attr][:compare_pad_start_idx, :]) def _check_hetero_data_nodes( original: HeteroData, padded: HeteroData, max_num_nodes: Union[int, Dict[NodeType, int]], node_pad_value: Optional[Padding] = None, is_mask_available: bool = False, exclude_keys: Optional[List[str]] = None, ): if is_mask_available: for store in padded.node_stores: assert 'pad_node_mask' in store expected_nodes = max_num_nodes for node_type, attr in _generate_heterodata_nodes(original): if attr in exclude_keys: assert attr not in padded[node_type].keys() continue assert attr in padded[node_type].keys() if not isinstance(padded[node_type][attr], torch.Tensor): continue compare_pad_start_idx = original[node_type].num_nodes padded_tensor = padded[node_type][attr] if attr == 'pad_node_mask': assert padded_tensor.numel() == padded[node_type].num_nodes assert torch.all(padded_tensor[:compare_pad_start_idx]) assert not torch.any(padded_tensor[compare_pad_start_idx:]) continue original_tensor = original[node_type][attr] # Check the number of nodes. 
if isinstance(max_num_nodes, dict): expected_nodes = max_num_nodes[node_type] assert padded_tensor.shape[0] == expected_nodes compare_pad_start_idx = original_tensor.shape[0] pad_value = node_pad_value.get_value( node_type, attr) if node_pad_value is not None else 0.0 assert all( i == pad_value for i in torch.flatten(padded_tensor[compare_pad_start_idx:])) # Compare non-padded area with the original. assert torch.equal(original_tensor, padded_tensor[:compare_pad_start_idx]) def _check_hetero_data_edges( original: HeteroData, padded: HeteroData, max_num_edges: Optional[Union[int, Dict[EdgeType, int]]] = None, edge_pad_value: Optional[Padding] = None, is_mask_available: bool = False, exclude_keys: Optional[List[str]] = None, ): if is_mask_available: for store in padded.edge_stores: assert 'pad_edge_mask' in store for edge_type, attr in _generate_heterodata_edges(padded): if attr in exclude_keys: assert attr not in padded[edge_type].keys() continue assert attr in padded[edge_type].keys() if not isinstance(padded[edge_type][attr], torch.Tensor): continue compare_pad_start_idx = original[edge_type].num_edges padded_tensor = padded[edge_type][attr] if attr == 'pad_edge_mask': assert padded_tensor.numel() == padded[edge_type].num_edges assert torch.all(padded_tensor[:compare_pad_start_idx]) assert not torch.any(padded_tensor[compare_pad_start_idx:]) continue original_tensor = original[edge_type][attr] if isinstance(max_num_edges, numbers.Number): expected_num_edges = max_num_edges elif max_num_edges is None or edge_type not in max_num_edges.keys(): v1, _, v2 = edge_type expected_num_edges = padded[v1].num_nodes * padded[v2].num_nodes else: expected_num_edges = max_num_edges[edge_type] if attr == 'edge_index': # Check the number of edges. assert padded_tensor.shape[1] == expected_num_edges # Check padded area values. 
src_nodes = original[edge_type[0]].num_nodes assert all( i == src_nodes for i in torch.flatten(padded_tensor[0, compare_pad_start_idx:])) dst_nodes = original[edge_type[2]].num_nodes assert all( i == dst_nodes for i in torch.flatten(padded_tensor[1, compare_pad_start_idx:])) # Compare non-padded area with the original. assert torch.equal(original_tensor, padded_tensor[:, :compare_pad_start_idx]) else: # Check padded area size. assert padded_tensor.shape[0] == expected_num_edges # Check padded area values. pad_value = edge_pad_value.get_value( edge_type, attr) if edge_pad_value is not None else 0.0 assert all(i == pad_value for i in torch.flatten(padded_tensor[ compare_pad_start_idx:, :])) # Compare non-padded area with the original. assert torch.equal(original_tensor, padded_tensor[:compare_pad_start_idx, :]) def _check_data( original: Union[Data, HeteroData], padded: Union[Data, HeteroData], max_num_nodes: Union[int, Dict[NodeType, int]], max_num_edges: Optional[Union[int, Dict[EdgeType, int]]] = None, node_pad_value: Optional[Union[Padding, int, float]] = None, edge_pad_value: Optional[Union[Padding, int, float]] = None, is_mask_available: bool = False, exclude_keys: Optional[List[str]] = None, ): if not isinstance(node_pad_value, Padding) and node_pad_value is not None: node_pad_value = UniformPadding(node_pad_value) if not isinstance(edge_pad_value, Padding) and edge_pad_value is not None: edge_pad_value = UniformPadding(edge_pad_value) if is_mask_available is None: is_mask_available = False if exclude_keys is None: exclude_keys = [] if isinstance(original, Data): _check_homo_data_nodes(original, padded, max_num_nodes, node_pad_value, is_mask_available, exclude_keys) _check_homo_data_edges(original, padded, max_num_edges, edge_pad_value, is_mask_available, exclude_keys) else: _check_hetero_data_nodes(original, padded, max_num_nodes, node_pad_value, is_mask_available, exclude_keys) _check_hetero_data_edges(original, padded, max_num_edges, edge_pad_value, 
is_mask_available, exclude_keys) def test_pad_repr(): pad_str = 'Pad(max_num_nodes=10, max_num_edges=15, ' \ 'node_pad_value=UniformPadding(value=3.0), ' \ 'edge_pad_value=UniformPadding(value=1.5))' assert str(eval(pad_str)) == pad_str @pytest.mark.parametrize('data', [fake_data(), fake_hetero_data()]) @pytest.mark.parametrize('num_nodes', [32, 64]) @pytest.mark.parametrize('add_pad_mask', [True, False]) def test_pad_auto_edges(data, num_nodes, add_pad_mask): transform = Pad(max_num_nodes=num_nodes, add_pad_mask=add_pad_mask) out = transform(data) _check_data(data, out, num_nodes, is_mask_available=add_pad_mask) @pytest.mark.parametrize('num_nodes', [32, 64]) @pytest.mark.parametrize('num_edges', [300, 411]) @pytest.mark.parametrize('add_pad_mask', [True, False]) def test_pad_data_explicit_edges(num_nodes, num_edges, add_pad_mask): data = fake_data() transform = Pad(max_num_nodes=num_nodes, max_num_edges=num_edges, add_pad_mask=add_pad_mask) out = transform(data) _check_data(data, out, num_nodes, num_edges, is_mask_available=add_pad_mask) @pytest.mark.parametrize('num_nodes', [32, {'v0': 64, 'v1': 36}]) @pytest.mark.parametrize('num_edges', [300, {('v0', 'e0', 'v1'): 397}]) @pytest.mark.parametrize('add_pad_mask', [True, False]) def test_pad_heterodata_explicit_edges(num_nodes, num_edges, add_pad_mask): data = fake_hetero_data() transform = Pad(max_num_nodes=num_nodes, max_num_edges=num_edges, add_pad_mask=add_pad_mask) out = transform(data) _check_data(data, out, num_nodes, num_edges, is_mask_available=add_pad_mask) @pytest.mark.parametrize('node_pad_value', [10, AttrNamePadding({'x': 3.0})]) @pytest.mark.parametrize('edge_pad_value', [11, AttrNamePadding({'edge_attr': 2.0})]) def test_pad_data_pad_values(node_pad_value, edge_pad_value): data = fake_data() num_nodes = 32 transform = Pad(max_num_nodes=num_nodes, node_pad_value=node_pad_value, edge_pad_value=edge_pad_value) out = transform(data) _check_data(data, out, num_nodes, node_pad_value=node_pad_value, 
edge_pad_value=edge_pad_value) @pytest.mark.parametrize('node_pad_value', [ UniformPadding(12), AttrNamePadding({'x': 0}), NodeTypePadding({ 'v0': UniformPadding(12), 'v1': AttrNamePadding({'x': 7}) }) ]) @pytest.mark.parametrize('edge_pad_value', [ UniformPadding(13), EdgeTypePadding({ ('v0', 'e0', 'v1'): UniformPadding(13), ('v1', 'e0', 'v0'): AttrNamePadding({'edge_attr': UniformPadding(-1.0)}) }) ]) def test_pad_heterodata_pad_values(node_pad_value, edge_pad_value): data = fake_hetero_data() num_nodes = 32 transform = Pad(max_num_nodes=num_nodes, node_pad_value=node_pad_value, edge_pad_value=edge_pad_value) out = transform(data) _check_data(data, out, num_nodes, node_pad_value=node_pad_value, edge_pad_value=edge_pad_value) @pytest.mark.parametrize('data', [fake_data(), fake_hetero_data()]) @pytest.mark.parametrize('add_pad_mask', [True, False]) @pytest.mark.parametrize('exclude_keys', [ ['y'], ['edge_attr'], ['y', 'edge_attr'], ]) def test_pad_data_exclude_keys(data, add_pad_mask, exclude_keys): num_nodes = 32 transform = Pad(max_num_nodes=num_nodes, add_pad_mask=add_pad_mask, exclude_keys=exclude_keys) out = transform(data) _check_data(data, out, num_nodes, is_mask_available=add_pad_mask, exclude_keys=exclude_keys) @pytest.mark.parametrize('is_hetero', [False, True]) def test_pad_invalid_max_num_nodes(is_hetero): if is_hetero: data = fake_hetero_data(node_types=1) else: data = fake_data() transform = Pad(max_num_nodes=data.num_nodes - 1) with pytest.raises(AssertionError, match="after padding"): transform(data) @pytest.mark.parametrize('is_hetero', [False, True]) def test_pad_invalid_max_num_edges(is_hetero): if is_hetero: data = fake_hetero_data(node_types=1, edge_types=1) else: data = fake_data() transform = Pad(max_num_nodes=data.num_nodes + 10, max_num_edges=data.num_edges - 1) with pytest.raises(AssertionError, match="after padding"): transform(data) def test_pad_num_nodes_not_complete(): data = fake_hetero_data(node_types=2, edge_types=1) transform = 
Pad(max_num_nodes={'v0': 100}) with pytest.raises(KeyError): transform(data) def test_pad_invalid_padding_type(): with pytest.raises(ValueError, match="to be an integer or float"): Pad(max_num_nodes=100, node_pad_value='somestring') with pytest.raises(ValueError, match="to be an integer or float"): Pad(max_num_nodes=100, edge_pad_value='somestring') def test_pad_data_non_tensor_attr(): data = fake_data() batch_size = 13 data.batch_size = batch_size transform = Pad(max_num_nodes=100) padded = transform(data) assert padded.batch_size == batch_size exclude_transform = Pad(max_num_nodes=101, exclude_keys=('batch_size', )) padded = exclude_transform(data) assert 'batch_size' not in padded.keys() @pytest.mark.parametrize('mask_pad_value', [True, False]) def test_pad_node_additional_attr_mask(mask_pad_value): data = fake_data() mask = torch.randn(data.num_nodes) > 0 mask_names = ['train_mask', 'test_mask', 'val_mask'] for mask_name in mask_names: setattr(data, mask_name, mask) padding_num = 20 max_num_nodes = int(data.num_nodes) + padding_num max_num_edges = data.num_edges + padding_num transform = Pad(max_num_nodes, max_num_edges, node_pad_value=0.1, mask_pad_value=mask_pad_value) padded = transform(data) padded_masks = [getattr(padded, mask_name) for mask_name in mask_names] for padded_mask in padded_masks: assert padded_mask.ndim == 1 assert padded_mask.size()[0] == max_num_nodes assert torch.all(padded_mask[-padding_num:] == mask_pad_value) def test_uniform_padding(): pad_val = 10.0 p = UniformPadding(pad_val) assert p.get_value() == pad_val assert p.get_value("v1", "x") == pad_val p = UniformPadding() assert p.get_value() == 0.0 with pytest.raises(ValueError, match="to be an integer or float"): UniformPadding('') def test_attr_name_padding(): x_val = 10.0 y_val = 15.0 default = 3.0 padding_dict = {'x': x_val, 'y': UniformPadding(y_val)} padding = AttrNamePadding(padding_dict, default=default) assert padding.get_value(attr_name='x') == x_val assert 
padding.get_value('v1', 'x') == x_val assert padding.get_value(attr_name='y') == y_val assert padding.get_value('v1', 'y') == y_val assert padding.get_value(attr_name='x2') == default padding = AttrNamePadding({}) assert padding.get_value(attr_name='x') == 0.0 def test_attr_name_padding_invalid(): with pytest.raises(ValueError, match="to be a dictionary"): AttrNamePadding(10.0) with pytest.raises(ValueError, match="to be a string"): AttrNamePadding({10: 10.0}) with pytest.raises(ValueError, match="to be of type"): AttrNamePadding({"x": {}}) with pytest.raises(ValueError, match="to be of type"): AttrNamePadding({"x": {}}) node_type_padding = NodeTypePadding({"x": 10.0}) with pytest.raises(ValueError, match="to be of type"): AttrNamePadding({'x': node_type_padding}) @pytest.mark.parametrize('store_type', ['node', 'edge']) def test_node_edge_type_padding(store_type): if store_type == "node": stores = ['v1', 'v2', 'v3', 'v4'] padding_cls = NodeTypePadding else: stores = [('v1', 'e1', 'v1'), ('v1', 'e2', 'v1'), ('v1', 'e1', 'v2'), ('v2', 'e1', 'v1')] padding_cls = EdgeTypePadding s0_default = 3.0 s0_padding_dict = {'x': 10.0, 'y': -12.0} s0_padding = AttrNamePadding(s0_padding_dict, s0_default) s1_default = 0.1 s1_padding_dict = {'y': 0.0, 'p': 13.0} s1_padding = AttrNamePadding(s1_padding_dict, s1_default) s2_default = 7.5 store_default = -11.0 padding_dict = { stores[0]: s0_padding, stores[1]: s1_padding, stores[2]: s2_default } padding = padding_cls(padding_dict, store_default) assert padding.get_value(stores[0], 'x') == s0_padding_dict['x'] assert padding.get_value(stores[0], 'y') == s0_padding_dict['y'] assert padding.get_value(stores[0], 'p') == s0_default assert padding.get_value(stores[0], 'z') == s0_default assert padding.get_value(stores[1], 'x') == s1_default assert padding.get_value(stores[1], 'y') == s1_padding_dict['y'] assert padding.get_value(stores[1], 'p') == s1_padding_dict['p'] assert padding.get_value(stores[1], 'z') == s1_default assert 
padding.get_value(stores[2], 'x') == s2_default assert padding.get_value(stores[2], 'z') == s2_default assert padding.get_value(stores[3], 'x') == store_default def test_edge_padding_invalid(): with pytest.raises(ValueError, match="to be a tuple"): EdgeTypePadding({'v1': 10.0}) with pytest.raises(ValueError, match="got 1"): EdgeTypePadding({('v1', ): 10.0}) with pytest.raises(ValueError, match="got 2"): EdgeTypePadding({('v1', 'v2'): 10.0}) with pytest.raises(ValueError, match="got 4"): EdgeTypePadding({('v1', 'e2', 'v1', 'v2'): 10.0}) ================================================ FILE: test/transforms/test_point_pair_features.py ================================================ from math import pi as PI import torch from torch_geometric.data import Data from torch_geometric.transforms import PointPairFeatures def test_point_pair_features(): transform = PointPairFeatures() assert str(transform) == 'PointPairFeatures()' pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) edge_index = torch.tensor([[0, 1], [1, 0]]) norm = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) edge_attr = torch.tensor([1.0, 1.0]) data = Data(edge_index=edge_index, pos=pos, norm=norm) data = transform(data) assert len(data) == 4 assert data.pos.tolist() == pos.tolist() assert data.norm.tolist() == norm.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert torch.allclose( data.edge_attr, torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, PI, PI, 0.0]]), atol=1e-4, ) data = Data(edge_index=edge_index, pos=pos, norm=norm, edge_attr=edge_attr) data = transform(data) assert len(data) == 4 assert data.pos.tolist() == pos.tolist() assert data.norm.tolist() == norm.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert torch.allclose( data.edge_attr, torch.tensor([[1.0, 1.0, 0.0, 0.0, 0.0], [1.0, 1.0, PI, PI, 0.0]]), atol=1e-4, ) ================================================ FILE: test/transforms/test_polar.py ================================================ from math 
import pi as PI import torch from torch_geometric.data import Data from torch_geometric.transforms import Polar def test_polar(): assert str(Polar()) == 'Polar(norm=True, max_value=None)' pos = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) edge_index = torch.tensor([[0, 1], [1, 0]]) edge_attr = torch.tensor([1.0, 1.0]) data = Data(edge_index=edge_index, pos=pos) data = Polar(norm=False)(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert torch.allclose( data.edge_attr, torch.tensor([[1.0, 0.0], [1.0, PI]]), atol=1e-4, ) data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr) data = Polar(norm=True)(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert torch.allclose( data.edge_attr, torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.5]]), atol=1e-4, ) ================================================ FILE: test/transforms/test_radius_graph.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.testing import withPackage from torch_geometric.transforms import RadiusGraph from torch_geometric.utils import coalesce @withPackage('torch_cluster') def test_radius_graph(): assert str(RadiusGraph(r=1)) == 'RadiusGraph(r=1)' pos = torch.tensor([ [0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [-2.0, 0.0], [0.0, -2.0], ]) data = Data(pos=pos) data = RadiusGraph(r=1.5)(data) assert len(data) == 2 assert data.pos.tolist() == pos.tolist() assert coalesce(data.edge_index).tolist() == [[0, 0, 1, 1, 1, 2, 3, 3], [1, 3, 0, 2, 3, 1, 0, 1]] ================================================ FILE: test/transforms/test_random_flip.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import RandomFlip def test_random_flip(): assert str(RandomFlip(axis=0)) == 'RandomFlip(axis=0, p=0.5)' pos = 
torch.tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]]) data = Data(pos=pos) data = RandomFlip(axis=0, p=1)(data) assert len(data) == 1 assert data.pos.tolist() == [[1.0, 1.0], [3.0, 0.0], [-2.0, -1.0]] data = Data(pos=pos) data = RandomFlip(axis=1, p=1)(data) assert len(data) == 1 assert data.pos.tolist() == [[-1.0, -1.0], [-3.0, 0.0], [2.0, 1.0]] ================================================ FILE: test/transforms/test_random_jitter.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import RandomJitter def test_random_jitter(): assert str(RandomJitter(0.1)) == 'RandomJitter(0.1)' pos = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) data = Data(pos=pos) data = RandomJitter(0)(data) assert len(data) == 1 assert torch.allclose(data.pos, pos) data = Data(pos=pos) data = RandomJitter(0.1)(data) assert len(data) == 1 assert data.pos.min() >= -0.1 assert data.pos.max() <= 0.1 data = Data(pos=pos) data = RandomJitter([0.1, 1])(data) assert len(data) == 1 assert data.pos[:, 0].min() >= -0.1 assert data.pos[:, 0].max() <= 0.1 assert data.pos[:, 1].min() >= -1 assert data.pos[:, 1].max() <= 1 ================================================ FILE: test/transforms/test_random_link_split.py ================================================ import pytest import torch from torch_geometric.data import Data, HeteroData from torch_geometric.testing import ( get_random_edge_index, onlyFullTest, onlyOnline, ) from torch_geometric.transforms import RandomLinkSplit, ToSparseTensor from torch_geometric.utils import is_undirected, to_undirected def test_random_link_split(): assert str(RandomLinkSplit()) == ('RandomLinkSplit(' 'num_val=0.1, num_test=0.2)') edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]) edge_attr = torch.randn(edge_index.size(1), 3) data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=100) # No test split: transform = 
RandomLinkSplit(num_val=2, num_test=0, is_undirected=True) train_data, val_data, test_data = transform(data) assert len(train_data) == 5 assert train_data.num_nodes == 100 assert train_data.edge_index.size() == (2, 6) assert train_data.edge_attr.size() == (6, 3) assert train_data.edge_label_index.size(1) == 6 assert train_data.edge_label.size(0) == 6 assert len(val_data) == 5 assert val_data.num_nodes == 100 assert val_data.edge_index.size() == (2, 6) assert val_data.edge_attr.size() == (6, 3) assert val_data.edge_label_index.size(1) == 4 assert val_data.edge_label.size(0) == 4 assert len(test_data) == 5 assert test_data.num_nodes == 100 assert test_data.edge_index.size() == (2, 10) assert test_data.edge_attr.size() == (10, 3) assert test_data.edge_label_index.size() == (2, 0) assert test_data.edge_label.size() == (0, ) # Percentage split: transform = RandomLinkSplit(num_val=0.2, num_test=0.2, neg_sampling_ratio=2.0, is_undirected=False) train_data, val_data, test_data = transform(data) assert len(train_data) == 5 assert train_data.num_nodes == 100 assert train_data.edge_index.size() == (2, 6) assert train_data.edge_attr.size() == (6, 3) assert train_data.edge_label_index.size(1) == 18 assert train_data.edge_label.size(0) == 18 assert len(val_data) == 5 assert val_data.num_nodes == 100 assert val_data.edge_index.size() == (2, 6) assert val_data.edge_attr.size() == (6, 3) assert val_data.edge_label_index.size(1) == 6 assert val_data.edge_label.size(0) == 6 assert len(test_data) == 5 assert test_data.num_nodes == 100 assert test_data.edge_index.size() == (2, 8) assert test_data.edge_attr.size() == (8, 3) assert test_data.edge_label_index.size(1) == 6 assert test_data.edge_label.size(0) == 6 # Disjoint training split: transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=False, disjoint_train_ratio=0.5) train_data, val_data, test_data = transform(data) assert len(train_data) == 5 assert train_data.num_nodes == 100 assert train_data.edge_index.size() == 
(2, 3) assert train_data.edge_attr.size() == (3, 3) assert train_data.edge_label_index.size(1) == 6 assert train_data.edge_label.size(0) == 6 def test_random_link_split_with_to_sparse_tensor(): edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]) data = Data(edge_index=edge_index, num_nodes=6) transform = RandomLinkSplit(num_val=2, num_test=2, neg_sampling_ratio=0.0) train_data1, _, _ = transform(data) assert train_data1.edge_index.size(1) == train_data1.edge_label.size(0) train_data2 = ToSparseTensor()(train_data1) assert train_data1.edge_label.equal(train_data2.edge_label) assert train_data1.edge_label_index.equal(train_data2.edge_label_index) def test_random_link_split_with_label(): edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]) edge_label = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) data = Data(edge_index=edge_index, edge_label=edge_label, num_nodes=6) transform = RandomLinkSplit(num_val=0.2, num_test=0.2, neg_sampling_ratio=0.0) train_data, _, _ = transform(data) assert len(train_data) == 4 assert train_data.num_nodes == 6 assert train_data.edge_index.size() == (2, 6) assert train_data.edge_label_index.size() == (2, 6) assert train_data.edge_label.size() == (6, ) assert train_data.edge_label.min() == 0 assert train_data.edge_label.max() == 1 transform = RandomLinkSplit(num_val=0.2, num_test=0.2, neg_sampling_ratio=1.0) train_data, _, _ = transform(data) assert len(train_data) == 4 assert train_data.num_nodes == 6 assert train_data.edge_index.size() == (2, 6) assert train_data.edge_label_index.size() == (2, 12) assert train_data.edge_label.size() == (12, ) assert train_data.edge_label.min() == 0 assert train_data.edge_label.max() == 2 assert train_data.edge_label[6:].sum() == 0 def test_random_link_split_increment_label(): edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]) edge_label = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) data = 
Data(edge_index=edge_index, edge_label=edge_label, num_nodes=6) transform = RandomLinkSplit(num_val=0, num_test=0, neg_sampling_ratio=0.0) train_data, _, _ = transform(data) assert train_data.edge_label.numel() == edge_index.size(1) assert train_data.edge_label.min() == 0 assert train_data.edge_label.max() == 1 transform = RandomLinkSplit(num_val=0, num_test=0, neg_sampling_ratio=1.0) train_data, _, _ = transform(data) assert train_data.edge_label.numel() == 2 * edge_index.size(1) assert train_data.edge_label.min() == 0 assert train_data.edge_label.max() == 2 assert train_data.edge_label[edge_index.size(1):].sum() == 0 def test_random_link_split_on_hetero_data(): data = HeteroData() data['p'].x = torch.arange(100) data['a'].x = torch.arange(100, 300) data['p', 'p'].edge_index = get_random_edge_index(100, 100, 500) data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index) data['p', 'p'].edge_attr = torch.arange(data['p', 'p'].num_edges) data['p', 'a'].edge_index = get_random_edge_index(100, 200, 1000) data['p', 'a'].edge_attr = torch.arange(500, 1500) data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0]) data['a', 'p'].edge_attr = torch.arange(1500, 2500) transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=True, edge_types=('p', 'p')) train_data, val_data, test_data = transform(data) assert len(train_data['p']) == 1 assert len(train_data['a']) == 1 assert len(train_data['p', 'p']) == 4 assert len(train_data['p', 'a']) == 2 assert len(train_data['a', 'p']) == 2 assert is_undirected(train_data['p', 'p'].edge_index, train_data['p', 'p'].edge_attr) assert is_undirected(val_data['p', 'p'].edge_index, val_data['p', 'p'].edge_attr) assert is_undirected(test_data['p', 'p'].edge_index, test_data['p', 'p'].edge_attr) transform = RandomLinkSplit(num_val=0.2, num_test=0.2, edge_types=('p', 'a'), rev_edge_types=('a', 'p')) train_data, val_data, test_data = transform(data) assert len(train_data['p']) == 1 assert len(train_data['a']) == 1 
assert len(train_data['p', 'p']) == 2 assert len(train_data['p', 'a']) == 4 assert len(train_data['a', 'p']) == 2 assert train_data['p', 'a'].edge_index.size() == (2, 600) assert train_data['p', 'a'].edge_attr.size() == (600, ) assert train_data['p', 'a'].edge_attr.min() >= 500 assert train_data['p', 'a'].edge_attr.max() <= 1500 assert train_data['a', 'p'].edge_index.size() == (2, 600) assert train_data['a', 'p'].edge_attr.size() == (600, ) assert train_data['a', 'p'].edge_attr.min() >= 500 assert train_data['a', 'p'].edge_attr.max() <= 1500 assert train_data['p', 'a'].edge_label_index.size() == (2, 1200) assert train_data['p', 'a'].edge_label.size() == (1200, ) assert val_data['p', 'a'].edge_index.size() == (2, 600) assert val_data['p', 'a'].edge_attr.size() == (600, ) assert val_data['p', 'a'].edge_attr.min() >= 500 assert val_data['p', 'a'].edge_attr.max() <= 1500 assert val_data['a', 'p'].edge_index.size() == (2, 600) assert val_data['a', 'p'].edge_attr.size() == (600, ) assert val_data['a', 'p'].edge_attr.min() >= 500 assert val_data['a', 'p'].edge_attr.max() <= 1500 assert val_data['p', 'a'].edge_label_index.size() == (2, 400) assert val_data['p', 'a'].edge_label.size() == (400, ) assert test_data['p', 'a'].edge_index.size() == (2, 800) assert test_data['p', 'a'].edge_attr.size() == (800, ) assert test_data['p', 'a'].edge_attr.min() >= 500 assert test_data['p', 'a'].edge_attr.max() <= 1500 assert test_data['a', 'p'].edge_index.size() == (2, 800) assert test_data['a', 'p'].edge_attr.size() == (800, ) assert test_data['a', 'p'].edge_attr.min() >= 500 assert test_data['a', 'p'].edge_attr.max() <= 1500 assert test_data['p', 'a'].edge_label_index.size() == (2, 400) assert test_data['p', 'a'].edge_label.size() == (400, ) transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=True, edge_types=[('p', 'p'), ('p', 'a')], rev_edge_types=[None, ('a', 'p')]) train_data, val_data, test_data = transform(data) assert len(train_data['p']) == 1 assert 
len(train_data['a']) == 1 assert len(train_data['p', 'p']) == 4 assert len(train_data['p', 'a']) == 4 assert len(train_data['a', 'p']) == 2 assert is_undirected(train_data['p', 'p'].edge_index, train_data['p', 'p'].edge_attr) assert train_data['p', 'a'].edge_index.size() == (2, 600) assert train_data['a', 'p'].edge_index.size() == (2, 600) # No reverse edge types specified: transform = RandomLinkSplit(edge_types=[('p', 'p'), ('p', 'a')]) train_data, val_data, test_data = transform(data) assert train_data['p', 'p'].num_edges < data['p', 'p'].num_edges assert train_data['p', 'a'].num_edges < data['p', 'a'].num_edges assert train_data['a', 'p'].num_edges == data['a', 'p'].num_edges def test_random_link_split_on_undirected_hetero_data(): data = HeteroData() data['p'].x = torch.arange(100) data['p', 'p'].edge_index = get_random_edge_index(100, 100, 500) data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index) transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p')) train_data, val_data, test_data = transform(data) assert train_data['p', 'p'].is_undirected() transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'), rev_edge_types=('p', 'p')) train_data, val_data, test_data = transform(data) assert train_data['p', 'p'].is_undirected() transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'), rev_edge_types=('p', 'p')) train_data, val_data, test_data = transform(data) assert train_data['p', 'p'].is_undirected() def test_random_link_split_insufficient_negative_edges(): edge_index = torch.tensor([[0, 0, 1, 1, 2, 2], [1, 3, 0, 2, 0, 1]]) data = Data(edge_index=edge_index, num_nodes=4) transform = RandomLinkSplit(num_val=0.34, num_test=0.34, is_undirected=False, neg_sampling_ratio=2, split_labels=True) with pytest.warns(UserWarning, match="not enough negative edges"): train_data, val_data, test_data = transform(data) assert train_data.neg_edge_label_index.size() == (2, 2) assert val_data.neg_edge_label_index.size() == (2, 2) 
assert test_data.neg_edge_label_index.size() == (2, 2) def test_random_link_split_non_contiguous(): edge_index = get_random_edge_index(40, 40, num_edges=150) edge_index = edge_index[:, :100] assert not edge_index.is_contiguous() data = Data(edge_index=edge_index, num_nodes=40) transform = RandomLinkSplit(num_val=0.2, num_test=0.2) train_data, val_data, test_data = transform(data) assert train_data.num_edges == 60 assert train_data.edge_index.is_contiguous() data = HeteroData() data['p'].num_nodes = 40 data['p', 'p'].edge_index = edge_index transform = RandomLinkSplit(num_val=0.2, num_test=0.2, edge_types=('p', 'p')) train_data, val_data, test_data = transform(data) assert train_data['p', 'p'].num_edges == 60 assert train_data['p', 'p'].edge_index.is_contiguous() @onlyOnline @onlyFullTest def test_random_link_split_on_dataset(get_dataset): dataset = get_dataset(name='MUTAG') dataset.transform = RandomLinkSplit( num_val=0.1, num_test=0.1, disjoint_train_ratio=0.3, add_negative_train_samples=False, ) train_dataset, val_dataset, test_dataset = zip(*dataset) assert len(train_dataset) == len(dataset) assert len(val_dataset) == len(dataset) assert len(test_dataset) == len(dataset) assert isinstance(train_dataset[0], Data) assert train_dataset[0].edge_label.min() == 1.0 assert train_dataset[0].edge_label.max() == 1.0 assert isinstance(val_dataset[0], Data) assert val_dataset[0].edge_label.min() == 0.0 assert val_dataset[0].edge_label.max() == 1.0 assert isinstance(test_dataset[0], Data) assert test_dataset[0].edge_label.min() == 0.0 assert test_dataset[0].edge_label.max() == 1.0 ================================================ FILE: test/transforms/test_random_node_split.py ================================================ import pytest import torch from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import RandomNodeSplit @pytest.mark.parametrize('num_splits', [1, 2]) def test_random_node_split(num_splits): num_nodes, num_classes = 1000, 4 x = 
torch.randn(num_nodes, 16) y = torch.randint(num_classes, (num_nodes, ), dtype=torch.long) data = Data(x=x, y=y) transform = RandomNodeSplit(split='train_rest', num_splits=num_splits, num_val=100, num_test=200) assert str(transform) == 'RandomNodeSplit(split=train_rest)' data = transform(data) assert len(data) == 5 train_mask = data.train_mask train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask assert train_mask.size() == (num_nodes, num_splits) val_mask = data.val_mask val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask assert val_mask.size() == (num_nodes, num_splits) test_mask = data.test_mask test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask assert test_mask.size() == (num_nodes, num_splits) for i in range(train_mask.size(-1)): assert train_mask[:, i].sum() == num_nodes - 100 - 200 assert val_mask[:, i].sum() == 100 assert test_mask[:, i].sum() == 200 assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 assert ((train_mask[:, i] | val_mask[:, i] | test_mask[:, i]).sum() == num_nodes) transform = RandomNodeSplit(split='train_rest', num_splits=num_splits, num_val=0.1, num_test=0.2) data = transform(data) train_mask = data.train_mask train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask val_mask = data.val_mask val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask test_mask = data.test_mask test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask for i in range(train_mask.size(-1)): assert train_mask[:, i].sum() == num_nodes - 100 - 200 assert val_mask[:, i].sum() == 100 assert test_mask[:, i].sum() == 200 assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 assert ((train_mask[:, i] | val_mask[:, i] | test_mask[:, i]).sum() == num_nodes) transform = RandomNodeSplit(split='test_rest', num_splits=num_splits, num_train_per_class=10, num_val=100) assert str(transform) == 'RandomNodeSplit(split=test_rest)' data = transform(data) 
assert len(data) == 5 train_mask = data.train_mask train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask val_mask = data.val_mask val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask test_mask = data.test_mask test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask for i in range(train_mask.size(-1)): assert train_mask[:, i].sum() == 10 * num_classes assert val_mask[:, i].sum() == 100 assert test_mask[:, i].sum() == num_nodes - 10 * num_classes - 100 assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 assert ((train_mask[:, i] | val_mask[:, i] | test_mask[:, i]).sum() == num_nodes) transform = RandomNodeSplit(split='test_rest', num_splits=num_splits, num_train_per_class=10, num_val=0.1) data = transform(data) train_mask = data.train_mask train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask val_mask = data.val_mask val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask test_mask = data.test_mask test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask for i in range(train_mask.size(-1)): assert train_mask[:, i].sum() == 10 * num_classes assert val_mask[:, i].sum() == 100 assert test_mask[:, i].sum() == num_nodes - 10 * num_classes - 100 assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 assert ((train_mask[:, i] | val_mask[:, i] | test_mask[:, i]).sum() == num_nodes) transform = RandomNodeSplit(split='random', num_splits=num_splits, num_train_per_class=10, num_val=100, num_test=200) assert str(transform) == 'RandomNodeSplit(split=random)' data = transform(data) assert len(data) == 5 train_mask = data.train_mask train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask val_mask = data.val_mask val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask test_mask = data.test_mask test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask for i in range(train_mask.size(-1)): assert train_mask[:, i].sum() 
== 10 * num_classes assert val_mask[:, i].sum() == 100 assert test_mask[:, i].sum() == 200 assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 assert ((train_mask[:, i] | val_mask[:, i] | test_mask[:, i]).sum() == 10 * num_classes + 100 + 200) transform = RandomNodeSplit(split='random', num_splits=num_splits, num_train_per_class=10, num_val=0.1, num_test=0.2) assert str(transform) == 'RandomNodeSplit(split=random)' data = transform(data) train_mask = data.train_mask train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask val_mask = data.val_mask val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask test_mask = data.test_mask test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask for i in range(train_mask.size(-1)): assert train_mask[:, i].sum() == 10 * num_classes assert val_mask[:, i].sum() == 100 assert test_mask[:, i].sum() == 200 assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 assert ((train_mask[:, i] | val_mask[:, i] | test_mask[:, i]).sum() == 10 * num_classes + 100 + 200) def test_random_node_split_on_hetero_data(): data = HeteroData() data['paper'].x = torch.randn(2000, 16) data['paper'].y = torch.randint(4, (2000, ), dtype=torch.long) data['author'].x = torch.randn(300, 16) transform = RandomNodeSplit() assert str(transform) == 'RandomNodeSplit(split=train_rest)' data = transform(data) assert len(data) == 5 assert len(data['author']) == 1 assert len(data['paper']) == 5 assert data['paper'].train_mask.sum() == 500 assert data['paper'].val_mask.sum() == 500 assert data['paper'].test_mask.sum() == 1000 ================================================ FILE: test/transforms/test_random_rotate.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import RandomRotate def test_random_rotate(): assert str(RandomRotate([-180, 180])) == ('RandomRotate(' '[-180, 180], axis=0)') pos = 
torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) data = Data(pos=pos) data = RandomRotate(0)(data) assert len(data) == 1 assert data.pos.tolist() == pos.tolist() data = Data(pos=pos) data = RandomRotate([180, 180])(data) assert len(data) == 1 assert data.pos.tolist() == [[1, 1], [1, -1], [-1, 1], [-1, -1]] pos = torch.tensor([ [-1.0, -1.0, 1.0], [-1.0, 1.0, 1.0], [1.0, -1.0, -1.0], [1.0, 1.0, -1.0], ]) data = Data(pos=pos) data = RandomRotate([180, 180], axis=0)(data) assert len(data) == 1 assert data.pos.tolist() == [[-1, 1, -1], [-1, -1, -1], [1, 1, 1], [1, -1, 1]] data = Data(pos=pos) data = RandomRotate([180, 180], axis=1)(data) assert len(data) == 1 assert data.pos.tolist() == [[1, -1, -1], [1, 1, -1], [-1, -1, 1], [-1, 1, 1]] data = Data(pos=pos) data = RandomRotate([180, 180], axis=2)(data) assert len(data) == 1 assert data.pos.tolist() == [[1, 1, 1], [1, -1, 1], [-1, 1, -1], [-1, -1, -1]] ================================================ FILE: test/transforms/test_random_scale.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import RandomScale def test_random_scale(): assert str(RandomScale([1, 2])) == 'RandomScale([1, 2])' pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) data = Data(pos=pos) data = RandomScale([1, 1])(data) assert len(data) == 1 assert data.pos.tolist() == pos.tolist() data = Data(pos=pos) data = RandomScale([2, 2])(data) assert len(data) == 1 assert data.pos.tolist() == [[-2, -2], [-2, 2], [2, -2], [2, 2]] ================================================ FILE: test/transforms/test_random_shear.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import RandomShear def test_random_shear(): assert str(RandomShear(0.1)) == 'RandomShear(0.1)' pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) data = Data(pos=pos) 
data = RandomShear(0)(data) assert len(data) == 1 assert torch.allclose(data.pos, pos) data = Data(pos=pos) data = RandomShear(0.1)(data) assert len(data) == 1 assert not torch.allclose(data.pos, pos) ================================================ FILE: test/transforms/test_remove_duplicated_edges.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import RemoveDuplicatedEdges def test_remove_duplicated_edges(): edge_index = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 0, 0, 1, 1]]) edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=2) transform = RemoveDuplicatedEdges() assert str(transform) == 'RemoveDuplicatedEdges()' out = transform(data) assert len(out) == 3 assert out.num_nodes == 2 assert out.edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]] assert out.edge_weight.tolist() == [3, 7, 11, 15] ================================================ FILE: test/transforms/test_remove_isolated_nodes.py ================================================ import torch from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import RemoveIsolatedNodes def test_remove_isolated_nodes(): assert str(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()' data = Data() data.x = torch.arange(3) data.edge_index = torch.tensor([[0, 2], [2, 0]]) data.edge_attr = torch.arange(2) data = RemoveIsolatedNodes()(data) assert len(data) == 3 assert data.x.tolist() == [0, 2] assert data.edge_index.tolist() == [[0, 1], [1, 0]] assert data.edge_attr.tolist() == [0, 1] def test_remove_isolated_nodes_in_hetero_data(): data = HeteroData() data['p'].x = torch.arange(6) data['a'].x = torch.arange(6) data['i'].num_nodes = 4 # isolated paper nodes: {4} # isolated author nodes: {3, 4, 5} # isolated institution nodes: {0, 1, 2, 3} data['p', '1', 'p'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 3]]) data['p', '2', 
'a'].edge_index = torch.tensor([[1, 3, 5], [0, 1, 2]]) data['p', '2', 'a'].edge_attr = torch.arange(3) data['p', '3', 'a'].edge_index = torch.tensor([[5], [2]]) data = RemoveIsolatedNodes()(data) assert len(data) == 4 assert data['p'].num_nodes == 5 assert data['a'].num_nodes == 3 assert data['i'].num_nodes == 0 assert data['p'].x.tolist() == [0, 1, 2, 3, 5] assert data['a'].x.tolist() == [0, 1, 2] assert data['1'].edge_index.tolist() == [[0, 1, 2], [0, 1, 3]] assert data['2'].edge_index.tolist() == [[1, 3, 4], [0, 1, 2]] assert data['2'].edge_attr.tolist() == [0, 1, 2] assert data['3'].edge_index.tolist() == [[4], [2]] ================================================ FILE: test/transforms/test_remove_self_loops.py ================================================ import torch from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import RemoveSelfLoops def test_remove_self_loops(): assert str(RemoveSelfLoops()) == 'RemoveSelfLoops()' assert len(RemoveSelfLoops()(Data())) == 0 edge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) edge_weight = torch.tensor([1, 2, 3, 4]) edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) data = Data(edge_index=edge_index, num_nodes=3) data = RemoveSelfLoops()(data) assert len(data) == 2 assert data.edge_index.tolist() == [[1, 2], [0, 1]] assert data.num_nodes == 3 data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3) data = RemoveSelfLoops(attr='edge_weight')(data) assert data.edge_index.tolist() == [[1, 2], [0, 1]] assert data.num_nodes == 3 assert data.edge_weight.tolist() == [2, 4] data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) data = RemoveSelfLoops(attr='edge_attr')(data) assert data.edge_index.tolist() == [[1, 2], [0, 1]] assert data.num_nodes == 3 assert data.edge_attr.tolist() == [[3, 4], [7, 8]] def test_hetero_remove_self_loops(): edge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) data = HeteroData() data['v'].num_nodes = 3 
data['w'].num_nodes = 3 data['v', 'v'].edge_index = edge_index data['v', 'w'].edge_index = edge_index data = RemoveSelfLoops()(data) assert data['v', 'v'].edge_index.tolist() == [[1, 2], [0, 1]] assert data['v', 'w'].edge_index.tolist() == edge_index.tolist() ================================================ FILE: test/transforms/test_remove_training_classes.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import RemoveTrainingClasses def test_remove_training_classes(): y = torch.tensor([1, 0, 0, 2, 1, 3]) train_mask = torch.tensor([False, False, True, True, True, True]) data = Data(y=y, train_mask=train_mask) transform = RemoveTrainingClasses(classes=[0, 1]) assert str(transform) == 'RemoveTrainingClasses([0, 1])' data = transform(data) assert len(data) == 2 assert torch.equal(data.y, y) assert data.train_mask.tolist() == [False, False, False, True, False, True] ================================================ FILE: test/transforms/test_rooted_subgraph.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.loader import DataLoader from torch_geometric.testing import withPackage from torch_geometric.transforms import RootedEgoNets, RootedRWSubgraph def test_rooted_ego_nets(): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_attr = torch.randn(4, 8) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) transform = RootedEgoNets(num_hops=1) assert str(transform) == 'RootedEgoNets(num_hops=1)' out = transform(data) assert len(out) == 8 assert torch.equal(out.x, data.x) assert torch.equal(out.edge_index, data.edge_index) assert torch.equal(out.edge_attr, data.edge_attr) assert out.sub_edge_index.tolist() == [[0, 1, 2, 3, 3, 4, 5, 6], [1, 0, 3, 2, 4, 3, 6, 5]] assert out.n_id.tolist() == [0, 1, 0, 1, 2, 1, 2] assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 1, 2, 2] assert 
out.e_id.tolist() == [0, 1, 0, 1, 2, 3, 2, 3] assert out.e_sub_batch.tolist() == [0, 0, 1, 1, 1, 1, 2, 2] out = out.map_data() assert len(out) == 4 assert torch.allclose(out.x, x[[0, 1, 0, 1, 2, 1, 2]]) assert out.edge_index.tolist() == [[0, 1, 2, 3, 3, 4, 5, 6], [1, 0, 3, 2, 4, 3, 6, 5]] assert torch.allclose(out.edge_attr, edge_attr[[0, 1, 0, 1, 2, 3, 2, 3]]) assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 1, 2, 2] @withPackage('torch_cluster') def test_rooted_rw_subgraph(): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) data = Data(edge_index=edge_index, num_nodes=3) transform = RootedRWSubgraph(walk_length=1) assert str(transform) == 'RootedRWSubgraph(walk_length=1)' out = transform(data) assert len(out) == 7 assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 2, 2] assert out.sub_edge_index.size() == (2, 6) out = out.map_data() assert len(out) == 3 assert out.edge_index.size() == (2, 6) assert out.num_nodes == 6 assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 2, 2] def test_rooted_subgraph_minibatch(): x = torch.randn(3, 8) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_attr = torch.randn(4, 8) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) transform = RootedEgoNets(num_hops=1) data = transform(data) loader = DataLoader([data, data], batch_size=2) batch = next(iter(loader)) batch = batch.map_data() assert batch.num_graphs == len(batch) == 2 assert batch.x.size() == (14, 8) assert batch.edge_index.size() == (2, 16) assert batch.edge_attr.size() == (16, 8) assert batch.n_sub_batch.size() == (14, ) assert batch.batch.size() == (14, ) assert batch.ptr.size() == (3, ) assert batch.edge_index.min() == 0 assert batch.edge_index.max() == 13 assert batch.n_sub_batch.min() == 0 assert batch.n_sub_batch.max() == 5 ================================================ FILE: test/transforms/test_sample_points.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms 
import SamplePoints def test_sample_points(): assert str(SamplePoints(1024)) == 'SamplePoints(1024)' pos = torch.tensor([ [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0], ]) face = torch.tensor([[0, 1], [1, 2], [2, 3]]) data = Data(pos=pos) data.face = face data = SamplePoints(8)(data) assert len(data) == 1 assert pos[:, 0].min() >= 0 and pos[:, 0].max() <= 1 assert pos[:, 1].min() >= 0 and pos[:, 1].max() <= 1 assert pos[:, 2].abs().sum() == 0 data = Data(pos=pos) data.face = face data = SamplePoints(8, include_normals=True)(data) assert len(data) == 2 assert data.normal[:, :2].abs().sum() == 0 assert data.normal[:, 2].abs().sum() == 8 ================================================ FILE: test/transforms/test_sign.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import SIGN def test_sign(): x = torch.ones(5, 3) edge_index = torch.tensor([ [0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3], ]) data = Data(x=x, edge_index=edge_index) transform = SIGN(K=2) assert str(transform) == 'SIGN(K=2)' expected_x1 = torch.tensor([ [1, 1, 1], [1, 1, 1], [0.7071, 0.7071, 0.7071], [1.4142, 1.4142, 1.4142], [0.7071, 0.7071, 0.7071], ]) expected_x2 = torch.ones(5, 3) out = transform(data) assert len(out) == 4 assert torch.equal(out.edge_index, edge_index) assert torch.allclose(out.x, x) assert torch.allclose(out.x1, expected_x1, atol=1e-4) assert torch.allclose(out.x2, expected_x2) ================================================ FILE: test/transforms/test_spherical.py ================================================ from math import pi as PI import torch from torch_geometric.data import Data from torch_geometric.transforms import Spherical def test_spherical(): assert str(Spherical()) == 'Spherical(norm=True, max_value=None)' pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) edge_index = torch.tensor([[0, 1], [1, 0]]) edge_attr = torch.tensor([1.0, 1.0]) data = 
Data(edge_index=edge_index, pos=pos) data = Spherical(norm=False)(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert torch.allclose( data.edge_attr, torch.tensor([[1.0, 0.0, PI / 2.0], [1.0, PI, PI / 2.0]]), atol=1e-4, ) data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr) data = Spherical(norm=True)(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert torch.allclose( data.edge_attr, torch.tensor([[1.0, 1.0, 0.0, 0.5], [1.0, 1.0, 0.5, 0.5]]), atol=1e-4, ) pos = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) edge_index = torch.tensor([[0, 1], [1, 0]]) data = Data(edge_index=edge_index, pos=pos) data = Spherical(norm=False)(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert torch.allclose( data.edge_attr, torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, PI]]), atol=1e-4, ) data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr) data = Spherical(norm=True)(data) assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() assert torch.allclose( data.edge_attr, torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 1.0]]), atol=1e-4, ) ================================================ FILE: test/transforms/test_svd_feature_reduction.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import SVDFeatureReduction def test_svd_feature_reduction(): assert str(SVDFeatureReduction(10)) == 'SVDFeatureReduction(10)' x = torch.randn(4, 16) U, S, _ = torch.linalg.svd(x) data = Data(x=x) data = SVDFeatureReduction(10)(data) assert torch.allclose(data.x, torch.mm(U[:, :10], torch.diag(S[:10]))) x = torch.randn(4, 8) data.x = x data = SVDFeatureReduction(10)(Data(x=x)) assert torch.allclose(data.x, 
x) ================================================ FILE: test/transforms/test_target_indegree.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import TargetIndegree def test_target_indegree(): assert str(TargetIndegree()) == 'TargetIndegree(norm=True, max_value=None)' edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_attr = torch.tensor([1.0, 1.0, 1.0, 1.0]) data = Data(edge_index=edge_index, num_nodes=3) data = TargetIndegree(norm=False)(data) assert len(data) == 3 assert data.edge_index.tolist() == edge_index.tolist() assert data.edge_attr.tolist() == [[2], [1], [1], [2]] assert data.num_nodes == 3 data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) data = TargetIndegree(norm=True)(data) assert len(data) == 3 assert data.edge_index.tolist() == edge_index.tolist() assert data.edge_attr.tolist() == [[1, 1], [1, 0.5], [1, 0.5], [1, 1]] assert data.num_nodes == 3 ================================================ FILE: test/transforms/test_to_dense.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import ToDense def test_to_dense(): edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) num_nodes = edge_index.max().item() + 1 x = torch.randn((num_nodes, 4)) pos = torch.randn((num_nodes, 3)) y = torch.randint(0, 4, (num_nodes, ), dtype=torch.long) transform = ToDense() assert str(transform) == 'ToDense()' data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, y=y) data = transform(data) assert len(data) == 5 assert data.x.tolist() == x.tolist() assert data.pos.tolist() == pos.tolist() assert data.y.tolist() == y.tolist() assert data.adj.size() == (num_nodes, num_nodes) assert data.adj.tolist() == [ [0, 1, 2, 3], [4, 0, 0, 0], [5, 0, 0, 0], [6, 0, 0, 0], ] assert data.mask.tolist() == [1, 1, 
1, 1] transform = ToDense(num_nodes=5) assert str(transform) == 'ToDense(num_nodes=5)' data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, y=y) data = transform(data) assert len(data) == 5 assert data.x.size() == (5, 4) assert data.x[:4].tolist() == x.tolist() assert data.x[4].tolist() == [0, 0, 0, 0] assert data.pos.size() == (5, 3) assert data.pos[:4].tolist() == pos.tolist() assert data.pos[4].tolist() == [0, 0, 0] assert data.y.size() == (5, ) assert data.y[:4].tolist() == y.tolist() assert data.y[4].tolist() == 0 assert data.adj.size() == (5, 5) assert data.adj.tolist() == [ [0, 1, 2, 3, 0], [4, 0, 0, 0, 0], [5, 0, 0, 0, 0], [6, 0, 0, 0, 0], [0, 0, 0, 0, 0], ] assert data.mask.tolist() == [1, 1, 1, 1, 0] ================================================ FILE: test/transforms/test_to_device.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.testing import withDevice from torch_geometric.transforms import ToDevice @withDevice def test_to_device(device): x = torch.randn(3, 4) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_weight = torch.randn(edge_index.size(1)) data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight) transform = ToDevice(device) assert str(transform) == f'ToDevice({device})' data = transform(data) for _, value in data: assert value.device == device ================================================ FILE: test/transforms/test_to_sparse_tensor.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import ToSparseTensor @pytest.mark.parametrize('layout', [None, torch.sparse_coo, torch.sparse_csr]) def test_to_sparse_tensor_basic(layout): transform = ToSparseTensor(layout=layout) assert str(transform) == (f'ToSparseTensor(attr=edge_weight, ' f'layout={layout})') edge_index = torch.tensor([[0, 1, 1, 2], [1, 
0, 2, 1]]) edge_weight = torch.randn(edge_index.size(1)) edge_attr = torch.randn(edge_index.size(1), 8) perm = torch.tensor([1, 0, 3, 2]) data = Data(edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr, num_nodes=3) data = transform(data) assert len(data) == 3 assert data.num_nodes == 3 assert torch.equal(data.edge_attr, edge_attr[perm]) assert 'adj_t' in data if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE: row, col, value = data.adj_t.coo() assert row.tolist() == [0, 1, 1, 2] assert col.tolist() == [1, 0, 2, 1] assert torch.equal(value, edge_weight[perm]) else: adj_t = data.adj_t assert adj_t.layout == layout or torch.sparse_csr if layout != torch.sparse_coo: adj_t = adj_t.to_sparse_coo() assert adj_t.coalesce().indices().tolist() == [ [0, 1, 1, 2], [1, 0, 2, 1], ] assert torch.equal(adj_t.coalesce().values(), edge_weight[perm]) def test_to_sparse_tensor_and_keep_edge_index(): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_weight = torch.randn(edge_index.size(1)) edge_attr = torch.randn(edge_index.size(1), 8) perm = torch.tensor([1, 0, 3, 2]) data = Data(edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr, num_nodes=3) data = ToSparseTensor(remove_edge_index=False)(data) assert len(data) == 5 assert torch.equal(data.edge_index, edge_index[:, perm]) assert torch.equal(data.edge_weight, edge_weight[perm]) assert torch.equal(data.edge_attr, edge_attr[perm]) @pytest.mark.parametrize('layout', [None, torch.sparse_coo, torch.sparse_csr]) def test_hetero_to_sparse_tensor(layout): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) data = HeteroData() data['v'].num_nodes = 3 data['w'].num_nodes = 3 data['v', 'v'].edge_index = edge_index data['v', 'w'].edge_index = edge_index data = ToSparseTensor(layout=layout)(data) if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE: row, col, value = data['v', 'v'].adj_t.coo() assert row.tolist() == [0, 1, 1, 2] assert col.tolist() == [1, 0, 2, 1] assert value 
is None row, col, value = data['v', 'w'].adj_t.coo() assert row.tolist() == [0, 1, 1, 2] assert col.tolist() == [1, 0, 2, 1] assert value is None else: adj_t = data['v', 'v'].adj_t assert adj_t.layout == layout or torch.sparse_csr if layout != torch.sparse_coo: adj_t = adj_t.to_sparse_coo() assert adj_t.coalesce().indices().tolist() == [ [0, 1, 1, 2], [1, 0, 2, 1], ] assert adj_t.coalesce().values().tolist() == [1., 1., 1., 1.] adj_t = data['v', 'w'].adj_t assert adj_t.layout == layout or torch.sparse_csr if layout != torch.sparse_coo: adj_t = adj_t.to_sparse_coo() assert adj_t.coalesce().indices().tolist() == [ [0, 1, 1, 2], [1, 0, 2, 1], ] assert adj_t.coalesce().values().tolist() == [1., 1., 1., 1.] def test_to_sparse_tensor_num_nodes_equals_num_edges(): x = torch.arange(4) y = torch.arange(4) edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_weight = torch.randn(edge_index.size(1)) edge_attr = torch.randn(edge_index.size(1), 8) perm = torch.tensor([1, 0, 3, 2]) data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr, y=y) data = ToSparseTensor()(data) assert len(data) == 4 assert torch.equal(data.x, x) assert torch.equal(data.y, y) assert torch.equal(data.edge_attr, edge_attr[perm]) ================================================ FILE: test/transforms/test_to_superpixels.py ================================================ import os import os.path as osp import torch from torch_geometric.data import download_url, extract_gz from torch_geometric.loader import DataLoader from torch_geometric.testing import onlyOnline, withPackage from torch_geometric.transforms import ToSLIC resources = [ 'https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz', 'https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz', ] @onlyOnline @withPackage('torchvision', 'skimage') def test_to_superpixels(tmp_path): import torchvision.transforms as T from torchvision.datasets.mnist import ( MNIST, read_image_file, 
read_label_file, ) raw_folder = osp.join(tmp_path, 'MNIST', 'raw') processed_folder = osp.join(tmp_path, 'MNIST', 'processed') os.makedirs(raw_folder, exist_ok=True) os.makedirs(processed_folder, exist_ok=True) for resource in resources: path = download_url(resource, raw_folder) extract_gz(path, osp.join(tmp_path, raw_folder)) test_set = ( read_image_file(osp.join(raw_folder, 't10k-images-idx3-ubyte')), read_label_file(osp.join(raw_folder, 't10k-labels-idx1-ubyte')), ) torch.save(test_set, osp.join(processed_folder, 'training.pt')) torch.save(test_set, osp.join(processed_folder, 'test.pt')) dataset = MNIST(tmp_path, download=False) dataset.transform = T.Compose([T.ToTensor(), ToSLIC()]) data, y = dataset[0] assert len(data) == 2 assert data.pos.dim() == 2 and data.pos.size(1) == 2 assert data.x.dim() == 2 and data.x.size(1) == 1 assert data.pos.size(0) == data.x.size(0) assert y == 7 loader = DataLoader(dataset, batch_size=2, shuffle=False) for batch, y in loader: assert batch.num_graphs == len(batch) == 2 assert batch.pos.dim() == 2 and batch.pos.size(1) == 2 assert batch.x.dim() == 2 and batch.x.size(1) == 1 assert batch.batch.dim() == 1 assert batch.ptr.dim() == 1 assert batch.pos.size(0) == batch.x.size(0) == batch.batch.size(0) assert y.tolist() == [7, 2] break dataset.transform = T.Compose( [T.ToTensor(), ToSLIC(add_seg=True, add_img=True)]) data, y = dataset[0] assert len(data) == 4 assert data.pos.dim() == 2 and data.pos.size(1) == 2 assert data.x.dim() == 2 and data.x.size(1) == 1 assert data.pos.size(0) == data.x.size(0) assert data.seg.size() == (1, 28, 28) assert data.img.size() == (1, 1, 28, 28) assert data.seg.max().item() + 1 == data.x.size(0) assert y == 7 loader = DataLoader(dataset, batch_size=2, shuffle=False) for batch, y in loader: assert batch.num_graphs == len(batch) == 2 assert batch.pos.dim() == 2 and batch.pos.size(1) == 2 assert batch.x.dim() == 2 and batch.x.size(1) == 1 assert batch.batch.dim() == 1 assert batch.ptr.dim() == 1 assert 
batch.pos.size(0) == batch.x.size(0) == batch.batch.size(0) assert batch.seg.size() == (2, 28, 28) assert batch.img.size() == (2, 1, 28, 28) assert y.tolist() == [7, 2] break ================================================ FILE: test/transforms/test_to_undirected.py ================================================ import torch from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import ToUndirected def test_to_undirected(): assert str(ToUndirected()) == 'ToUndirected()' edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]]) edge_weight = torch.randn(edge_index.size(1)) edge_attr = torch.randn(edge_index.size(1), 8) perm = torch.tensor([1, 2, 1, 2, 0, 0]) data = Data(edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr, num_nodes=4) data = ToUndirected()(data) assert len(data) == 4 assert data.edge_index.tolist() == [[0, 0, 1, 2, 2, 3], [1, 2, 0, 0, 3, 2]] assert data.edge_weight.tolist() == edge_weight[perm].tolist() assert data.edge_attr.tolist() == edge_attr[perm].tolist() assert data.num_nodes == 4 def test_to_undirected_with_duplicates(): edge_index = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 2]]) edge_weight = torch.ones(4) data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3) data = ToUndirected()(data) assert len(data) == 3 assert data.edge_index.tolist() == [[0, 0, 1, 1, 2], [0, 1, 0, 2, 1]] assert data.edge_weight.tolist() == [2, 2, 2, 1, 1] assert data.num_nodes == 3 def test_hetero_to_undirected(): edge_index = torch.tensor([[2, 0], [3, 1]]) edge_weight = torch.randn(edge_index.size(1)) edge_attr = torch.randn(edge_index.size(1), 8) perm = torch.tensor([1, 1, 0, 0]) data = HeteroData() data['v'].num_nodes = 4 data['w'].num_nodes = 4 data['v', 'v'].edge_index = edge_index data['v', 'v'].edge_weight = edge_weight data['v', 'v'].edge_attr = edge_attr data['v', 'w'].edge_index = edge_index data['v', 'w'].edge_weight = edge_weight data['v', 'w'].edge_attr = edge_attr from torch_geometric.transforms import 
ToUndirected assert not data.is_undirected() data = ToUndirected()(data) assert data.is_undirected() assert data['v', 'v'].edge_index.tolist() == [[0, 1, 2, 3], [1, 0, 3, 2]] assert data['v', 'v'].edge_weight.tolist() == edge_weight[perm].tolist() assert data['v', 'v'].edge_attr.tolist() == edge_attr[perm].tolist() assert data['v', 'w'].edge_index.tolist() == edge_index.tolist() assert data['v', 'w'].edge_weight.tolist() == edge_weight.tolist() assert data['v', 'w'].edge_attr.tolist() == edge_attr.tolist() assert data['w', 'v'].edge_index.tolist() == [[3, 1], [2, 0]] assert data['w', 'v'].edge_weight.tolist() == edge_weight.tolist() assert data['w', 'v'].edge_attr.tolist() == edge_attr.tolist() ================================================ FILE: test/transforms/test_two_hop.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import TwoHop def test_two_hop(): transform = TwoHop() assert str(transform) == 'TwoHop()' edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) edge_attr = torch.tensor([1, 2, 3, 1, 2, 3], dtype=torch.float) data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=4) data = transform(data) assert len(data) == 3 assert data.edge_index.equal( torch.tensor([ [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2], ])) assert data.edge_attr.equal( torch.tensor([1, 2, 3, 1, 0, 0, 2, 0, 0, 3, 0, 0])) assert data.num_nodes == 4 data = Data(edge_index=edge_index, num_nodes=4) data = transform(data) assert len(data) == 2 assert data.edge_index.equal( torch.tensor([ [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2], ])) assert data.num_nodes == 4 ================================================ FILE: test/transforms/test_virtual_node.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.transforms import VirtualNode def 
test_virtual_node(): assert str(VirtualNode()) == 'VirtualNode()' x = torch.randn(4, 16) edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]]) edge_weight = torch.rand(edge_index.size(1)) edge_attr = torch.randn(edge_index.size(1), 8) data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr, num_nodes=x.size(0)) data = VirtualNode()(data) assert len(data) == 6 assert data.x.size() == (5, 16) assert torch.allclose(data.x[:4], x) assert data.x[4:].abs().sum() == 0 assert data.edge_index.tolist() == [[2, 0, 2, 0, 1, 2, 3, 4, 4, 4, 4], [3, 1, 0, 4, 4, 4, 4, 0, 1, 2, 3]] assert data.edge_weight.size() == (11, ) assert torch.allclose(data.edge_weight[:3], edge_weight) assert data.edge_weight[3:].abs().sum() == 8 assert data.edge_attr.size() == (11, 8) assert torch.allclose(data.edge_attr[:3], edge_attr) assert data.edge_attr[3:].abs().sum() == 0 assert data.num_nodes == 5 assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] data = Data(x=x, edge_index=torch.empty(2, 0, dtype=torch.long)) data = VirtualNode()(data) assert len(data) == 3 assert data.x.size() == (5, 16) assert torch.allclose(data.x[:4], x) assert data.x[4:].abs().sum() == 0 assert data.edge_index.tolist() == [ [0, 1, 2, 3, 4, 4, 4, 4], [4, 4, 4, 4, 0, 1, 2, 3], ] assert data.edge_type.tolist() == [1, 1, 1, 1, 2, 2, 2, 2] ================================================ FILE: test/utils/conftest.py ================================================ import pytest import torch from torch_geometric.data import HeteroData from torch_geometric.explain.config import ModelMode, ModelReturnType from torch_geometric.nn import SAGEConv, to_hetero from torch_geometric.testing import get_random_edge_index class GraphSAGE(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SAGEConv((-1, -1), 32) self.conv2 = SAGEConv((-1, -1), 32) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() return self.conv2(x, edge_index) class HeteroSAGE(torch.nn.Module): 
def __init__(self, metadata, model_config=None): super().__init__() self.model_config = model_config self.graph_sage = to_hetero(GraphSAGE(), metadata, debug=False) # Determine output channels based on model_config out_channels = 1 if (model_config and model_config.mode == ModelMode.multiclass_classification): out_channels = 7 self.lin = torch.nn.Linear(32, out_channels) def forward(self, x_dict, edge_index_dict, additonal_arg=None) -> torch.Tensor: x = self.lin(self.graph_sage(x_dict, edge_index_dict)['paper']) # Apply transformations based on model_config if available if hasattr(self, 'model_config') and self.model_config: if self.model_config.mode == ModelMode.binary_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.sigmoid() elif self.model_config.mode == ModelMode.multiclass_classification: if self.model_config.return_type == ModelReturnType.probs: x = x.softmax(dim=-1) elif (self.model_config.return_type == ModelReturnType.log_probs): x = x.log_softmax(dim=-1) return x @pytest.fixture() def hetero_data(): data = HeteroData() data['paper'].x = torch.randn(8, 16) data['author'].x = torch.randn(10, 8) data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10) data['paper', 'paper'].edge_attr = torch.randn(10, 16) data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10) data['paper', 'author'].edge_attr = torch.randn(10, 8) data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10) data['author', 'paper'].edge_attr = torch.randn(10, 8) return data @pytest.fixture() def hetero_model(): return HeteroSAGE ================================================ FILE: test/utils/test_assortativity.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.typing import SparseTensor from torch_geometric.utils import assortativity def test_assortativity(): # Completely assortative graph: edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 
3, 3, 4, 5], [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) out = assortativity(edge_index) assert pytest.approx(out, abs=1e-5) == 1.0 if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[6, 6]) out = assortativity(adj) assert pytest.approx(out, abs=1e-5) == 1.0 # Completely disassortative graph: edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 5, 5, 5, 5], [5, 5, 5, 5, 5, 0, 1, 2, 3, 4]]) out = assortativity(edge_index) assert pytest.approx(out, abs=1e-5) == -1.0 if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[6, 6]) out = assortativity(adj) assert pytest.approx(out, abs=1e-5) == -1.0 ================================================ FILE: test/utils/test_augmentation.py ================================================ import pytest import torch from torch_geometric import seed_everything from torch_geometric.utils import ( add_random_edge, is_undirected, mask_feature, shuffle_node, ) def test_shuffle_node(): x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.float) out = shuffle_node(x, training=False) assert out[0].tolist() == x.tolist() assert out[1].tolist() == list(range(len(x))) torch.manual_seed(5) out = shuffle_node(x) assert out[0].tolist() == [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]] assert out[1].tolist() == [1, 0] torch.manual_seed(10) x = torch.arange(21).view(7, 3).to(torch.float) batch = torch.tensor([0, 0, 1, 1, 2, 2, 2]) out = shuffle_node(x, batch) assert out[0].tolist() == [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [9.0, 10.0, 11.0], [6.0, 7.0, 8.0], [12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [15.0, 16.0, 17.0]] assert out[1].tolist() == [1, 0, 3, 2, 4, 6, 5] def test_mask_feature(): x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float) out = mask_feature(x, training=False) assert out[0].tolist() == x.tolist() assert torch.all(out[1]) torch.manual_seed(4) out = mask_feature(x) assert out[0].tolist() == [[1.0, 2.0, 0.0, 0.0], 
[5.0, 6.0, 0.0, 0.0], [9.0, 10.0, 0.0, 0.0]] assert out[1].tolist() == [[True, True, False, False]] torch.manual_seed(5) out = mask_feature(x, mode='row') assert out[0].tolist() == [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0], [9.0, 10.0, 11.0, 12.0]] assert out[1].tolist() == [[True], [False], [True]] torch.manual_seed(7) out = mask_feature(x, mode='all') assert out[0].tolist() == [[1.0, 0.0, 3.0, 4.0], [0.0, 0.0, 0.0, 8.0], [0.0, 10.0, 11.0, 12.0]] assert out[1].tolist() == [[True, False, True, True], [False, False, False, True], [False, True, True, True]] torch.manual_seed(7) out = mask_feature(x, mode='all', fill_value=-1) assert out[0].tolist() == [[1.0, -1.0, 3.0, 4.0], [-1.0, -1.0, -1.0, 8.0], [-1.0, 10.0, 11.0, 12.0]] assert out[1].tolist() == [[True, False, True, True], [False, False, False, True], [False, True, True, True]] def test_add_random_edge(): edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) out = add_random_edge(edge_index, p=0.5, training=False) assert out[0].tolist() == edge_index.tolist() assert out[1].tolist() == [[], []] seed_everything(5) out = add_random_edge(edge_index, p=0.5) assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 3, 1, 2], [1, 0, 2, 1, 3, 2, 0, 3, 0]] assert out[1].tolist() == [[3, 1, 2], [0, 3, 0]] seed_everything(6) out = add_random_edge(edge_index, p=0.5, force_undirected=True) assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 1, 3], [1, 0, 2, 1, 3, 2, 3, 1]] assert out[1].tolist() == [[1, 3], [3, 1]] assert is_undirected(out[0]) assert is_undirected(out[1]) # Test for bipartite graph: seed_everything(7) edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [2, 3, 1, 4, 2, 1]]) with pytest.raises(RuntimeError, match="not supported for bipartite"): add_random_edge(edge_index, force_undirected=True, num_nodes=(6, 5)) out = add_random_edge(edge_index, p=0.5, num_nodes=(6, 5)) assert out[0].tolist() == [[0, 1, 2, 3, 4, 5, 2, 0, 2], [2, 3, 1, 4, 2, 1, 0, 4, 2]] assert out[1].tolist() == [[2, 0, 2], [0, 4, 2]] 
================================================ FILE: test/utils/test_coalesce.py ================================================ from typing import List, Optional, Tuple import torch from torch import Tensor from torch_geometric.utils import coalesce def test_coalesce(): edge_index = torch.tensor([[2, 1, 1, 0, 2], [1, 2, 0, 1, 1]]) edge_attr = torch.tensor([[1], [2], [3], [4], [5]]) out = coalesce(edge_index) assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] out = coalesce(edge_index, None) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1] is None out = coalesce(edge_index, edge_attr) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1].tolist() == [[4], [3], [2], [6]] out = coalesce(edge_index, [edge_attr, edge_attr.view(-1)]) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1][0].tolist() == [[4], [3], [2], [6]] assert out[1][1].tolist() == [4, 3, 2, 6] out = coalesce((edge_index[0], edge_index[1])) assert isinstance(out, tuple) assert out[0].tolist() == [0, 1, 1, 2] assert out[1].tolist() == [1, 0, 2, 1] def test_coalesce_without_duplicates(): edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) edge_attr = torch.tensor([[1], [2], [3], [4]]) out = coalesce(edge_index) assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] out = coalesce(edge_index, None) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1] is None out = coalesce(edge_index, edge_attr) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1].tolist() == [[4], [3], [2], [1]] out = coalesce(edge_index, [edge_attr, edge_attr.view(-1)]) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1][0].tolist() == [[4], [3], [2], [1]] assert out[1][1].tolist() == [4, 3, 2, 1] def test_coalesce_jit(): @torch.jit.script def wrapper1(edge_index: Tensor) -> Tensor: return coalesce(edge_index) @torch.jit.script def wrapper2( edge_index: Tensor, edge_attr: Optional[Tensor], ) -> Tuple[Tensor, Optional[Tensor]]: return 
coalesce(edge_index, edge_attr) @torch.jit.script def wrapper3( edge_index: Tensor, edge_attr: List[Tensor], ) -> Tuple[Tensor, List[Tensor]]: return coalesce(edge_index, edge_attr) edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) edge_attr = torch.tensor([[1], [2], [3], [4]]) out = wrapper1(edge_index) assert out.size() == edge_index.size() out = wrapper2(edge_index, None) assert out[0].size() == edge_index.size() assert out[1] is None out = wrapper2(edge_index, edge_attr) assert out[0].size() == edge_index.size() assert out[1].size() == edge_attr.size() out = wrapper3(edge_index, [edge_attr, edge_attr.view(-1)]) assert out[0].size() == edge_index.size() assert len(out[1]) == 2 assert out[1][0].size() == edge_attr.size() assert out[1][1].size() == edge_attr.view(-1).size() ================================================ FILE: test/utils/test_convert.py ================================================ import pytest import torch from torch_geometric.data import Data, HeteroData from torch_geometric.testing import get_random_edge_index, withPackage from torch_geometric.utils import ( from_cugraph, from_dgl, from_networkit, from_networkx, from_scipy_sparse_matrix, from_trimesh, sort_edge_index, subgraph, to_cugraph, to_dgl, to_networkit, to_networkx, to_scipy_sparse_matrix, to_trimesh, ) @withPackage('scipy') def test_to_scipy_sparse_matrix(): import scipy.sparse as sp edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) adj = to_scipy_sparse_matrix(edge_index) assert isinstance(adj, sp.coo_matrix) assert adj.shape == (2, 2) assert adj.row.tolist() == edge_index[0].tolist() assert adj.col.tolist() == edge_index[1].tolist() assert adj.data.tolist() == [1, 1, 1] edge_attr = torch.tensor([1.0, 2.0, 3.0]) adj = to_scipy_sparse_matrix(edge_index, edge_attr) assert isinstance(adj, sp.coo_matrix) assert adj.shape == (2, 2) assert adj.row.tolist() == edge_index[0].tolist() assert adj.col.tolist() == edge_index[1].tolist() assert adj.data.tolist() == edge_attr.tolist() 
@withPackage('scipy') def test_from_scipy_sparse_matrix(): edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) adj = to_scipy_sparse_matrix(edge_index) out = from_scipy_sparse_matrix(adj) assert out[0].tolist() == edge_index.tolist() assert out[1].tolist() == [1, 1, 1] @withPackage('networkx') def test_to_networkx(): import networkx as nx x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) pos = torch.tensor([[0.0, 0.0], [1.0, 1.0]]) edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) edge_attr = torch.tensor([1.0, 2.0, 3.0]) data = Data(x=x, pos=pos, edge_index=edge_index, weight=edge_attr) for remove_self_loops in [True, False]: G = to_networkx(data, node_attrs=['x', 'pos'], edge_attrs=['weight'], remove_self_loops=remove_self_loops) assert G.nodes[0]['x'] == [1.0, 2.0] assert G.nodes[1]['x'] == [3.0, 4.0] assert G.nodes[0]['pos'] == [0.0, 0.0] assert G.nodes[1]['pos'] == [1.0, 1.0] if remove_self_loops: assert nx.to_numpy_array(G).tolist() == [[0.0, 1.0], [2.0, 0.0]] else: assert nx.to_numpy_array(G).tolist() == [[3.0, 1.0], [2.0, 0.0]] @withPackage('networkx') def test_from_networkx_set_node_attributes(): import networkx as nx G = nx.path_graph(3) attrs = { 0: { 'x': torch.tensor([1, 0, 0]) }, 1: { 'x': torch.tensor([0, 1, 0]) }, 2: { 'x': torch.tensor([0, 0, 1]) }, } nx.set_node_attributes(G, attrs) assert from_networkx(G).x.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]] @withPackage('networkx') def test_to_networkx_undirected(): import networkx as nx x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) pos = torch.tensor([[0.0, 0.0], [1.0, 1.0]]) edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) edge_attr = torch.tensor([1.0, 2.0, 3.0]) data = Data(x=x, pos=pos, edge_index=edge_index, weight=edge_attr) for remove_self_loops in [True, False]: G = to_networkx( data, node_attrs=['x', 'pos'], edge_attrs=['weight'], remove_self_loops=remove_self_loops, to_undirected=True, ) assert G.nodes[0]['x'] == [1, 2] assert G.nodes[1]['x'] == [3, 4] assert G.nodes[0]['pos'] == [0, 0] assert 
G.nodes[1]['pos'] == [1, 1] if remove_self_loops: assert nx.to_numpy_array(G).tolist() == [[0, 2], [2, 0]] else: assert nx.to_numpy_array(G).tolist() == [[3, 2], [2, 0]] G = to_networkx(data, edge_attrs=['weight'], to_undirected=False) assert nx.to_numpy_array(G).tolist() == [[3, 1], [2, 0]] G = to_networkx(data, edge_attrs=['weight'], to_undirected='upper') assert nx.to_numpy_array(G).tolist() == [[3, 1], [1, 0]] G = to_networkx(data, edge_attrs=['weight'], to_undirected='lower') assert nx.to_numpy_array(G).tolist() == [[3, 2], [2, 0]] @withPackage('networkx') def test_to_networkx_undirected_options(): import networkx as nx edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 0]]) data = Data(edge_index=edge_index, num_nodes=3) G = to_networkx(data, to_undirected=True) assert nx.to_numpy_array(G).tolist() == [[0, 1, 1], [1, 0, 1], [1, 1, 0]] G = to_networkx(data, to_undirected='upper') assert nx.to_numpy_array(G).tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] G = to_networkx(data, to_undirected='lower') assert nx.to_numpy_array(G).tolist() == [[0, 1, 1], [1, 0, 0], [1, 0, 0]] @withPackage('networkx') def test_to_networkx_hetero(): edge_index = get_random_edge_index(5, 10, 20, coalesce=True) data = HeteroData() data['global_id'] = 0 data['author'].x = torch.arange(5) data['paper'].x = torch.arange(10) data['author', 'paper'].edge_index = edge_index data['author', 'paper'].edge_attr = torch.arange(edge_index.size(1)) G = to_networkx(data, node_attrs=['x'], edge_attrs=['edge_attr'], graph_attrs=['global_id']) assert G.number_of_nodes() == 15 assert G.number_of_edges() == edge_index.size(1) assert G.graph == {'global_id': 0} for i, (v, data) in enumerate(G.nodes(data=True)): assert i == v assert len(data) == 2 if i < 5: assert data['x'] == i assert data['type'] == 'author' else: assert data['x'] == i - 5 assert data['type'] == 'paper' for i, (v, w, data) in enumerate(G.edges(data=True)): assert v == int(edge_index[0, i]) assert w == int(edge_index[1, i]) + 5 assert 
len(data) == 2 assert data['type'] == ('author', 'to', 'paper') assert data['edge_attr'] == i @withPackage('networkx') def test_from_networkx(): x = torch.randn(2, 8) pos = torch.randn(2, 3) edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) edge_attr = torch.randn(edge_index.size(1)) perm = torch.tensor([0, 2, 1]) data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr) G = to_networkx(data, node_attrs=['x', 'pos'], edge_attrs=['edge_attr']) data = from_networkx(G) assert len(data) == 4 assert data.x.tolist() == x.tolist() assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index[:, perm].tolist() assert data.edge_attr.tolist() == edge_attr[perm].tolist() @withPackage('networkx') def test_from_networkx_group_attrs(): x = torch.randn(2, 2) x1 = torch.randn(2, 4) x2 = torch.randn(2, 8) edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) edge_attr1 = torch.randn(edge_index.size(1)) edge_attr2 = torch.randn(edge_index.size(1)) perm = torch.tensor([0, 2, 1]) data = Data(x=x, x1=x1, x2=x2, edge_index=edge_index, edge_attr1=edge_attr1, edge_attr2=edge_attr2) G = to_networkx(data, node_attrs=['x', 'x1', 'x2'], edge_attrs=['edge_attr1', 'edge_attr2']) data = from_networkx(G, group_node_attrs=['x', 'x2'], group_edge_attrs=all) assert len(data) == 4 assert data.x.tolist() == torch.cat([x, x2], dim=-1).tolist() assert data.x1.tolist() == x1.tolist() assert data.edge_index.tolist() == edge_index[:, perm].tolist() assert data.edge_attr.tolist() == torch.stack([edge_attr1, edge_attr2], dim=-1)[perm].tolist() @withPackage('networkx') def test_networkx_vice_versa_convert(): import networkx as nx G = nx.complete_graph(5) assert G.is_directed() is False data = from_networkx(G) assert data.is_directed() is False G = to_networkx(data) assert G.is_directed() is True G = nx.to_undirected(G) assert G.is_directed() is False @withPackage('networkx') def test_from_networkx_non_consecutive(): import networkx as nx graph = nx.Graph() graph.add_node(4) 
graph.add_node(2) graph.add_edge(4, 2) for node in graph.nodes(): graph.nodes[node]['x'] = node data = from_networkx(graph) assert len(data) == 2 assert data.x.tolist() == [4, 2] assert data.edge_index.tolist() == [[0, 1], [1, 0]] @withPackage('networkx') def test_from_networkx_inverse(): import networkx as nx graph = nx.Graph() graph.add_node(3) graph.add_node(2) graph.add_node(1) graph.add_node(0) graph.add_edge(3, 1) graph.add_edge(2, 1) graph.add_edge(1, 0) data = from_networkx(graph) assert len(data) == 2 assert data.edge_index.tolist() == [[0, 1, 2, 2, 2, 3], [2, 2, 0, 1, 3, 2]] assert data.num_nodes == 4 @withPackage('networkx') def test_from_networkx_non_numeric_labels(): import networkx as nx graph = nx.Graph() graph.add_node('4') graph.add_node('2') graph.add_edge('4', '2') for node in graph.nodes(): graph.nodes[node]['x'] = node data = from_networkx(graph) assert len(data) == 2 assert data.x == ['4', '2'] assert data.edge_index.tolist() == [[0, 1], [1, 0]] @withPackage('networkx') def test_from_networkx_without_edges(): import networkx as nx graph = nx.Graph() graph.add_node(1) graph.add_node(2) data = from_networkx(graph) assert len(data) == 2 assert data.edge_index.size() == (2, 0) assert data.num_nodes == 2 @withPackage('networkx') def test_from_networkx_with_same_node_and_edge_attributes(): import networkx as nx G = nx.Graph() G.add_nodes_from([(0, {'age': 1}), (1, {'age': 6}), (2, {'age': 5})]) G.add_edges_from([(0, 1, {'age': 2}), (1, 2, {'age': 7})]) data = from_networkx(G) assert len(data) == 4 assert data.age.tolist() == [1, 6, 5] assert data.num_nodes == 3 assert data.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert data.edge_age.tolist() == [2, 2, 7, 7] data = from_networkx(G, group_node_attrs=all, group_edge_attrs=all) assert len(data) == 3 assert data.x.tolist() == [[1], [6], [5]] assert data.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert data.edge_attr.tolist() == [[2], [2], [7], [7]] @withPackage('networkx') def 
test_from_networkx_subgraph_convert(): import networkx as nx G = nx.complete_graph(5) edge_index = from_networkx(G).edge_index sub_edge_index_1, _ = subgraph([0, 1, 3, 4], edge_index, relabel_nodes=True) sub_edge_index_2 = from_networkx(G.subgraph([0, 1, 3, 4])).edge_index assert sub_edge_index_1.tolist() == sub_edge_index_2.tolist() @withPackage('networkx') @pytest.mark.parametrize('n', [100]) @pytest.mark.parametrize('p', [0.8]) @pytest.mark.parametrize('q', [0.2]) def test_from_networkx_sbm(n, p, q): import networkx as nx G = nx.stochastic_block_model( sizes=[n // 2, n // 2], p=[[p, q], [q, p]], seed=0, directed=False, ) data = from_networkx(G) assert data.num_nodes == 100 assert torch.equal(data.block[:50], data.block.new_zeros(50)) assert torch.equal(data.block[50:], data.block.new_ones(50)) @withPackage('networkit') def test_to_networkit_vice_versa(): edge_index = torch.tensor([[0, 1], [1, 0]]) g = to_networkit(edge_index, directed=False) assert not g.isDirected() assert not g.isWeighted() edge_index, edge_weight = from_networkit(g) assert edge_index.tolist() == [[0, 1], [1, 0]] assert edge_weight is None @withPackage('networkit') @pytest.mark.parametrize('directed', [True, False]) @pytest.mark.parametrize('num_nodes', [None, 3]) @pytest.mark.parametrize('edge_weight', [None, torch.rand(3)]) def test_to_networkit(directed, edge_weight, num_nodes): import networkit edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]], dtype=torch.long) g = to_networkit(edge_index, edge_weight, num_nodes, directed) assert isinstance(g, networkit.Graph) assert g.isDirected() == directed assert g.numberOfNodes() == 3 if edge_weight is None: edge_weight = torch.tensor([1., 1., 1.]) assert g.weight(0, 1) == float(edge_weight[0]) assert g.weight(1, 2) == float(edge_weight[2]) if directed: assert g.numberOfEdges() == 3 assert g.weight(1, 0) == float(edge_weight[1]) else: assert g.numberOfEdges() == 2 @pytest.mark.parametrize('directed', [True, False]) @pytest.mark.parametrize('weighted', 
[True, False]) @withPackage('networkit') def test_from_networkit(directed, weighted): import networkit g = networkit.Graph(3, weighted=weighted, directed=directed) g.addEdge(0, 1) g.addEdge(1, 2) if directed: g.addEdge(1, 0) if weighted: for i, (u, v) in enumerate(g.iterEdges()): g.setWeight(u, v, i + 1) edge_index, edge_weight = from_networkit(g) if directed: assert edge_index.tolist() == [[0, 1, 1], [1, 2, 0]] if weighted: assert edge_weight.tolist() == [1, 2, 3] else: assert edge_weight is None else: assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] if weighted: assert edge_weight.tolist() == [1, 1, 2, 2] else: assert edge_weight is None @withPackage('trimesh') def test_trimesh_vice_versa(): pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], dtype=torch.float) face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t() data = Data(pos=pos, face=face) mesh = to_trimesh(data) data = from_trimesh(mesh) assert pos.tolist() == data.pos.tolist() assert face.tolist() == data.face.tolist() @withPackage('trimesh') def test_to_trimesh(): import trimesh pos = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]) face = torch.tensor([[0, 1, 2], [2, 1, 3]]).t() data = Data(pos=pos, face=face) obj = to_trimesh(data) assert isinstance(obj, trimesh.Trimesh) assert obj.vertices.shape == (4, 3) assert obj.faces.shape == (2, 3) assert obj.vertices.tolist() == data.pos.tolist() assert obj.faces.tolist() == data.face.t().contiguous().tolist() @withPackage('trimesh') def test_from_trimesh(): import trimesh vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0]] faces = [[0, 1, 2]] mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) data = from_trimesh(mesh) assert data.pos.tolist() == vertices assert data.face.t().contiguous().tolist() == faces @withPackage('cudf', 'cugraph') @pytest.mark.parametrize('edge_weight', [None, torch.rand(4)]) @pytest.mark.parametrize('relabel_nodes', [True, False]) @pytest.mark.parametrize('directed', [True, False]) def 
test_to_cugraph(edge_weight, directed, relabel_nodes): import cugraph if directed: edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) else: edge_index = torch.tensor([[0, 1], [1, 2]]) if edge_weight is not None: edge_weight = edge_weight[:edge_index.size(1)] graph = to_cugraph(edge_index, edge_weight, relabel_nodes, directed) assert isinstance(graph, cugraph.Graph) assert graph.number_of_nodes() == 3 edge_list = graph.view_edge_list() assert edge_list is not None edge_list = edge_list.sort_values( by=[graph.source_columns, graph.destination_columns]) cu_edge_index = edge_list[[ graph.source_columns, graph.destination_columns ]].to_pandas().values cu_edge_index = torch.from_numpy(cu_edge_index).t() cu_edge_weight = None if edge_weight is not None: cu_edge_weight = edge_list[graph.weight_column].to_pandas().values cu_edge_weight = torch.from_numpy(cu_edge_weight) cu_edge_index, cu_edge_weight = sort_edge_index(cu_edge_index, cu_edge_weight) assert torch.equal(edge_index, cu_edge_index.cpu()) if edge_weight is not None: assert torch.allclose(edge_weight, cu_edge_weight.cpu()) @withPackage('cudf', 'cugraph') @pytest.mark.parametrize('edge_weight', [None, torch.randn(4)]) @pytest.mark.parametrize('directed', [True, False]) @pytest.mark.parametrize('relabel_nodes', [True, False]) def test_from_cugraph(edge_weight, directed, relabel_nodes): import cudf import cugraph from torch.utils.dlpack import to_dlpack if directed: edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) else: edge_index = torch.tensor([[0, 1], [1, 2]]) if edge_weight is not None: edge_weight = edge_weight[:edge_index.size(1)] G = cugraph.Graph(directed=directed) df = cudf.DataFrame({ 'source': cudf.from_dlpack(to_dlpack(edge_index[0])), 'destination': cudf.from_dlpack(to_dlpack(edge_index[1])), }) if edge_weight is not None: df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight)) G.from_cudf_edgelist( df, source='source', destination='destination', edge_attr='weight' if edge_weight is not None 
else None, renumber=relabel_nodes, ) cu_edge_index, cu_edge_weight = from_cugraph(G) cu_edge_index, cu_edge_weight = sort_edge_index(cu_edge_index, cu_edge_weight) assert torch.equal(edge_index, cu_edge_index.cpu()) if edge_weight is not None: assert torch.allclose(edge_weight, cu_edge_weight.cpu()) else: assert cu_edge_weight is None @withPackage('dgl') def test_to_dgl_graph(): x = torch.randn(5, 3) edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]]) edge_attr = torch.randn(edge_index.size(1), 2) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) g = to_dgl(data) assert torch.equal(data.x, g.ndata['x']) row, col = g.edges() assert torch.equal(row, edge_index[0]) assert torch.equal(col, edge_index[1]) assert torch.equal(data.edge_attr, g.edata['edge_attr']) @withPackage('dgl') def test_to_dgl_hetero_graph(): data = HeteroData() data['v1'].x = torch.randn(4, 3) data['v2'].x = torch.randn(4, 3) data['v1', 'v2'].edge_index = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]) data['v1', 'v2'].edge_attr = torch.randn(4, 2) g = to_dgl(data) assert data['v1', 'v2'].num_edges == g.num_edges(('v1', 'to', 'v2')) assert data['v1'].num_nodes == g.num_nodes('v1') assert data['v2'].num_nodes == g.num_nodes('v2') assert torch.equal(data['v1'].x, g.nodes['v1'].data['x']) assert torch.equal(data['v2'].x, g.nodes['v2'].data['x']) row, col = g.edges() assert torch.equal(row, data['v1', 'v2'].edge_index[0]) assert torch.equal(col, data['v1', 'v2'].edge_index[1]) assert torch.equal(g.edata['edge_attr'], data['v1', 'v2'].edge_attr) @withPackage('dgl', 'torch_sparse') def test_to_dgl_sparse(): from torch_geometric.transforms import ToSparseTensor x = torch.randn(5, 3) edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]]) edge_attr = torch.randn(edge_index.size(1), 2) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) data = ToSparseTensor()(data) g = to_dgl(data) assert torch.equal(data.x, g.ndata["x"]) pyg_row, pyg_col, _ = data.adj_t.t().coo() 
dgl_row, dgl_col = g.edges() assert torch.equal(pyg_row, dgl_row) assert torch.equal(pyg_col, dgl_col) assert torch.equal(data.edge_attr, g.edata['edge_attr']) @withPackage('dgl') def test_from_dgl_graph(): import dgl g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) g.ndata['x'] = torch.randn(g.num_nodes(), 3) g.edata['edge_attr'] = torch.randn(g.num_edges()) data = from_dgl(g) assert torch.equal(data.x, g.ndata['x']) row, col = g.edges() assert torch.equal(data.edge_index[0], row) assert torch.equal(data.edge_index[1], col) assert torch.equal(data.edge_attr, g.edata['edge_attr']) @withPackage('dgl') def test_from_dgl_hetero_graph(): import dgl g = dgl.heterograph({ ('v1', 'to', 'v2'): ( [0, 1, 1, 2, 3, 3, 4], [0, 0, 1, 1, 1, 2, 2], ) }) g.nodes['v1'].data['x'] = torch.randn(5, 3) g.nodes['v2'].data['x'] = torch.randn(3, 3) data = from_dgl(g) assert data['v1', 'v2'].num_edges == g.num_edges(('v1', 'to', 'v2')) assert data['v1'].num_nodes == g.num_nodes('v1') assert data['v2'].num_nodes == g.num_nodes('v2') assert torch.equal(data['v1'].x, g.nodes['v1'].data['x']) assert torch.equal(data['v2'].x, g.nodes['v2'].data['x']) ================================================ FILE: test/utils/test_cross_entropy.py ================================================ import pytest import torch import torch.nn.functional as F from torch_geometric.testing import withCUDA from torch_geometric.utils.cross_entropy import sparse_cross_entropy @withCUDA @pytest.mark.parametrize('with_edge_label_weight', [False, True]) def test_sparse_cross_entropy_multiclass( with_edge_label_weight: bool, device: torch.device, ) -> None: x = torch.randn(5, 5, device=device, requires_grad=True) y = torch.eye(5, device=device) edge_label_index = y.nonzero().t() edge_label_weight = None if with_edge_label_weight: edge_label_weight = torch.rand(edge_label_index.size(1), device=device) y[y == 1.0] = edge_label_weight expected = F.cross_entropy(x, y) expected.backward() expected_grad = x.grad x.grad = None out = 
sparse_cross_entropy(x, edge_label_index, edge_label_weight) out.backward() assert torch.allclose(expected, out) assert torch.allclose(expected_grad, x.grad) @withCUDA @pytest.mark.parametrize('with_edge_label_weight', [False, True]) def test_sparse_cross_entropy_multilabel( with_edge_label_weight: bool, device: torch.device, ) -> None: x = torch.randn(4, 4, device=device, requires_grad=True) y = torch.randint_like(x, 0, 2) edge_label_index = y.nonzero().t() edge_label_weight = None if with_edge_label_weight: edge_label_weight = torch.rand(edge_label_index.size(1), device=device) y[y == 1.0] = edge_label_weight expected = F.cross_entropy(x, y) expected.backward() expected_grad = x.grad x.grad = None out = sparse_cross_entropy(x, edge_label_index, edge_label_weight) out.backward() assert torch.allclose(expected, out) assert torch.allclose(expected_grad, x.grad) @withCUDA @pytest.mark.parametrize('edge_label_weights', [ [2.0, -10.0, 1.0, -5.0, 4.0, 0.0, -1.0], [-2.0, -1.0, -1.0, -3.0, -4.0, -10.0, -1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ]) def test_sparse_cross_entropy_negative_weight( edge_label_weights: list[float], device: torch.device, ) -> None: x = torch.randn(4, 8, device=device, requires_grad=True) edge_label_index = torch.tensor([ [0, 0, 1, 2, 2, 2, 3], [2, 4, 6, 5, 3, 1, 1], ], device=device) edge_label_weight = torch.tensor(edge_label_weights, device=device) pos_mask = edge_label_weight >= 0 y = torch.zeros_like(x) y[ edge_label_index[0, pos_mask], edge_label_index[1, pos_mask], ] = edge_label_weight[pos_mask] _x = x.clone() _x[ edge_label_index[0, ~pos_mask], edge_label_index[1, ~pos_mask], ] += edge_label_weight[~pos_mask].abs().log() expected = F.cross_entropy(_x, y) expected.backward() expected_grad = x.grad x.grad = None out = sparse_cross_entropy(x, edge_label_index, edge_label_weight) out.backward() assert torch.allclose(expected, out) assert torch.allclose(expected_grad, x.grad) ================================================ FILE: 
test/utils/test_degree.py ================================================ import torch from torch_geometric.utils import degree def test_degree(): row = torch.tensor([0, 1, 0, 2, 0]) deg = degree(row, dtype=torch.long) assert deg.dtype == torch.long assert deg.tolist() == [3, 1, 1] ================================================ FILE: test/utils/test_dropout.py ================================================ import pytest import torch from torch_geometric.testing import withPackage from torch_geometric.utils import ( dropout_adj, dropout_edge, dropout_node, dropout_path, ) def test_dropout_adj(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2], ]) edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) with pytest.warns(UserWarning, match="'dropout_adj' is deprecated"): out = dropout_adj(edge_index, edge_attr, training=False) assert edge_index.tolist() == out[0].tolist() assert edge_attr.tolist() == out[1].tolist() torch.manual_seed(5) with pytest.warns(UserWarning, match="'dropout_adj' is deprecated"): out = dropout_adj(edge_index, edge_attr) assert out[0].tolist() == [[0, 1, 2, 2], [1, 2, 1, 3]] assert out[1].tolist() == [1, 3, 4, 5] torch.manual_seed(6) with pytest.warns(UserWarning, match="'dropout_adj' is deprecated"): out = dropout_adj(edge_index, edge_attr, force_undirected=True) assert out[0].tolist() == [[0, 1, 1, 2], [1, 2, 0, 1]] assert out[1].tolist() == [1, 3, 1, 3] def test_dropout_node(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2], ]) out = dropout_node(edge_index, training=False) assert edge_index.tolist() == out[0].tolist() assert out[1].tolist() == [True, True, True, True, True, True] assert out[2].tolist() == [True, True, True, True] torch.manual_seed(5) out = dropout_node(edge_index) assert out[0].tolist() == [[2, 3], [3, 2]] assert out[1].tolist() == [False, False, False, False, True, True] assert out[2].tolist() == [True, False, True, True] def test_dropout_edge(): edge_index = torch.tensor([[0, 
1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) out = dropout_edge(edge_index, training=False) assert edge_index.tolist() == out[0].tolist() assert out[1].tolist() == [True, True, True, True, True, True] torch.manual_seed(5) out = dropout_edge(edge_index) assert out[0].tolist() == [[0, 1, 2, 2], [1, 2, 1, 3]] assert out[1].tolist() == [True, False, True, True, True, False] torch.manual_seed(6) out = dropout_edge(edge_index, force_undirected=True) assert out[0].tolist() == [[0, 1, 1, 2], [1, 2, 0, 1]] assert out[1].tolist() == [0, 2, 0, 2] @withPackage('torch_cluster') def test_dropout_path(): edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) out = dropout_path(edge_index, training=False) assert edge_index.tolist() == out[0].tolist() assert out[1].tolist() == [True, True, True, True, True, True] torch.manual_seed(4) out = dropout_path(edge_index, p=0.2) assert out[0].tolist() == [[0, 1], [1, 0]] assert out[1].tolist() == [True, True, False, False, False, False] assert edge_index[:, out[1]].tolist() == out[0].tolist() # test with unsorted edges torch.manual_seed(5) edge_index = torch.tensor([[3, 5, 2, 2, 2, 1], [1, 0, 0, 1, 3, 2]]) out = dropout_path(edge_index, p=0.2) assert out[0].tolist() == [[3, 2, 2, 1], [1, 1, 3, 2]] assert out[1].tolist() == [True, False, False, True, True, True] assert edge_index[:, out[1]].tolist() == out[0].tolist() # test with isolated nodes torch.manual_seed(7) edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 2, 4]]) out = dropout_path(edge_index, p=0.2) assert out[0].tolist() == [[2, 3], [2, 4]] assert out[1].tolist() == [False, False, True, True] assert edge_index[:, out[1]].tolist() == out[0].tolist() ================================================ FILE: test/utils/test_embedding.py ================================================ import pytest import torch from torch_geometric.nn import GCNConv, Linear from torch_geometric.utils import get_embeddings from torch_geometric.utils.embedding import get_embeddings_hetero class 
GNN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(5, 6) self.conv2 = GCNConv(6, 7) def forward(self, x0, edge_index): x1 = self.conv1(x0, edge_index) x2 = self.conv2(x1, edge_index) return [x1, x2] def test_get_embeddings(): x = torch.randn(6, 5) edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]) with pytest.warns(UserWarning, match="any 'MessagePassing' layers"): intermediate_outs = get_embeddings(Linear(5, 5), x) assert len(intermediate_outs) == 0 model = GNN() expected_embeddings = model(x, edge_index) embeddings = get_embeddings(model, x, edge_index) assert len(embeddings) == 2 for expected, out in zip(expected_embeddings, embeddings): assert torch.allclose(expected, out) def test_get_embeddings_hetero(hetero_data, hetero_model): # Create model using the metadata from hetero_data metadata = hetero_data.metadata() model = hetero_model(metadata) # Get heterogeneous embeddings embeddings_dict = get_embeddings_hetero(model, None, hetero_data.x_dict, hetero_data.edge_index_dict) # Verify the structure of the returned embeddings assert isinstance(embeddings_dict, dict) assert 'paper' in embeddings_dict assert 'author' in embeddings_dict # Verify that we have embeddings for both node types assert len(embeddings_dict['paper']) > 0 assert len(embeddings_dict['author']) > 0 # Check that the embeddings have the right shape num_paper_nodes = hetero_data['paper'].num_nodes num_author_nodes = hetero_data['author'].num_nodes # Verify dimensions of embeddings assert embeddings_dict['paper'][0].shape == (num_paper_nodes, 32 ) # First layer assert embeddings_dict['author'][0].shape == (num_author_nodes, 32 ) # First layer ================================================ FILE: test/utils/test_functions.py ================================================ import torch from torch_geometric.utils import cumsum def test_cumsum(): x = torch.tensor([2, 4, 1]) assert cumsum(x).tolist() == [0, 2, 6, 7] x = torch.tensor([[2, 4], [3, 6]]) assert 
cumsum(x, dim=1).tolist() == [[0, 2, 6], [0, 3, 9]] ================================================ FILE: test/utils/test_geodesic.py ================================================ from math import sqrt import torch from torch_geometric.testing import withPackage from torch_geometric.utils import geodesic_distance @withPackage('gdist') def test_geodesic_distance(): pos = torch.tensor([ [0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [2.0, 2.0, 0.0], ]) face = torch.tensor([[0, 1, 3], [0, 2, 3]]).t() out = geodesic_distance(pos, face) expected = torch.tensor([ [0.0, 1.0, 1.0, sqrt(2)], [1.0, 0.0, sqrt(2), 1.0], [1.0, sqrt(2), 0.0, 1.0], [sqrt(2), 1.0, 1.0, 0.0], ]) assert torch.allclose(out, expected) assert torch.allclose(out, geodesic_distance(pos, face, num_workers=-1)) out = geodesic_distance(pos, face, norm=False) expected = torch.tensor([ [0, 2, 2, 2 * sqrt(2)], [2, 0, 2 * sqrt(2), 2], [2, 2 * sqrt(2), 0, 2], [2 * sqrt(2), 2, 2, 0], ]) assert torch.allclose(out, expected) src = torch.tensor([0, 0, 0, 0]) dst = torch.tensor([0, 1, 2, 3]) out = geodesic_distance(pos, face, src=src, dst=dst) expected = torch.tensor([0.0, 1.0, 1.0, sqrt(2)]) assert torch.allclose(out, expected) out = geodesic_distance(pos, face, dst=dst) expected = torch.tensor([0.0, 0.0, 0.0, 0.0]) assert torch.allclose(out, expected) ================================================ FILE: test/utils/test_grid.py ================================================ import torch from torch_geometric.testing import is_full_test from torch_geometric.utils import grid def test_grid(): (row, col), pos = grid(height=3, width=2) expected_row = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2] expected_col = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5] expected_row += [3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5] expected_col += [0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5] expected_pos = [[0, 2], [1, 2], [0, 1], [1, 1], [0, 0], [1, 0]] assert row.tolist() == expected_row assert col.tolist() == expected_col assert 
pos.tolist() == expected_pos if is_full_test(): jit = torch.jit.script(grid) (row, col), pos = jit(height=3, width=2) assert row.tolist() == expected_row assert col.tolist() == expected_col assert pos.tolist() == expected_pos ================================================ FILE: test/utils/test_hetero.py ================================================ import torch from torch_geometric.testing import get_random_edge_index from torch_geometric.utils.hetero import construct_bipartite_edge_index def test_construct_bipartite_edge_index(): edge_index = get_random_edge_index(4, 6, num_edges=20) edge_index_dict = { ('author', 'paper'): edge_index, ('paper', 'author'): edge_index.flip([0]), } edge_attr_dict = { ('author', 'paper'): torch.randn(edge_index.size(1), 16), ('paper', 'author'): torch.randn(edge_index.size(1), 16) } edge_index, edge_attr = construct_bipartite_edge_index( edge_index_dict, src_offset_dict={ ('author', 'paper'): 0, ('paper', 'author'): 4 }, dst_offset_dict={ 'author': 0, 'paper': 4 }, edge_attr_dict=edge_attr_dict, ) assert edge_index.size() == (2, 40) assert edge_index.min() >= 0 assert edge_index[0].max() > 4 and edge_index[1].max() > 6 assert edge_index.max() <= 10 assert edge_attr.size() == (40, 16) assert torch.equal(edge_attr[:20], edge_attr_dict['author', 'paper']) assert torch.equal(edge_attr[20:], edge_attr_dict['paper', 'author']) ================================================ FILE: test/utils/test_homophily.py ================================================ import pytest import torch import torch_geometric.typing from torch_geometric.typing import SparseTensor from torch_geometric.utils import homophily def test_homophily(): edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 0, 4]]) y = torch.tensor([0, 0, 0, 0, 1]) batch = torch.tensor([0, 0, 0, 1, 1]) row, col = edge_index if torch_geometric.typing.WITH_TORCH_SPARSE: adj = SparseTensor(row=row, col=col, sparse_sizes=(5, 5)) method = 'edge' assert pytest.approx(homophily(edge_index, y, 
method=method)) == 0.75 if torch_geometric.typing.WITH_TORCH_SPARSE: assert pytest.approx(homophily(adj, y, method=method)) == 0.75 assert homophily(edge_index, y, batch, method).tolist() == [1., 0.] method = 'node' assert pytest.approx(homophily(edge_index, y, method=method)) == 0.6 if torch_geometric.typing.WITH_TORCH_SPARSE: assert pytest.approx(homophily(adj, y, method=method)) == 0.6 assert homophily(edge_index, y, batch, method).tolist() == [1., 0.] method = 'edge_insensitive' assert pytest.approx(homophily(edge_index, y, method=method)) == 0.1999999 if torch_geometric.typing.WITH_TORCH_SPARSE: assert pytest.approx(homophily(adj, y, method=method)) == 0.1999999 assert homophily(edge_index, y, batch, method).tolist() == [0., 0.] ================================================ FILE: test/utils/test_index_sort.py ================================================ import torch from torch_geometric.testing import withDevice from torch_geometric.utils import index_sort @withDevice def test_index_sort_stable(device): for _ in range(100): inputs = torch.randint(0, 4, size=(10, ), device=device) out = index_sort(inputs, stable=True) expected = torch.sort(inputs, stable=True) assert torch.equal(out[0], expected[0]) assert torch.equal(out[1], expected[1]) ================================================ FILE: test/utils/test_isolated.py ================================================ import torch from torch_geometric.testing import is_full_test from torch_geometric.utils import ( contains_isolated_nodes, remove_isolated_nodes, ) def test_contains_isolated_nodes(): edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) assert not contains_isolated_nodes(edge_index) assert contains_isolated_nodes(edge_index, num_nodes=3) if is_full_test(): jit = torch.jit.script(contains_isolated_nodes) assert not jit(edge_index) assert jit(edge_index, num_nodes=3) edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 2, 0]]) assert contains_isolated_nodes(edge_index) def 
test_remove_isolated_nodes(): edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) out, _, mask = remove_isolated_nodes(edge_index) assert out.tolist() == [[0, 1, 0], [1, 0, 0]] assert mask.tolist() == [1, 1] if is_full_test(): jit = torch.jit.script(remove_isolated_nodes) out, _, mask = jit(edge_index) assert out.tolist() == [[0, 1, 0], [1, 0, 0]] assert mask.tolist() == [1, 1] out, _, mask = remove_isolated_nodes(edge_index, num_nodes=3) assert out.tolist() == [[0, 1, 0], [1, 0, 0]] assert mask.tolist() == [1, 1, 0] edge_index = torch.tensor([[0, 2, 1, 0, 2], [2, 0, 1, 0, 2]]) edge_attr = torch.tensor([1, 2, 3, 4, 5]) out1, out2, mask = remove_isolated_nodes(edge_index, edge_attr) assert out1.tolist() == [[0, 1, 0, 1], [1, 0, 0, 1]] assert out2.tolist() == [1, 2, 4, 5] assert mask.tolist() == [1, 0, 1] ================================================ FILE: test/utils/test_laplacian.py ================================================ import torch from torch_geometric.testing import is_full_test from torch_geometric.utils import get_laplacian def test_get_laplacian(): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) edge_weight = torch.tensor([1, 2, 2, 4], dtype=torch.float) lap = get_laplacian(edge_index, edge_weight) assert lap[0].tolist() == [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]] assert lap[1].tolist() == [-1, -2, -2, -4, 1, 4, 4] if is_full_test(): jit = torch.jit.script(get_laplacian) lap = jit(edge_index, edge_weight) assert lap[0].tolist() == [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]] assert lap[1].tolist() == [-1, -2, -2, -4, 1, 4, 4] lap_sym = get_laplacian(edge_index, edge_weight, normalization='sym') assert lap_sym[0].tolist() == lap[0].tolist() assert lap_sym[1].tolist() == [-0.5, -1, -0.5, -1, 1, 1, 1] lap_rw = get_laplacian(edge_index, edge_weight, normalization='rw') assert lap_rw[0].tolist() == lap[0].tolist() assert lap_rw[1].tolist() == [-1, -0.5, -0.5, -1, 1, 1, 1] ================================================ 
FILE: test/utils/test_lexsort.py ================================================ import numpy as np import torch from torch_geometric.utils import lexsort def test_lexsort(): keys = [torch.randn(100) for _ in range(3)] expected = np.lexsort([key.numpy() for key in keys]) assert torch.equal(lexsort(keys), torch.from_numpy(expected)) ================================================ FILE: test/utils/test_loop.py ================================================ import torch from torch_geometric import EdgeIndex from torch_geometric.utils import ( add_remaining_self_loops, add_self_loops, contains_self_loops, get_self_loop_attr, remove_self_loops, segregate_self_loops, to_torch_coo_tensor, ) def test_contains_self_loops(): edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) assert contains_self_loops(edge_index) edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]]) assert not contains_self_loops(edge_index) def test_remove_self_loops(): edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6]]) expected = torch.tensor([[0, 1], [1, 0]]) out = remove_self_loops(edge_index) assert out[0].equal(expected) assert out[1] is None out = remove_self_loops(edge_index, edge_attr) assert out[0].equal(expected) assert out[1].equal(torch.tensor([[1, 2], [3, 4]])) adj = to_torch_coo_tensor(edge_index) adj, _ = remove_self_loops(adj) assert torch.diag(adj.to_dense()).tolist() == [0, 0] edge_index = EdgeIndex( edge_index, sparse_size=(2, 2), sort_order='row', is_undirected=True, ) out = remove_self_loops(edge_index) assert out[0].equal(expected) assert out[0].sparse_size() == (2, 2) assert out[0].sort_order == 'row' assert out[0].is_undirected assert out[1] is None out = remove_self_loops(edge_index, edge_attr) assert out[0].equal(expected) assert out[0].sparse_size() == (2, 2) assert out[0].sort_order == 'row' assert out[0].is_undirected assert out[1].equal(torch.tensor([[1, 2], [3, 4]])) def test_segregate_self_loops(): edge_index = 
torch.tensor([[0, 0, 1], [0, 1, 0]]) out = segregate_self_loops(edge_index) assert out[0].equal(torch.tensor([[0, 1], [1, 0]])) assert out[1] is None assert out[2].equal(torch.tensor([[0], [0]])) assert out[3] is None edge_attr = torch.tensor([1, 2, 3]) out = segregate_self_loops(edge_index, edge_attr) assert out[0].equal(torch.tensor([[0, 1], [1, 0]])) assert out[1].equal(torch.tensor([2, 3])) assert out[2].equal(torch.tensor([[0], [0]])) assert out[3].equal(torch.tensor([1])) edge_index = EdgeIndex( edge_index, sparse_size=(2, 2), sort_order='row', is_undirected=True, ) out = segregate_self_loops(edge_index) assert out[0].equal(torch.tensor([[0, 1], [1, 0]])) assert out[0].sparse_size() == (2, 2) assert out[0].sort_order == 'row' assert out[0].is_undirected assert out[1] is None assert out[2].equal(torch.tensor([[0], [0]])) assert out[2].sparse_size() == (2, 2) assert out[2].sort_order == 'row' assert out[2].is_undirected assert out[3] is None out = segregate_self_loops(edge_index, edge_attr) assert out[0].equal(torch.tensor([[0, 1], [1, 0]])) assert out[0].sparse_size() == (2, 2) assert out[0].sort_order == 'row' assert out[0].is_undirected assert out[1].equal(torch.tensor([2, 3])) assert out[2].equal(torch.tensor([[0], [0]])) assert out[2].sparse_size() == (2, 2) assert out[2].sort_order == 'row' assert out[2].is_undirected assert out[3].equal(torch.tensor([1])) def test_add_self_loops(): edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) edge_weight = torch.tensor([0.5, 0.5, 0.5]) edge_attr = torch.eye(3) adj = to_torch_coo_tensor(edge_index, edge_weight) expected = torch.tensor([[0, 1, 0, 0, 1], [1, 0, 0, 0, 1]]) assert add_self_loops(edge_index)[0].equal(expected) out = add_self_loops(edge_index, edge_weight) assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 1., 1.])) out = add_self_loops(adj)[0] assert out._indices().equal(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])) assert out._values().equal(torch.tensor([1.5, 0.5, 0.5, 1.0])) 
out = add_self_loops(edge_index, edge_weight, fill_value=5) assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 5.0, 5.0])) out = add_self_loops(adj, fill_value=5)[0] assert out._indices().equal(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])) assert out._values().equal(torch.tensor([5.5, 0.5, 0.5, 5.0])) out = add_self_loops(edge_index, edge_weight, fill_value=torch.tensor(2.)) assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 2., 2.])) out = add_self_loops(adj, fill_value=torch.tensor(2.))[0] assert out._indices().equal(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])) assert out._values().equal(torch.tensor([2.5, 0.5, 0.5, 2.0])) out = add_self_loops(edge_index, edge_weight, fill_value='add') assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 1, 0.5])) # Tests with `edge_attr`: out = add_self_loops(edge_index, edge_attr) assert out[0].equal(expected) assert out[1].equal( torch.tensor([ [1., 0., 0.], [0., 1., 0.], [0., 0., 1.], [1., 1., 1.], [1., 1., 1.], ])) out = add_self_loops(edge_index, edge_attr, fill_value=torch.tensor([0., 1., 0.])) assert out[0].equal(expected) assert out[1].equal( torch.tensor([ [1., 0., 0.], [0., 1., 0.], [0., 0., 1.], [0., 1., 0.], [0., 1., 0.], ])) out = add_self_loops(edge_index, edge_attr, fill_value='add') assert out[0].equal(expected) assert out[1].equal( torch.tensor([ [1., 0., 0.], [0., 1., 0.], [0., 0., 1.], [0., 1., 1.], [1., 0., 0.], ])) edge_index = EdgeIndex( edge_index, sparse_size=(2, 2), sort_order='row', is_undirected=True, ) out, _ = add_self_loops(edge_index) assert out.equal(expected) assert out.sparse_size() == (2, 2) assert out.sort_order is None assert out.is_undirected # Test empty `edge_index` and `edge_weight`: edge_index = torch.empty(2, 0, dtype=torch.long) edge_weight = torch.empty(0) out = add_self_loops(edge_index, edge_weight, num_nodes=1) assert out[0].equal(torch.tensor([[0], [0]])) assert out[1].equal(torch.tensor([1.])) def 
test_add_self_loops_bipartite(): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj = to_torch_coo_tensor(edge_index, size=(4, 2)) edge_index, _ = add_self_loops(edge_index, num_nodes=(4, 2)) assert edge_index.equal( torch.tensor([ [0, 1, 2, 3, 0, 1], [0, 0, 1, 1, 0, 1], ])) adj, _ = add_self_loops(adj) assert adj._indices().equal( torch.tensor([ [0, 1, 1, 2, 3], [0, 0, 1, 1, 1], ])) def test_add_remaining_self_loops(): edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) edge_weight = torch.tensor([0.5, 0.5, 0.5]) edge_attr = torch.eye(3) expected = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]]) out = add_remaining_self_loops(edge_index, edge_weight) assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 1])) out = add_remaining_self_loops(edge_index, edge_weight, fill_value=5) assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 5.0])) out = add_remaining_self_loops(edge_index, edge_weight, fill_value=torch.tensor(2.)) assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 2.0])) out = add_remaining_self_loops(edge_index, edge_weight, fill_value='add') assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 0.5])) # Test with `edge_attr`: out = add_remaining_self_loops(edge_index, edge_attr, fill_value=torch.tensor([0., 1., 0.])) assert out[0].equal(expected) assert out[1].equal( torch.tensor([ [1., 0., 0.], [0., 1., 0.], [0., 0., 1.], [0., 1., 0.], ])) edge_index = EdgeIndex( edge_index, sparse_size=(2, 2), sort_order='row', is_undirected=True, ) out, _ = add_remaining_self_loops(edge_index) assert out.equal(expected) assert out.sparse_size() == (2, 2) assert out.sort_order is None assert out.is_undirected def test_add_remaining_self_loops_without_initial_loops(): edge_index = torch.tensor([[0, 1], [1, 0]]) edge_weight = torch.tensor([0.5, 0.5]) expected = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]]) out = add_remaining_self_loops(edge_index, edge_weight) assert 
out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 1, 1])) out = add_remaining_self_loops(edge_index, edge_weight, fill_value=5) assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 5.0, 5.0])) out = add_remaining_self_loops(edge_index, edge_weight, fill_value=torch.tensor(2.0)) assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 2.0, 2.0])) # Test string `fill_value`: out = add_remaining_self_loops(edge_index, edge_weight, fill_value='add') assert out[0].equal(expected) assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 0.5])) def test_get_self_loop_attr(): edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]]) edge_weight = torch.tensor([0.2, 0.3, 0.5]) full_loop_weight = get_self_loop_attr(edge_index, edge_weight) assert full_loop_weight.equal(torch.tensor([0.5, 0.0])) full_loop_weight = get_self_loop_attr(edge_index, edge_weight, num_nodes=4) assert full_loop_weight.equal(torch.tensor([0.5, 0.0, 0.0, 0.0])) full_loop_weight = get_self_loop_attr(edge_index) assert full_loop_weight.equal(torch.tensor([1.0, 0.0])) edge_attr = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 1.0]]) full_loop_attr = get_self_loop_attr(edge_index, edge_attr) assert full_loop_attr.equal(torch.tensor([[0.5, 1.0], [0.0, 0.0]])) ================================================ FILE: test/utils/test_map.py ================================================ import pytest import torch from torch_geometric.profile import benchmark from torch_geometric.testing import withDevice, withPackage from torch_geometric.utils.map import map_index @withDevice @withPackage('pandas') @pytest.mark.parametrize('max_index', [3, 100_000_000]) def test_map_index(device, max_index): src = torch.tensor([2, 0, 1, 0, max_index], device=device) index = torch.tensor([max_index, 2, 0, 1], device=device) out, mask = map_index(src, index, inclusive=True) assert out.device == device assert mask is None assert out.tolist() == [1, 2, 3, 2, 0] @withDevice @withPackage('pandas') 
@pytest.mark.parametrize('max_index', [3, 100_000_000]) def test_map_index_na(device, max_index): src = torch.tensor([2, 0, 1, 0, max_index], device=device) index = torch.tensor([max_index, 2, 0], device=device) out, mask = map_index(src, index, inclusive=False) assert out.device == device assert mask.device == device assert out.tolist() == [1, 2, 2, 0] assert mask.tolist() == [True, True, False, True, True] if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') args = parser.parse_args() src = torch.randint(0, 100_000_000, (100_000, ), device=args.device) index = src.unique() def trivial_map(src, index, max_index, inclusive): if max_index is None: max_index = max(src.max(), index.max()) if inclusive: assoc = src.new_empty(max_index + 1) else: assoc = src.new_full((max_index + 1, ), -1) assoc[index] = torch.arange(index.numel(), device=index.device) out = assoc[src] if inclusive: return out, None else: mask = out != -1 return out[mask], mask print('Inclusive:') benchmark( funcs=[trivial_map, map_index], func_names=['trivial', 'map_index'], args=(src, index, None, True), num_steps=100, num_warmups=50, ) print('Exclusive:') benchmark( funcs=[trivial_map, map_index], func_names=['trivial', 'map_index'], args=(src, index[:50_000], None, False), num_steps=100, num_warmups=50, ) ================================================ FILE: test/utils/test_mask.py ================================================ import torch from torch_geometric.utils import index_to_mask, mask_select, mask_to_index def test_mask_select(): src = torch.randn(6, 8) mask = torch.tensor([False, True, False, True, False, True]) out = mask_select(src, 0, mask) assert out.size() == (3, 8) assert torch.equal(src[torch.tensor([1, 3, 5])], out) jit = torch.jit.script(mask_select) assert torch.equal(jit(src, 0, mask), out) def test_index_to_mask(): index = torch.tensor([1, 3, 5]) mask = index_to_mask(index) assert 
mask.tolist() == [False, True, False, True, False, True] mask = index_to_mask(index, size=7) assert mask.tolist() == [False, True, False, True, False, True, False] def test_mask_to_index(): mask = torch.tensor([False, True, False, True, False, True]) index = mask_to_index(mask) assert index.tolist() == [1, 3, 5] ================================================ FILE: test/utils/test_mesh_laplacian.py ================================================ import torch from torch_geometric.utils import get_mesh_laplacian def test_get_mesh_laplacian_of_cube(): pos = torch.tensor([ [1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [-1.0, -1.0, 1.0], [-1.0, 1.0, 1.0], [1.0, 1.0, -1.0], [1.0, -1.0, -1.0], [-1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], ]) face = torch.tensor([ [0, 1, 2], [0, 3, 2], [4, 5, 1], [4, 0, 1], [7, 6, 5], [7, 4, 5], [3, 2, 6], [3, 7, 6], [4, 0, 3], [4, 7, 3], [1, 5, 6], [1, 2, 6], ]) edge_index, edge_weight = get_mesh_laplacian(pos, face.t(), normalization='rw') assert edge_index.tolist() == [ [ 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 0, 1, 2, 3, 4, 5, 6, 7 ], [ 1, 2, 3, 4, 0, 2, 4, 5, 6, 0, 1, 3, 6, 0, 2, 4, 6, 7, 0, 1, 3, 5, 7, 1, 4, 6, 7, 1, 2, 3, 5, 7, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7 ], ] assert torch.allclose( edge_weight, torch.tensor([ 0.375, 0.0, 0.375, 0.375, 0.3, 0.3, 0.0, 0.3, 0.0, 0.0, 0.375, 0.375, 0.375, 0.3, 0.3, 0.0, 0.0, 0.3, 0.3, 0.0, 0.0, 0.3, 0.3, 0.375, 0.375, 0.375, 0.0, 0.0, 0.3, 0.0, 0.3, 0.3, 0.375, 0.375, 0.0, 0.375, -1.125, -0.9, -1.125, -0.9, -0.9, -1.125, -0.9, -1.125 ])) def test_get_mesh_laplacian_of_irregular_triangular_prism(): pos = torch.tensor([ [0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [0.0, 0.0, -3.0], [1.0, 5.0, -1.0], [3.0, 5.0, -1.0], [2.0, 5.0, -2.0], ]) face = torch.tensor([ [0, 1, 2], [3, 4, 5], [0, 1, 4], [0, 3, 4], [1, 2, 5], [1, 4, 5], [2, 0, 3], [2, 5, 3], ]) edge_index, edge_weight = get_mesh_laplacian(pos, face.t(), normalization='rw') assert edge_index.tolist() == 
[ [ 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5 ], [ 1, 2, 3, 4, 0, 2, 4, 5, 0, 1, 3, 5, 0, 2, 4, 5, 0, 1, 3, 5, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5 ], ] assert torch.allclose( edge_weight, torch.tensor([ 0.09730332, 0.15039921, 0.05081503, 0.00000000, 0.08726977, 0.03521059, 0.05363689, 0.00723919, 0.14497279, 0.03784235, 0.01629947, 0.03438699, 0.08362866, 0.02782887, 0.24252312, 0.40727590, 0.00000000, 0.08728313, 0.21507657, 0.38582093, 0.01117009, 0.04936920, 0.34247482, 0.36583540, -0.29851755, -0.18335645, -0.23350160, -0.76125660, -0.68818060, -0.76884955 ])) ================================================ FILE: test/utils/test_negative_sampling.py ================================================ import torch from torch_geometric.utils import ( batched_negative_sampling, contains_self_loops, is_undirected, negative_sampling, structured_negative_sampling, structured_negative_sampling_feasible, to_undirected, ) from torch_geometric.utils._negative_sampling import ( edge_index_to_vector, vector_to_edge_index, ) def is_negative(edge_index, neg_edge_index, size, bipartite): adj = torch.zeros(size, dtype=torch.bool) neg_adj = torch.zeros(size, dtype=torch.bool) adj[edge_index[0], edge_index[1]] = True neg_adj[neg_edge_index[0], neg_edge_index[1]] = True if not bipartite: arange = torch.arange(size[0]) assert neg_adj[arange, arange].sum() == 0 return (adj & neg_adj).sum() == 0 def test_edge_index_to_vector_and_vice_versa(): # Create a fully-connected graph: N = 10 row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1) col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1) edge_index = torch.stack([row, col], dim=0) idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True) assert population == N * N assert idx.tolist() == list(range(population)) edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=True) assert is_undirected(edge_index2) assert edge_index.tolist() == edge_index2.tolist() idx, population 
= edge_index_to_vector(edge_index, (N, N), bipartite=False) assert population == N * N - N assert idx.tolist() == list(range(population)) mask = edge_index[0] != edge_index[1] # Remove self-loops. edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False) assert is_undirected(edge_index2) assert edge_index[:, mask].tolist() == edge_index2.tolist() idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False, force_undirected=True) assert population == (N * (N + 1)) / 2 - N assert idx.tolist() == list(range(population)) mask = edge_index[0] != edge_index[1] # Remove self-loops. edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False, force_undirected=True) assert is_undirected(edge_index2) assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist() def test_negative_sampling(): edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) neg_edge_index = negative_sampling(edge_index) assert neg_edge_index.size(1) == edge_index.size(1) assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) neg_edge_index = negative_sampling(edge_index, method='dense') assert neg_edge_index.size(1) == edge_index.size(1) assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) neg_edge_index = negative_sampling(edge_index, num_neg_samples=2) assert neg_edge_index.size(1) == 2 assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) # Test with float multiplier less than 1 neg_edge_index = negative_sampling(edge_index, num_neg_samples=0.5) assert neg_edge_index.size(1) == 2 # 50% of 4 edges = 2 edges assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) # Test with float multiplier greater than 1 neg_edge_index = negative_sampling(edge_index, num_neg_samples=1.5) assert neg_edge_index.size(1) == 6 # 150% of 4 edges = 6 edges assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) edge_index = to_undirected(edge_index) neg_edge_index = negative_sampling(edge_index, 
force_undirected=True) assert neg_edge_index.size(1) == edge_index.size(1) - 1 assert is_undirected(neg_edge_index) assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) def test_bipartite_negative_sampling(): edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) neg_edge_index = negative_sampling(edge_index, num_nodes=(3, 4)) assert neg_edge_index.size(1) == edge_index.size(1) assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True) neg_edge_index = negative_sampling(edge_index, num_nodes=(3, 4), num_neg_samples=2) assert neg_edge_index.size(1) == 2 assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True) def test_batched_negative_sampling(): edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) edge_index = torch.cat([edge_index, edge_index + 4], dim=1) batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) neg_edge_index = batched_negative_sampling(edge_index, batch) assert neg_edge_index.size(1) <= edge_index.size(1) # Test with float multiplier less than 1 neg_edge_index = batched_negative_sampling(edge_index, batch, num_neg_samples=0.5) assert neg_edge_index.size(1) <= 4 # 50% of 8 edges = 4 edges # Test with float multiplier greater than 1 neg_edge_index = batched_negative_sampling(edge_index, batch, num_neg_samples=1.5) assert neg_edge_index.size(1) <= 12 # 150% of 8 edges = 12 edges adj = torch.zeros(8, 8, dtype=torch.bool) adj[edge_index[0], edge_index[1]] = True neg_adj = torch.zeros(8, 8, dtype=torch.bool) neg_adj[neg_edge_index[0], neg_edge_index[1]] = True assert (adj & neg_adj).sum() == 0 assert (adj | neg_adj).sum() == edge_index.size(1) + neg_edge_index.size(1) assert neg_adj[:4, 4:].sum() == 0 assert neg_adj[4:, :4].sum() == 0 def test_bipartite_batched_negative_sampling(): edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]]) edge_index2 = edge_index1 + torch.tensor([[2], [4]]) edge_index3 = edge_index2 + torch.tensor([[2], [4]]) edge_index = torch.cat([edge_index1, edge_index2, edge_index3], 
dim=1) src_batch = torch.tensor([0, 0, 1, 1, 2, 2]) dst_batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]) neg_edge_index = batched_negative_sampling(edge_index, (src_batch, dst_batch)) assert neg_edge_index.size(1) <= edge_index.size(1) adj = torch.zeros(6, 12, dtype=torch.bool) adj[edge_index[0], edge_index[1]] = True neg_adj = torch.zeros(6, 12, dtype=torch.bool) neg_adj[neg_edge_index[0], neg_edge_index[1]] = True assert (adj & neg_adj).sum() == 0 assert (adj | neg_adj).sum() == edge_index.size(1) + neg_edge_index.size(1) def test_structured_negative_sampling(): edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) i, j, k = structured_negative_sampling(edge_index) assert i.size(0) == edge_index.size(1) assert j.size(0) == edge_index.size(1) assert k.size(0) == edge_index.size(1) adj = torch.zeros(4, 4, dtype=torch.bool) adj[i, j] = 1 neg_adj = torch.zeros(4, 4, dtype=torch.bool) neg_adj[i, k] = 1 assert (adj & neg_adj).sum() == 0 # Test with no self-loops: edge_index = torch.LongTensor([[0, 0, 1, 1, 2], [1, 2, 0, 2, 1]]) i, j, k = structured_negative_sampling(edge_index, num_nodes=4, contains_neg_self_loops=False) neg_edge_index = torch.vstack([i, k]) assert not contains_self_loops(neg_edge_index) def test_structured_negative_sampling_feasible(): edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]) assert not structured_negative_sampling_feasible(edge_index, 3, False) assert structured_negative_sampling_feasible(edge_index, 3, True) assert structured_negative_sampling_feasible(edge_index, 4, False) ================================================ FILE: test/utils/test_nested.py ================================================ import pytest import torch from torch_geometric.utils import from_nested_tensor, to_nested_tensor def test_to_nested_tensor(): x = torch.randn(5, 4, 3) out = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1])) out = out.to_padded_tensor(padding=0) assert out.size() == (2, 3, 4, 3) assert 
torch.allclose(out[0, :2], x[0:2]) assert torch.allclose(out[1, :3], x[2:5]) out = to_nested_tensor(x, ptr=torch.tensor([0, 2, 5])) out = out.to_padded_tensor(padding=0) assert out.size() == (2, 3, 4, 3) assert torch.allclose(out[0, :2], x[0:2]) assert torch.allclose(out[1, :3], x[2:5]) out = to_nested_tensor(x) out = out.to_padded_tensor(padding=0) assert out.size() == (1, 5, 4, 3) assert torch.allclose(out[0], x) def test_from_nested_tensor(): x = torch.randn(5, 4, 3) nested = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1])) out, batch = from_nested_tensor(nested, return_batch=True) assert torch.equal(x, out) assert batch.tolist() == [0, 0, 1, 1, 1] nested = torch.nested.nested_tensor([torch.randn(4, 3), torch.randn(5, 4)]) with pytest.raises(ValueError, match="have the same size in dimension 1"): from_nested_tensor(nested) # Test zero-copy: nested = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1])) out = from_nested_tensor(nested) out += 1 # Increment in-place (which should increment `nested` as well). 
assert torch.equal(nested.to_padded_tensor(padding=0)[0, :2], out[0:2]) assert torch.equal(nested.to_padded_tensor(padding=0)[1, :3], out[2:5]) def test_to_and_from_nested_tensor_autograd(): x = torch.randn(5, 4, 3, requires_grad=True) grad = torch.randn_like(x) out = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1])) out = from_nested_tensor(out) out.backward(grad) assert torch.equal(x.grad, grad) ================================================ FILE: test/utils/test_noise_scheduler.py ================================================ import pytest import torch from torch_geometric.utils.noise_scheduler import ( get_diffusion_beta_schedule, get_smld_sigma_schedule, ) def test_get_smld_sigma_schedule(): expected = torch.tensor([ 1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637, 0.04641589, 0.02782559, 0.01668101, 0.01 ]) out = get_smld_sigma_schedule( sigma_min=0.01, sigma_max=1.0, num_scales=10, ) assert torch.allclose(out, expected) @pytest.mark.parametrize( 'schedule_type', ['linear', 'quadratic', 'constant', 'sigmoid'], ) def test_get_diffusion_beta_schedule(schedule_type): out = get_diffusion_beta_schedule( schedule_type, beta_start=0.1, beta_end=0.2, num_diffusion_timesteps=10, ) assert out.size() == (10, ) ================================================ FILE: test/utils/test_normalize_edge_index.py ================================================ import pytest import torch from torch_geometric.utils import normalize_edge_index @pytest.mark.parametrize('add_self_loops', [False, True]) @pytest.mark.parametrize('symmetric', [False, True]) def test_normalize_edge_index(add_self_loops: bool, symmetric: bool): edge_index = torch.tensor([[0, 2, 2, 3], [2, 0, 3, 0]]) out = normalize_edge_index( edge_index, add_self_loops=add_self_loops, symmetric=symmetric, ) assert isinstance(out, tuple) and len(out) == 2 if not add_self_loops: assert out[0].equal(edge_index) else: assert out[0].tolist() == [ [0, 2, 2, 3, 0, 1, 2, 3], [2, 0, 3, 0, 0, 1, 2, 3], ] 
assert out[1].min() >= 0.0 assert out[1].min() <= 1.0 ================================================ FILE: test/utils/test_normalized_cut.py ================================================ import torch from torch_geometric.testing import is_full_test from torch_geometric.utils import normalized_cut def test_normalized_cut(): row = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4]) col = torch.tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3]) edge_attr = torch.tensor( [3.0, 3.0, 6.0, 3.0, 6.0, 1.0, 3.0, 2.0, 1.0, 2.0]) expected = torch.tensor([4.0, 4.0, 5.0, 2.5, 5.0, 1.0, 2.5, 2.0, 1.0, 2.0]) out = normalized_cut(torch.stack([row, col], dim=0), edge_attr) assert torch.allclose(out, expected) if is_full_test(): jit = torch.jit.script(normalized_cut) out = jit(torch.stack([row, col], dim=0), edge_attr) assert torch.allclose(out, expected) ================================================ FILE: test/utils/test_num_nodes.py ================================================ import torch from torch_geometric.utils import to_torch_coo_tensor from torch_geometric.utils.num_nodes import ( maybe_num_nodes, maybe_num_nodes_dict, ) def test_maybe_num_nodes(): edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]) assert maybe_num_nodes(edge_index, 4) == 4 assert maybe_num_nodes(edge_index) == 3 adj = to_torch_coo_tensor(edge_index) assert maybe_num_nodes(adj, 4) == 4 assert maybe_num_nodes(adj) == 3 def test_maybe_num_nodes_dict(): edge_index_dict = { '1': torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]), '2': torch.tensor([[0, 0, 1, 3], [1, 2, 0, 4]]) } num_nodes_dict = {'2': 6} assert maybe_num_nodes_dict(edge_index_dict) == {'1': 3, '2': 5} assert maybe_num_nodes_dict(edge_index_dict, num_nodes_dict) == { '1': 3, '2': 6, } ================================================ FILE: test/utils/test_one_hot.py ================================================ import torch from torch_geometric.utils import one_hot def test_one_hot(): index = torch.tensor([0, 1, 2]) 
out = one_hot(index) assert out.size() == (3, 3) assert out.dtype == torch.float assert out.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]] out = one_hot(index, num_classes=4, dtype=torch.long) assert out.size() == (3, 4) assert out.dtype == torch.long assert out.tolist() == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]] ================================================ FILE: test/utils/test_ppr.py ================================================ import pytest import torch from torch_geometric.datasets import KarateClub from torch_geometric.testing import withPackage from torch_geometric.utils import get_ppr @withPackage('numba') @pytest.mark.parametrize('target', [None, torch.tensor([0, 4, 5, 6])]) def test_get_ppr(target): data = KarateClub()[0] edge_index, edge_weight = get_ppr( data.edge_index, alpha=0.1, eps=1e-5, target=target, ) assert edge_index.size(0) == 2 assert edge_index.size(1) == edge_weight.numel() min_row = 0 if target is None else target.min() max_row = data.num_nodes - 1 if target is None else target.max() assert edge_index[0].min() == min_row and edge_index[0].max() == max_row assert edge_index[1].min() >= 0 and edge_index[1].max() < data.num_nodes assert edge_weight.min() >= 0.0 and edge_weight.max() <= 1.0 ================================================ FILE: test/utils/test_random.py ================================================ import numpy as np import torch from torch_geometric.utils import ( barabasi_albert_graph, erdos_renyi_graph, stochastic_blockmodel_graph, ) def test_erdos_renyi_graph(): torch.manual_seed(1234) edge_index = erdos_renyi_graph(5, 0.2, directed=False) assert edge_index.tolist() == [ [0, 1, 1, 1, 2, 4], [1, 0, 2, 4, 1, 1], ] edge_index = erdos_renyi_graph(5, 0.5, directed=True) assert edge_index.tolist() == [ [1, 1, 2, 2, 3, 4, 4, 4], [0, 3, 0, 4, 0, 0, 1, 3], ] def test_stochastic_blockmodel_graph(): torch.manual_seed(12345) block_sizes = [2, 2, 4] edge_probs = [ [0.25, 0.05, 0.02], [0.05, 0.35, 0.07], [0.02, 0.07, 0.40], 
] edge_index = stochastic_blockmodel_graph(block_sizes, edge_probs, directed=False) assert edge_index.tolist() == [ [2, 3, 4, 4, 5, 5, 6, 7, 7, 7], [3, 2, 5, 7, 4, 7, 7, 4, 5, 6], ] edge_index = stochastic_blockmodel_graph(block_sizes, edge_probs, directed=True) assert edge_index.tolist() == [ [0, 1, 3, 5, 6, 6, 7, 7], [3, 3, 2, 4, 4, 7, 5, 6], ] def test_barabasi_albert_graph(): torch.manual_seed(12345) np.random.seed(12345) edge_index = barabasi_albert_graph(num_nodes=8, num_edges=3) assert edge_index.size() == (2, 26) ================================================ FILE: test/utils/test_repeat.py ================================================ from torch_geometric.utils.repeat import repeat def test_repeat(): assert repeat(None, length=4) is None assert repeat(4, length=4) == [4, 4, 4, 4] assert repeat([2, 3, 4], length=4) == [2, 3, 4, 4] assert repeat([1, 2, 3, 4], length=4) == [1, 2, 3, 4] assert repeat([1, 2, 3, 4, 5], length=4) == [1, 2, 3, 4] ================================================ FILE: test/utils/test_scatter.py ================================================ from itertools import product import pytest import torch import torch_geometric.typing from torch_geometric.profile import benchmark from torch_geometric.testing import withCUDA, withDevice, withPackage from torch_geometric.utils import group_argsort, group_cat, scatter from torch_geometric.utils._scatter import scatter_argmax def test_scatter_validate(): src = torch.randn(100, 32) index = torch.randint(0, 10, (100, ), dtype=torch.long) with pytest.raises(ValueError, match="must be one-dimensional"): scatter(src, index.view(-1, 1)) with pytest.raises(ValueError, match="must lay between 0 and 1"): scatter(src, index, dim=2) with pytest.raises(ValueError, match="invalid `reduce` argument 'std'"): scatter(src, index, reduce='std') @withDevice @withPackage('torch_scatter') @pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max']) def test_scatter(reduce, device): import 
torch_scatter src = torch.randn(100, 16, device=device) index = torch.randint(0, 8, (100, ), device=device) if device.type == 'mps' and reduce in ['min', 'max']: with pytest.raises(NotImplementedError, match="for the MPS device"): scatter(src, index, dim=0, reduce=reduce) return out1 = scatter(src, index, dim=0, reduce=reduce) out2 = torch_scatter.scatter(src, index, dim=0, reduce=reduce) assert out1.device == device assert torch.allclose(out1, out2, atol=1e-6) jit = torch.jit.script(scatter) out3 = jit(src, index, dim=0, reduce=reduce) assert torch.allclose(out1, out3, atol=1e-6) src = torch.randn(8, 100, 16, device=device) out1 = scatter(src, index, dim=1, reduce=reduce) out2 = torch_scatter.scatter(src, index, dim=1, reduce=reduce) assert out1.device == device assert torch.allclose(out1, out2, atol=1e-6) @withDevice @pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max']) def test_scatter_backward(reduce, device): src = torch.randn(8, 100, 16, device=device, requires_grad=True) index = torch.randint(0, 8, (100, ), device=device) if device.type == 'mps' and reduce in ['min', 'max']: with pytest.raises(NotImplementedError, match="for the MPS device"): scatter(src, index, dim=1, reduce=reduce) return out = scatter(src, index, dim=1, reduce=reduce) assert src.grad is None out.mean().backward() assert src.grad is not None @withDevice def test_scatter_any(device): src = torch.randn(6, 4, device=device) index = torch.tensor([0, 0, 1, 1, 2, 2], device=device) out = scatter(src, index, dim=0, reduce='any') for i in range(3): for j in range(4): assert float(out[i, j]) in src[2 * i:2 * i + 2, j].tolist() @withDevice @pytest.mark.parametrize('num_groups', [4]) @pytest.mark.parametrize('descending', [False, True]) def test_group_argsort(num_groups, descending, device): src = torch.randn(20, device=device) index = torch.randint(0, num_groups, (20, ), device=device) out = group_argsort(src, index, 0, num_groups, descending=descending) expected = 
torch.empty_like(index) for i in range(num_groups): mask = index == i tmp = src[mask].argsort(descending=descending) perm = torch.empty_like(tmp) perm[tmp] = torch.arange(tmp.numel(), device=device) expected[mask] = perm assert torch.equal(out, expected) empty_tensor = torch.tensor([], device=device) out = group_argsort(empty_tensor, empty_tensor) assert out.numel() == 0 @withCUDA def test_scatter_argmax(device): src = torch.arange(5, device=device) index = torch.tensor([2, 2, 0, 0, 3], device=device) old_state = torch_geometric.typing.WITH_TORCH_SCATTER torch_geometric.typing.WITH_TORCH_SCATTER = False argmax = scatter_argmax(src, index, dim_size=6) torch_geometric.typing.WITH_TORCH_SCATTER = old_state assert argmax.tolist() == [3, 5, 1, 4, 5, 5] @withDevice def test_group_cat(device): x1 = torch.randn(4, 4, device=device) x2 = torch.randn(2, 4, device=device) index1 = torch.tensor([0, 0, 1, 2], device=device) index2 = torch.tensor([0, 2], device=device) expected = torch.cat([x1[:2], x2[:1], x1[2:4], x2[1:]], dim=0) out, index = group_cat( [x1, x2], [index1, index2], dim=0, return_index=True, ) assert torch.equal(out, expected) assert index.tolist() == [0, 0, 0, 1, 2, 2] if __name__ == '__main__': # Insights on GPU: # ================ # * "sum": Prefer `scatter_add_` implementation # * "mean": Prefer manual implementation via `scatter_add_` + `count` # * "min"/"max": # * Prefer `scatter_reduce_` implementation without gradients # * Prefer `torch_sparse` implementation with gradients # * "mul": Prefer `torch_sparse` implementation # # Insights on CPU: # ================ # * "sum": Prefer `scatter_add_` implementation # * "mean": Prefer manual implementation via `scatter_add_` + `count` # * "min"/"max": Prefer `scatter_reduce_` implementation # * "mul" (probably not worth branching for this): # * Prefer `scatter_reduce_` implementation without gradients # * Prefer `torch_sparse` implementation with gradients import argparse from torch_geometric.typing import 
WITH_TORCH_SCATTER, torch_scatter parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') parser.add_argument('--aggr', type=str, default='all') args = parser.parse_args() num_nodes_list = [4_000, 8_000, 16_000, 32_000, 64_000] if args.aggr == 'all': aggrs = ['sum', 'mean', 'min', 'max', 'mul'] else: aggrs = args.aggr.split(',') def pytorch_scatter(x, index, dim_size, reduce): if reduce == 'min' or reduce == 'max': reduce = f'a{aggr}' # `amin` or `amax` elif reduce == 'mul': reduce = 'prod' out = x.new_zeros(dim_size, x.size(-1)) include_self = reduce in ['sum', 'mean'] index = index.view(-1, 1).expand(-1, x.size(-1)) out.scatter_reduce_(0, index, x, reduce, include_self=include_self) return out def pytorch_index_add(x, index, dim_size, reduce): if reduce != 'sum': raise NotImplementedError out = x.new_zeros(dim_size, x.size(-1)) out.index_add_(0, index, x) return out def own_scatter(x, index, dim_size, reduce): return torch_scatter.scatter(x, index, dim=0, dim_size=num_nodes, reduce=reduce) def optimized_scatter(x, index, dim_size, reduce): return scatter(x, index, dim=0, dim_size=dim_size, reduce=reduce) for aggr, num_nodes in product(aggrs, num_nodes_list): num_edges = num_nodes * 50 print(f'aggr: {aggr}, #nodes: {num_nodes}, #edges: {num_edges}') x = torch.randn(num_edges, 64, device=args.device) index = torch.randint(num_nodes, (num_edges, ), device=args.device) funcs = [pytorch_scatter] func_names = ['PyTorch scatter_reduce'] if aggr == 'sum': funcs.append(pytorch_index_add) func_names.append('PyTorch index_add') if WITH_TORCH_SCATTER: funcs.append(own_scatter) func_names.append('torch_scatter') funcs.append(optimized_scatter) func_names.append('Optimized PyG Scatter') benchmark( funcs=funcs, func_names=func_names, args=(x, index, num_nodes, aggr), num_steps=100 if args.device == 'cpu' else 1000, num_warmups=50 if args.device == 'cpu' else 500, backward=args.backward, 
) ================================================ FILE: test/utils/test_segment.py ================================================ from itertools import product import pytest import torch import torch_geometric.typing from torch_geometric.index import index2ptr from torch_geometric.profile import benchmark from torch_geometric.testing import withCUDA, withoutExtensions from torch_geometric.utils import scatter, segment, segment_logsumexp @withCUDA @withoutExtensions @pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max']) def test_segment(device, without_extensions, reduce): src = torch.randn(20, 16, device=device) ptr = torch.tensor([0, 0, 5, 10, 15, 20], device=device) if (not torch_geometric.typing.WITH_TORCH_SCATTER and not torch_geometric.typing.WITH_PT20): with pytest.raises(ImportError, match="requires the 'torch-scatter'"): segment(src, ptr, reduce=reduce) else: out = segment(src, ptr, reduce=reduce) expected = getattr(torch, reduce)(src.view(4, 5, -1), dim=1) expected = expected[0] if isinstance(expected, tuple) else expected assert torch.allclose(out[:1], torch.zeros(1, 16, device=device)) assert torch.allclose(out[1:], expected) @withCUDA @withoutExtensions def test_segment_logsumexp(device, without_extensions) -> None: src = torch.randn(5, 4, device=device) expected = src.logsumexp(dim=0) ptr = torch.tensor([0, 0, 5, 5], device=device) out = segment_logsumexp(src, ptr, dim=0) assert out.size() == (3, 4) assert out[0].abs().sum() == 0.0 assert torch.allclose(expected, out[1]) assert out[2].abs().sum() == 0.0 expected = src.logsumexp(dim=1) ptr = torch.tensor([0, 0, 4, 4], device=device) out = segment_logsumexp(src, ptr, dim=1) assert out.size() == (5, 3) assert out[:, 0].abs().sum() == 0.0 assert torch.allclose(expected, out[:, 1]) assert out[:, 2].abs().sum() == 0.0 if __name__ == '__main__': # Insights on GPU: # ================ # * "mean": Prefer `torch._segment_reduce` implementation # * others: Prefer `torch_scatter` implementation # # 
Insights on CPU: # ================ # * "all": Prefer `torch_scatter` implementation (but `scatter(...)` # implementation is far superior due to multi-threading usage. import argparse from torch_geometric.typing import WITH_TORCH_SCATTER, torch_scatter parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') parser.add_argument('--aggr', type=str, default='all') args = parser.parse_args() num_nodes_list = [4_000, 8_000, 16_000, 32_000, 64_000] if args.aggr == 'all': aggrs = ['sum', 'mean', 'min', 'max'] else: aggrs = args.aggr.split(',') def pytorch_segment(x, ptr, reduce): if reduce == 'min' or reduce == 'max': reduce = f'a{aggr}' # `amin` or `amax` return torch._segment_reduce(x, reduce, offsets=ptr) def own_segment(x, ptr, reduce): return torch_scatter.segment_csr(x, ptr, reduce=reduce) def optimized_scatter(x, index, reduce, dim_size): return scatter(x, index, dim=0, dim_size=dim_size, reduce=reduce) def optimized_segment(x, index, reduce): return segment(x, ptr, reduce=reduce) for aggr, num_nodes in product(aggrs, num_nodes_list): num_edges = num_nodes * 50 print(f'aggr: {aggr}, #nodes: {num_nodes}, #edges: {num_edges}') x = torch.randn(num_edges, 64, device=args.device) index = torch.randint(num_nodes, (num_edges, ), device=args.device) index, _ = index.sort() ptr = index2ptr(index, size=num_nodes) funcs = [pytorch_segment] func_names = ['PyTorch segment_reduce'] arg_list = [(x, ptr, aggr)] if WITH_TORCH_SCATTER: funcs.append(own_segment) func_names.append('torch_scatter') arg_list.append((x, ptr, aggr)) funcs.append(optimized_scatter) func_names.append('Optimized PyG Scatter') arg_list.append((x, index, aggr, num_nodes)) funcs.append(optimized_segment) func_names.append('Optimized PyG Segment') arg_list.append((x, ptr, aggr)) benchmark( funcs=funcs, func_names=func_names, args=arg_list, num_steps=100 if args.device == 'cpu' else 1000, num_warmups=50 if args.device == 
'cpu' else 500, backward=args.backward, ) ================================================ FILE: test/utils/test_select.py ================================================ import torch from torch_geometric.utils import narrow, select def test_select(): src = torch.randn(5, 3) index = torch.tensor([0, 2, 4]) mask = torch.tensor([True, False, True, False, True]) out = select(src, index, dim=0) assert torch.equal(out, src[index]) assert torch.equal(out, select(src, mask, dim=0)) assert torch.equal(out, torch.tensor(select(src.tolist(), index, dim=0))) assert torch.equal(out, torch.tensor(select(src.tolist(), mask, dim=0))) def test_narrow(): src = torch.randn(5, 3) out = narrow(src, dim=0, start=2, length=2) assert torch.equal(out, src[2:4]) assert torch.equal(out, torch.tensor(narrow(src.tolist(), 0, 2, 2))) ================================================ FILE: test/utils/test_smiles.py ================================================ import pytest from torch_geometric.testing import withPackage from torch_geometric.utils import from_rdmol, from_smiles, to_rdmol, to_smiles smiles = [ r'F/C=C/F', r'F/C=C\F', r'F/C=C\F', (r'COc1cccc([C@@H]2Oc3ccc(OC)cc3/C(=N/OC[C@@H](C)[C@H](OCc3ccccc3)' r'C(C)C)[C@@H]2O)c1'), r'C/C(=C\C(=O)c1ccc(C)o1)Nc1ccc2c(c1)OCO2', r'F[B-](F)(F)c1cnc2ccccc2c1', r'COC(=O)[C@@]1(Cc2ccccc2)[C@H]2C(=O)N(C)C(=O)[C@H]2[C@H]2CN=C(SC)N21', (r'O=C(O)c1ccc(NS(=O)(=O)c2ccc3c(c2)C(=O)c2cc(S(=O)(=O)Nc4ccc(C(=O)O)' r'cc4)ccc2-3)cc1'), ] @withPackage('rdkit') @pytest.mark.parametrize('smiles', smiles) def test_from_to_smiles(smiles): data = from_smiles(smiles) assert to_smiles(data) == smiles @withPackage('rdkit') @pytest.mark.parametrize('smiles', smiles) def test_from_to_rdmol(smiles): from rdkit import Chem mol1 = Chem.MolFromSmiles(smiles) data = from_rdmol(mol1) mol2 = to_rdmol(data) assert Chem.MolToSmiles(mol1) == Chem.MolToSmiles(mol2) ================================================ FILE: test/utils/test_softmax.py 
================================================ import pytest import torch import torch_geometric.typing from torch_geometric.profile import benchmark from torch_geometric.utils import softmax CALCULATION_VIA_PTR_AVAILABLE = (torch_geometric.typing.WITH_SOFTMAX or torch_geometric.typing.WITH_TORCH_SCATTER) def test_softmax(): src = torch.tensor([1., 1., 1., 1.]) index = torch.tensor([0, 0, 1, 2]) ptr = torch.tensor([0, 2, 3, 4]) out = softmax(src, index) assert out.tolist() == [0.5, 0.5, 1, 1] assert softmax(src, ptr=ptr).tolist() == out.tolist() src = src.view(-1, 1) out = softmax(src, index) assert out.tolist() == [[0.5], [0.5], [1], [1]] assert softmax(src, ptr=ptr).tolist() == out.tolist() jit = torch.jit.script(softmax) assert torch.allclose(jit(src, index), out) def test_softmax_backward(): src_sparse = torch.rand(4, 8) index = torch.tensor([0, 0, 1, 1]) src_dense = src_sparse.clone().view(2, 2, src_sparse.size(-1)) src_sparse.requires_grad_(True) src_dense.requires_grad_(True) out_sparse = softmax(src_sparse, index) out_sparse.mean().backward() out_dense = src_dense.softmax(dim=1) out_dense.mean().backward() assert torch.allclose(out_sparse, out_dense.view_as(out_sparse)) assert torch.allclose(src_sparse.grad, src_dense.grad.view_as(src_sparse)) def test_softmax_dim(): index = torch.tensor([0, 0, 0, 0]) ptr = torch.tensor([0, 4]) src = torch.randn(4) assert torch.allclose(softmax(src, index, dim=0), src.softmax(dim=0)) assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0)) src = torch.randn(4, 16) assert torch.allclose(softmax(src, index, dim=0), src.softmax(dim=0)) assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0)) src = torch.randn(4, 4) assert torch.allclose(softmax(src, index, dim=-1), src.softmax(dim=-1)) if CALCULATION_VIA_PTR_AVAILABLE: assert torch.allclose(softmax(src, ptr=ptr, dim=-1), src.softmax(-1)) else: with pytest.raises(ImportError, match="requires the 'torch-scatter'"): softmax(src, ptr=ptr, dim=-1) 
src = torch.randn(4, 4, 16) assert torch.allclose(softmax(src, index, dim=1), src.softmax(dim=1)) if CALCULATION_VIA_PTR_AVAILABLE: assert torch.allclose(softmax(src, ptr=ptr, dim=1), src.softmax(dim=1)) else: with pytest.raises(ImportError, match="requires the 'torch-scatter'"): softmax(src, ptr=ptr, dim=1) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() num_nodes, num_edges = 10_000, 200_000 x = torch.randn(num_edges, 64, device=args.device) index = torch.randint(num_nodes, (num_edges, ), device=args.device) compiled_softmax = torch.compile(softmax) def dense_softmax(x, index): x = x.view(num_nodes, -1, x.size(-1)) return x.softmax(dim=-1) benchmark( funcs=[dense_softmax, softmax, compiled_softmax], func_names=['Dense Softmax', 'Vanilla', 'Compiled'], args=(x, index), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) ================================================ FILE: test/utils/test_sort_edge_index.py ================================================ from typing import List, Optional, Tuple import torch from torch import Tensor import torch_geometric.typing from torch_geometric.utils import sort_edge_index def test_sort_edge_index(): edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) edge_attr = torch.tensor([[1], [2], [3], [4]]) out = sort_edge_index(edge_index) assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] torch_geometric.typing.MAX_INT64 = 1 out = sort_edge_index(edge_index) torch_geometric.typing.MAX_INT64 = torch.iinfo(torch.int64).max assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] out = sort_edge_index((edge_index[0], edge_index[1])) assert isinstance(out, tuple) assert out[0].tolist() == [0, 1, 1, 2] assert out[1].tolist() == [1, 0, 2, 1] out = sort_edge_index(edge_index, None) assert out[0].tolist() 
== [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1] is None out = sort_edge_index(edge_index, edge_attr) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1].tolist() == [[4], [3], [2], [1]] out = sort_edge_index(edge_index, [edge_attr, edge_attr.view(-1)]) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1][0].tolist() == [[4], [3], [2], [1]] assert out[1][1].tolist() == [4, 3, 2, 1] def test_sort_edge_index_jit(): @torch.jit.script def wrapper1(edge_index: Tensor) -> Tensor: return sort_edge_index(edge_index) @torch.jit.script def wrapper2( edge_index: Tensor, edge_attr: Optional[Tensor], ) -> Tuple[Tensor, Optional[Tensor]]: return sort_edge_index(edge_index, edge_attr) @torch.jit.script def wrapper3( edge_index: Tensor, edge_attr: List[Tensor], ) -> Tuple[Tensor, List[Tensor]]: return sort_edge_index(edge_index, edge_attr) edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) edge_attr = torch.tensor([[1], [2], [3], [4]]) out = wrapper1(edge_index) assert out.size() == edge_index.size() out = wrapper2(edge_index, None) assert out[0].size() == edge_index.size() assert out[1] is None out = wrapper2(edge_index, edge_attr) assert out[0].size() == edge_index.size() assert out[1].size() == edge_attr.size() out = wrapper3(edge_index, [edge_attr, edge_attr.view(-1)]) assert out[0].size() == edge_index.size() assert len(out[1]) == 2 assert out[1][0].size() == edge_attr.size() assert out[1][1].size() == edge_attr.view(-1).size() ================================================ FILE: test/utils/test_sparse.py ================================================ import os.path as osp import pytest import torch import torch_geometric.typing from torch_geometric.io import fs from torch_geometric.profile import benchmark from torch_geometric.testing import is_full_test, withCUDA, withPackage from torch_geometric.typing import SparseTensor from torch_geometric.utils import ( dense_to_sparse, is_sparse, is_torch_sparse_tensor, to_edge_index, 
to_torch_coo_tensor, to_torch_csc_tensor, to_torch_csr_tensor, to_torch_sparse_tensor, ) from torch_geometric.utils.sparse import cat def test_dense_to_sparse(): adj = torch.tensor([ [3.0, 1.0], [2.0, 0.0], ]) edge_index, edge_attr = dense_to_sparse(adj) assert edge_index.tolist() == [[0, 0, 1], [0, 1, 0]] assert edge_attr.tolist() == [3, 1, 2] if is_full_test(): jit = torch.jit.script(dense_to_sparse) edge_index, edge_attr = jit(adj) assert edge_index.tolist() == [[0, 0, 1], [0, 1, 0]] assert edge_attr.tolist() == [3, 1, 2] adj = torch.tensor([[ [3.0, 1.0], [2.0, 0.0], ], [ [0.0, 1.0], [0.0, 2.0], ]]) edge_index, edge_attr = dense_to_sparse(adj) assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]] assert edge_attr.tolist() == [3, 1, 2, 1, 2] if is_full_test(): jit = torch.jit.script(dense_to_sparse) edge_index, edge_attr = jit(adj) assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]] assert edge_attr.tolist() == [3, 1, 2, 1, 2] adj = torch.tensor([ [ [3.0, 1.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 0.0], ], [ [0.0, 1.0, 0.0], [0.0, 2.0, 3.0], [0.0, 5.0, 0.0], ], ]) mask = torch.tensor([[True, True, False], [True, True, True]]) edge_index, edge_attr = dense_to_sparse(adj, mask) assert edge_index.tolist() == [[0, 0, 1, 2, 3, 3, 4], [0, 1, 0, 3, 3, 4, 3]] assert edge_attr.tolist() == [3, 1, 2, 1, 2, 3, 5] if is_full_test(): jit = torch.jit.script(dense_to_sparse) edge_index, edge_attr = jit(adj, mask) assert edge_index.tolist() == [[0, 0, 1, 2, 3, 3, 4], [0, 1, 0, 3, 3, 4, 3]] assert edge_attr.tolist() == [3, 1, 2, 1, 2, 3, 5] def test_dense_to_sparse_bipartite(): edge_index, edge_attr = dense_to_sparse(torch.rand(2, 10, 5)) assert edge_index[0].max() == 19 assert edge_index[1].max() == 9 def test_is_torch_sparse_tensor(): x = torch.randn(5, 5) assert not is_torch_sparse_tensor(x) assert is_torch_sparse_tensor(x.to_sparse()) if torch_geometric.typing.WITH_TORCH_SPARSE: assert not is_torch_sparse_tensor(SparseTensor.from_dense(x)) def 
test_is_sparse(): x = torch.randn(5, 5) assert not is_sparse(x) assert is_sparse(x.to_sparse()) if torch_geometric.typing.WITH_TORCH_SPARSE: assert is_sparse(SparseTensor.from_dense(x)) def test_to_torch_coo_tensor(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2], ]) edge_attr = torch.randn(edge_index.size(1), 8) adj = to_torch_coo_tensor(edge_index, is_coalesced=False) assert adj.is_coalesced() assert adj.size() == (4, 4) assert adj.layout == torch.sparse_coo assert torch.allclose(adj.indices(), edge_index) adj = to_torch_coo_tensor(edge_index, is_coalesced=True) assert adj.is_coalesced() assert adj.size() == (4, 4) assert adj.layout == torch.sparse_coo assert torch.allclose(adj.indices(), edge_index) adj = to_torch_coo_tensor(edge_index, size=6) assert adj.size() == (6, 6) assert adj.layout == torch.sparse_coo assert torch.allclose(adj.indices(), edge_index) adj = to_torch_coo_tensor(edge_index, edge_attr) assert adj.size() == (4, 4, 8) assert adj.layout == torch.sparse_coo assert torch.allclose(adj.indices(), edge_index) assert torch.allclose(adj.values(), edge_attr) if is_full_test(): jit = torch.jit.script(to_torch_coo_tensor) adj = jit(edge_index, edge_attr) assert adj.size() == (4, 4, 8) assert adj.layout == torch.sparse_coo assert torch.allclose(adj.indices(), edge_index) assert torch.allclose(adj.values(), edge_attr) def test_to_torch_csr_tensor(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2], ]) adj = to_torch_csr_tensor(edge_index) assert adj.size() == (4, 4) assert adj.layout == torch.sparse_csr assert torch.allclose(adj.to_sparse_coo().coalesce().indices(), edge_index) edge_weight = torch.randn(edge_index.size(1)) adj = to_torch_csr_tensor(edge_index, edge_weight) assert adj.size() == (4, 4) assert adj.layout == torch.sparse_csr coo = adj.to_sparse_coo().coalesce() assert torch.allclose(coo.indices(), edge_index) assert torch.allclose(coo.values(), edge_weight) if torch_geometric.typing.WITH_PT20: edge_attr = 
torch.randn(edge_index.size(1), 8) adj = to_torch_csr_tensor(edge_index, edge_attr) assert adj.size() == (4, 4, 8) assert adj.layout == torch.sparse_csr coo = adj.to_sparse_coo().coalesce() assert torch.allclose(coo.indices(), edge_index) assert torch.allclose(coo.values(), edge_attr) def test_to_torch_csc_tensor(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2], ]) adj = to_torch_csc_tensor(edge_index) assert adj.size() == (4, 4) assert adj.layout == torch.sparse_csc adj_coo = adj.to_sparse_coo().coalesce() if torch_geometric.typing.WITH_PT20: assert torch.allclose(adj_coo.indices(), edge_index) else: assert torch.allclose(adj_coo.indices().flip([0]), edge_index) edge_weight = torch.randn(edge_index.size(1)) adj = to_torch_csc_tensor(edge_index, edge_weight) assert adj.size() == (4, 4) assert adj.layout == torch.sparse_csc adj_coo = adj.to_sparse_coo().coalesce() if torch_geometric.typing.WITH_PT20: assert torch.allclose(adj_coo.indices(), edge_index) assert torch.allclose(adj_coo.values(), edge_weight) else: perm = adj_coo.indices()[0].argsort() assert torch.allclose(adj_coo.indices()[:, perm], edge_index) assert torch.allclose(adj_coo.values()[perm], edge_weight) if torch_geometric.typing.WITH_PT20: edge_attr = torch.randn(edge_index.size(1), 8) adj = to_torch_csc_tensor(edge_index, edge_attr) assert adj.size() == (4, 4, 8) assert adj.layout == torch.sparse_csc assert torch.allclose(adj.to_sparse_coo().coalesce().indices(), edge_index) assert torch.allclose(adj.to_sparse_coo().coalesce().values(), edge_attr) @withPackage('torch>=2.1.0') def test_to_torch_coo_tensor_save_load(tmp_path): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2], ]) adj = to_torch_coo_tensor(edge_index, is_coalesced=False) assert adj.is_coalesced() path = osp.join(tmp_path, 'adj.t') torch.save(adj, path) adj = fs.torch_load(path) assert adj.is_coalesced() def test_to_edge_index(): adj = torch.tensor([ [0., 1., 0., 0.], [1., 0., 1., 0.], [0., 1., 0., 
1.], [0., 0., 1., 0.], ]).to_sparse() edge_index, edge_attr = to_edge_index(adj) assert edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]] assert edge_attr.tolist() == [1., 1., 1., 1., 1., 1.] if is_full_test(): jit = torch.jit.script(to_edge_index) edge_index, edge_attr = jit(adj) assert edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]] assert edge_attr.tolist() == [1., 1., 1., 1., 1., 1.] @withCUDA @pytest.mark.parametrize( 'layout', [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc], ) @pytest.mark.parametrize('dim', [0, 1, (0, 1)]) def test_cat(layout, dim, device): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device) if torch_geometric.typing.WITH_PT20: edge_weight = torch.rand(4, 2, device=device) else: edge_weight = torch.rand(4, device=device) adj = to_torch_sparse_tensor(edge_index, edge_weight, layout=layout) out = cat([adj, adj], dim=dim) edge_index, edge_weight = to_edge_index(out.to_sparse_csr()) if dim == 0: if torch_geometric.typing.WITH_PT20: assert out.size() == (6, 3, 2) else: assert out.size() == (6, 3) assert edge_index[0].tolist() == [0, 1, 1, 2, 3, 4, 4, 5] assert edge_index[1].tolist() == [1, 0, 2, 1, 1, 0, 2, 1] elif dim == 1: if torch_geometric.typing.WITH_PT20: assert out.size() == (3, 6, 2) else: assert out.size() == (3, 6) assert edge_index[0].tolist() == [0, 0, 1, 1, 1, 1, 2, 2] assert edge_index[1].tolist() == [1, 4, 0, 2, 3, 5, 1, 4] else: if torch_geometric.typing.WITH_PT20: assert out.size() == (6, 6, 2) else: assert out.size() == (6, 6) assert edge_index[0].tolist() == [0, 1, 1, 2, 3, 4, 4, 5] assert edge_index[1].tolist() == [1, 0, 2, 1, 4, 3, 5, 4] if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') args = parser.parse_args() num_nodes, num_edges = 10_000, 200_000 edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device) benchmark( funcs=[ SparseTensor.from_edge_index, 
to_torch_coo_tensor, to_torch_csr_tensor, to_torch_csc_tensor ], func_names=['SparseTensor', 'To COO', 'To CSR', 'To CSC'], args=(edge_index, None, (num_nodes, num_nodes)), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, ) ================================================ FILE: test/utils/test_spmm.py ================================================ import itertools import warnings import pytest import torch from torch import Tensor import torch_geometric.typing from torch_geometric import EdgeIndex from torch_geometric.profile import benchmark from torch_geometric.testing import withCUDA, withPackage from torch_geometric.typing import SparseTensor from torch_geometric.utils import spmm, to_torch_coo_tensor @withCUDA @pytest.mark.parametrize('reduce', ['sum', 'mean']) def test_spmm_basic(device, reduce): src = torch.randn(5, 4, device=device) other = torch.randn(4, 8, device=device) out1 = (src @ other) / (src.size(1) if reduce == 'mean' else 1) out2 = spmm(src.to_sparse_csr(), other, reduce=reduce) assert out1.size() == (5, 8) assert torch.allclose(out1, out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce) assert torch.allclose(out2, out3, atol=1e-6) # Test `mean` reduction with isolated nodes: src[0] = 0. out1 = (src @ other) / (4. if reduce == 'mean' else 1.) 
out2 = spmm(src.to_sparse_csr(), other, reduce=reduce) assert out1.size() == (5, 8) assert torch.allclose(out1, out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce) assert torch.allclose(out2, out3, atol=1e-6) @withCUDA @withPackage('torch>=2.0.0') @pytest.mark.parametrize('reduce', ['min', 'max']) def test_spmm_reduce(device, reduce): src = torch.randn(5, 4, device=device) other = torch.randn(4, 8, device=device) if src.is_cuda: with pytest.raises(NotImplementedError, match="not yet supported"): spmm(src.to_sparse_csr(), other, reduce) else: out1 = spmm(src.to_sparse_csr(), other, reduce) assert out1.size() == (5, 8) if torch_geometric.typing.WITH_TORCH_SPARSE: out2 = spmm(SparseTensor.from_dense(src), other, reduce=reduce) assert torch.allclose(out1, out2) @withCUDA @withPackage('torch>=2.0.0') @pytest.mark.parametrize( 'layout', [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc]) @pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max']) def test_spmm_layout(device, layout, reduce): src = torch.randn(5, 4, device=device) if layout == torch.sparse_coo: src = src.to_sparse_coo() elif layout == torch.sparse_csr: src = src.to_sparse_csr() else: assert layout == torch.sparse_csc src = src.to_sparse_csc() other = torch.randn(4, 8, device=device) if src.is_cuda and reduce in {'min', 'max'}: with pytest.raises(NotImplementedError, match="not yet supported"): spmm(src, other, reduce=reduce) elif layout != torch.sparse_csr: with pytest.warns(UserWarning, match="Converting sparse tensor"): spmm(src, other, reduce=reduce) else: spmm(src, other, reduce=reduce) @pytest.mark.parametrize('reduce', ['sum', 'mean']) def test_spmm_jit(reduce): @torch.jit.script def jit_torch_sparse(src: SparseTensor, other: Tensor, reduce: str) -> Tensor: return spmm(src, other, reduce=reduce) @torch.jit.script def jit_torch(src: Tensor, other: Tensor, reduce: str) -> Tensor: return spmm(src, other, reduce=reduce) 
src = torch.randn(5, 4) other = torch.randn(4, 8) out1 = src @ other out2 = jit_torch(src.to_sparse_csr(), other, reduce) assert out1.size() == (5, 8) if reduce == 'sum': assert torch.allclose(out1, out2, atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: out3 = jit_torch_sparse(SparseTensor.from_dense(src), other, reduce) assert torch.allclose(out2, out3, atol=1e-6) @withCUDA @withPackage('torch>=2.0.0') @pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max']) def test_spmm_edge_index(device, reduce): src = EdgeIndex( [[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(4, 3), sort_order='row', device=device, ) other = torch.rand(3, 4, device=device) out = spmm(src, other, reduce=reduce) assert out.size() == (4, 4) if not other.is_cuda or reduce not in ['min', 'max']: out2 = spmm(src.to_sparse_csr(), other, reduce=reduce) assert torch.allclose(out, out2) if __name__ == '__main__': import argparse warnings.filterwarnings('ignore', ".*Sparse CSR tensor support.*") warnings.filterwarnings('ignore', ".*Converting sparse tensor to CSR.*") parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') args = parser.parse_args() num_nodes, num_edges = 10_000, 200_000 x = torch.randn(num_nodes, 64, device=args.device) edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device) reductions = ['sum', 'mean'] if not x.is_cuda: reductions.extend(['min', 'max']) layouts = [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc] for reduce, layout in itertools.product(reductions, layouts): print(f'Aggregator: {reduce}, Layout: {layout}') adj = to_torch_coo_tensor(edge_index, size=num_nodes) adj = adj.to_sparse(layout=layout) benchmark( funcs=[spmm], func_names=['spmm'], args=(adj, x, reduce), num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, ) ================================================ FILE: 
test/utils/test_subgraph.py ================================================ import torch from torch_geometric.nn import GCNConv, Linear from torch_geometric.testing import withDevice, withPackage from torch_geometric.utils import ( bipartite_subgraph, get_num_hops, index_to_mask, k_hop_subgraph, subgraph, ) def test_get_num_hops(): class GNN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(3, 16, normalize=False) self.conv2 = GCNConv(16, 16, normalize=False) self.lin = Linear(16, 2) def forward(self, x, edge_index): x = torch.F.relu(self.conv1(x, edge_index)) x = self.conv2(x, edge_index) return self.lin(x) assert get_num_hops(GNN()) == 2 def test_subgraph(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5], ]) edge_attr = torch.tensor( [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]) idx = torch.tensor([3, 4, 5]) mask = index_to_mask(idx, 7) indices = idx.tolist() for subset in [idx, mask, indices]: out = subgraph(subset, edge_index, edge_attr, return_edge_mask=True) assert out[0].tolist() == [[3, 4, 4, 5], [4, 3, 5, 4]] assert out[1].tolist() == [7.0, 8.0, 9.0, 10.0] assert out[2].tolist() == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0] out = subgraph(subset, edge_index, edge_attr, relabel_nodes=True) assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert out[1].tolist() == [7, 8, 9, 10] @withDevice @withPackage('pandas') def test_subgraph_large_index(device): subset = torch.tensor([50_000_000], device=device) edge_index = torch.tensor([[50_000_000], [50_000_000]], device=device) edge_index, _ = subgraph(subset, edge_index, relabel_nodes=True) assert edge_index.tolist() == [[0], [0]] def test_bipartite_subgraph(): edge_index = torch.tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6], [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]]) edge_attr = torch.tensor( [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]) idx = (torch.tensor([2, 3, 5]), torch.tensor([2, 3])) mask = 
(index_to_mask(idx[0], 7), index_to_mask(idx[1], 4)) indices = (idx[0].tolist(), idx[1].tolist()) mixed = (mask[0], idx[1]) for subset in [idx, mask, indices, mixed]: out = bipartite_subgraph(subset, edge_index, edge_attr, return_edge_mask=True) assert out[0].tolist() == [[2, 3, 5, 5], [3, 2, 2, 3]] assert out[1].tolist() == [3.0, 4.0, 9.0, 10.0] assert out[2].tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0] out = bipartite_subgraph(subset, edge_index, edge_attr, relabel_nodes=True) assert out[0].tolist() == [[0, 1, 2, 2], [1, 0, 0, 1]] assert out[1].tolist() == [3.0, 4.0, 9.0, 10.0] @withDevice @withPackage('pandas') def test_bipartite_subgraph_large_index(device): subset = torch.tensor([50_000_000], device=device) edge_index = torch.tensor([[50_000_000], [50_000_000]], device=device) edge_index, _ = bipartite_subgraph( (subset, subset), edge_index, relabel_nodes=True, ) assert edge_index.tolist() == [[0], [0]] def test_k_hop_subgraph(): edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5], [2, 2, 4, 4, 6, 6], ]) subset, edge_index, mapping, edge_mask = k_hop_subgraph( node_idx=6, num_hops=2, edge_index=edge_index, relabel_nodes=True, ) assert subset.tolist() == [2, 3, 4, 5, 6] assert edge_index.tolist() == [[0, 1, 2, 3], [2, 2, 4, 4]] assert mapping.tolist() == [4] assert edge_mask.tolist() == [False, False, True, True, True, True] edge_index = torch.tensor([ [1, 2, 4, 5], [0, 1, 5, 6], ]) subset, edge_index, mapping, edge_mask = k_hop_subgraph( node_idx=[0, 6], num_hops=2, edge_index=edge_index, relabel_nodes=True, ) assert subset.tolist() == [0, 1, 2, 4, 5, 6] assert edge_index.tolist() == [[1, 2, 3, 4], [0, 1, 4, 5]] assert mapping.tolist() == [0, 5] assert edge_mask.tolist() == [True, True, True, True] edge_index = torch.tensor([ [0, 1, 2, 3, 4, 4, 5], [2, 2, 4, 4, 2, 6, 6], ]) subset, edge_index, mapping, edge_mask = k_hop_subgraph( node_idx=6, num_hops=2, edge_index=edge_index, relabel_nodes=False, directed=True, ) assert subset.tolist() == [2, 3, 4, 5, 6] assert 
edge_index.tolist() == [[2, 3, 4, 5], [4, 4, 6, 6]] assert mapping.tolist() == [4] assert edge_mask.tolist() == [False, False, True, True, False, True, True] ================================================ FILE: test/utils/test_to_dense_adj.py ================================================ import torch from torch_geometric.testing import is_full_test from torch_geometric.utils import to_dense_adj def test_to_dense_adj(): edge_index = torch.tensor([ [0, 0, 1, 2, 3, 4], [0, 1, 0, 3, 4, 2], ]) batch = torch.tensor([0, 0, 1, 1, 1]) adj = to_dense_adj(edge_index, batch) assert adj.size() == (2, 3, 3) assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]] assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]] if is_full_test(): jit = torch.jit.script(to_dense_adj) adj = jit(edge_index, batch) assert adj.size() == (2, 3, 3) assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]] assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]] adj = to_dense_adj(edge_index, batch, max_num_nodes=2) assert adj.size() == (2, 2, 2) assert adj[0].tolist() == [[1, 1], [1, 0]] assert adj[1].tolist() == [[0, 1], [0, 0]] adj = to_dense_adj(edge_index, batch, max_num_nodes=5) assert adj.size() == (2, 5, 5) assert adj[0][:3, :3].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]] assert adj[1][:3, :3].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]] edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) adj = to_dense_adj(edge_index, batch, edge_attr) assert adj.size() == (2, 3, 3) assert adj[0].tolist() == [[1, 2, 0], [3, 0, 0], [0, 0, 0]] assert adj[1].tolist() == [[0, 4, 0], [0, 0, 5], [6, 0, 0]] adj = to_dense_adj(edge_index, batch, edge_attr, max_num_nodes=5) assert adj.size() == (2, 5, 5) assert adj[0][:3, :3].tolist() == [[1, 2, 0], [3, 0, 0], [0, 0, 0]] assert adj[1][:3, :3].tolist() == [[0, 4, 0], [0, 0, 5], [6, 0, 0]] edge_attr = edge_attr.view(-1, 1) adj = to_dense_adj(edge_index, batch, edge_attr) assert adj.size() == (2, 3, 3, 1) edge_attr = edge_attr.view(-1, 1) 
adj = to_dense_adj(edge_index, batch, edge_attr, max_num_nodes=5) assert adj.size() == (2, 5, 5, 1) adj = to_dense_adj(edge_index) assert adj.size() == (1, 5, 5) assert adj[0].nonzero(as_tuple=False).t().tolist() == edge_index.tolist() adj = to_dense_adj(edge_index, max_num_nodes=10) assert adj.size() == (1, 10, 10) assert adj[0].nonzero(as_tuple=False).t().tolist() == edge_index.tolist() adj = to_dense_adj(edge_index, batch, batch_size=4) assert adj.size() == (4, 3, 3) def test_to_dense_adj_with_empty_edge_index(): edge_index = torch.tensor([[], []], dtype=torch.long) batch = torch.tensor([0, 0, 1, 1, 1]) adj = to_dense_adj(edge_index) assert adj.size() == (1, 0, 0) adj = to_dense_adj(edge_index, max_num_nodes=10) assert adj.size() == (1, 10, 10) and adj.sum() == 0 adj = to_dense_adj(edge_index, batch) assert adj.size() == (2, 3, 3) and adj.sum() == 0 adj = to_dense_adj(edge_index, batch, max_num_nodes=10) assert adj.size() == (2, 10, 10) and adj.sum() == 0 def test_to_dense_adj_with_duplicate_entries(): edge_index = torch.tensor([ [0, 0, 0, 1, 2, 3, 3, 4], [0, 0, 1, 0, 3, 4, 4, 2], ]) batch = torch.tensor([0, 0, 1, 1, 1]) adj = to_dense_adj(edge_index, batch) assert adj.size() == (2, 3, 3) assert adj[0].tolist() == [[2, 1, 0], [1, 0, 0], [0, 0, 0]] assert adj[1].tolist() == [[0, 1, 0], [0, 0, 2], [1, 0, 0]] edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) adj = to_dense_adj(edge_index, batch, edge_attr) assert adj.size() == (2, 3, 3) assert adj[0].tolist() == [ [3.0, 3.0, 0.0], [4.0, 0.0, 0.0], [0.0, 0.0, 0.0], ] assert adj[1].tolist() == [ [0.0, 5.0, 0.0], [0.0, 0.0, 13.0], [8.0, 0.0, 0.0], ] ================================================ FILE: test/utils/test_to_dense_batch.py ================================================ from typing import Tuple import pytest import torch from torch import Tensor from torch_geometric.experimental import set_experimental_mode from torch_geometric.testing import onlyFullTest from torch_geometric.utils 
import to_dense_batch @pytest.mark.parametrize('fill', [70.0, torch.tensor(49.0)]) def test_to_dense_batch(fill): x = torch.tensor([ [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0], ]) batch = torch.tensor([0, 0, 1, 2, 2, 2]) item = fill.item() if isinstance(fill, Tensor) else fill expected = torch.tensor([ [[1.0, 2.0], [3.0, 4.0], [item, item]], [[5.0, 6.0], [item, item], [item, item]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], ]) out, mask = to_dense_batch(x, batch, fill_value=fill) assert out.size() == (3, 3, 2) assert torch.equal(out, expected) assert mask.tolist() == [[1, 1, 0], [1, 0, 0], [1, 1, 1]] out, mask = to_dense_batch(x, batch, max_num_nodes=2, fill_value=fill) assert out.size() == (3, 2, 2) assert torch.equal(out, expected[:, :2]) assert mask.tolist() == [[1, 1], [1, 0], [1, 1]] out, mask = to_dense_batch(x, batch, max_num_nodes=5, fill_value=fill) assert out.size() == (3, 5, 2) assert torch.equal(out[:, :3], expected) assert mask.tolist() == [[1, 1, 0, 0, 0], [1, 0, 0, 0, 0], [1, 1, 1, 0, 0]] out, mask = to_dense_batch(x, fill_value=fill) assert out.size() == (1, 6, 2) assert torch.equal(out[0], x) assert mask.tolist() == [[1, 1, 1, 1, 1, 1]] out, mask = to_dense_batch(x, max_num_nodes=2, fill_value=fill) assert out.size() == (1, 2, 2) assert torch.equal(out[0], x[:2]) assert mask.tolist() == [[1, 1]] out, mask = to_dense_batch(x, max_num_nodes=10, fill_value=fill) assert out.size() == (1, 10, 2) assert torch.equal(out[0, :6], x) assert mask.tolist() == [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0]] out, mask = to_dense_batch(x, batch, batch_size=4, fill_value=fill) assert out.size() == (4, 3, 2) def test_to_dense_batch_disable_dynamic_shapes(): x = torch.tensor([ [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0], ]) batch = torch.tensor([0, 0, 1, 2, 2, 2]) with set_experimental_mode(True, 'disable_dynamic_shapes'): with pytest.raises(ValueError, match="'batch_size' needs to be set"): out, mask = to_dense_batch(x, 
batch, max_num_nodes=6) with pytest.raises(ValueError, match="'max_num_nodes' needs to be"): out, mask = to_dense_batch(x, batch, batch_size=4) with pytest.raises(ValueError, match="'batch_size' needs to be set"): out, mask = to_dense_batch(x) out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6) assert out.size() == (1, 6, 2) assert mask.size() == (1, 6) out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10) assert out.size() == (3, 10, 2) assert mask.size() == (3, 10) @onlyFullTest def test_to_dense_batch_jit(): @torch.jit.script def to_dense_batch_jit( x: Tensor, batch: Tensor, fill_value: Tensor, ) -> Tuple[Tensor, Tensor]: return to_dense_batch(x, batch, fill_value=fill_value) x = torch.randn(6, 2) batch = torch.tensor([0, 0, 1, 2, 2, 2]) out, mask = to_dense_batch_jit(x, batch, fill_value=torch.tensor(0.0)) assert out.size() == (3, 3, 2) assert mask.size() == (3, 3) ================================================ FILE: test/utils/test_total_influence.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.nn import GCNConv from torch_geometric.utils import total_influence class GNN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(5, 6) self.conv2 = GCNConv(6, 7) def forward(self, x0, edge_index): x1 = self.conv1(x0, edge_index) x2 = self.conv2(x1, edge_index) return x2 def test_total_influence_smoke(): x = torch.randn(6, 5) edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]) max_hops = 2 num_samples = 4 data = Data( x=x, edge_index=edge_index, ) model = GNN() I, R = total_influence( model, data, max_hops=max_hops, num_samples=num_samples, ) assert I.shape == (max_hops + 1, ) assert 0.0 <= R <= max_hops I, R = total_influence( model, data, max_hops=max_hops, num_samples=num_samples, average=False, ) assert I.shape == torch.Size([num_samples, max_hops + 1]) ================================================ FILE: 
test/utils/test_train_test_split_edges.py ================================================ import pytest import torch from torch_geometric.data import Data from torch_geometric.utils import train_test_split_edges def test_train_test_split_edges(): edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) edge_attr = torch.arange(edge_index.size(1)) data = Data(edge_index=edge_index, edge_attr=edge_attr) data.num_nodes = edge_index.max().item() + 1 with pytest.warns(UserWarning, match='deprecated'): data = train_test_split_edges(data, val_ratio=0.2, test_ratio=0.3) assert len(data) == 10 assert data.val_pos_edge_index.size() == (2, 2) assert data.val_neg_edge_index.size() == (2, 2) assert data.test_pos_edge_index.size() == (2, 3) assert data.test_neg_edge_index.size() == (2, 3) assert data.train_pos_edge_index.size() == (2, 10) assert data.train_neg_adj_mask.size() == (11, 11) assert data.train_neg_adj_mask.sum().item() == (11**2 - 11) / 2 - 4 - 6 - 5 assert data.train_pos_edge_attr.size() == (10, ) assert data.val_pos_edge_attr.size() == (2, ) assert data.test_pos_edge_attr.size() == (3, ) ================================================ FILE: test/utils/test_tree_decomposition.py ================================================ import pytest from torch_geometric.testing import withPackage from torch_geometric.utils import tree_decomposition @withPackage('rdkit') @pytest.mark.parametrize('smiles', [ r'F/C=C/F', r'C/C(=C\C(=O)c1ccc(C)o1)Nc1ccc2c(c1)OCO2', ]) def test_tree_decomposition(smiles): from rdkit import Chem mol = Chem.MolFromSmiles(smiles) tree_decomposition(mol) # TODO Test output ================================================ FILE: test/utils/test_trim_to_layer.py ================================================ from typing import List, Optional import torch from torch import Tensor import torch_geometric.typing from torch_geometric.data import Data from torch_geometric.loader import NeighborLoader from 
torch_geometric.nn import GraphConv from torch_geometric.testing import withPackage from torch_geometric.typing import SparseTensor from torch_geometric.utils import trim_to_layer from torch_geometric.utils._trim_to_layer import trim_sparse_tensor @withPackage('torch_sparse') def test_trim_sparse_tensor(): edge_index = torch.tensor([[0, 0, 1, 2], [1, 2, 3, 4]]) adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[5, 5]) adj = trim_sparse_tensor(adj, size=(3, 3), num_seed_nodes=1) row, col, _ = adj.coo() assert row.tolist() == [0, 0] assert col.tolist() == [1, 2] def test_trim_to_layer_basic(): x0 = torch.arange(4) edge_index0 = torch.tensor([[1, 2, 3], [0, 1, 2]]) edge_weight0 = torch.arange(3) num_sampled_nodes_per_hop = [1, 1, 1] num_sampled_edges_per_hop = [1, 1, 1] x1, edge_index1, edge_weight1 = trim_to_layer( layer=0, num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, num_sampled_edges_per_hop=num_sampled_edges_per_hop, x=x0, edge_index=edge_index0, edge_attr=edge_weight0, ) assert torch.equal(x1, torch.arange(4)) assert edge_index1.tolist() == [[1, 2, 3], [0, 1, 2]] assert torch.equal(edge_weight1, torch.arange(3)) if torch_geometric.typing.WITH_TORCH_SPARSE: adj0 = SparseTensor.from_edge_index(edge_index0, edge_weight0, (4, 4)) x1, adj_t1, _ = trim_to_layer( layer=0, num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, num_sampled_edges_per_hop=num_sampled_edges_per_hop, x=x0, edge_index=adj0.t(), edge_attr=edge_weight0, ) adj1 = adj_t1.t() assert adj1.sizes() == [4, 4] row, col, value = adj1.coo() assert torch.equal(x1, torch.arange(4)) assert row.tolist() == [1, 2, 3] assert col.tolist() == [0, 1, 2] assert torch.equal(value, torch.arange(3)) x2, edge_index2, edge_weight2 = trim_to_layer( layer=1, num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, num_sampled_edges_per_hop=num_sampled_edges_per_hop, x=x1, edge_index=edge_index1, edge_attr=edge_weight1, ) assert torch.equal(x2, torch.arange(3)) assert edge_index2.tolist() == [[1, 2], [0, 1]] 
assert torch.equal(edge_weight2, torch.arange(2)) if torch_geometric.typing.WITH_TORCH_SPARSE: adj1 = SparseTensor.from_edge_index(edge_index1, edge_weight1, (4, 4)) x2, adj_t2, _ = trim_to_layer( layer=1, num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, num_sampled_edges_per_hop=num_sampled_edges_per_hop, x=x1, edge_index=adj1.t(), ) adj2 = adj_t2.t() assert adj2.sizes() == [3, 3] row, col, value = adj2.coo() assert torch.equal(x2, torch.arange(3)) assert row.tolist() == [1, 2] assert col.tolist() == [0, 1] assert torch.equal(value, torch.arange(2)) x3, edge_index3, edge_weight3 = trim_to_layer( layer=2, num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, num_sampled_edges_per_hop=num_sampled_edges_per_hop, x=x2, edge_index=edge_index2, edge_attr=edge_weight2, ) assert torch.equal(x3, torch.arange(2)) assert edge_index3.tolist() == [[1], [0]] assert torch.equal(edge_weight3, torch.arange(1)) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index2, edge_weight2, (3, 3)) x3, adj_t3, _ = trim_to_layer( layer=2, num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, num_sampled_edges_per_hop=num_sampled_edges_per_hop, x=x2, edge_index=adj2.t(), ) adj3 = adj_t3.t() assert adj3.sizes() == [2, 2] row, col, value = adj3.coo() assert torch.equal(x3, torch.arange(2)) assert row.tolist() == [1] assert col.tolist() == [0] assert torch.equal(value, torch.arange(1)) def test_trim_to_layer_hetero(): x = {'v': torch.arange(4)} edge_index = {('v', 'to', 'v'): torch.tensor([[1, 2, 3], [0, 1, 2]])} edge_weight = {('v', 'to', 'v'): torch.arange(3)} num_sampled_nodes_per_hop = {'v': [1, 1, 1, 1]} num_sampled_edges_per_hop = {('v', 'to', 'v'): [1, 1, 1]} x, edge_index, edge_weight = trim_to_layer( layer=1, num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, num_sampled_edges_per_hop=num_sampled_edges_per_hop, x=x, edge_index=edge_index, edge_attr=edge_weight, ) assert torch.equal(x['v'], torch.arange(3)) assert edge_index['v', 'to', 
'v'].tolist() == [[1, 2], [0, 1]] assert torch.equal(edge_weight['v', 'to', 'v'], torch.arange(2)) class GNN(torch.nn.Module): def __init__(self, num_layers: int): super().__init__() self.convs = torch.nn.ModuleList( GraphConv(16, 16) for _ in range(num_layers)) def forward( self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, num_sampled_nodes: Optional[List[int]] = None, num_sampled_edges: Optional[List[int]] = None, ) -> Tensor: for i, conv in enumerate(self.convs): if num_sampled_nodes is not None: x, edge_index, edge_weight = trim_to_layer( i, num_sampled_nodes, num_sampled_edges, x, edge_index, edge_weight) x = conv(x, edge_index, edge_weight) return x @withPackage('pyg_lib') def test_trim_to_layer_with_neighbor_loader(): x = torch.randn(14, 16) edge_index = torch.tensor([ [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], ]) edge_weight = torch.rand(edge_index.size(1)) data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight) loader = NeighborLoader( data, num_neighbors=[1, 2, 4], batch_size=2, shuffle=False, ) batch = next(iter(loader)) model = GNN(num_layers=3) out1 = model(batch.x, batch.edge_index, batch.edge_weight)[:2] assert out1.size() == (2, 16) out2 = model(batch.x, batch.edge_index, batch.edge_weight, batch.num_sampled_nodes, batch.num_sampled_edges)[:2] assert out2.size() == (2, 16) assert torch.allclose(out1, out2, atol=1e-6) ================================================ FILE: test/utils/test_unbatch.py ================================================ import torch from torch_geometric.utils import unbatch, unbatch_edge_index def test_unbatch(): src = torch.arange(10) batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 3, 4, 4]) out = unbatch(src, batch) assert len(out) == 5 for i in range(len(out)): assert torch.equal(out[i], src[batch == i]) def test_unbatch_edge_index(): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 4, 5, 5, 6], [1, 0, 2, 1, 3, 2, 5, 4, 6, 5], ]) batch = torch.tensor([0, 0, 0, 0, 1, 1, 1]) edge_indices 
= unbatch_edge_index(edge_index, batch) assert edge_indices[0].tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]] assert edge_indices[1].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] ================================================ FILE: test/utils/test_undirected.py ================================================ import torch from torch import Tensor from torch_geometric.utils import is_undirected, to_undirected def test_is_undirected(): row = torch.tensor([0, 1, 0]) col = torch.tensor([1, 0, 0]) sym_weight = torch.tensor([0, 0, 1]) asym_weight = torch.tensor([0, 1, 1]) assert is_undirected(torch.stack([row, col], dim=0)) assert is_undirected(torch.stack([row, col], dim=0), sym_weight) assert not is_undirected(torch.stack([row, col], dim=0), asym_weight) row = torch.tensor([0, 1, 1]) col = torch.tensor([1, 0, 2]) assert not is_undirected(torch.stack([row, col], dim=0)) @torch.jit.script def jit(edge_index: Tensor) -> bool: return is_undirected(edge_index) assert not jit(torch.stack([row, col], dim=0)) def test_to_undirected(): row = torch.tensor([0, 1, 1]) col = torch.tensor([1, 0, 2]) edge_index = to_undirected(torch.stack([row, col], dim=0)) assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] @torch.jit.script def jit(edge_index: Tensor) -> Tensor: return to_undirected(edge_index) assert torch.equal(jit(torch.stack([row, col], dim=0)), edge_index) ================================================ FILE: test/visualization/test_graph_visualization.py ================================================ import os.path as osp import pytest import torch from torch_geometric.testing import onlyGraphviz, withPackage from torch_geometric.visualization import visualize_graph @onlyGraphviz @pytest.mark.parametrize('backend', [None, 'graphviz']) def test_visualize_graph_via_graphviz(tmp_path, backend): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3], ]) edge_weight = (torch.rand(edge_index.size(1)) > 0.5).float() path = osp.join(tmp_path, 
'graph.pdf') visualize_graph(edge_index, edge_weight, path, backend) assert osp.exists(path) @onlyGraphviz @pytest.mark.parametrize('backend', [None, 'graphviz']) def test_visualize_graph_via_graphviz_with_node_labels(tmp_path, backend): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3], ]) edge_weight = (torch.rand(edge_index.size(1)) > 0.5).float() node_labels = ['A', 'B', 'C', 'D', 'E'] path = osp.join(tmp_path, 'graph.pdf') visualize_graph(edge_index, edge_weight, path, backend, node_labels) assert osp.exists(path) @withPackage('networkx', 'matplotlib') @pytest.mark.parametrize('backend', [None, 'networkx']) def test_visualize_graph_via_networkx(tmp_path, backend): edge_index = torch.tensor([ [0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3], ]) edge_weight = (torch.rand(edge_index.size(1)) > 0.5).float() path = osp.join(tmp_path, 'graph.pdf') visualize_graph(edge_index, edge_weight, path, backend) assert osp.exists(path) ================================================ FILE: test/visualization/test_influence.py ================================================ import torch from torch_geometric.datasets import KarateClub from torch_geometric.nn import GCNConv from torch_geometric.visualization import influence class Net(torch.nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, out_channels) self.conv2 = GCNConv(out_channels, out_channels) def forward(self, x, edge_index): x = torch.nn.functional.relu(self.conv1(x, edge_index)) x = self.conv2(x, edge_index) return x def test_influence(): data = KarateClub()[0] x = torch.randn(data.num_nodes, 8) out = influence(Net(x.size(1), 16), x, data.edge_index) assert out.size() == (data.num_nodes, data.num_nodes) assert torch.allclose(out.sum(dim=-1), torch.ones(data.num_nodes), atol=1e-04) ================================================ FILE: torch_geometric/__init__.py ================================================ from 
collections import defaultdict import torch import torch_geometric.typing from ._compile import compile, is_compiling from ._onnx import is_in_onnx_export, safe_onnx_export from .index import Index from .edge_index import EdgeIndex from .hash_tensor import HashTensor from .seed import seed_everything from .home import get_home_dir, set_home_dir from .device import is_mps_available, is_xpu_available, device from .isinstance import is_torch_instance from .debug import is_debug_enabled, debug, set_debug import torch_geometric.utils import torch_geometric.data import torch_geometric.sampler import torch_geometric.loader import torch_geometric.transforms import torch_geometric.datasets import torch_geometric.nn import torch_geometric.explain import torch_geometric.profile from .experimental import (is_experimental_mode_enabled, experimental_mode, set_experimental_mode) from .lazy_loader import LazyLoader contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib') graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym') __version__ = '2.8.0' __all__ = [ 'Index', 'EdgeIndex', 'HashTensor', 'seed_everything', 'get_home_dir', 'set_home_dir', 'compile', 'is_compiling', 'is_in_onnx_export', 'safe_onnx_export', 'is_mps_available', 'is_xpu_available', 'device', 'is_torch_instance', 'is_debug_enabled', 'debug', 'set_debug', 'is_experimental_mode_enabled', 'experimental_mode', 'set_experimental_mode', 'torch_geometric', '__version__', ] if not torch_geometric.typing.WITH_PT113: import warnings as std_warnings std_warnings.warn( "PyG 2.7 removed support for PyTorch < 1.13. Consider " "Consider upgrading to PyTorch >= 1.13 or downgrading " "to PyG <= 2.6. 
", stacklevel=2) # Serialization ############################################################### if torch_geometric.typing.WITH_PT24: torch.serialization.add_safe_globals([ dict, list, defaultdict, Index, torch_geometric.index.CatMetadata, EdgeIndex, torch_geometric.edge_index.SortOrder, torch_geometric.edge_index.CatMetadata, HashTensor, ]) ================================================ FILE: torch_geometric/_compile.py ================================================ import warnings from typing import Any, Callable, Optional, Union import torch import torch_geometric.typing def is_compiling() -> bool: r"""Returns :obj:`True` in case :pytorch:`PyTorch` is compiling via :meth:`torch.compile`. """ if torch_geometric.typing.WITH_PT23: return torch.compiler.is_compiling() if torch_geometric.typing.WITH_PT21: return torch._dynamo.is_compiling() return False # pragma: no cover def compile( model: Optional[torch.nn.Module] = None, *args: Any, **kwargs: Any, ) -> Union[torch.nn.Module, Callable[[torch.nn.Module], torch.nn.Module]]: r"""Optimizes the given :pyg:`PyG` model/function via :meth:`torch.compile`. This function has the same signature as :meth:`torch.compile` (see `here `__). Args: model: The model to compile. *args: Additional arguments of :meth:`torch.compile`. **kwargs: Additional keyword arguments of :meth:`torch.compile`. .. note:: :meth:`torch_geometric.compile` is deprecated in favor of :meth:`torch.compile`. """ warnings.warn( "'torch_geometric.compile' is deprecated in favor of " "'torch.compile'", stacklevel=2) return torch.compile(model, *args, **kwargs) # type: ignore ================================================ FILE: torch_geometric/_onnx.py ================================================ import warnings from os import PathLike from typing import Any, Union import torch from torch_geometric import is_compiling def is_in_onnx_export() -> bool: r"""Returns :obj:`True` in case :pytorch:`PyTorch` is exporting to ONNX via :meth:`torch.onnx.export`. 
""" if is_compiling(): return False if torch.jit.is_scripting(): return False return torch.onnx.is_in_onnx_export() def safe_onnx_export( model: torch.nn.Module, args: Union[torch.Tensor, tuple[Any, ...]], f: Union[str, PathLike[Any], None], skip_on_error: bool = False, **kwargs: Any, ) -> bool: r"""A safe wrapper around :meth:`torch.onnx.export` that handles known ONNX serialization issues in PyTorch Geometric. This function provides workarounds for the ``onnx_ir.serde.SerdeError`` with boolean ``allowzero`` attributes that occurs in certain environments. Args: model (torch.nn.Module): The model to export. args (torch.Tensor or tuple): The input arguments for the model. f (str or PathLike): The file path to save the model. skip_on_error (bool): If True, return False instead of raising when workarounds fail. Useful for CI environments. **kwargs: Additional arguments passed to :meth:`torch.onnx.export`. Returns: bool: True if export succeeded, False if skipped due to known issues (only when skip_on_error=True). Example: >>> from torch_geometric.nn import SAGEConv >>> from torch_geometric import safe_onnx_export >>> >>> class MyModel(torch.nn.Module): ... def __init__(self): ... super().__init__() ... self.conv = SAGEConv(8, 16) ... def forward(self, x, edge_index): ... return self.conv(x, edge_index) >>> >>> model = MyModel() >>> x = torch.randn(3, 8) >>> edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]]) >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx') >>> >>> # For CI environments: >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx', ... skip_on_error=True) >>> if not success: ... 
print("ONNX export skipped due to known upstream issue") """ # Convert single tensor to tuple for torch.onnx.export compatibility if isinstance(args, torch.Tensor): args = (args, ) try: # First attempt: standard ONNX export torch.onnx.export(model, args, f, **kwargs) return True except Exception as e: error_str = str(e) error_type = type(e).__name__ # Check for the specific onnx_ir.serde.SerdeError patterns is_allowzero_error = (('onnx_ir.serde.SerdeError' in error_str and 'allowzero' in error_str) or 'ValueError: Value out of range: 1' in error_str or 'serialize_model_into' in error_str or 'serialize_attribute_into' in error_str) if is_allowzero_error: warnings.warn( f"Encountered known ONNX serialization issue ({error_type}). " "This is likely the allowzero boolean attribute bug. " "Attempting workaround...", UserWarning, stacklevel=2) # Apply workaround strategies return _apply_onnx_allowzero_workaround(model, args, f, skip_on_error, **kwargs) else: # Re-raise other errors raise def _apply_onnx_allowzero_workaround( model: torch.nn.Module, args: tuple[Any, ...], f: Union[str, PathLike[Any], None], skip_on_error: bool = False, **kwargs: Any, ) -> bool: r"""Apply workaround strategies for onnx_ir.serde.SerdeError with allowzero attributes. Returns: bool: True if export succeeded, False if skipped (when skip_on_error=True). 
""" # Strategy 1: Try without dynamo if it was enabled if kwargs.get('dynamo', False): try: kwargs_no_dynamo = kwargs.copy() kwargs_no_dynamo['dynamo'] = False warnings.warn( "Retrying ONNX export with dynamo=False as workaround", UserWarning, stacklevel=3) torch.onnx.export(model, args, f, **kwargs_no_dynamo) return True except Exception: pass # Strategy 2: Try with different opset versions original_opset = kwargs.get('opset_version', 18) for opset_version in [17, 16, 15, 14, 13, 11]: if opset_version != original_opset: try: kwargs_opset = kwargs.copy() kwargs_opset['opset_version'] = opset_version warnings.warn( f"Retrying ONNX export with opset_version={opset_version}", UserWarning, stacklevel=3) torch.onnx.export(model, args, f, **kwargs_opset) return True except Exception: continue # Strategy 3: Try legacy export (non-dynamo with older opset) try: kwargs_legacy = kwargs.copy() kwargs_legacy['dynamo'] = False kwargs_legacy['opset_version'] = 11 warnings.warn( "Retrying ONNX export with legacy settings " "(dynamo=False, opset_version=11)", UserWarning, stacklevel=3) torch.onnx.export(model, args, f, **kwargs_legacy) return True except Exception: pass # Strategy 4: Try with minimal settings try: minimal_kwargs: dict[str, Any] = { 'opset_version': 11, 'dynamo': False, } # Add optional parameters if they exist if kwargs.get('input_names') is not None: minimal_kwargs['input_names'] = kwargs.get('input_names') if kwargs.get('output_names') is not None: minimal_kwargs['output_names'] = kwargs.get('output_names') warnings.warn( "Retrying ONNX export with minimal settings as last resort", UserWarning, stacklevel=3) torch.onnx.export(model, args, f, **minimal_kwargs) return True except Exception: pass # If all strategies fail, handle based on skip_on_error flag import os pytest_detected = 'PYTEST_CURRENT_TEST' in os.environ or 'pytest' in str(f) if skip_on_error: # For CI environments: skip gracefully instead of failing warnings.warn( "ONNX export skipped due to known 
upstream issue " "(onnx_ir.serde.SerdeError). " "This is caused by a bug in the onnx_ir package where boolean " "allowzero attributes cannot be serialized. All workarounds " "failed. Consider updating packages: pip install --upgrade onnx " "onnxscript " "onnx_ir", UserWarning, stacklevel=3) return False # For regular usage: provide detailed error message error_msg = ( "Failed to export model to ONNX due to known serialization issue. " "This is caused by a bug in the onnx_ir package where boolean " "allowzero attributes cannot be serialized. " "Workarounds attempted: dynamo=False, multiple opset versions, " "and legacy export. ") if pytest_detected: error_msg += ( "\n\nThis error commonly occurs in pytest environments. " "Try one of these solutions:\n" "1. Run the export outside of pytest (in a regular Python " "script)\n" "2. Update packages: pip install --upgrade onnx onnxscript " "onnx_ir\n" "3. Use torch.jit.script() instead of ONNX export for testing\n" "4. Use safe_onnx_export(..., skip_on_error=True) to skip " "gracefully in CI") else: error_msg += ("\n\nTry updating packages: pip install --upgrade onnx " "onnxscript onnx_ir") raise RuntimeError(error_msg) ================================================ FILE: torch_geometric/backend.py ================================================ from typing import Optional import torch # If set to `True`, PyG is configured to use the `segment_matmul` and # `grouped_matmul` kernels from `pyg-lib` to parallelize matrix multiplication # across segments/groups of potentially varying size. # If set to `None`, will automatically decide whether to utilize # `segment_matmul` and `grouped_matmul` based on input sizes. # Requires `pyg-lib` to be installed. 
use_segment_matmul: Optional[bool] = None # Helper functions ############################################################ def use_segment_matmul_heuristic( num_segments: int, max_segment_size: int, in_channels: int, out_channels: int, ) -> bool: r"""A heuristic based on input sizes to determine whether the usage of :meth:`segment_matmul` can speed up computation. """ # NOTE This heuristic was learned on an A100 via sklearn using a simple # StandardScaler() -> LinearSVC() model. # For now, it is only used in combination with `RGCNConv`. x = torch.tensor([ num_segments, max_segment_size, in_channels, out_channels, ]) mean = torch.tensor([ 125.11603189, 12133.21523472, 163.81222321, 32.43755536, ]) std = torch.tensor([ 163.34480422, 27572.94543809, 177.6426489, 56.82103934, ]) weight = torch.tensor([ 2.43877659e+00, 1.67583047e+00, -5.20527282e-04, 3.43925501e-01, ]) bias = 1.20236999 x = (x - mean) / std return bool(x @ weight >= bias) ================================================ FILE: torch_geometric/config_mixin.py ================================================ import inspect from dataclasses import fields, is_dataclass from importlib import import_module from typing import Any, Dict from torch.nn import ModuleDict, ModuleList from torch_geometric.config_store import ( class_from_dataclass, dataclass_from_class, ) from torch_geometric.isinstance import is_torch_instance class ConfigMixin: r"""Enables a class to serialize/deserialize itself to a dataclass.""" def config(self) -> Any: r"""Creates a serializable configuration of the class.""" data_cls = dataclass_from_class(self.__class__) if data_cls is None: raise ValueError(f"Could not find the configuration class that " f"belongs to '{self.__class__.__name__}'. 
Please " f"register it in the configuration store.") kwargs: Dict[str, Any] = {} for field in fields(data_cls): if not hasattr(self, field.name): continue kwargs[field.name] = _recursive_config(getattr(self, field.name)) return data_cls(**kwargs) @classmethod def from_config(cls, cfg: Any, *args: Any, **kwargs: Any) -> Any: r"""Instantiates the class from a serializable configuration.""" if getattr(cfg, '_target_', None): cls = _locate_cls(cfg._target_) elif isinstance(cfg, dict) and '_target_' in cfg: cls = _locate_cls(cfg['_target_']) data_cls = cfg.__class__ if not is_dataclass(data_cls): data_cls = dataclass_from_class(cls) if data_cls is None: raise ValueError(f"Could not find the configuration class " f"that belongs to '{cls.__name__}'. Please " f"register it in the configuration store.") field_names = {field.name for field in fields(data_cls)} if isinstance(cfg, dict): _kwargs = {k: v for k, v in cfg.items() if k in field_names} cfg = data_cls(**_kwargs) assert is_dataclass(cfg) if len(args) > 0: # Convert `*args` to `**kwargs`: param_names = list(inspect.signature(cls).parameters.keys()) if 'args' in param_names: param_names.remove('args') if 'kwargs' in param_names: param_names.remove('kwargs') for name, arg in zip(param_names, args): kwargs[name] = arg for key in field_names: if key not in kwargs and key != '_target_': kwargs[key] = _recursive_from_config(getattr(cfg, key)) return cls(**kwargs) def _recursive_config(value: Any) -> Any: if isinstance(value, ConfigMixin): return value.config() if is_torch_instance(value, ConfigMixin): return value.config() if isinstance(value, (tuple, list, ModuleList)): return [_recursive_config(v) for v in value] if isinstance(value, (dict, ModuleDict)): return {k: _recursive_config(v) for k, v in value.items()} return value def _recursive_from_config(value: Any) -> Any: cls: Any = None if is_dataclass(value): if getattr(value, '_target_', None): try: cls = _locate_cls(value._target_) # type: ignore except ImportError: 
pass # Keep the dataclass as it is. else: cls = class_from_dataclass(value.__class__) elif isinstance(value, dict) and '_target_' in value: cls = _locate_cls(value['_target_']) if cls is not None and issubclass(cls, ConfigMixin): return cls.from_config(value) if isinstance(value, (tuple, list)): return [_recursive_from_config(v) for v in value] if isinstance(value, dict): return {k: _recursive_from_config(v) for k, v in value.items()} return value def _locate_cls(qualname: str) -> Any: parts = qualname.split('.') if len(parts) <= 1: raise ValueError(f"Qualified name is missing a dot (got '{qualname}')") if any([len(part) == 0 for part in parts]): raise ValueError(f"Relative imports not supported (got '{qualname}')") module_name, cls_name = '.'.join(parts[:-1]), parts[-1] return getattr(import_module(module_name), cls_name) ================================================ FILE: torch_geometric/config_store.py ================================================ import copy import inspect import typing from collections import defaultdict from dataclasses import dataclass, field, make_dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch EXCLUDE = {'self', 'args', 'kwargs'} MAPPING = { torch.nn.Module: Any, torch.Tensor: Any, } try: from omegaconf import MISSING except Exception: MISSING = '???' 
try: import hydra # noqa WITH_HYDRA = True except Exception: WITH_HYDRA = False if not typing.TYPE_CHECKING and WITH_HYDRA: from hydra.core.config_store import ConfigStore def get_node(cls: Union[str, Any]) -> Optional[Any]: if (not isinstance(cls, str) and cls.__module__ in {'builtins', 'typing'}): return None def _get_candidates(repo: Dict[str, Any]) -> List[Any]: outs: List[Any] = [] for key, value in repo.items(): if isinstance(value, dict): outs.extend(_get_candidates(value)) elif getattr(value.node._metadata, 'object_type', None) == cls: outs.append(value.node) elif getattr(value.node._metadata, 'orig_type', None) == cls: outs.append(value.node) elif isinstance(cls, str) and key == f'{cls}.yaml': outs.append(value.node) return outs candidates = _get_candidates(get_config_store().repo) if len(candidates) > 1: raise ValueError(f"Found multiple entries in the configuration " f"store for the same node '{candidates[0].name}'") return candidates[0] if len(candidates) == 1 else None def dataclass_from_class(cls: Union[str, Any]) -> Optional[Any]: r"""Returns the :obj:`dataclass` of a class registered in the global configuration store. """ node = get_node(cls) return node._metadata.object_type if node is not None else None def class_from_dataclass(cls: Union[str, Any]) -> Optional[Any]: r"""Returns the original class of a :obj:`dataclass` registered in the global configuration store. 
""" node = get_node(cls) return node._metadata.orig_type if node is not None else None else: class Singleton(type): _instances: Dict[type, Any] = {} def __call__(cls, *args: Any, **kwargs: Any) -> Any: if cls not in cls._instances: instance = super().__call__(*args, **kwargs) cls._instances[cls] = instance return instance return cls._instances[cls] @dataclass class Metadata: orig_type: Optional[Any] = None @dataclass class ConfigNode: name: str node: Any group: Optional[str] = None _metadata: Metadata = field(default_factory=Metadata) class ConfigStore(metaclass=Singleton): def __init__(self) -> None: self.repo: Dict[str, Any] = defaultdict(dict) @classmethod def instance(cls, *args: Any, **kwargs: Any) -> 'ConfigStore': return cls(*args, **kwargs) def store( self, name: str, node: Any, group: Optional[str] = None, orig_type: Optional[Any] = None, ) -> None: cur = self.repo if group is not None: cur = cur[group] if name in cur: raise KeyError(f"Configuration '{name}' already registered. 
" f"Please store it under a different group.") metadata = Metadata(orig_type=orig_type) cur[name] = ConfigNode(name, node, group, metadata) def get_node(cls: Union[str, Any]) -> Optional[ConfigNode]: if (not isinstance(cls, str) and cls.__module__ in {'builtins', 'typing'}): return None def _get_candidates(repo: Dict[str, Any]) -> List[ConfigNode]: outs: List[ConfigNode] = [] for key, value in repo.items(): if isinstance(value, dict): outs.extend(_get_candidates(value)) elif value.node == cls: outs.append(value) elif value._metadata.orig_type == cls: outs.append(value) elif isinstance(cls, str) and key == cls: outs.append(value) return outs candidates = _get_candidates(get_config_store().repo) if len(candidates) > 1: raise ValueError(f"Found multiple entries in the configuration " f"store for the same node '{candidates[0].name}'") return candidates[0] if len(candidates) == 1 else None def dataclass_from_class(cls: Union[str, Any]) -> Optional[Any]: r"""Returns the :obj:`dataclass` of a class registered in the global configuration store. """ node = get_node(cls) return node.node if node is not None else None def class_from_dataclass(cls: Union[str, Any]) -> Optional[Any]: r"""Returns the original class of a :obj:`dataclass` registered in the global configuration store. """ node = get_node(cls) return node._metadata.orig_type if node is not None else None def map_annotation( annotation: Any, mapping: Optional[Dict[Any, Any]] = None, ) -> Any: origin = getattr(annotation, '__origin__', None) args: Tuple[Any, ...] 
= getattr(annotation, '__args__', tuple()) if origin in {Union, list, dict, tuple}: assert origin is not None args = tuple(map_annotation(a, mapping) for a in args) if type(annotation).__name__ == 'GenericAlias': # If annotated with `list[...]` or `dict[...]`: annotation = origin[args] else: # If annotated with `typing.List[...]` or `typing.Dict[...]`: annotation = copy.copy(annotation) annotation.__args__ = args return annotation if mapping is not None and annotation in mapping: return mapping[annotation] out = dataclass_from_class(annotation) if out is not None: return out return annotation def to_dataclass( cls: Any, base_cls: Optional[Any] = None, with_target: Optional[bool] = None, map_args: Optional[Dict[str, Tuple]] = None, exclude_args: Optional[List[str]] = None, strict: bool = False, ) -> Any: r"""Converts the input arguments of a given class :obj:`cls` to a :obj:`dataclass` schema. For example, .. code-block:: python from torch_geometric.transforms import NormalizeFeatures dataclass = to_dataclass(NormalizeFeatures) will generate .. code-block:: python @dataclass class NormalizeFeatures: _target_: str = "torch_geometric.transforms.NormalizeFeatures" attrs: List[str] = field(default_factory = lambda: ["x"]) Args: cls (Any): The class to generate a schema for. base_cls (Any, optional): The base class of the schema. (default: :obj:`None`) with_target (bool, optional): If set to :obj:`False`, will not add the :obj:`_target_` attribute to the schema. If set to :obj:`None`, will only add the :obj:`_target_` in case :obj:`base_cls` is given. (default: :obj:`None`) map_args (Dict[str, Tuple], optional): Arguments for which annotation and default values should be overridden. (default: :obj:`None`) exclude_args (List[str or int], optional): Arguments to exclude. (default: :obj:`None`) strict (bool, optional): If set to :obj:`True`, ensures that all arguments in both :obj:`map_args` and :obj:`exclude_args` are present in the input parameters. 
(default: :obj:`False`) """ fields = [] params = inspect.signature(cls.__init__).parameters if strict: # Check that keys in map_args or exclude_args are present. keys = set() if map_args is None else set(map_args.keys()) if exclude_args is not None: keys |= {arg for arg in exclude_args if isinstance(arg, str)} diff = keys - set(params.keys()) if len(diff) > 0: raise ValueError(f"Expected input argument(s) {diff} in " f"'{cls.__name__}'") for i, (name, arg) in enumerate(params.items()): if name in EXCLUDE: continue if exclude_args is not None: if name in exclude_args or i in exclude_args: continue if base_cls is not None: if name in base_cls.__dataclass_fields__: continue if map_args is not None and name in map_args: fields.append((name, ) + map_args[name]) continue annotation, default = arg.annotation, arg.default annotation = map_annotation(annotation, mapping=MAPPING) if annotation != inspect.Parameter.empty: # `Union` types are not supported (except for `Optional`). # As such, we replace them with either `Any` or `Optional[Any]`. 
origin = getattr(annotation, '__origin__', None) args = getattr(annotation, '__args__', []) if origin == Union and type(None) in args and len(args) > 2: annotation = Optional[Any] elif origin == Union and type(None) not in args: annotation = Any elif origin == list: if getattr(args[0], '__origin__', None) == Union: annotation = List[Any] elif origin == dict: if getattr(args[1], '__origin__', None) == Union: annotation = Dict[args[0], Any] # type: ignore else: annotation = Any if str(default) == "": # Fix `torch.optim.SGD.lr = _RequiredParameter()`: # https://github.com/pytorch/hydra-torch/blob/main/ # hydra-configs-torch/hydra_configs/torch/optim/sgd.py default = field(default=MISSING) elif default != inspect.Parameter.empty: if isinstance(default, (list, dict)): # Avoid late binding of default values inside a loop: # https://stackoverflow.com/questions/3431676/ # creating-functions-in-a-loop def wrapper(default: Any) -> Callable[[], Any]: return lambda: default default = field(default_factory=wrapper(default)) else: default = field(default=MISSING) fields.append((name, annotation, default)) with_target = base_cls is not None if with_target is None else with_target if with_target: full_cls_name = f'{cls.__module__}.{cls.__qualname__}' fields.append(('_target_', str, field(default=full_cls_name))) return make_dataclass(cls.__qualname__, fields=fields, bases=() if base_cls is None else (base_cls, )) def get_config_store() -> ConfigStore: r"""Returns the global configuration store.""" return ConfigStore.instance() def clear_config_store() -> ConfigStore: r"""Clears the global configuration store.""" config_store = get_config_store() for key in list(config_store.repo.keys()): if key != 'hydra' and not key.endswith('.yaml'): del config_store.repo[key] return config_store def register( cls: Optional[Any] = None, data_cls: Optional[Any] = None, group: Optional[str] = None, **kwargs: Any, ) -> Union[Any, Callable]: r"""Registers a class in the global configuration store. 
Args: cls (Any, optional): The class to register. If set to :obj:`None`, will return a decorator. (default: :obj:`None`) data_cls (Any, optional): The data class to register. If set to :obj:`None`, will dynamically create the data class according to :class:`~torch_geometric.config_store.to_dataclass`. (default: :obj:`None`) group (str, optional): The group in the global configuration store. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`~torch_geometric.config_store.to_dataclass`. """ if cls is not None: name = cls.__name__ if get_node(cls): raise ValueError(f"The class '{name}' is already registered in " "the global configuration store") if data_cls is None: data_cls = to_dataclass(cls, **kwargs) elif get_node(data_cls): raise ValueError( f"The data class '{data_cls.__name__}' is already registered " f"in the global configuration store") if not typing.TYPE_CHECKING and WITH_HYDRA: get_config_store().store(name, data_cls, group) get_node(name)._metadata.orig_type = cls else: get_config_store().store(name, data_cls, group, cls) return data_cls def bounded_register(cls: Any) -> Any: # Other-wise, return a decorator: register(cls=cls, data_cls=data_cls, group=group, **kwargs) return cls return bounded_register ############################################################################### @dataclass class Transform: pass @dataclass class Dataset: pass @dataclass class Model: pass @dataclass class Optimizer: pass @dataclass class LRScheduler: pass @dataclass class Config: dataset: Dataset = MISSING model: Model = MISSING optim: Optimizer = MISSING lr_scheduler: Optional[LRScheduler] = None def fill_config_store() -> None: import torch_geometric config_store = get_config_store() # Register `torch_geometric.transforms` ################################### transforms = torch_geometric.transforms for cls_name in set(transforms.__all__) - { 'BaseTransform', 'Compose', 'ComposeFilters', 'LinearTransformation', 'AddMetaPaths', # TODO }: cls = 
to_dataclass(getattr(transforms, cls_name), base_cls=Transform) # We use an explicit additional nesting level inside each config to # allow for applying multiple transformations. # See: hydra.cc/docs/patterns/select_multiple_configs_from_config_group config_store.store(cls_name, group='transform', node={cls_name: cls}) # Register `torch_geometric.datasets` ##################################### datasets = torch_geometric.datasets map_dataset_args: Dict[str, Any] = { 'transform': (Dict[str, Transform], field(default_factory=dict)), 'pre_transform': (Dict[str, Transform], field(default_factory=dict)), } for cls_name in set(datasets.__all__) - set(): cls = to_dataclass(getattr(datasets, cls_name), base_cls=Dataset, map_args=map_dataset_args, exclude_args=['pre_filter']) config_store.store(cls_name, group='dataset', node=cls) # Register `torch_geometric.models` ####################################### models = torch_geometric.nn.models.basic_gnn for cls_name in set(models.__all__) - set(): cls = to_dataclass(getattr(models, cls_name), base_cls=Model) config_store.store(cls_name, group='model', node=cls) # Register `torch.optim.Optimizer` ######################################## for cls_name in { key for key, cls in torch.optim.__dict__.items() if inspect.isclass(cls) and issubclass(cls, torch.optim.Optimizer) } - { 'Optimizer', }: cls = to_dataclass(getattr(torch.optim, cls_name), base_cls=Optimizer, exclude_args=['params']) config_store.store(cls_name, group='optimizer', node=cls) # Register `torch.optim.lr_scheduler` ##################################### for cls_name in { key for key, cls in torch.optim.lr_scheduler.__dict__.items() if inspect.isclass(cls) } - { 'Optimizer', '_LRScheduler', 'Counter', 'SequentialLR', 'ChainedScheduler', }: cls = to_dataclass(getattr(torch.optim.lr_scheduler, cls_name), base_cls=LRScheduler, exclude_args=['optimizer']) config_store.store(cls_name, group='lr_scheduler', node=cls) # Register global schema 
################################################## config_store.store('config', node=Config) ================================================ FILE: torch_geometric/contrib/__init__.py ================================================ import warnings import torch_geometric.contrib.transforms # noqa import torch_geometric.contrib.datasets # noqa import torch_geometric.contrib.nn # noqa import torch_geometric.contrib.explain # noqa warnings.warn( "'torch_geometric.contrib' contains experimental code and is subject to " "change. Please use with caution.", stacklevel=2) __all__ = [] ================================================ FILE: torch_geometric/contrib/datasets/__init__.py ================================================ __all__ = classes = [] ================================================ FILE: torch_geometric/contrib/explain/__init__.py ================================================ from torch_geometric.deprecation import deprecated from .pgm_explainer import PGMExplainer from torch_geometric.explain.algorithm.graphmask_explainer import ( GraphMaskExplainer as NewGraphMaskExplainer) GraphMaskExplainer = deprecated( "use 'torch_geometric.explain.algorithm.GraphMaskExplainer' instead", )( NewGraphMaskExplainer) __all__ = classes = [ 'PGMExplainer', ] ================================================ FILE: torch_geometric/contrib/explain/pgm_explainer.py ================================================ import logging from typing import List, Optional, Tuple, Union import numpy as np import torch from torch import Tensor from torch_geometric.explain import ExplainerAlgorithm from torch_geometric.explain.config import ModelMode, ModelTaskLevel from torch_geometric.explain.explanation import Explanation from torch_geometric.utils import k_hop_subgraph from torch_geometric.utils._subgraph import get_num_hops class PGMExplainer(ExplainerAlgorithm): r"""The PGMExplainer model from the `"PGMExplainer: Probabilistic Graphical Model Explanations for Graph Neural 
Networks" `_ paper. The generated :class:`~torch_geometric.explain.Explanation` provides a :obj:`node_mask` and a :obj:`pgm_stats` tensor, which stores the :math:`p`-values of each node as calculated by the Chi-squared test. Args: feature_index (List): The indices of the perturbed features. If set to :obj:`None`, all features are perturbed. (default: :obj:`None`) perturb_mode (str, optional): The method to generate the variations in features. One of :obj:`"randint"`, :obj:`"mean"`, :obj:`"zero"`, :obj:`"max"` or :obj:`"uniform"`. (default: :obj:`"randint"`) perturbations_is_positive_only (bool, optional): If set to :obj:`True`, restrict perturbed values to be positive. (default: :obj:`False`) is_perturbation_scaled (bool, optional): If set to :obj:`True`, will normalize the range of the perturbed features. (default: :obj:`False`) num_samples (int, optional): The number of samples of perturbations used to test the significance of nodes to the prediction. (default: :obj:`100`) max_subgraph_size (int, optional): The maximum number of neighbors to consider for the explanation. (default: :obj:`None`) significance_threshold (float, optional): The statistical threshold (:math:`p`-value) for which a node is considered to have an effect on the prediction. (default: :obj:`0.05`) pred_threshold (float, optional): The buffer value (in range :obj:`[0, 1]`) to consider the output from a perturbed data to be different from the original. 
(default: :obj:`0.1`) """ def __init__( self, feature_index: Optional[List] = None, perturbation_mode: str = "randint", perturbations_is_positive_only: bool = False, is_perturbation_scaled: bool = False, num_samples: int = 100, max_subgraph_size: Optional[int] = None, significance_threshold: float = 0.05, pred_threshold: float = 0.1, ): super().__init__() self.feature_index = feature_index self.perturbation_mode = perturbation_mode self.perturbations_is_positive_only = perturbations_is_positive_only self.is_perturbation_scaled = is_perturbation_scaled self.num_samples = num_samples self.max_subgraph_size = max_subgraph_size self.significance_threshold = significance_threshold self.pred_threshold = pred_threshold def _perturb_features_on_nodes( self, x: Tensor, index: Tensor, ) -> Tensor: r"""Perturbs feature matrix :obj:`x`. Args: x (torch.Tensor): The feature matrix. index (torch.Tensor): The indices of nodes to perturb. """ x_perturb = x.detach().clone() perturb_array = x_perturb[index] epsilon = 0.05 * torch.max(x, dim=0).values if self.perturbation_mode == "randint": perturb_array = torch.randint(high=2, size=perturb_array.size(), device=x.device) elif self.perturbation_mode == "mean": perturb_array[:, self.feature_index] = torch.mean( x[:, self.feature_index]) elif self.perturbation_mode == "zero": perturb_array[:, self.feature_index] = 0 elif self.perturbation_mode == "max": perturb_array[:, self.feature_index] = torch.max( x[:, self.feature_index]) elif self.perturbation_mode == "uniform": random_perturbations = torch.rand( perturb_array.shape) * 2 * epsilon - epsilon perturb_array[:, self.feature_index] = perturb_array[ self.feature_index] + random_perturbations perturb_array.clamp(min=0, max=torch.max(x, dim=0)) if self.is_perturbation_scaled: perturb_array = torch.multiply( perturb_array, torch.rand(size=perturb_array.size())) * 2 x_perturb[index] = perturb_array.type(x_perturb.dtype) return x_perturb def _batch_perturb_features_on_node( self, model: 
torch.nn.Module, x: Tensor, edge_index: Tensor, indices_to_perturb: np.array, percentage: float = 50., # % time node gets perturbed **kwargs, ) -> Tensor: r"""Perturbs the node features of a batch of graphs for graph classification tasks. Args: model (torch.nn.Module): The GNN model. x (torch.Tensor): The node feature matrix edge_index (torch.Tensor): The edge indices. indices_to_perturb (np.array): The indices of nodes to perturb. percentage (float, optional): The percentage of times a node gets perturbed. (default: :obj:`50.`) **kwargs (optional): Additional arguments passed to :meth:`model.forward`. """ pred_torch = model(x, edge_index, **kwargs) soft_pred = torch.softmax(pred_torch, dim=1) pred_label = torch.argmax(soft_pred, dim=1) num_nodes = x.shape[0] samples = [] for _ in range(self.num_samples): x_perturb = x.detach().clone() seeds = np.random.randint(0, 100, size=len(indices_to_perturb)) perturbed_node_indexes = indices_to_perturb[(seeds < percentage)] x_perturb = self._perturb_features_on_nodes( x=x_perturb, index=perturbed_node_indexes, ) sample = np.zeros(num_nodes + 1) sample[perturbed_node_indexes] = 1 pred_perturb_torch = model(x_perturb, edge_index, **kwargs) soft_pred_perturb = torch.softmax(pred_perturb_torch, dim=1).squeeze() pred_change = torch.max(soft_pred) - soft_pred_perturb[pred_label] sample[num_nodes] = pred_change.detach() samples.append(sample) samples = torch.tensor(np.array(samples)) if self.perturbations_is_positive_only: samples = torch.abs(samples) top = int(self.num_samples / 8) top_idx = torch.argsort(samples[:, num_nodes])[-top:] for i in range(self.num_samples): if i in top_idx: samples[i, num_nodes] = 1 else: samples[i, num_nodes] = 0 return samples def _explain_graph( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, target=None, **kwargs, ) -> Tuple[Tensor, Tensor]: r"""Generates explanations for graph classification tasks. Args: model (torch.nn.Module): The model to explain. x (torch.Tensor): The node features. 
edge_index (torch.Tensor): The edge indices of the input graph. target (torch.Tensor, optional): The predicted label from the model. (default: :obj:`None`) **kwargs (optional): Additional arguments passed to :meth:`model.forward`. Returns: pgm_nodes (List): The neighbor nodes that are significant in the selected node's prediction. pgm_stats (torch.Tensor): The :math:`p`-values of all the nodes in the graph, ordered by node index. """ import pandas as pd from pgmpy.estimators.CITests import chi_square num_nodes = x.shape[0] if not self.max_subgraph_size: self.max_subgraph_size = int(num_nodes / 20) samples = self._batch_perturb_features_on_node( indices_to_perturb=np.array(range(num_nodes)), x=x, model=model, edge_index=edge_index, ) # note: the PC estimator is in the original code, ie. est= PC(data) # but as it does nothing it is not included here data = pd.DataFrame(np.array(samples.detach().cpu())) p_values = [] for node in range(num_nodes): chi2, p, _ = chi_square( node, int(target.detach().cpu()), [], data, boolean=False, significance_level=self.significance_threshold) p_values.append(p) # the original code uses number_candidates_nodes = int(top_nodes * 4) # if we consider 'top nodes' to equate to max number of nodes # it seems more correct to limit number_candidates_nodes to this candidate_nodes = np.argpartition( p_values, self.max_subgraph_size)[0:self.max_subgraph_size] # Round 2 samples = self._batch_perturb_features_on_node( indices_to_perturb=candidate_nodes, x=x, edge_index=edge_index, model=model, **kwargs) # note: the PC estimator is in the original code, ie. 
est= PC(data) # but as it does nothing it is not included here data = pd.DataFrame(np.array(samples.detach().cpu())) p_values = [] dependent_nodes = [] target = num_nodes for node in range(num_nodes): _, p, _ = chi_square( node, target, [], data, boolean=False, significance_level=self.significance_threshold) p_values.append(p) if p < self.significance_threshold: dependent_nodes.append(node) top_p = np.min((self.max_subgraph_size, num_nodes - 1)) ind_top_p = np.argpartition(p_values, top_p)[0:top_p] pgm_nodes = list(ind_top_p) node_mask = torch.zeros(x.size(), dtype=torch.int) node_mask[pgm_nodes] = 1 pgm_stats = torch.tensor(p_values) return node_mask, pgm_stats def _explain_node( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, target: Tensor, index: int, **kwargs, ) -> Tuple[Tensor, Tensor]: r"""Generates explanations for node classification tasks. Args: model (torch.nn.Module): The model to explain. x (torch.Tensor): The node features. edge_index (torch.Tensor): The edge indices of the input graph. target (torch.Tensor): The predicted label from the model. index (int): The index of the node for which the explanations is generated. **kwargs (optional): Additional arguments passed to :meth:`model.forward`. Returns: node_mask (torch.Tensor): A hard node mask corresponding to whether a node is significant in the selected node's prediction. pgm_stats (torch.Tensor): The :math:`p`-values of all the nodes in the graph, ordered by node index. 
""" import pandas as pd from pgmpy.estimators.CITests import chi_square neighbors, _, _, _ = k_hop_subgraph( node_idx=index, num_hops=get_num_hops(model), edge_index=edge_index, relabel_nodes=False, num_nodes=x.size(0), ) if index not in neighbors: neighbors = torch.cat([neighbors, index], dim=1) pred_model = model(x, edge_index, **kwargs) softmax_pred = torch.softmax(pred_model, dim=1) samples = [] pred_samples = [] for _ in range(self.num_samples): # A subset of neighbors are selected randomly for perturbing: seeds = np.random.choice([1, 0], size=(len(neighbors), )) x_perturb = self._perturb_features_on_nodes( x=x, index=neighbors[seeds == 1], ) # prediction after perturbation pred_perturb = model(x_perturb, edge_index, **kwargs) softmax_pred_perturb = torch.softmax(pred_perturb, dim=1) sample_bool = np.ones(shape=(len(neighbors), )) sample_bool[((softmax_pred_perturb[neighbors, target] + self.pred_threshold) >= softmax_pred[neighbors, target]).cpu()] = 0 samples.append(seeds) pred_samples.append(sample_bool) samples = np.asarray(samples) pred_samples = np.asarray(pred_samples) combine_samples = (samples * 10 + pred_samples) + 1 neighbors = np.array(neighbors.detach().cpu()) data_pgm = pd.DataFrame(combine_samples) data_pgm = data_pgm.rename(columns={ 0: "A", 1: "B" }) # Trick to use chi_square test on first two data columns index_original_to_subgraph = dict( zip(neighbors, list(data_pgm.columns))) index_subgraph_to_original = dict( zip(list(data_pgm.columns), neighbors)) p_values = [] dependent_neighbors = [] dependent_neighbors_p_values = [] for node in neighbors: if node == index: # null hypothesis is perturbing a particular # node has no effect on result p = 0 else: _, p, _ = chi_square( index_original_to_subgraph[node], index_original_to_subgraph[index], [], data_pgm, boolean=False, significance_level=self.significance_threshold) p_values.append(p) if p < self.significance_threshold: dependent_neighbors.append(node) dependent_neighbors_p_values.append(p) 
pgm_stats = torch.ones(x.size(0), dtype=torch.float) node_mask = torch.zeros(x.size(), dtype=torch.int) pgm_stats[neighbors] = torch.tensor(p_values, dtype=torch.float) if self.max_subgraph_size is None: pgm_nodes = dependent_neighbors else: top_p = np.min((self.max_subgraph_size, len(neighbors) - 1)) ind_top_p = np.argpartition(p_values, top_p)[0:top_p] pgm_nodes = [ index_subgraph_to_original[node] for node in ind_top_p ] node_mask[pgm_nodes] = 1 return node_mask, pgm_stats def forward( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, # node index **kwargs, ) -> Explanation: if self.feature_index is None: self.feature_index = list(range(x.shape[-1])) if isinstance(index, Tensor): if index.numel() > 1: raise NotImplementedError( f"'{self.__class__.__name}' only supports a single " f"`index` for now") index = index.item() if self.model_config.task_level == ModelTaskLevel.node: node_mask, pgm_stats = self._explain_node( model=model, x=x, edge_index=edge_index, target=target[index], index=index, **kwargs, ) return Explanation( x=x, edge_index=edge_index, node_mask=node_mask, pgm_stats=pgm_stats, ) elif self.model_config.task_level == ModelTaskLevel.graph: node_mask, pgm_stats = self._explain_graph( model=model, x=x, target=target, edge_index=edge_index, **kwargs, ) return Explanation( node_mask=node_mask, pgm_stats=pgm_stats, ) def supports(self) -> bool: task_level = self.model_config.task_level if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]: logging.error(f"Task level '{task_level.value}' not supported") return False if self.explainer_config.edge_mask_type is not None: logging.error("Generation of edge masks is not supported") return False if self.model_config.mode == ModelMode.regression: logging.error("'PGMExplainer' only supports classification tasks") return False return True ================================================ FILE: torch_geometric/contrib/nn/__init__.py 
================================================ from .conv import * # noqa from .models import * # noqa __all__ = [] ================================================ FILE: torch_geometric/contrib/nn/conv/__init__.py ================================================ __all__ = classes = [] ================================================ FILE: torch_geometric/contrib/nn/models/__init__.py ================================================ from .rbcd_attack import PRBCDAttack, GRBCDAttack __all__ = classes = [ 'PRBCDAttack', 'GRBCDAttack', ] ================================================ FILE: torch_geometric/contrib/nn/models/rbcd_attack.py ================================================ from collections import defaultdict from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import Tensor from tqdm import tqdm from torch_geometric.utils import coalesce, to_undirected # (predictions, labels, ids/mask) -> Tensor with one element LOSS_TYPE = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor] class PRBCDAttack(torch.nn.Module): r"""The Projected Randomized Block Coordinate Descent (PRBCD) adversarial attack from the `Robustness of Graph Neural Networks at Scale `_ paper. This attack uses an efficient gradient based approach that (during the attack) relaxes the discrete entries in the adjacency matrix :math:`\{0, 1\}` to :math:`[0, 1]` and solely perturbs the adjacency matrix (no feature perturbations). Thus, this attack supports all models that can handle weighted graphs that are differentiable w.r.t. these edge weights, *e.g.*, :class:`~torch_geometric.nn.conv.GCNConv` or :class:`~torch_geometric.nn.conv.GraphConv`. For non-differentiable models you might need modifications, e.g., see example for :class:`~torch_geometric.nn.conv.GATConv`. The memory overhead is driven by the additional edges (at most :attr:`block_size`). 
For scalability reasons, the block is drawn with replacement and then the index is made unique. Thus, the actual block size is typically slightly smaller than specified. This attack can be used for both global and local attacks as well as test-time attacks (evasion) and training-time attacks (poisoning). Please see the provided examples. This attack is designed with a focus on node- or graph-classification; to adapt it to other tasks, you most likely only need to provide an appropriate loss and model. However, we currently do not support batching out of the box (sampling needs to be adapted). .. note:: For examples of using the PRBCD Attack, see `examples/contrib/rbcd_attack.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/contrib/rbcd_attack.py>`_ for a test-time attack (evasion) or `examples/contrib/rbcd_attack_poisoning.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/contrib/rbcd_attack_poisoning.py>`_ for a training-time attack (poisoning). Args: model (torch.nn.Module): The GNN module to assess. block_size (int): Number of randomly selected elements in the adjacency matrix to consider. epochs (int, optional): Number of attack epochs. (default: :obj:`125`) epochs_resampling (int, optional): Number of epochs during which the random block is resampled. (default: :obj:`100`) loss (str or callable, optional): A loss to quantify the "strength" of an attack. Note that this function must match the output format of :attr:`model`. By default, it is assumed that the task is classification and that the model returns raw predictions (*i.e.*, no output activation) or uses :obj:`logsoftmax`. Moreover, the number of predictions should match the number of labels passed to :attr:`attack`. Either pass a callable or one of: :obj:`'masked'`, :obj:`'margin'`, :obj:`'prob_margin'`, :obj:`'tanh_margin'`. (default: :obj:`'prob_margin'`) metric (callable, optional): Second (potentially non-differentiable) loss used for monitoring and, if the :obj:`'with_early_stopping'` coefficient is enabled, for early stopping. (default: same as :attr:`loss`) lr (float, optional): Learning rate for updating edge weights. 
Additionally, it is heuristically corrected for :attr:`block_size`, budget (see :attr:`attack`) and graph size. (default: :obj:`1_000`) is_undirected (bool, optional): If :obj:`True` the graph is assumed to be undirected. (default: :obj:`True`) log (bool, optional): If set to :obj:`False`, will not log any learning progress. (default: :obj:`True`) """ coeffs = { 'max_final_samples': 20, 'max_trials_sampling': 20, 'with_early_stopping': True, 'eps': 1e-7 } def __init__( self, model: torch.nn.Module, block_size: int, epochs: int = 125, epochs_resampling: int = 100, loss: Optional[Union[str, LOSS_TYPE]] = 'prob_margin', metric: Optional[Union[str, LOSS_TYPE]] = None, lr: float = 1_000, is_undirected: bool = True, log: bool = True, **kwargs, ): super().__init__() self.model = model self.block_size = block_size self.epochs = epochs if isinstance(loss, str): if loss == 'masked': self.loss = self._masked_cross_entropy elif loss == 'margin': self.loss = partial(self._margin_loss, reduce='mean') elif loss == 'prob_margin': self.loss = self._probability_margin_loss elif loss == 'tanh_margin': self.loss = self._tanh_margin_loss else: raise ValueError(f'Unknown loss `{loss}`') else: self.loss = loss self.is_undirected = is_undirected self.log = log self.metric = metric or self.loss self.epochs_resampling = epochs_resampling self.lr = lr self.coeffs.update(kwargs) def attack( self, x: Tensor, edge_index: Tensor, labels: Tensor, budget: int, idx_attack: Optional[Tensor] = None, **kwargs, ) -> Tuple[Tensor, Tensor]: """Attack the predictions for the provided model and graph. A subset of predictions may be specified with :attr:`idx_attack`. The attack is allowed to flip (i.e. add or delete) :attr:`budget` edges and will return the strongest perturbation it can find. It returns both the resulting perturbed :attr:`edge_index` as well as the perturbations. Args: x (torch.Tensor): The node feature matrix. edge_index (torch.Tensor): The edge indices. labels (torch.Tensor): The labels. 
budget (int): The number of allowed perturbations (i.e. number of edges that are flipped at most). idx_attack (torch.Tensor, optional): Filter for predictions/labels. Shape and type must match that it can index :attr:`labels` and the model's predictions. **kwargs (optional): Additional arguments passed to the GNN module. :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`) """ self.model.eval() self.device = x.device assert kwargs.get('edge_weight') is None edge_weight = torch.ones(edge_index.size(1), device=self.device) self.edge_index = edge_index.cpu().clone() self.edge_weight = edge_weight.cpu().clone() self.num_nodes = x.size(0) # For collecting attack statistics self.attack_statistics = defaultdict(list) # Prepare attack and define `self.iterable` to iterate over step_sequence = self._prepare(budget) # Loop over the epochs (Algorithm 1, line 5) for step in tqdm(step_sequence, disable=not self.log, desc='Attack'): loss, gradient = self._forward_and_gradient( x, labels, idx_attack, **kwargs) scalars = self._update(step, gradient, x, labels, budget, idx_attack, **kwargs) scalars['loss'] = loss.item() self._append_statistics(scalars) perturbed_edge_index, flipped_edges = self._close( x, labels, budget, idx_attack, **kwargs) assert flipped_edges.size(1) <= budget, ( f'# perturbed edges {flipped_edges.size(1)} ' f'exceeds budget {budget}') return perturbed_edge_index, flipped_edges def _prepare(self, budget: int) -> Iterable[int]: """Prepare attack.""" if self.block_size <= budget: raise ValueError( f'The search space size ({self.block_size}) must be ' f'greater than the number of permutations ({budget})') # For early stopping (not explicitly covered by pseudo code) self.best_metric = float('-Inf') # Sample initial search space (Algorithm 1, line 3-4) self._sample_random_block(budget) steps = range(self.epochs) return steps @torch.no_grad() def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor, budget: int, idx_attack: Optional[Tensor] = None, 
**kwargs) -> Dict[str, float]: """Update edge weights given gradient.""" # Gradient update step (Algorithm 1, line 7) self.block_edge_weight = self._update_edge_weights( budget, self.block_edge_weight, epoch, gradient) # For monitoring pmass_update = torch.clamp(self.block_edge_weight, 0, 1) # Projection to stay within relaxed `L_0` budget # (Algorithm 1, line 8) self.block_edge_weight = self._project(budget, self.block_edge_weight, self.coeffs['eps']) # For monitoring scalars = dict( prob_mass_after_update=pmass_update.sum().item(), prob_mass_after_update_max=pmass_update.max().item(), prob_mass_after_projection=self.block_edge_weight.sum().item(), prob_mass_after_projection_nonzero_weights=( self.block_edge_weight > self.coeffs['eps']).sum().item(), prob_mass_after_projection_max=self.block_edge_weight.max().item()) if not self.coeffs['with_early_stopping']: return scalars # Calculate metric after the current epoch (overhead # for monitoring and early stopping) topk_block_edge_weight = torch.zeros_like(self.block_edge_weight) topk_block_edge_weight[torch.topk(self.block_edge_weight, budget).indices] = 1 edge_index, edge_weight = self._get_modified_adj( self.edge_index, self.edge_weight, self.block_edge_index, topk_block_edge_weight) prediction = self._forward(x, edge_index, edge_weight, **kwargs) metric = self.metric(prediction, labels, idx_attack) # Save best epoch for early stopping # (not explicitly covered by pseudo code) if metric > self.best_metric: self.best_metric = metric self.best_block = self.current_block.cpu().clone() self.best_edge_index = self.block_edge_index.cpu().clone() self.best_pert_edge_weight = self.block_edge_weight.cpu().clone() # Resampling of search space (Algorithm 1, line 9-14) if epoch < self.epochs_resampling - 1: self._resample_random_block(budget) elif epoch == self.epochs_resampling - 1: # Retrieve best epoch if early stopping is active # (not explicitly covered by pseudo code) self.current_block = self.best_block.to(self.device) 
self.block_edge_index = self.best_edge_index.to(self.device) block_edge_weight = self.best_pert_edge_weight.clone() self.block_edge_weight = block_edge_weight.to(self.device) scalars['metric'] = metric.item() return scalars @torch.no_grad() def _close(self, x: Tensor, labels: Tensor, budget: int, idx_attack: Optional[Tensor] = None, **kwargs) -> Tuple[Tensor, Tensor]: """Clean up and prepare return argument.""" # Retrieve best epoch if early stopping is active # (not explicitly covered by pseudo code) if self.coeffs['with_early_stopping']: self.current_block = self.best_block.to(self.device) self.block_edge_index = self.best_edge_index.to(self.device) self.block_edge_weight = self.best_pert_edge_weight.to(self.device) # Sample final discrete graph (Algorithm 1, line 16) edge_index, flipped_edges = self._sample_final_edges( x, labels, budget, idx_attack=idx_attack, **kwargs) return edge_index, flipped_edges def _forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, **kwargs) -> Tensor: """Forward model.""" return self.model(x, edge_index, edge_weight, **kwargs) def _forward_and_gradient(self, x: Tensor, labels: Tensor, idx_attack: Optional[Tensor] = None, **kwargs) -> Tuple[Tensor, Tensor]: """Forward and update edge weights.""" self.block_edge_weight.requires_grad = True # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}` # (Algorithm 1, line 6 / Algorithm 2, line 7) edge_index, edge_weight = self._get_modified_adj( self.edge_index, self.edge_weight, self.block_edge_index, self.block_edge_weight) # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7) prediction = self._forward(x, edge_index, edge_weight, **kwargs) # Calculate loss combining all each node # (Algorithm 1, line 7 / Algorithm 2, line 8) loss = self.loss(prediction, labels, idx_attack) # Retrieve gradient towards the current block # (Algorithm 1, line 7 / Algorithm 2, line 8) gradient = torch.autograd.grad(loss, self.block_edge_weight)[0] return loss, gradient def 
_get_modified_adj(self, edge_index: Tensor, edge_weight: Tensor, block_edge_index: Tensor, block_edge_weight: Tensor) -> Tuple[Tensor, Tensor]: """Merges adjacency matrix with current block (incl. weights).""" if self.is_undirected: block_edge_index, block_edge_weight = to_undirected( block_edge_index, block_edge_weight, num_nodes=self.num_nodes, reduce='mean') modified_edge_index = torch.cat( (edge_index.to(self.device), block_edge_index), dim=-1) modified_edge_weight = torch.cat( (edge_weight.to(self.device), block_edge_weight)) modified_edge_index, modified_edge_weight = coalesce( modified_edge_index, modified_edge_weight, num_nodes=self.num_nodes, reduce='sum') # Allow (soft) removal of edges is_edge_in_clean_adj = modified_edge_weight > 1 modified_edge_weight[is_edge_in_clean_adj] = ( 2 - modified_edge_weight[is_edge_in_clean_adj]) return modified_edge_index, modified_edge_weight def _filter_self_loops_in_block(self, with_weight: bool): is_not_sl = self.block_edge_index[0] != self.block_edge_index[1] self.current_block = self.current_block[is_not_sl] self.block_edge_index = self.block_edge_index[:, is_not_sl] if with_weight: self.block_edge_weight = self.block_edge_weight[is_not_sl] def _sample_random_block(self, budget: int = 0): for _ in range(self.coeffs['max_trials_sampling']): num_possible_edges = self._num_possible_edges( self.num_nodes, self.is_undirected) self.current_block = torch.randint(num_possible_edges, (self.block_size, ), device=self.device) self.current_block = torch.unique(self.current_block, sorted=True) if self.is_undirected: self.block_edge_index = self._linear_to_triu_idx( self.num_nodes, self.current_block) else: self.block_edge_index = self._linear_to_full_idx( self.num_nodes, self.current_block) self._filter_self_loops_in_block(with_weight=False) self.block_edge_weight = torch.full(self.current_block.shape, self.coeffs['eps'], device=self.device) if self.current_block.size(0) >= budget: return raise RuntimeError('Sampling random block 
was not successful. ' 'Please decrease `budget`.') def _resample_random_block(self, budget: int): # Keep at most half of the block (i.e. resample low weights) sorted_idx = torch.argsort(self.block_edge_weight) keep_above = (self.block_edge_weight <= self.coeffs['eps']).sum().long() if keep_above < sorted_idx.size(0) // 2: keep_above = sorted_idx.size(0) // 2 sorted_idx = sorted_idx[keep_above:] self.current_block = self.current_block[sorted_idx] # Sample until enough edges were drawn for _ in range(self.coeffs['max_trials_sampling']): n_edges_resample = self.block_size - self.current_block.size(0) num_possible_edges = self._num_possible_edges( self.num_nodes, self.is_undirected) lin_index = torch.randint(num_possible_edges, (n_edges_resample, ), device=self.device) current_block = torch.cat((self.current_block, lin_index)) self.current_block, unique_idx = torch.unique( current_block, sorted=True, return_inverse=True) if self.is_undirected: self.block_edge_index = self._linear_to_triu_idx( self.num_nodes, self.current_block) else: self.block_edge_index = self._linear_to_full_idx( self.num_nodes, self.current_block) # Merge existing weights with new edge weights block_edge_weight_prev = self.block_edge_weight[sorted_idx] self.block_edge_weight = torch.full(self.current_block.shape, self.coeffs['eps'], device=self.device) self.block_edge_weight[ unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev if not self.is_undirected: self._filter_self_loops_in_block(with_weight=True) if self.current_block.size(0) > budget: return raise RuntimeError('Sampling random block was not successful.' 
'Please decrease `budget`.') def _sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, idx_attack: Optional[Tensor] = None, **kwargs) -> Tuple[Tensor, Tensor]: best_metric = float('-Inf') block_edge_weight = self.block_edge_weight block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0 for i in range(self.coeffs['max_final_samples']): if i == 0: # In first iteration employ top k heuristic instead of sampling sampled_edges = torch.zeros_like(block_edge_weight) sampled_edges[torch.topk(block_edge_weight, budget).indices] = 1 else: sampled_edges = torch.bernoulli(block_edge_weight).float() if sampled_edges.sum() > budget: # Allowed budget is exceeded continue edge_index, edge_weight = self._get_modified_adj( self.edge_index, self.edge_weight, self.block_edge_index, sampled_edges) prediction = self._forward(x, edge_index, edge_weight, **kwargs) metric = self.metric(prediction, labels, idx_attack) # Save best sample if metric > best_metric: best_metric = metric self.block_edge_weight = sampled_edges.clone().cpu() # Recover best sample self.block_edge_weight = self.block_edge_weight.to(self.device) flipped_edges = self.block_edge_index[:, self.block_edge_weight > 0] edge_index, edge_weight = self._get_modified_adj( self.edge_index, self.edge_weight, self.block_edge_index, self.block_edge_weight) edge_mask = edge_weight == 1 edge_index = edge_index[:, edge_mask] return edge_index, flipped_edges def _update_edge_weights(self, budget: int, block_edge_weight: Tensor, epoch: int, gradient: Tensor) -> Tensor: # The learning rate is refined heuristically, s.t. (1) it is # independent of the number of perturbations (assuming an undirected # adjacency matrix) and (2) to decay learning rate during fine-tuning # (i.e. fixed search space). 
lr = (budget / self.num_nodes * self.lr / np.sqrt(max(0, epoch - self.epochs_resampling) + 1)) return block_edge_weight + lr * gradient @staticmethod def _project(budget: int, values: Tensor, eps: float = 1e-7) -> Tensor: r"""Project :obj:`values`: :math:`budget \ge \sum \Pi_{[0, 1]}(\text{values})`. """ if torch.clamp(values, 0, 1).sum() > budget: left = (values - 1).min() right = values.max() miu = PRBCDAttack._bisection(values, left, right, budget) values = values - miu return torch.clamp(values, min=eps, max=1 - eps) @staticmethod def _bisection(edge_weights: Tensor, a: float, b: float, n_pert: int, eps=1e-5, max_iter=1e3) -> Tensor: """Bisection search for projection.""" def shift(offset: float): return (torch.clamp(edge_weights - offset, 0, 1).sum() - n_pert) miu = a for _ in range(int(max_iter)): miu = (a + b) / 2 # Check if middle point is root if (shift(miu) == 0.0): break # Decide the side to repeat the steps if (shift(miu) * shift(a) < 0): b = miu else: a = miu if ((b - a) <= eps): break return miu @staticmethod def _num_possible_edges(n: int, is_undirected: bool) -> int: """Determine number of possible edges for graph.""" if is_undirected: return n * (n - 1) // 2 else: return int(n**2) # We filter self-loops later @staticmethod def _linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor: """Linear index to upper triangular matrix without diagonal. This is similar to https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498 with number nodes decremented and col index incremented by one. 
""" nn = n * (n - 1) row_idx = n - 2 - torch.floor( torch.sqrt(-8 * lin_idx.double() + 4 * nn - 7) / 2.0 - 0.5).long() col_idx = 1 + lin_idx + row_idx - nn // 2 + torch.div( (n - row_idx) * (n - row_idx - 1), 2, rounding_mode='floor') return torch.stack((row_idx, col_idx)) @staticmethod def _linear_to_full_idx(n: int, lin_idx: Tensor) -> Tensor: """Linear index to dense matrix including diagonal.""" row_idx = torch.div(lin_idx, n, rounding_mode='floor') col_idx = lin_idx % n return torch.stack((row_idx, col_idx)) @staticmethod def _margin_loss(score: Tensor, labels: Tensor, idx_mask: Optional[Tensor] = None, reduce: Optional[str] = None) -> Tensor: r"""Margin loss between true score and highest non-target score. .. math:: m = - s_{y} + max_{y' \ne y} s_{y'} where :math:`m` is the margin :math:`s` the score and :math:`y` the labels. Args: score (Tensor): Some score (*e.g.*, logits) of shape :obj:`[n_elem, dim]`. labels (LongTensor): The labels of shape :obj:`[n_elem]`. idx_mask (Tensor, optional): To select subset of `score` and `labels` of shape :obj:`[n_select]`. Defaults to None. reduce (str, optional): if :obj:`mean` the result is aggregated. Otherwise, return element wise margin. :rtype: (Tensor) """ if idx_mask is not None: score = score[idx_mask] labels = labels[idx_mask] linear_idx = torch.arange(score.size(0), device=score.device) true_score = score[linear_idx, labels] score = score.clone() score[linear_idx, labels] = float('-Inf') best_non_target_score = score.amax(dim=-1) margin_ = best_non_target_score - true_score if reduce is None: return margin_ return margin_.mean() @staticmethod def _tanh_margin_loss(prediction: Tensor, labels: Tensor, idx_mask: Optional[Tensor] = None) -> Tensor: """Calculate tanh margin loss, a node-classification loss that focuses on nodes next to decision boundary. Args: prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`. labels (LongTensor): The labels of shape :obj:`[n_elem]`. 
idx_mask (Tensor, optional): To select subset of `score` and `labels` of shape :obj:`[n_select]`. Defaults to None. :rtype: (Tensor) """ log_prob = F.log_softmax(prediction, dim=-1) margin_ = GRBCDAttack._margin_loss(log_prob, labels, idx_mask) loss = torch.tanh(margin_).mean() return loss @staticmethod def _probability_margin_loss(prediction: Tensor, labels: Tensor, idx_mask: Optional[Tensor] = None) -> Tensor: """Calculate probability margin loss, a node-classification loss that focuses on nodes next to decision boundary. See `Are Defenses for Graph Neural Networks Robust? `_ for details. Args: prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`. labels (LongTensor): The labels of shape :obj:`[n_elem]`. idx_mask (Tensor, optional): To select subset of `score` and `labels` of shape :obj:`[n_select]`. Defaults to None. :rtype: (Tensor) """ prob = F.softmax(prediction, dim=-1) margin_ = GRBCDAttack._margin_loss(prob, labels, idx_mask) return margin_.mean() @staticmethod def _masked_cross_entropy(log_prob: Tensor, labels: Tensor, idx_mask: Optional[Tensor] = None) -> Tensor: """Calculate masked cross entropy loss, a node-classification loss that focuses on nodes next to decision boundary. Args: log_prob (Tensor): Log probabilities of shape :obj:`[n_elem, dim]`. labels (LongTensor): The labels of shape :obj:`[n_elem]`. idx_mask (Tensor, optional): To select subset of `score` and `labels` of shape :obj:`[n_select]`. Defaults to None. 
:rtype: (Tensor) """ if idx_mask is not None: log_prob = log_prob[idx_mask] labels = labels[idx_mask] is_correct = log_prob.argmax(-1) == labels if is_correct.any(): log_prob = log_prob[is_correct] labels = labels[is_correct] return F.nll_loss(log_prob, labels) def _append_statistics(self, mapping: Dict[str, Any]): for key, value in mapping.items(): self.attack_statistics[key].append(value) def __repr__(self) -> str: return f'{self.__class__.__name__}()' class GRBCDAttack(PRBCDAttack): r"""The Greedy Randomized Block Coordinate Descent (GRBCD) adversarial attack from the `Robustness of Graph Neural Networks at Scale `_ paper. GRBCD shares most of the properties and requirements with :class:`PRBCDAttack`. It also uses an efficient gradient-based approach. However, it greedily flips edges based on the gradient towards the adjacency matrix. .. note:: For examples of using the GRBCD Attack, see `examples/contrib/rbcd_attack.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/contrib/rbcd_attack.py>`_ for a test-time attack (evasion). Args: model (torch.nn.Module): The GNN module to assess. block_size (int): Number of randomly selected elements in the adjacency matrix to consider. epochs (int, optional): Number of attack steps over which the budget is split as evenly as possible. If the budget is smaller than :obj:`epochs`, only as many steps as the budget are run, each flipping a single edge. (default: :obj:`125`) loss (str or callable, optional): A loss to quantify the "strength" of an attack. Note that this function must match the output format of :attr:`model`. By default, it is assumed that the task is classification and that the model returns raw predictions (*i.e.*, no output activation) or uses :obj:`logsoftmax`. Moreover, the number of predictions should match the number of labels passed to :attr:`attack`. Either pass a callable or one of: :obj:`'masked'`, :obj:`'margin'`, :obj:`'prob_margin'`, :obj:`'tanh_margin'`. (default: :obj:`'masked'`) is_undirected (bool, optional): If :obj:`True` the graph is assumed to be undirected. (default: :obj:`True`) log (bool, optional): If set to :obj:`False`, will not log any learning progress. 
(default: :obj:`True`) """ coeffs = {'max_trials_sampling': 20, 'eps': 1e-7} def __init__( self, model: torch.nn.Module, block_size: int, epochs: int = 125, loss: Optional[Union[str, LOSS_TYPE]] = 'masked', is_undirected: bool = True, log: bool = True, **kwargs, ): super().__init__(model, block_size, epochs, loss=loss, is_undirected=is_undirected, log=log, **kwargs) @torch.no_grad() def _prepare(self, budget: int) -> List[int]: """Prepare attack.""" self.flipped_edges = self.edge_index.new_empty(2, 0).to(self.device) # Determine the number of edges to be flipped in each attach step/epoch step_size = budget // self.epochs if step_size > 0: steps = self.epochs * [step_size] for i in range(budget % self.epochs): steps[i] += 1 else: steps = [1] * budget # Sample initial search space (Algorithm 2, line 3-4) self._sample_random_block(step_size) return steps @torch.no_grad() def _update(self, step_size: int, gradient: Tensor, *args, **kwargs) -> Dict[str, Any]: """Update edge weights given gradient.""" _, topk_edge_index = torch.topk(gradient, step_size) flip_edge_index = self.block_edge_index[:, topk_edge_index] flip_edge_weight = torch.ones_like(flip_edge_index[0], dtype=torch.float32) self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index), axis=-1) if self.is_undirected: flip_edge_index, flip_edge_weight = to_undirected( flip_edge_index, flip_edge_weight, num_nodes=self.num_nodes, reduce='mean') edge_index = torch.cat( (self.edge_index.to(self.device), flip_edge_index.to(self.device)), dim=-1) edge_weight = torch.cat((self.edge_weight.to(self.device), flip_edge_weight.to(self.device))) edge_index, edge_weight = coalesce(edge_index, edge_weight, num_nodes=self.num_nodes, reduce='sum') is_one_mask = torch.isclose(edge_weight, torch.tensor(1.)) self.edge_index = edge_index[:, is_one_mask] self.edge_weight = edge_weight[is_one_mask] # self.edge_weight = torch.ones_like(self.edge_weight) assert self.edge_index.size(1) == self.edge_weight.size(0) # Sample 
initial search space (Algorithm 2, line 3-4) self._sample_random_block(step_size) # Return debug information scalars = { 'number_positive_entries_in_gradient': (gradient > 0).sum().item() } return scalars def _close(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: """Clean up and prepare return argument.""" return self.edge_index, self.flipped_edges ================================================ FILE: torch_geometric/contrib/transforms/__init__.py ================================================ __all__ = classes = [] ================================================ FILE: torch_geometric/data/__init__.py ================================================ # flake8: noqa import torch import torch_geometric.typing from .feature_store import FeatureStore, TensorAttr from .graph_store import GraphStore, EdgeAttr, EdgeLayout from .data import Data from .hetero_data import HeteroData from .batch import Batch from .temporal import TemporalData from .database import Database, SQLiteDatabase, RocksDatabase from .dataset import Dataset from .in_memory_dataset import InMemoryDataset from .on_disk_dataset import OnDiskDataset from .makedirs import makedirs from .download import download_url, download_google_url from .extract import extract_tar, extract_zip, extract_bz2, extract_gz from torch_geometric.lazy_loader import LazyLoader data_classes = [ 'Data', 'HeteroData', 'Batch', 'TemporalData', 'Dataset', 'InMemoryDataset', 'OnDiskDataset', ] remote_backend_classes = [ 'FeatureStore', 'GraphStore', 'TensorAttr', 'EdgeAttr', ] database_classes = [ 'Database', 'SQLiteDatabase', 'RocksDatabase', ] helper_functions = [ 'makedirs', 'download_url', 'download_google_url', 'extract_tar', 'extract_zip', 'extract_bz2', 'extract_gz', ] __all__ = data_classes + remote_backend_classes + helper_functions lightning = LazyLoader('lightning', globals(), 'torch_geometric.data.lightning') from torch_geometric.deprecation import deprecated from torch_geometric.loader import NeighborSampler from 
torch_geometric.loader import ClusterData from torch_geometric.loader import ClusterLoader from torch_geometric.loader import GraphSAINTSampler from torch_geometric.loader import GraphSAINTNodeSampler from torch_geometric.loader import GraphSAINTEdgeSampler from torch_geometric.loader import GraphSAINTRandomWalkSampler from torch_geometric.loader import ShaDowKHopSampler from torch_geometric.loader import RandomNodeLoader from torch_geometric.loader import DataLoader from torch_geometric.loader import DataListLoader from torch_geometric.loader import DenseDataLoader # Serialization ############################################################### if torch_geometric.typing.WITH_PT24: torch.serialization.add_safe_globals([ Data, HeteroData, TemporalData, ClusterData, TensorAttr, EdgeAttr, EdgeLayout, ]) # Deprecations ################################################################ NeighborSampler = deprecated( # type: ignore details="use 'loader.NeighborSampler' instead", func_name='data.NeighborSampler', )(NeighborSampler) ClusterData = deprecated( # type: ignore details="use 'loader.ClusterData' instead", func_name='data.ClusterData', )(ClusterData) ClusterLoader = deprecated( # type: ignore details="use 'loader.ClusterLoader' instead", func_name='data.ClusterLoader', )(ClusterLoader) GraphSAINTSampler = deprecated( # type: ignore details="use 'loader.GraphSAINTSampler' instead", func_name='data.GraphSAINTSampler', )(GraphSAINTSampler) GraphSAINTNodeSampler = deprecated( # type: ignore details="use 'loader.GraphSAINTNodeSampler' instead", func_name='data.GraphSAINTNodeSampler', )(GraphSAINTNodeSampler) GraphSAINTEdgeSampler = deprecated( # type: ignore details="use 'loader.GraphSAINTEdgeSampler' instead", func_name='data.GraphSAINTEdgeSampler', )(GraphSAINTEdgeSampler) GraphSAINTRandomWalkSampler = deprecated( # type: ignore details="use 'loader.GraphSAINTRandomWalkSampler' instead", func_name='data.GraphSAINTRandomWalkSampler', )(GraphSAINTRandomWalkSampler) 
ShaDowKHopSampler = deprecated( # type: ignore details="use 'loader.ShaDowKHopSampler' instead", func_name='data.ShaDowKHopSampler', )(ShaDowKHopSampler) RandomNodeSampler = deprecated( details="use 'loader.RandomNodeLoader' instead", func_name='data.RandomNodeSampler', )(RandomNodeLoader) DataLoader = deprecated( # type: ignore details="use 'loader.DataLoader' instead", func_name='data.DataLoader', )(DataLoader) DataListLoader = deprecated( # type: ignore details="use 'loader.DataListLoader' instead", func_name='data.DataListLoader', )(DataListLoader) DenseDataLoader = deprecated( # type: ignore details="use 'loader.DenseDataLoader' instead", func_name='data.DenseDataLoader', )(DenseDataLoader) ================================================ FILE: torch_geometric/data/batch.py ================================================ import inspect from collections.abc import Sequence from typing import Any, List, Optional, Type, Union import numpy as np import torch from torch import Tensor from typing_extensions import Self from torch_geometric.data.collate import collate from torch_geometric.data.data import BaseData, Data from torch_geometric.data.dataset import IndexType from torch_geometric.data.separate import separate class DynamicInheritance(type): # A meta class that sets the base class of a `Batch` object, e.g.: # * `Batch(Data)` in case `Data` objects are batched together # * `Batch(HeteroData)` in case `HeteroData` objects are batched together def __call__(cls, *args: Any, **kwargs: Any) -> Any: base_cls = kwargs.pop('_base_cls', Data) if issubclass(base_cls, Batch): new_cls = base_cls else: name = f'{base_cls.__name__}{cls.__name__}' # NOTE `MetaResolver` is necessary to resolve metaclass conflict # problems between `DynamicInheritance` and the metaclass of # `base_cls`. In particular, it creates a new common metaclass # from the defined metaclasses. 
class MetaResolver(type(cls), type(base_cls)): # type: ignore pass if name not in globals(): globals()[name] = MetaResolver(name, (cls, base_cls), {}) new_cls = globals()[name] params = list(inspect.signature(base_cls.__init__).parameters.items()) for i, (k, v) in enumerate(params[1:]): if k == 'args' or k == 'kwargs': continue if i < len(args) or k in kwargs: continue if v.default is not inspect.Parameter.empty: continue kwargs[k] = None return super(DynamicInheritance, new_cls).__call__(*args, **kwargs) class DynamicInheritanceGetter: def __call__(self, cls: Type, base_cls: Type) -> Self: return cls(_base_cls=base_cls) class Batch(metaclass=DynamicInheritance): r"""A data object describing a batch of graphs as one big (disconnected) graph. Inherits from :class:`torch_geometric.data.Data` or :class:`torch_geometric.data.HeteroData`. In addition, single graphs can be identified via the assignment vector :obj:`batch`, which maps each node to its respective graph identifier. :pyg:`PyG` allows modification to the underlying batching procedure by overwriting the :meth:`~Data.__inc__` and :meth:`~Data.__cat_dim__` functionalities. The :meth:`~Data.__inc__` method defines the incremental count between two consecutive graph attributes. By default, :pyg:`PyG` increments attributes by the number of nodes whenever their attribute names contain the substring :obj:`index` (for historical reasons), which comes in handy for attributes such as :obj:`edge_index` or :obj:`node_index`. However, note that this may lead to unexpected behavior for attributes whose names contain the substring :obj:`index` but should not be incremented. To make sure, it is best practice to always double-check the output of batching. Furthermore, :meth:`~Data.__cat_dim__` defines in which dimension graph tensors of the same attribute should be concatenated together. 
""" @classmethod def from_data_list( cls, data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None, ) -> Self: r"""Constructs a :class:`~torch_geometric.data.Batch` object from a list of :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects. The assignment vector :obj:`batch` is created on the fly. In addition, creates assignment vectors for each key in :obj:`follow_batch`. Will exclude any keys given in :obj:`exclude_keys`. """ batch, slice_dict, inc_dict = collate( cls, data_list=data_list, increment=True, add_batch=not isinstance(data_list[0], Batch), follow_batch=follow_batch, exclude_keys=exclude_keys, ) batch._num_graphs = len(data_list) # type: ignore batch._slice_dict = slice_dict # type: ignore batch._inc_dict = inc_dict # type: ignore return batch def get_example(self, idx: int) -> BaseData: r"""Gets the :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object at index :obj:`idx`. The :class:`~torch_geometric.data.Batch` object must have been created via :meth:`from_data_list` in order to be able to reconstruct the initial object. """ if not hasattr(self, '_slice_dict'): raise RuntimeError( "Cannot reconstruct 'Data' object from 'Batch' because " "'Batch' was not created via 'Batch.from_data_list()'") data = separate( cls=self.__class__.__bases__[-1], batch=self, idx=idx, slice_dict=self._slice_dict, inc_dict=self._inc_dict, decrement=True, ) return data def index_select(self, idx: IndexType) -> List[BaseData]: r"""Creates a subset of :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects from specified indices :obj:`idx`. Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or bool. 
The :class:`~torch_geometric.data.Batch` object must have been created via :meth:`from_data_list` in order to be able to reconstruct the initial objects. """ index: Sequence[int] if isinstance(idx, slice): index = list(range(self.num_graphs)[idx]) elif isinstance(idx, Tensor) and idx.dtype == torch.long: index = idx.flatten().tolist() elif isinstance(idx, Tensor) and idx.dtype == torch.bool: index = idx.flatten().nonzero(as_tuple=False).flatten().tolist() elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: index = idx.flatten().tolist() elif isinstance(idx, np.ndarray) and idx.dtype == bool: index = idx.flatten().nonzero()[0].flatten().tolist() elif isinstance(idx, Sequence) and not isinstance(idx, str): index = idx else: raise IndexError( f"Only slices (':'), list, tuples, torch.tensor and " f"np.ndarray of dtype long or bool are valid indices (got " f"'{type(idx).__name__}')") return [self.get_example(i) for i in index] def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any: if (isinstance(idx, (int, np.integer)) or (isinstance(idx, Tensor) and idx.dim() == 0) or (isinstance(idx, np.ndarray) and np.isscalar(idx))): return self.get_example(idx) # type: ignore elif isinstance(idx, str) or (isinstance(idx, tuple) and isinstance(idx[0], str)): # Accessing attributes or node/edge types: return super().__getitem__(idx) # type: ignore else: return self.index_select(idx) def to_data_list(self) -> List[BaseData]: r"""Reconstructs the list of :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects from the :class:`~torch_geometric.data.Batch` object. The :class:`~torch_geometric.data.Batch` object must have been created via :meth:`from_data_list` in order to be able to reconstruct the initial objects. 
""" return [self.get_example(i) for i in range(self.num_graphs)] @property def num_graphs(self) -> int: """Returns the number of graphs in the batch.""" if hasattr(self, '_num_graphs'): return self._num_graphs elif hasattr(self, 'ptr'): return self.ptr.numel() - 1 elif hasattr(self, 'batch'): return int(self.batch.max()) + 1 else: raise ValueError("Can not infer the number of graphs") @property def batch_size(self) -> int: r"""Alias for :obj:`num_graphs`.""" return self.num_graphs def __len__(self) -> int: return self.num_graphs def __reduce__(self) -> Any: state = self.__dict__.copy() return DynamicInheritanceGetter(), self.__class__.__bases__, state ================================================ FILE: torch_geometric/data/collate.py ================================================ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import ( Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, ) import torch from torch import Tensor import torch_geometric.typing from torch_geometric import EdgeIndex, Index from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage, NodeStorage from torch_geometric.edge_index import SortOrder from torch_geometric.typing import ( SparseTensor, TensorFrame, torch_frame, torch_sparse, ) from torch_geometric.utils import cumsum, is_sparse, is_torch_sparse_tensor from torch_geometric.utils.sparse import cat T = TypeVar('T') SliceDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]] IncDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]] def collate( cls: Type[T], data_list: List[BaseData], increment: bool = True, add_batch: bool = True, follow_batch: Optional[Iterable[str]] = None, exclude_keys: Optional[Iterable[str]] = None, ) -> Tuple[T, SliceDictType, IncDictType]: # Collates a list of `data` objects into a single object of type `cls`. 
# `collate` can handle both homogeneous and heterogeneous data objects by # individually collating all their stores. # In addition, `collate` can handle nested data structures such as # dictionaries and lists. if not isinstance(data_list, (list, tuple)): # Materialize `data_list` to keep the `_parent` weakref alive. data_list = list(data_list) if cls != data_list[0].__class__: # Dynamic inheritance. out = cls(_base_cls=data_list[0].__class__) # type: ignore else: out = cls() # Create empty stores: out.stores_as(data_list[0]) # type: ignore follow_batch = set(follow_batch or []) exclude_keys = set(exclude_keys or []) # Group all storage objects of every data object in the `data_list` by key, # i.e. `key_to_stores = { key: [store_1, store_2, ...], ... }`: key_to_stores = defaultdict(list) for data in data_list: for store in data.stores: key_to_stores[store._key].append(store) # With this, we iterate over each list of storage objects and recursively # collate all its attributes into a unified representation: # We maintain two additional dictionaries: # * `slice_dict` stores a compressed index representation of each attribute # and is needed to re-construct individual elements from mini-batches. # * `inc_dict` stores how individual elements need to be incremented, e.g., # `edge_index` is incremented by the cumulated sum of previous elements. # We also need to make use of `inc_dict` when re-constructuing individual # elements as attributes that got incremented need to be decremented # while separating to obtain original values. device: Optional[torch.device] = None slice_dict: SliceDictType = {} inc_dict: IncDictType = {} for out_store in out.stores: # type: ignore key = out_store._key stores = key_to_stores[key] for attr in stores[0].keys(): if attr in exclude_keys: # Do not include top-level attribute. 
continue values = [store[attr] for store in stores] # The `num_nodes` attribute needs special treatment, as we need to # sum their values up instead of merging them to a list: if attr == 'num_nodes': out_store._num_nodes = values out_store.num_nodes = sum(values) continue # Skip batching of `ptr` vectors for now: if attr == 'ptr': continue # Collate attributes into a unified representation: value, slices, incs = _collate(attr, values, data_list, stores, increment) # If parts of the data are already on GPU, make sure that auxiliary # data like `batch` or `ptr` are also created on GPU: if isinstance(value, Tensor) and value.is_cuda: device = value.device out_store[attr] = value if key is not None: # Heterogeneous: store_slice_dict = slice_dict.get(key, {}) assert isinstance(store_slice_dict, dict) store_slice_dict[attr] = slices slice_dict[key] = store_slice_dict store_inc_dict = inc_dict.get(key, {}) assert isinstance(store_inc_dict, dict) store_inc_dict[attr] = incs inc_dict[key] = store_inc_dict else: # Homogeneous: slice_dict[attr] = slices inc_dict[attr] = incs # Add an additional batch vector for the given attribute: if attr in follow_batch: batch, ptr = _batch_and_ptr(slices, device) out_store[f'{attr}_batch'] = batch out_store[f'{attr}_ptr'] = ptr # In case of node-level storages, we add a top-level batch vector it: if (add_batch and isinstance(stores[0], NodeStorage) and stores[0].can_infer_num_nodes): repeats = [store.num_nodes or 0 for store in stores] out_store.batch = repeat_interleave(repeats, device=device) out_store.ptr = cumsum(torch.tensor(repeats, device=device)) return out, slice_dict, inc_dict def _collate( key: str, values: List[Any], data_list: List[BaseData], stores: List[BaseStorage], increment: bool, ) -> Tuple[Any, Any, Any]: elem = values[0] if isinstance(elem, Tensor) and not is_sparse(elem): # Concatenate a list of `torch.Tensor` along the `cat_dim`. # NOTE: We need to take care of incrementing elements appropriately. 
key = str(key) cat_dim = data_list[0].__cat_dim__(key, elem, stores[0]) if cat_dim is None or elem.dim() == 0: values = [value.unsqueeze(0) for value in values] sizes = torch.tensor([value.size(cat_dim or 0) for value in values]) slices = cumsum(sizes) if increment: incs = get_incs(key, values, data_list, stores) if incs.dim() > 1 or int(incs[-1]) != 0: values = [ value + inc.to(value.device) for value, inc in zip(values, incs) ] else: incs = None if getattr(elem, 'is_nested', False): tensors = [] for nested_tensor in values: tensors.extend(nested_tensor.unbind()) value = torch.nested.nested_tensor(tensors) return value, slices, incs out = None if (torch.utils.data.get_worker_info() is not None and not isinstance(elem, (Index, EdgeIndex))): # Write directly into shared memory to avoid an extra copy: numel = sum(value.numel() for value in values) if torch_geometric.typing.WITH_PT20: storage = elem.untyped_storage()._new_shared( numel * elem.element_size(), device=elem.device) else: storage = elem.storage()._new_shared(numel, device=elem.device) shape = list(elem.size()) if cat_dim is None or elem.dim() == 0: shape = [len(values)] + shape else: shape[cat_dim] = int(slices[-1]) out = elem.new(storage).resize_(*shape) value = torch.cat(values, dim=cat_dim or 0, out=out) if increment and isinstance(value, Index) and values[0].is_sorted: # Check whether the whole `Index` is sorted: if (value.diff() >= 0).all(): value._is_sorted = True if increment and isinstance(value, EdgeIndex) and values[0].is_sorted: # Check whether the whole `EdgeIndex` is sorted by row: if values[0].is_sorted_by_row and (value[0].diff() >= 0).all(): value._sort_order = SortOrder.ROW # Check whether the whole `EdgeIndex` is sorted by column: elif values[0].is_sorted_by_col and (value[1].diff() >= 0).all(): value._sort_order = SortOrder.COL return value, slices, incs elif isinstance(elem, TensorFrame): key = str(key) sizes = torch.tensor([value.num_rows for value in values]) slices = cumsum(sizes) 
value = torch_frame.cat(values, dim=0) return value, slices, None elif is_sparse(elem) and increment: # Concatenate a list of `SparseTensor` along the `cat_dim`. # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. key = str(key) cat_dim = data_list[0].__cat_dim__(key, elem, stores[0]) cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim repeats = [[value.size(dim) for dim in cat_dims] for value in values] slices = cumsum(torch.tensor(repeats)) if is_torch_sparse_tensor(elem): value = cat(values, dim=cat_dim) else: value = torch_sparse.cat(values, dim=cat_dim) return value, slices, None elif isinstance(elem, (int, float)): # Convert a list of numerical values to a `torch.Tensor`. value = torch.tensor(values) if increment: incs = get_incs(key, values, data_list, stores) if int(incs[-1]) != 0: value.add_(incs) else: incs = None slices = torch.arange(len(values) + 1) return value, slices, incs elif isinstance(elem, Mapping): # Recursively collate elements of dictionaries. value_dict, slice_dict, inc_dict = {}, {}, {} for key in elem.keys(): value_dict[key], slice_dict[key], inc_dict[key] = _collate( key, [v[key] for v in values], data_list, stores, increment) return value_dict, slice_dict, inc_dict elif (isinstance(elem, Sequence) and not isinstance(elem, str) and len(elem) > 0 and isinstance(elem[0], (Tensor, SparseTensor))): # Recursively collate elements of lists. value_list, slice_list, inc_list = [], [], [] for i in range(len(elem)): value, slices, incs = _collate(key, [v[i] for v in values], data_list, stores, increment) value_list.append(value) slice_list.append(slices) inc_list.append(incs) return value_list, slice_list, inc_list else: # Other-wise, just return the list of values as it is. 
slices = torch.arange(len(values) + 1) return values, slices, None def _batch_and_ptr( slices: Any, device: Optional[torch.device] = None, ) -> Tuple[Any, Any]: if (isinstance(slices, Tensor) and slices.dim() == 1): # Default case, turn slices tensor into batch. repeats = slices[1:] - slices[:-1] batch = repeat_interleave(repeats.tolist(), device=device) ptr = cumsum(repeats.to(device)) return batch, ptr elif isinstance(slices, Mapping): # Recursively batch elements of dictionaries. batch, ptr = {}, {} for k, v in slices.items(): batch[k], ptr[k] = _batch_and_ptr(v, device) return batch, ptr elif (isinstance(slices, Sequence) and not isinstance(slices, str) and isinstance(slices[0], Tensor)): # Recursively batch elements of lists. batch, ptr = [], [] for s in slices: sub_batch, sub_ptr = _batch_and_ptr(s, device) batch.append(sub_batch) ptr.append(sub_ptr) return batch, ptr else: # Failure of batching, usually due to slices.dim() != 1 return None, None ############################################################################### def repeat_interleave( repeats: List[int], device: Optional[torch.device] = None, ) -> Tensor: outs = [torch.full((n, ), i, device=device) for i, n in enumerate(repeats)] return torch.cat(outs, dim=0) def get_incs(key, values: List[Any], data_list: List[BaseData], stores: List[BaseStorage]) -> Tensor: repeats = [ data.__inc__(key, value, store) for value, data, store in zip(values, data_list, stores) ] if isinstance(repeats[0], Tensor): repeats = torch.stack(repeats, dim=0) else: repeats = torch.tensor(repeats) return cumsum(repeats[:-1]) ================================================ FILE: torch_geometric/data/data.py ================================================ import copy import warnings from collections import defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass from itertools import chain from typing import ( Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, 
overload, ) import numpy as np import torch from torch import Tensor from typing_extensions import Self from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr from torch_geometric.data.feature_store import _FieldStatus from torch_geometric.data.graph_store import EdgeLayout from torch_geometric.data.storage import ( BaseStorage, EdgeStorage, GlobalStorage, NodeStorage, ) from torch_geometric.deprecation import deprecated from torch_geometric.index import Index from torch_geometric.typing import ( EdgeTensorType, EdgeType, FeatureTensorType, NodeType, OptTensor, SparseTensor, TensorFrame, ) from torch_geometric.utils import is_sparse, select, subgraph class BaseData: def __getattr__(self, key: str) -> Any: raise NotImplementedError def __setattr__(self, key: str, value: Any): raise NotImplementedError def __delattr__(self, key: str): raise NotImplementedError def __getitem__(self, key: str) -> Any: raise NotImplementedError def __setitem__(self, key: str, value: Any): raise NotImplementedError def __delitem__(self, key: str): raise NotImplementedError def __copy__(self): raise NotImplementedError def __deepcopy__(self, memo): raise NotImplementedError def __repr__(self) -> str: raise NotImplementedError def stores_as(self, data: Self): raise NotImplementedError @property def stores(self) -> List[BaseStorage]: raise NotImplementedError @property def node_stores(self) -> List[NodeStorage]: raise NotImplementedError @property def edge_stores(self) -> List[EdgeStorage]: raise NotImplementedError def to_dict(self) -> Dict[str, Any]: r"""Returns a dictionary of stored key/value pairs.""" raise NotImplementedError def to_namedtuple(self) -> NamedTuple: r"""Returns a :obj:`NamedTuple` of stored key/value pairs.""" raise NotImplementedError def update(self, data: Self) -> Self: r"""Updates the data object with the elements from another data object. Added elements will override existing ones (in case of duplicates). 
""" raise NotImplementedError def concat(self, data: Self) -> Self: r"""Concatenates :obj:`self` with another :obj:`data` object. All values needs to have matching shapes at non-concat dimensions. """ out = copy.copy(self) for store, other_store in zip(out.stores, data.stores): store.concat(other_store) return out def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: r"""Returns the dimension for which the value :obj:`value` of the attribute :obj:`key` will get concatenated when creating mini-batches using :class:`torch_geometric.loader.DataLoader`. .. note:: This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute. """ raise NotImplementedError def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: r"""Returns the incremental count to cumulatively increase the value :obj:`value` of the attribute :obj:`key` when creating mini-batches using :class:`torch_geometric.loader.DataLoader`. .. note:: This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute. """ raise NotImplementedError def debug(self): raise NotImplementedError ########################################################################### def keys(self) -> List[str]: r"""Returns a list of all graph attribute names.""" out = [] for store in self.stores: out += list(store.keys()) return list(set(out)) def __len__(self) -> int: r"""Returns the number of graph attributes.""" return len(self.keys()) def __contains__(self, key: str) -> bool: r"""Returns :obj:`True` if the attribute :obj:`key` is present in the data. """ return key in self.keys() def __getstate__(self) -> Dict[str, Any]: return self.__dict__ def __setstate__(self, mapping: Dict[str, Any]): for key, value in mapping.items(): self.__dict__[key] = value @property def num_nodes(self) -> Optional[int]: r"""Returns the number of nodes in the graph. .. 
note:: The number of nodes in the data object is automatically inferred in case node-level attributes are present, *e.g.*, :obj:`data.x`. In some cases, however, a graph may only be given without any node-level attributes. :pyg:`PyG` then *guesses* the number of nodes according to :obj:`edge_index.max().item() + 1`. However, in case there exists isolated nodes, this number does not have to be correct which can result in unexpected behavior. Thus, we recommend to set the number of nodes in your data object explicitly via :obj:`data.num_nodes = ...`. You will be given a warning that requests you to do so. """ try: return sum([v.num_nodes for v in self.node_stores]) except TypeError: return None @overload def size(self) -> Tuple[Optional[int], Optional[int]]: pass @overload def size(self, dim: int) -> Optional[int]: pass def size( self, dim: Optional[int] = None ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: r"""Returns the size of the adjacency matrix induced by the graph.""" size = (self.num_nodes, self.num_nodes) return size if dim is None else size[dim] @property def num_edges(self) -> int: r"""Returns the number of edges in the graph. For undirected graphs, this will return the number of bi-directional edges, which is double the amount of unique edges. 
""" return sum([v.num_edges for v in self.edge_stores]) def node_attrs(self) -> List[str]: r"""Returns all node-level tensor attribute names.""" return list(set(chain(*[s.node_attrs() for s in self.node_stores]))) def edge_attrs(self) -> List[str]: r"""Returns all edge-level tensor attribute names.""" return list(set(chain(*[s.edge_attrs() for s in self.edge_stores]))) @property def node_offsets(self) -> Dict[NodeType, int]: out: Dict[NodeType, int] = {} offset: int = 0 for store in self.node_stores: out[store._key] = offset offset = offset + store.num_nodes return out def generate_ids(self): r"""Generates and sets :obj:`n_id` and :obj:`e_id` attributes to assign each node and edge to a continuously ascending and unique ID. """ for store in self.node_stores: store.n_id = torch.arange(store.num_nodes) for store in self.edge_stores: store.e_id = torch.arange(store.num_edges) def is_sorted(self, sort_by_row: bool = True) -> bool: r"""Returns :obj:`True` if edge indices :obj:`edge_index` are sorted. Args: sort_by_row (bool, optional): If set to :obj:`False`, will require column-wise order/by destination node order of :obj:`edge_index`. (default: :obj:`True`) """ return all( [store.is_sorted(sort_by_row) for store in self.edge_stores]) def sort(self, sort_by_row: bool = True) -> Self: r"""Sorts edge indices :obj:`edge_index` and their corresponding edge features. Args: sort_by_row (bool, optional): If set to :obj:`False`, will sort :obj:`edge_index` in column-wise order/by destination node. (default: :obj:`True`) """ out = copy.copy(self) for store in out.edge_stores: store.sort(sort_by_row) return out def is_coalesced(self) -> bool: r"""Returns :obj:`True` if edge indices :obj:`edge_index` are sorted and do not contain duplicate entries. """ return all([store.is_coalesced() for store in self.edge_stores]) def coalesce(self) -> Self: r"""Sorts and removes duplicated entries from edge indices :obj:`edge_index`. 
""" out = copy.copy(self) for store in out.edge_stores: store.coalesce() return out def is_sorted_by_time(self) -> bool: r"""Returns :obj:`True` if :obj:`time` is sorted.""" return all([store.is_sorted_by_time() for store in self.stores]) def sort_by_time(self) -> Self: r"""Sorts data associated with :obj:`time` according to :obj:`time`.""" out = copy.copy(self) for store in out.stores: store.sort_by_time() return out def snapshot( self, start_time: Union[float, int], end_time: Union[float, int], attr: str = 'time', ) -> Self: r"""Returns a snapshot of :obj:`data` to only hold events that occurred in period :obj:`[start_time, end_time]`. """ out = copy.copy(self) for store in out.stores: store.snapshot(start_time, end_time, attr) return out def up_to(self, end_time: Union[float, int]) -> Self: r"""Returns a snapshot of :obj:`data` to only hold events that occurred up to :obj:`end_time` (inclusive of :obj:`edge_time`). """ out = copy.copy(self) for store in out.stores: store.up_to(end_time) return out def has_isolated_nodes(self) -> bool: r"""Returns :obj:`True` if the graph contains isolated nodes.""" return any([store.has_isolated_nodes() for store in self.edge_stores]) def has_self_loops(self) -> bool: """Returns :obj:`True` if the graph contains self-loops.""" return any([store.has_self_loops() for store in self.edge_stores]) def is_undirected(self) -> bool: r"""Returns :obj:`True` if graph edges are undirected.""" return all([store.is_undirected() for store in self.edge_stores]) def is_directed(self) -> bool: r"""Returns :obj:`True` if graph edges are directed.""" return not self.is_undirected() def apply_(self, func: Callable, *args: str): r"""Applies the in-place function :obj:`func`, either to all attributes or only the ones given in :obj:`*args`. 
""" for store in self.stores: store.apply_(func, *args) return self def apply(self, func: Callable, *args: str): r"""Applies the function :obj:`func`, either to all attributes or only the ones given in :obj:`*args`. """ for store in self.stores: store.apply(func, *args) return self def clone(self, *args: str): r"""Performs cloning of tensors, either for all attributes or only the ones given in :obj:`*args`. """ return copy.copy(self).apply(lambda x: x.clone(), *args) def contiguous(self, *args: str): r"""Ensures a contiguous memory layout, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.contiguous(), *args) def to(self, device: Union[int, str, torch.device], *args: str, non_blocking: bool = False): r"""Performs tensor device conversion, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply( lambda x: x.to(device=device, non_blocking=non_blocking), *args) def cpu(self, *args: str): r"""Copies attributes to CPU memory, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.cpu(), *args) def cuda(self, device: Optional[Union[int, str]] = None, *args: str, non_blocking: bool = False): r"""Copies attributes to CUDA memory, either for all attributes or only the ones given in :obj:`*args`. """ # Some PyTorch tensor like objects require a default value for `cuda`: device = 'cuda' if device is None else device return self.apply(lambda x: x.cuda(device, non_blocking=non_blocking), *args) def pin_memory(self, *args: str): r"""Copies attributes to pinned memory, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.pin_memory(), *args) def share_memory_(self, *args: str): r"""Moves attributes to shared memory, either for all attributes or only the ones given in :obj:`*args`. 
""" return self.apply_(lambda x: x.share_memory_(), *args) def detach_(self, *args: str): r"""Detaches attributes from the computation graph, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply_(lambda x: x.detach_(), *args) def detach(self, *args: str): r"""Detaches attributes from the computation graph by creating a new tensor, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.detach(), *args) def requires_grad_(self, *args: str, requires_grad: bool = True): r"""Tracks gradient computation, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply_( lambda x: x.requires_grad_(requires_grad=requires_grad), *args) def record_stream(self, stream: torch.cuda.Stream, *args: str): r"""Ensures that the tensor memory is not reused for another tensor until all current work queued on :obj:`stream` has been completed, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply_(lambda x: x.record_stream(stream), *args) @property def is_cuda(self) -> bool: r"""Returns :obj:`True` if any :class:`torch.Tensor` attribute is stored on the GPU, :obj:`False` otherwise. 
""" for store in self.stores: for value in store.values(): if isinstance(value, Tensor) and value.is_cuda: return True return False # Deprecated functions #################################################### @deprecated(details="use 'has_isolated_nodes' instead") def contains_isolated_nodes(self) -> bool: return self.has_isolated_nodes() @deprecated(details="use 'has_self_loops' instead") def contains_self_loops(self) -> bool: return self.has_self_loops() ############################################################################### @dataclass class DataTensorAttr(TensorAttr): r"""Tensor attribute for `Data` without group name.""" def __init__( self, attr_name=_FieldStatus.UNSET, index=None, ): super().__init__(None, attr_name, index) @dataclass class DataEdgeAttr(EdgeAttr): r"""Edge attribute class for `Data` without edge type.""" def __init__( self, layout: Optional[EdgeLayout] = None, is_sorted: bool = False, size: Optional[Tuple[int, int]] = None, ): super().__init__(None, layout, is_sorted, size) ############################################################################### class Data(BaseData, FeatureStore, GraphStore): r"""A data object describing a homogeneous graph. The data object can hold node-level, link-level and graph-level attributes. In general, :class:`~torch_geometric.data.Data` tries to mimic the behavior of a regular :python:`Python` dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities. See `here `__ for the accompanying tutorial. .. code-block:: python from torch_geometric.data import Data data = Data(x=x, edge_index=edge_index, ...) 
# Add additional arguments to `data`:
        data.train_idx = torch.tensor([...], dtype=torch.long)
        data.test_mask = torch.tensor([...], dtype=torch.bool)

        # Analyzing the graph structure:
        data.num_nodes
        >>> 23

        data.is_directed()
        >>> False

        # PyTorch tensor functionality:
        data = data.pin_memory()
        data = data.to('cuda:0', non_blocking=True)

    Args:
        x (torch.Tensor, optional): Node feature matrix with shape
            :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in COO format
            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (torch.Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y (torch.Tensor or int or float, optional): Graph-level or
            node-level ground-truth labels with arbitrary shape.
            (default: :obj:`None`)
        pos (torch.Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
        time (torch.Tensor, optional): The timestamps for each event with shape
            :obj:`[num_edges]` or :obj:`[num_nodes]`. (default: :obj:`None`)
        **kwargs (optional): Additional attributes.
    """
    def __init__(
        self,
        x: Optional[Tensor] = None,
        edge_index: OptTensor = None,
        edge_attr: OptTensor = None,
        y: Optional[Union[Tensor, int, float]] = None,
        pos: OptTensor = None,
        time: OptTensor = None,
        **kwargs,
    ):
        # `Data` doesn't support group_name, so we need to adjust `TensorAttr`
        # accordingly here to avoid requiring `group_name` to be set:
        super().__init__(tensor_attr_cls=DataTensorAttr)

        # `Data` doesn't support edge_type, so we need to adjust `EdgeAttr`
        # accordingly here to avoid requiring `edge_type` to be set:
        GraphStore.__init__(self, edge_attr_cls=DataEdgeAttr)

        # Written through `__dict__` directly: `__setattr__` below forwards
        # every other attribute into this store, so it must exist first.
        self.__dict__['_store'] = GlobalStorage(_parent=self)

        # Only explicitly given attributes are stored; `None` means "absent".
        if x is not None:
            self.x = x
        if edge_index is not None:
            self.edge_index = edge_index
        if edge_attr is not None:
            self.edge_attr = edge_attr
        if y is not None:
            self.y = y
        if pos is not None:
            self.pos = pos
        if time is not None:
            self.time = time

        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, key: str) -> Any:
        """Forwards attribute lookups that miss on the instance to the store.

        Raises:
            RuntimeError: If the object has no `_store` (the message attributes
                this to objects created by an older PyG version).
        """
        if '_store' not in self.__dict__:
            raise RuntimeError(
                "The 'data' object was created by an older version of PyG. "
                "If this error occurred while loading an already existing "
                "dataset, remove the 'processed/' directory in the dataset's "
                "root folder and try again.")
        return getattr(self._store, key)

    def __setattr__(self, key: str, value: Any):
        """Routes assignments to a class property setter when one exists
        (e.g. :obj:`num_nodes`), and otherwise into the underlying store."""
        propobj = getattr(self.__class__, key, None)
        if propobj is not None and getattr(propobj, 'fset', None) is not None:
            propobj.fset(self, value)
        else:
            setattr(self._store, key, value)

    def __delattr__(self, key: str):
        """Deletes the attribute from the underlying store."""
        delattr(self._store, key)

    # TODO consider supporting the feature store interface for
    # __getitem__, __setitem__, and __delitem__ so, for example, we
    # can accept key: Union[str, TensorAttr] in __getitem__.
    def __getitem__(self, key: str) -> Any:
        """Returns the stored value for :obj:`key` (dict-style access)."""
        return self._store[key]

    def __setitem__(self, key: str, value: Any):
        """Stores :obj:`value` under :obj:`key` (dict-style access)."""
        self._store[key] = value

    def __delitem__(self, key: str):
        """Deletes :obj:`key` from the store; missing keys are ignored."""
        if key in self._store:
            del self._store[key]

    def __copy__(self):
        """Shallow copy: instance `__dict__` entries are shared, but the store
        itself is shallow-copied and re-parented to the new object."""
        out = self.__class__.__new__(self.__class__)
        for key, value in self.__dict__.items():
            out.__dict__[key] = value
        out.__dict__['_store'] = copy.copy(self._store)
        # The copied store must point back at the new object, not at `self`.
        out._store._parent = out
        return out

    def __deepcopy__(self, memo):
        """Deep copy: every `__dict__` entry (including `_store`) is deep
        copied, then the copied store is re-parented to the new object."""
        out = self.__class__.__new__(self.__class__)
        for key, value in self.__dict__.items():
            out.__dict__[key] = copy.deepcopy(value, memo)
        out._store._parent = out
        return out

    def __repr__(self) -> str:
        """Returns a one-line summary, or a multi-line one if any stored value
        is a mapping."""
        cls = self.__class__.__name__
        has_dict = any([isinstance(v, Mapping) for v in self._store.values()])

        if not has_dict:
            info = [size_repr(k, v) for k, v in self._store.items()]
            info = ', '.join(info)
            return f'{cls}({info})'
        else:
            info = [size_repr(k, v, indent=2) for k, v in self._store.items()]
            info = ',\n'.join(info)
            return f'{cls}(\n{info}\n)'

    @property
    def num_nodes(self) -> Optional[int]:
        """Number of nodes, as computed by the base-class implementation."""
        return super().num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: Optional[int]):
        self._store.num_nodes = num_nodes

    def stores_as(self, data: Self):
        """No-op for homogeneous graphs: :obj:`data` is ignored, returns
        :obj:`self`."""
        return self

    # A homogeneous `Data` object has exactly one store, which serves as the
    # global, node-level and edge-level store at the same time:
    @property
    def stores(self) -> List[BaseStorage]:
        return [self._store]

    @property
    def node_stores(self) -> List[NodeStorage]:
        return [self._store]

    @property
    def edge_stores(self) -> List[EdgeStorage]:
        return [self._store]

    def to_dict(self) -> Dict[str, Any]:
        """Returns the stored attributes as a dictionary."""
        return self._store.to_dict()

    def to_namedtuple(self) -> NamedTuple:
        """Returns the stored attributes as a named tuple."""
        return self._store.to_namedtuple()

    def update(self, data: Union[Self, Dict[str, Any]]) -> Self:
        """Copies every ``(key, value)`` pair of :obj:`data` into this object
        (overwriting existing keys) and returns :obj:`self`."""
        for key, value in data.items():
            self[key] = value
        return self

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Returns the dimension along which :obj:`value` is concatenated.

        Presumably consumed by the batching/collate logic (not visible here —
        confirm). Sparse adjacency-like values use ``(0, 1)``, keys containing
        ``'index'`` or equal to ``'face'`` use ``-1``, everything else ``0``.
        """
        if is_sparse(value) and ('adj' in key or 'edge_index' in key):
            return (0, 1)
        elif 'index' in key or key == 'face':
            return -1
        else:
            return 0

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Returns the value by which :obj:`value` is incremented when graphs
        are combined (presumably during batching — confirm with the caller).

        ``'batch'`` tensors are offset by their number of distinct ids
        (``max + 1``, or the known dimension size of an :class:`Index`);
        index-like attributes are offset by :obj:`num_nodes`.

        Raises:
            RuntimeError: If an index-like attribute needs :obj:`num_nodes`
                but it cannot be inferred.
        """
        if 'batch' in key and isinstance(value, Tensor):
            if isinstance(value, Index):
                return value.get_dim_size()
            return int(value.max()) + 1
        elif 'index' in key or key == 'face':
            num_nodes = self.num_nodes
            if num_nodes is None:
                raise RuntimeError(f"Unable to infer 'num_nodes' from the "
                                   f"attribute '{key}'. Please explicitly set "
                                   f"'num_nodes' as an attribute of 'data' to "
                                   f"prevent this error")
            return num_nodes
        else:
            return 0

    def validate(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the data.

        Checks that :obj:`num_nodes` is defined, that :obj:`edge_index` has
        shape :obj:`[2, num_edges]`, and that its entries lie in
        :obj:`[0, num_nodes)`.

        Args:
            raise_on_error (bool, optional): If :obj:`True`, the first failed
                check raises a :class:`ValueError`; otherwise every failed
                check emits a warning. (default: :obj:`True`)

        Returns:
            bool: :obj:`True` if all checks passed.
        """
        cls_name = self.__class__.__name__
        status = True

        num_nodes = self.num_nodes
        if num_nodes is None:
            status = False
            warn_or_raise(f"'num_nodes' is undefined in '{cls_name}'",
                          raise_on_error)

        if 'edge_index' in self:
            if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2:
                status = False
                warn_or_raise(
                    f"'edge_index' needs to be of shape [2, num_edges] in "
                    f"'{cls_name}' (found {self.edge_index.size()})",
                    raise_on_error)

        # Range checks are skipped for empty `edge_index` (min/max undefined).
        if 'edge_index' in self and self.edge_index.numel() > 0:
            if self.edge_index.min() < 0:
                status = False
                warn_or_raise(
                    f"'edge_index' contains negative indices in "
                    f"'{cls_name}' (found {int(self.edge_index.min())})",
                    raise_on_error)

            if num_nodes is not None and self.edge_index.max() >= num_nodes:
                status = False
                warn_or_raise(
                    f"'edge_index' contains larger indices than the number "
                    f"of nodes ({num_nodes}) in '{cls_name}' "
                    f"(found {int(self.edge_index.max())})", raise_on_error)

        return status

    def debug(self):
        pass  # TODO

    def is_node_attr(self, key: str) -> bool:
        r"""Returns :obj:`True` if the object at key :obj:`key` denotes a
        node-level tensor attribute.
        """
        return self._store.is_node_attr(key)

    def is_edge_attr(self, key: str) -> bool:
        r"""Returns :obj:`True` if the object at key :obj:`key` denotes an
        edge-level tensor attribute.
        """
        return self._store.is_edge_attr(key)

    def subgraph(self, subset: Tensor) -> Self:
        r"""Returns the induced subgraph given by the node indices
        :obj:`subset`.

        Args:
            subset (LongTensor or BoolTensor): The nodes to keep.
        """
        if 'edge_index' in self:
            # Keep only edges whose endpoints are both in `subset`, and relabel
            # node ids to their positions within `subset`:
            edge_index, _, edge_mask = subgraph(
                subset,
                self.edge_index,
                relabel_nodes=True,
                num_nodes=self.num_nodes,
                return_edge_mask=True,
            )
        else:
            # Without `edge_index`, every edge-level attribute is kept whole.
            edge_index = None
            edge_mask = torch.ones(
                self.num_edges,
                dtype=torch.bool,
                device=subset.device,
            )

        data = copy.copy(self)

        for key, value in self:
            if key == 'edge_index':
                data.edge_index = edge_index
            elif key == 'num_nodes':
                if subset.dtype == torch.bool:
                    data.num_nodes = int(subset.sum())
                else:
                    data.num_nodes = subset.size(0)
            elif self.is_node_attr(key):
                cat_dim = self.__cat_dim__(key, value)
                data[key] = select(value, subset, dim=cat_dim)
            elif self.is_edge_attr(key):
                cat_dim = self.__cat_dim__(key, value)
                data[key] = select(value, edge_mask, dim=cat_dim)
            # Any other (graph-level) attribute is shared from the shallow copy.

        return data

    def edge_subgraph(self, subset: Tensor) -> Self:
        r"""Returns the induced subgraph given by the edge indices
        :obj:`subset`.
        Will currently preserve all the nodes in the graph, even if they are
        isolated after subgraph computation.

        Args:
            subset (LongTensor or BoolTensor): The edges to keep.
        """
        data = copy.copy(self)
        # Only edge-level attributes are filtered; node ids are not relabeled.
        for key, value in self:
            if self.is_edge_attr(key):
                cat_dim = self.__cat_dim__(key, value)
                data[key] = select(value, subset, dim=cat_dim)
        return data

    def to_heterogeneous(
        self,
        node_type: Optional[Tensor] = None,
        edge_type: Optional[Tensor] = None,
        node_type_names: Optional[List[NodeType]] = None,
        edge_type_names: Optional[List[EdgeType]] = None,
    ):
        r"""Converts a :class:`~torch_geometric.data.Data` object to a
        heterogeneous :class:`~torch_geometric.data.HeteroData` object.
        For this, node and edge attributes are split according to the
        node-level and edge-level vectors :obj:`node_type` and
        :obj:`edge_type`, respectively.
        :obj:`node_type_names` and :obj:`edge_type_names` can be used to give
        meaningful node and edge type names, respectively.
        That is, the node_type :obj:`0` is given by :obj:`node_type_names[0]`.
If the :class:`~torch_geometric.data.Data` object was constructed via :meth:`~torch_geometric.data.HeteroData.to_homogeneous`, the object can be reconstructed without any need to pass in additional arguments. Args: node_type (torch.Tensor, optional): A node-level vector denoting the type of each node. (default: :obj:`None`) edge_type (torch.Tensor, optional): An edge-level vector denoting the type of each edge. (default: :obj:`None`) node_type_names (List[str], optional): The names of node types. (default: :obj:`None`) edge_type_names (List[Tuple[str, str, str]], optional): The names of edge types. (default: :obj:`None`) """ from torch_geometric.data import HeteroData if node_type is None: node_type = self._store.get('node_type', None) if node_type is None: node_type = torch.zeros(self.num_nodes, dtype=torch.long) if node_type_names is None: store = self._store node_type_names = store.__dict__.get('_node_type_names', None) if node_type_names is None: node_type_names = [str(i) for i in node_type.unique().tolist()] if edge_type is None: edge_type = self._store.get('edge_type', None) if edge_type is None: edge_type = torch.zeros(self.num_edges, dtype=torch.long) if edge_type_names is None: store = self._store edge_type_names = store.__dict__.get('_edge_type_names', None) if edge_type_names is None: edge_type_names = [] edge_index = self.edge_index for i in edge_type.unique().tolist(): src, dst = edge_index[:, edge_type == i] src_types = node_type[src].unique().tolist() dst_types = node_type[dst].unique().tolist() if len(src_types) != 1 and len(dst_types) != 1: raise ValueError( "Could not construct a 'HeteroData' object from the " "'Data' object because single edge types span over " "multiple node types") edge_type_names.append((node_type_names[src_types[0]], str(i), node_type_names[dst_types[0]])) # We iterate over node types to find the local node indices belonging # to each node type. 
Furthermore, we create a global `index_map` vector # that maps global node indices to local ones in the final # heterogeneous graph: node_ids, index_map = {}, torch.empty_like(node_type) for i in range(len(node_type_names)): node_ids[i] = (node_type == i).nonzero(as_tuple=False).view(-1) index_map[node_ids[i]] = torch.arange(len(node_ids[i]), device=index_map.device) # We iterate over edge types to find the local edge indices: edge_ids = {} for i in range(len(edge_type_names)): edge_ids[i] = (edge_type == i).nonzero(as_tuple=False).view(-1) data = HeteroData() for i, key in enumerate(node_type_names): for attr, value in self.items(): if attr in {'node_type', 'edge_type', 'ptr'}: continue elif isinstance(value, Tensor) and self.is_node_attr(attr): cat_dim = self.__cat_dim__(attr, value) data[key][attr] = value.index_select(cat_dim, node_ids[i]) elif (isinstance(value, TensorFrame) and self.is_node_attr(attr)): data[key][attr] = value[node_ids[i]] if len(data[key]) == 0: data[key].num_nodes = node_ids[i].size(0) for i, key in enumerate(edge_type_names): src, _, dst = key for attr, value in self.items(): if attr in {'node_type', 'edge_type', 'ptr'}: continue elif attr == 'edge_index': edge_index = value[:, edge_ids[i]] edge_index[0] = index_map[edge_index[0]] edge_index[1] = index_map[edge_index[1]] data[key].edge_index = edge_index elif isinstance(value, Tensor) and self.is_edge_attr(attr): cat_dim = self.__cat_dim__(attr, value) data[key][attr] = value.index_select(cat_dim, edge_ids[i]) elif (isinstance(value, TensorFrame) and self.is_edge_attr(attr)): data[key][attr] = value[edge_ids[i]] # Add global attributes. exclude_keys = set(data.keys()) | { 'node_type', 'edge_type', 'edge_index', 'num_nodes', 'ptr' } for attr, value in self.items(): if attr in exclude_keys: continue data[attr] = value return data def connected_components(self) -> List[Self]: r"""Extracts connected components of the graph using a union-find algorithm. 
The components are returned as a list of :class:`~torch_geometric.data.Data` objects, where each object represents a connected component of the graph. .. code-block:: data = Data() data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]]) data.edge_index = torch.tensor( [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long ) components = data.connected_components() print(len(components)) >>> 2 print(components[0].x) >>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2]) Returns: List[Data]: A list of disconnected components. """ # Union-Find algorithm to find connected components self._parents: Dict[int, int] = {} self._ranks: Dict[int, int] = {} for edge in self.edge_index.t().tolist(): self._union(edge[0], edge[1]) # Rerun _find_parent to ensure all nodes are covered correctly for node in range(self.num_nodes): self._find_parent(node) # Group parents grouped_parents = defaultdict(list) for node, parent in self._parents.items(): grouped_parents[parent].append(node) del self._ranks del self._parents # Create components based on the found parents (roots) components: List[Self] = [] for nodes in grouped_parents.values(): # Convert the list of node IDs to a tensor subset = torch.tensor(nodes, dtype=torch.long) # Use the existing subgraph function component_data = self.subgraph(subset) components.append(component_data) return components ########################################################################### @classmethod def from_dict(cls, mapping: Dict[str, Any]) -> Self: r"""Creates a :class:`~torch_geometric.data.Data` object from a dictionary. """ return cls(**mapping) @property def num_node_features(self) -> int: r"""Returns the number of features per node in the graph.""" return self._store.num_node_features @property def num_features(self) -> int: r"""Returns the number of features per node in the graph. Alias for :py:attr:`~num_node_features`. 
""" return self.num_node_features @property def num_edge_features(self) -> int: r"""Returns the number of features per edge in the graph.""" return self._store.num_edge_features @property def num_node_types(self) -> int: r"""Returns the number of node types in the graph.""" return int(self.node_type.max()) + 1 if 'node_type' in self else 1 @property def num_edge_types(self) -> int: r"""Returns the number of edge types in the graph.""" return int(self.edge_type.max()) + 1 if 'edge_type' in self else 1 def __iter__(self) -> Iterable: r"""Iterates over all attributes in the data, yielding their attribute names and values. """ yield from self._store.items() def __call__(self, *args: str) -> Iterable: r"""Iterates over all attributes :obj:`*args` in the data, yielding their attribute names and values. If :obj:`*args` is not given, will iterate over all attributes. """ yield from self._store.items(*args) @property def x(self) -> Optional[Tensor]: return self['x'] if 'x' in self._store else None @x.setter def x(self, x: Optional[Tensor]): self._store.x = x @property def edge_index(self) -> Optional[Tensor]: return self['edge_index'] if 'edge_index' in self._store else None @edge_index.setter def edge_index(self, edge_index: Optional[Tensor]): self._store.edge_index = edge_index @property def edge_weight(self) -> Optional[Tensor]: return self['edge_weight'] if 'edge_weight' in self._store else None @edge_weight.setter def edge_weight(self, edge_weight: Optional[Tensor]): self._store.edge_weight = edge_weight @property def edge_attr(self) -> Optional[Tensor]: return self['edge_attr'] if 'edge_attr' in self._store else None @edge_attr.setter def edge_attr(self, edge_attr: Optional[Tensor]): self._store.edge_attr = edge_attr @property def y(self) -> Optional[Union[Tensor, int, float]]: return self['y'] if 'y' in self._store else None @y.setter def y(self, y: Optional[Tensor]): self._store.y = y @property def pos(self) -> Optional[Tensor]: return self['pos'] if 'pos' in 
self._store else None @pos.setter def pos(self, pos: Optional[Tensor]): self._store.pos = pos @property def batch(self) -> Optional[Tensor]: return self['batch'] if 'batch' in self._store else None @batch.setter def batch(self, batch: Optional[Tensor]): self._store.batch = batch @property def time(self) -> Optional[Tensor]: return self['time'] if 'time' in self._store else None @time.setter def time(self, time: Optional[Tensor]): self._store.time = time @property def face(self) -> Optional[Tensor]: return self['face'] if 'face' in self._store else None @face.setter def face(self, face: Optional[Tensor]): self._store.face = face # Deprecated functions #################################################### @property @deprecated(details="use 'data.face.size(-1)' instead") def num_faces(self) -> Optional[int]: r"""Returns the number of faces in the mesh.""" if 'face' in self._store and isinstance(self.face, Tensor): return self.face.size(self.__cat_dim__('face', self.face)) return None # FeatureStore interface ################################################## def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: out = self.get(attr.attr_name) if out is not None and attr.index is not None: out[attr.index] = tensor else: assert attr.index is None setattr(self, attr.attr_name, tensor) return True def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: tensor = getattr(self, attr.attr_name, None) if tensor is not None: # TODO this behavior is a bit odd, since TensorAttr requires that # we set `index`. So, we assume here that indexing by `None` is # equivalent to not indexing at all, which is not in line with # Python semantics. 
return tensor[attr.index] if attr.index is not None else tensor return None def _remove_tensor(self, attr: TensorAttr) -> bool: if hasattr(self, attr.attr_name): delattr(self, attr.attr_name) return True return False def _get_tensor_size(self, attr: TensorAttr) -> Tuple: return self._get_tensor(attr).size() def get_all_tensor_attrs(self) -> List[TensorAttr]: r"""Obtains all feature attributes stored in `Data`.""" return [ TensorAttr(attr_name=name) for name in self._store.keys() if self._store.is_node_attr(name) ] # GraphStore interface #################################################### def _put_edge_index(self, edge_index: EdgeTensorType, edge_attr: EdgeAttr) -> bool: if not hasattr(self, '_edge_attrs'): self._edge_attrs = {} self._edge_attrs[edge_attr.layout] = edge_attr row, col = edge_index if edge_attr.layout == EdgeLayout.COO: self.edge_index = torch.stack([row, col], dim=0) elif edge_attr.layout == EdgeLayout.CSR: self.adj = SparseTensor( rowptr=row, col=col, sparse_sizes=edge_attr.size, is_sorted=True, trust_data=True, ) else: # edge_attr.layout == EdgeLayout.CSC: size = edge_attr.size[::-1] if edge_attr.size is not None else None self.adj_t = SparseTensor( rowptr=col, col=row, sparse_sizes=size, is_sorted=True, trust_data=True, ) return True def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]: if edge_attr.size is None: edge_attr.size = self.size() # Modify in-place. 
if edge_attr.layout == EdgeLayout.COO and 'edge_index' in self: row, col = self.edge_index return row, col elif edge_attr.layout == EdgeLayout.CSR and 'adj' in self: rowptr, col, _ = self.adj.csr() return rowptr, col elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in self: colptr, row, _ = self.adj_t.csr() return row, colptr return None def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool: if edge_attr.layout == EdgeLayout.COO and 'edge_index' in self: del self.edge_index if hasattr(self, '_edge_attrs'): self._edge_attrs.pop(EdgeLayout.COO, None) return True elif edge_attr.layout == EdgeLayout.CSR and 'adj' in self: del self.adj if hasattr(self, '_edge_attrs'): self._edge_attrs.pop(EdgeLayout.CSR, None) return True elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in self: del self.adj_t if hasattr(self, '_edge_attrs'): self._edge_attrs.pop(EdgeLayout.CSC, None) return True return False def get_all_edge_attrs(self) -> List[EdgeAttr]: edge_attrs = getattr(self, '_edge_attrs', {}) if 'edge_index' in self and EdgeLayout.COO not in edge_attrs: edge_attrs[EdgeLayout.COO] = DataEdgeAttr('coo', is_sorted=False) if 'adj' in self and EdgeLayout.CSR not in edge_attrs: size = self.adj.sparse_sizes() edge_attrs[EdgeLayout.CSR] = DataEdgeAttr('csr', size=size) if 'adj_t' in self and EdgeLayout.CSC not in edge_attrs: size = self.adj_t.sparse_sizes()[::-1] edge_attrs[EdgeLayout.CSC] = DataEdgeAttr('csc', size=size) return list(edge_attrs.values()) # Connected Components Helper Functions ################################### def _find_parent(self, node: int) -> int: r"""Finds and returns the representative parent of the given node in a disjoint-set (union-find) data structure. Implements path compression to optimize future queries. Args: node (int): The node for which to find the representative parent. Returns: int: The representative parent of the node. 
""" if node not in self._parents: self._parents[node] = node self._ranks[node] = 0 if self._parents[node] != node: self._parents[node] = self._find_parent(self._parents[node]) return self._parents[node] def _union(self, node1: int, node2: int): r"""Merges the sets containing node1 and node2 in the disjoint-set data structure. Finds the root parents of node1 and node2 using the _find_parent method. If they belong to different sets, updates the parent of root2 to be root1, effectively merging the two sets. Args: node1 (int): The index of the first node to union. node2 (int): The index of the second node to union. """ root1 = self._find_parent(node1) root2 = self._find_parent(node2) if root1 != root2: if self._ranks[root1] < self._ranks[root2]: self._parents[root1] = root2 elif self._ranks[root1] > self._ranks[root2]: self._parents[root2] = root1 else: self._parents[root2] = root1 self._ranks[root1] += 1 ############################################################################### def size_repr(key: Any, value: Any, indent: int = 0) -> str: pad = ' ' * indent if isinstance(value, Tensor) and value.dim() == 0: out = value.item() elif isinstance(value, Tensor) and getattr(value, 'is_nested', False): out = str(list(value.to_padded_tensor(padding=0.0).size())) elif isinstance(value, Tensor): out = str(list(value.size())) elif isinstance(value, np.ndarray): out = str(list(value.shape)) elif isinstance(value, SparseTensor): out = str(value.sizes())[:-1] + f', nnz={value.nnz()}]' elif isinstance(value, TensorFrame): out = (f'{value.__class__.__name__}(' f'[{value.num_rows}, {value.num_cols}])') elif isinstance(value, str): out = f"'{value}'" elif isinstance(value, (Sequence, set)): out = str([len(value)]) elif isinstance(value, Mapping) and len(value) == 0: out = '{}' elif (isinstance(value, Mapping) and len(value) == 1 and not isinstance(list(value.values())[0], Mapping)): lines = [size_repr(k, v, 0) for k, v in value.items()] out = '{ ' + ', '.join(lines) + ' }' elif 
isinstance(value, Mapping): lines = [size_repr(k, v, indent + 2) for k, v in value.items()] out = '{\n' + ',\n'.join(lines) + ',\n' + pad + '}' else: out = str(value) key = str(key).replace("'", '') return f'{pad}{key}={out}' def warn_or_raise(msg: str, raise_on_error: bool = True): if raise_on_error: raise ValueError(msg) else: warnings.warn(msg, stacklevel=2) ================================================ FILE: torch_geometric/data/database.py ================================================ import io import warnings from abc import ABC, abstractmethod from dataclasses import dataclass from functools import cached_property from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor from tqdm import tqdm from torch_geometric import EdgeIndex, Index from torch_geometric.edge_index import SortOrder from torch_geometric.utils.mixin import CastMixin @dataclass class TensorInfo(CastMixin): dtype: torch.dtype size: Tuple[int, ...] = (-1, ) is_index: bool = False is_edge_index: bool = False def __post_init__(self) -> None: if self.is_index and self.is_edge_index: raise ValueError("Tensor cannot be a 'Index' and 'EdgeIndex' " "tensor at the same time") if self.is_index: self.size = (-1, ) if self.is_edge_index: self.size = (2, -1) def maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]: if not isinstance(value, dict): return value if len(value) < 1 or len(value) > 3: return value if 'dtype' not in value: return value valid_keys = {'dtype', 'size', 'is_index', 'is_edge_index'} if len(set(value.keys()) | valid_keys) != len(valid_keys): return value return TensorInfo.cast(value) Schema = Union[Any, Dict[str, Any], Tuple[Any], List[Any]] SORT_ORDER_TO_INDEX: Dict[Optional[SortOrder], int] = { None: -1, SortOrder.ROW: 0, SortOrder.COL: 1, } INDEX_TO_SORT_ORDER = {v: k for k, v in SORT_ORDER_TO_INDEX.items()} class Database(ABC): r"""Base class for inserting and retrieving data from a database. 
A database acts as a persisted, out-of-memory and index-based key/value store for tensor and custom data: .. code-block:: python db = Database() db[0] = Data(x=torch.randn(5, 16), y=0, z='id_0') print(db[0]) >>> Data(x=[5, 16], y=0, z='id_0') To improve efficiency, it is recommended to specify the underlying :obj:`schema` of the data: .. code-block:: python db = Database(schema={ # Custom schema: # Tensor information can be specified through a dictionary: 'x': dict(dtype=torch.float, size=(-1, 16)), 'y': int, 'z': str, }) db[0] = dict(x=torch.randn(5, 16), y=0, z='id_0') print(db[0]) >>> {'x': torch.tensor(...), 'y': 0, 'z': 'id_0'} In addition, databases support batch-wise insert and get, and support syntactic sugar known from indexing :python:`Python` lists, *e.g.*: .. code-block:: python db = Database() db[2:5] = torch.randn(3, 16) print(db[torch.tensor([2, 3])]) >>> [torch.tensor(...), torch.tensor(...)] Args: schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of the input data. Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a dictionary with :obj:`dtype` and :obj:`size` keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema will improve efficiency, since by default the database will use python pickling for serializing and deserializing. (default: :obj:`object`) """ def __init__(self, schema: Schema = object) -> None: schema_dict = self._to_dict(maybe_cast_to_tensor_info(schema)) self.schema: Dict[Union[str, int], Any] = { key: maybe_cast_to_tensor_info(value) for key, value in schema_dict.items() } @abstractmethod def connect(self) -> None: r"""Connects to the database. Databases will automatically connect on instantiation. """ raise NotImplementedError @abstractmethod def close(self) -> None: r"""Closes the connection to the database.""" raise NotImplementedError @abstractmethod def insert(self, index: int, data: Any) -> None: r"""Inserts data at the specified index. 
Args: index (int): The index at which to insert. data (Any): The object to insert. """ raise NotImplementedError def multi_insert( self, indices: Union[Sequence[int], Tensor, slice, range], data_list: Sequence[Any], batch_size: Optional[int] = None, log: bool = False, ) -> None: r"""Inserts a chunk of data at the specified indices. Args: indices (List[int] or torch.Tensor or range): The indices at which to insert. data_list (List[Any]): The objects to insert. batch_size (int, optional): If specified, will insert the data to the database in batches of size :obj:`batch_size`. (default: :obj:`None`) log (bool, optional): If set to :obj:`True`, will log progress to the console. (default: :obj:`False`) """ if isinstance(indices, slice): indices = self.slice_to_range(indices) length = min(len(indices), len(data_list)) batch_size = length if batch_size is None else batch_size if log and length > batch_size: desc = f'Insert {length} entries' offsets = tqdm(range(0, length, batch_size), desc=desc) else: offsets = range(0, length, batch_size) for start in offsets: self._multi_insert( indices[start:start + batch_size], data_list[start:start + batch_size], ) def _multi_insert( self, indices: Union[Sequence[int], Tensor, range], data_list: Sequence[Any], ) -> None: if isinstance(indices, Tensor): indices = indices.tolist() for index, data in zip(indices, data_list): self.insert(index, data) @abstractmethod def get(self, index: int) -> Any: r"""Gets data from the specified index. Args: index (int): The index to query. """ raise NotImplementedError def multi_get( self, indices: Union[Sequence[int], Tensor, slice, range], batch_size: Optional[int] = None, ) -> List[Any]: r"""Gets a chunk of data from the specified indices. Args: indices (List[int] or torch.Tensor or range): The indices to query. batch_size (int, optional): If specified, will request the data from the database in batches of size :obj:`batch_size`. 
(default: :obj:`None`) """ if isinstance(indices, slice): indices = self.slice_to_range(indices) length = len(indices) batch_size = length if batch_size is None else batch_size data_list: List[Any] = [] for start in range(0, length, batch_size): chunk_indices = indices[start:start + batch_size] data_list.extend(self._multi_get(chunk_indices)) return data_list def _multi_get(self, indices: Union[Sequence[int], Tensor]) -> List[Any]: if isinstance(indices, Tensor): indices = indices.tolist() return [self.get(index) for index in indices] # Helper functions ######################################################## @staticmethod def _to_dict( value: Union[Dict[Union[int, str], Any], Sequence[Any], Any], ) -> Dict[Union[str, int], Any]: if isinstance(value, dict): return value if isinstance(value, (tuple, list)): return {i: v for i, v in enumerate(value)} else: return {0: value} def slice_to_range(self, indices: slice) -> range: start = 0 if indices.start is None else indices.start stop = len(self) if indices.stop is None else indices.stop step = 1 if indices.step is None else indices.step return range(start, stop, step) # Python built-ins ######################################################## def __len__(self) -> int: raise NotImplementedError def __getitem__( self, key: Union[int, Sequence[int], Tensor, slice, range], ) -> Union[Any, List[Any]]: if isinstance(key, int): return self.get(key) else: return self.multi_get(key) def __setitem__( self, key: Union[int, Sequence[int], Tensor, slice, range], value: Union[Any, Sequence[Any]], ) -> None: if isinstance(key, int): self.insert(key, value) else: self.multi_insert(key, value) def __repr__(self) -> str: try: return f'{self.__class__.__name__}({len(self)})' except NotImplementedError: return f'{self.__class__.__name__}()' class SQLiteDatabase(Database): r"""An index-based key/value database based on :obj:`sqlite3`. .. note:: This database implementation requires the :obj:`sqlite3` package. 
Args: path (str): The path to where the database should be saved. name (str): The name of the table to save the data to. schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of the input data. Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a dictionary with :obj:`dtype` and :obj:`size` keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema will improve efficiency, since by default the database will use python pickling for serializing and deserializing. (default: :obj:`object`) """ def __init__(self, path: str, name: str, schema: Schema = object) -> None: super().__init__(schema) warnings.filterwarnings('ignore', '.*given buffer is not writable.*') import sqlite3 self.path = path self.name = name self._connection: Optional[sqlite3.Connection] = None self._cursor: Optional[sqlite3.Cursor] = None self.connect() # Create the table (if it does not exist) by mapping the Python schema # to the corresponding SQL schema: sql_schema = ',\n'.join([ f' {col_name} {self._to_sql_type(type_info)}' for col_name, type_info in zip(self._col_names, self.schema.values()) ]) query = (f'CREATE TABLE IF NOT EXISTS {self.name} (\n' f' id INTEGER PRIMARY KEY,\n' f'{sql_schema}\n' f')') self.cursor.execute(query) def connect(self) -> None: import sqlite3 self._connection = sqlite3.connect(self.path) self._cursor = self._connection.cursor() def close(self) -> None: if self._connection is not None: self._connection.commit() self._connection.close() self._connection = None self._cursor = None @property def connection(self) -> Any: if self._connection is None: raise RuntimeError("No open database connection") return self._connection @property def cursor(self) -> Any: if self._cursor is None: raise RuntimeError("No open database connection") return self._cursor def insert(self, index: int, data: Any) -> None: query = (f'INSERT INTO {self.name} ' f'(id, {self._joined_col_names}) ' f'VALUES (?, 
{self._dummies})') self.cursor.execute(query, (index, *self._serialize(data))) self.connection.commit() def _multi_insert( self, indices: Union[Sequence[int], Tensor, range], data_list: Sequence[Any], ) -> None: if isinstance(indices, Tensor): indices = indices.tolist() data_list = [(index, *self._serialize(data)) for index, data in zip(indices, data_list)] query = (f'INSERT INTO {self.name} ' f'(id, {self._joined_col_names}) ' f'VALUES (?, {self._dummies})') self.cursor.executemany(query, data_list) self.connection.commit() def get(self, index: int) -> Any: query = (f'SELECT {self._joined_col_names} FROM {self.name} ' f'WHERE id = ?') self.cursor.execute(query, (index, )) return self._deserialize(self.cursor.fetchone()) def multi_get( self, indices: Union[Sequence[int], Tensor, slice, range], batch_size: Optional[int] = None, ) -> List[Any]: if isinstance(indices, slice): indices = self.slice_to_range(indices) elif isinstance(indices, Tensor): indices = indices.tolist() # We create a temporary ID table to then perform an INNER JOIN. # This avoids having a long IN clause and guarantees sorted outputs: join_table_name = f'{self.name}__join' # Temporary tables do not lock the database. 
query = (f'CREATE TEMP TABLE {join_table_name} (\n' f' id INTEGER,\n' f' row_id INTEGER\n' f')') self.cursor.execute(query) query = f'INSERT INTO {join_table_name} (id, row_id) VALUES (?, ?)' self.cursor.executemany(query, zip(indices, range(len(indices)))) self.connection.commit() query = f'SELECT * FROM {join_table_name}' self.cursor.execute(query) query = (f'SELECT {self._joined_col_names} ' f'FROM {self.name} INNER JOIN {join_table_name} ' f'ON {self.name}.id = {join_table_name}.id ' f'ORDER BY {join_table_name}.row_id') self.cursor.execute(query) if batch_size is None: data_list = self.cursor.fetchall() else: data_list = [] while True: chunk_list = self.cursor.fetchmany(size=batch_size) if len(chunk_list) == 0: break data_list.extend(chunk_list) query = f'DROP TABLE {join_table_name}' self.cursor.execute(query) return [self._deserialize(data) for data in data_list] def __len__(self) -> int: query = f'SELECT COUNT(*) FROM {self.name}' self.cursor.execute(query) return self.cursor.fetchone()[0] # Helper functions ######################################################## @cached_property def _col_names(self) -> List[str]: return [f'COL_{key}' for key in self.schema.keys()] @cached_property def _joined_col_names(self) -> str: return ', '.join(self._col_names) @cached_property def _dummies(self) -> str: return ', '.join(['?'] * len(self.schema.keys())) def _to_sql_type(self, type_info: Any) -> str: if type_info == int: return 'INTEGER NOT NULL' if type_info == float: return 'FLOAT' if type_info == str: return 'TEXT NOT NULL' else: return 'BLOB NOT NULL' def _serialize(self, row: Any) -> List[Any]: # Serializes the given input data according to `schema`: # * {int, float, str}: Use as they are. # * torch.Tensor: Convert into the raw byte string # * object: Dump via pickle # If we find a `torch.Tensor` that is not registered as such in # `schema`, we modify the schema in-place for improved efficiency. 
out: List[Any] = [] row_dict = self._to_dict(row) for key, schema in self.schema.items(): col = row_dict[key] if isinstance(col, Tensor) and not isinstance(schema, TensorInfo): self.schema[key] = schema = TensorInfo( col.dtype, is_index=isinstance(col, Index), is_edge_index=isinstance(col, EdgeIndex), ) if isinstance(schema, TensorInfo) and schema.is_index: assert isinstance(col, Index) meta = torch.tensor([ col.dim_size if col.dim_size is not None else -1, col.is_sorted, ], dtype=torch.long) out.append(meta.numpy().tobytes() + col.as_tensor().numpy().tobytes()) elif isinstance(schema, TensorInfo) and schema.is_edge_index: assert isinstance(col, EdgeIndex) num_rows, num_cols = col.sparse_size() meta = torch.tensor([ num_rows if num_rows is not None else -1, num_cols if num_cols is not None else -1, SORT_ORDER_TO_INDEX[col._sort_order], col.is_undirected, ], dtype=torch.long) out.append(meta.numpy().tobytes() + col.as_tensor().numpy().tobytes()) elif isinstance(schema, TensorInfo): assert isinstance(col, Tensor) out.append(col.numpy().tobytes()) elif schema in {int, float, str}: out.append(col) else: buffer = io.BytesIO() torch.save(col, buffer) out.append(buffer.getvalue()) return out def _deserialize(self, row: Tuple[Any]) -> Any: # Deserializes the DB data according to `schema`: # * {int, float, str}: Use as they are. 
# * torch.Tensor: Load raw byte string with `dtype` and `size` # information from `schema` # * object: Load via pickle out_dict = {} for i, (key, schema) in enumerate(self.schema.items()): value = row[i] if isinstance(schema, TensorInfo) and schema.is_index: meta = torch.frombuffer(value[:16], dtype=torch.long).tolist() dim_size = meta[0] if meta[0] >= 0 else None is_sorted = meta[1] > 0 if len(value) > 16: tensor = torch.frombuffer(value[16:], dtype=schema.dtype) else: tensor = torch.empty(0, dtype=schema.dtype) out_dict[key] = Index( tensor.view(*schema.size), dim_size=dim_size, is_sorted=is_sorted, ) elif isinstance(schema, TensorInfo) and schema.is_edge_index: meta = torch.frombuffer(value[:32], dtype=torch.long).tolist() num_rows = meta[0] if meta[0] >= 0 else None num_cols = meta[1] if meta[1] >= 0 else None sort_order = INDEX_TO_SORT_ORDER[meta[2]] is_undirected = meta[3] > 0 if len(value) > 32: tensor = torch.frombuffer(value[32:], dtype=schema.dtype) else: tensor = torch.empty(0, dtype=schema.dtype) out_dict[key] = EdgeIndex( tensor.view(*schema.size), sparse_size=(num_rows, num_cols), sort_order=sort_order, is_undirected=is_undirected, ) elif isinstance(schema, TensorInfo): if len(value) > 0: tensor = torch.frombuffer(value, dtype=schema.dtype) else: tensor = torch.empty(0, dtype=schema.dtype) out_dict[key] = tensor.view(*schema.size) elif schema == float: out_dict[key] = value if value is not None else float('NaN') elif schema in {int, str}: out_dict[key] = value else: out_dict[key] = torch.load( io.BytesIO(value), weights_only=False, ) # In case `0` exists as integer in the schema, this means that the # schema was passed as either a single entry or a tuple: if 0 in self.schema: if len(self.schema) == 1: return out_dict[0] else: return tuple(out_dict.values()) else: # Otherwise, return the dictionary as it is: return out_dict class RocksDatabase(Database): r"""An index-based key/value database based on :obj:`RocksDB`. .. 
note:: This database implementation requires the :obj:`rocksdict` package. .. warning:: :class:`RocksDatabase` is currently less optimized than :class:`SQLiteDatabase`. Args: path (str): The path to where the database should be saved. schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of the input data. Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a dictionary with :obj:`dtype` and :obj:`size` keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema will improve efficiency, since by default the database will use python pickling for serializing and deserializing. (default: :obj:`object`) """ def __init__(self, path: str, schema: Schema = object) -> None: super().__init__(schema) import rocksdict self.path = path self._db: Optional[rocksdict.Rdict] = None self.connect() def connect(self) -> None: import rocksdict self._db = rocksdict.Rdict( self.path, options=rocksdict.Options(raw_mode=True), ) def close(self) -> None: if self._db is not None: self._db.close() self._db = None @property def db(self) -> Any: if self._db is None: raise RuntimeError("No open database connection") return self._db @staticmethod def to_key(index: int) -> bytes: return index.to_bytes(8, byteorder='big', signed=True) def insert(self, index: int, data: Any) -> None: self.db[self.to_key(index)] = self._serialize(data) def get(self, index: int) -> Any: return self._deserialize(self.db[self.to_key(index)]) def _multi_get(self, indices: Union[Sequence[int], Tensor]) -> List[Any]: if isinstance(indices, Tensor): indices = indices.tolist() data_list = self.db[[self.to_key(index) for index in indices]] return [self._deserialize(data) for data in data_list] # Helper functions ######################################################## def _serialize(self, row: Any) -> bytes: # Ensure that data is not a view of a larger tensor: if isinstance(row, Tensor): row = row.clone() buffer = io.BytesIO() torch.save(row, buffer) 
return buffer.getvalue() def _deserialize(self, row: bytes) -> Any: return torch.load( io.BytesIO(row), weights_only=False, ) ================================================ FILE: torch_geometric/data/datapipes.py ================================================ import copy from typing import Any, Callable, Iterator, Optional, Sequence import torch from torch_geometric.data import Batch from torch_geometric.utils import from_smiles try: from torch.utils.data import IterDataPipe, functional_datapipe from torch.utils.data.datapipes.iter import Batcher as IterBatcher except ImportError: IterDataPipe = IterBatcher = object # type: ignore def functional_datapipe(name: str) -> Callable: # type: ignore return lambda cls: cls @functional_datapipe('batch_graphs') class Batcher(IterBatcher): def __init__( self, dp: IterDataPipe, batch_size: int, drop_last: bool = False, ) -> None: super().__init__( dp, batch_size=batch_size, drop_last=drop_last, wrapper_class=Batch.from_data_list, ) @functional_datapipe('parse_smiles') class SMILESParser(IterDataPipe): def __init__( self, dp: IterDataPipe, smiles_key: str = 'smiles', target_key: Optional[str] = None, ) -> None: super().__init__() self.dp = dp self.smiles_key = smiles_key self.target_key = target_key def __iter__(self) -> Iterator: for d in self.dp: if isinstance(d, str): data = from_smiles(d) elif isinstance(d, dict): data = from_smiles(d[self.smiles_key]) if self.target_key is not None: y = d.get(self.target_key, None) if y is not None: y = float(y) if len(y) > 0 else float('NaN') data.y = torch.tensor([y], dtype=torch.float) else: raise ValueError( f"'{self.__class__.__name__}' expected either a string or " f"a dict as input (got '{type(d)}')") yield data class DatasetAdapter(IterDataPipe): def __init__(self, dataset: Sequence[Any]) -> None: super().__init__() self.dataset = dataset self.range = range(len(self)) def is_shardable(self) -> bool: return True def apply_sharding(self, num_shards: int, shard_idx: int) -> None: 
self.range = range(shard_idx, len(self), num_shards) def __iter__(self) -> Iterator: for i in self.range: yield self.dataset[i] def __len__(self) -> int: return len(self.dataset) def functional_transform(name: str) -> Callable: def wrapper(cls: Any) -> Any: @functional_datapipe(name) class DynamicMapper(IterDataPipe): def __init__( self, dp: IterDataPipe, *args: Any, **kwargs: Any, ) -> None: super().__init__() self.dp = dp self.fn = cls(*args, **kwargs) def __iter__(self) -> Iterator: for data in self.dp: yield self.fn(copy.copy(data)) return cls return wrapper ================================================ FILE: torch_geometric/data/dataset.py ================================================ import copy import os import os.path as osp import re import sys import warnings from collections.abc import Sequence from typing import ( Any, Callable, Iterable, Iterator, List, Optional, Tuple, Union, ) import numpy as np import torch.utils.data from torch import Tensor from torch_geometric.data.data import BaseData from torch_geometric.io import fs IndexType = Union[slice, Tensor, np.ndarray, Sequence] MISSING = '???' class Dataset(torch.utils.data.Dataset): r"""Dataset base class for creating graph datasets. See `here `__ for the accompanying tutorial. Args: root (str, optional): Root directory where the dataset should be saved. (optional: :obj:`None`) transform (callable, optional): A function/transform that takes in a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) pre_filter (callable, optional): A function that takes in a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) log (bool, optional): Whether to print any console output while downloading and processing the dataset. (default: :obj:`True`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ @property def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]: r"""The name of the files in the :obj:`self.raw_dir` folder that must be present in order to skip downloading. """ raise NotImplementedError @property def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]: r"""The name of the files in the :obj:`self.processed_dir` folder that must be present in order to skip processing. """ raise NotImplementedError def download(self) -> None: r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" raise NotImplementedError def process(self) -> None: r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" raise NotImplementedError def len(self) -> int: r"""Returns the number of data objects stored in the dataset.""" raise NotImplementedError def get(self, idx: int) -> BaseData: r"""Gets the data object at index :obj:`idx`.""" raise NotImplementedError def __init__( self, root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, log: bool = True, force_reload: bool = False, ) -> None: super().__init__() if isinstance(root, str): root = osp.expanduser(fs.normpath(root)) self.root = root or MISSING self.transform = transform self.pre_transform = pre_transform self.pre_filter = pre_filter self.log = log self._indices: Optional[Sequence] = None self.force_reload = force_reload if self.has_download: self._download() if 
self.has_process: self._process() def indices(self) -> Sequence: return range(self.len()) if self._indices is None else self._indices @property def raw_dir(self) -> str: return osp.join(self.root, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, 'processed') @property def num_node_features(self) -> int: r"""Returns the number of features per node in the dataset.""" data = self[0] # Do not fill cache for `InMemoryDataset`: if hasattr(self, '_data_list') and self._data_list is not None: self._data_list[0] = None data = data[0] if isinstance(data, tuple) else data if hasattr(data, 'num_node_features'): return data.num_node_features raise AttributeError(f"'{data.__class__.__name__}' object has no " f"attribute 'num_node_features'") @property def num_features(self) -> int: r"""Returns the number of features per node in the dataset. Alias for :py:attr:`~num_node_features`. """ return self.num_node_features @property def num_edge_features(self) -> int: r"""Returns the number of features per edge in the dataset.""" data = self[0] # Do not fill cache for `InMemoryDataset`: if hasattr(self, '_data_list') and self._data_list is not None: self._data_list[0] = None data = data[0] if isinstance(data, tuple) else data if hasattr(data, 'num_edge_features'): return data.num_edge_features raise AttributeError(f"'{data.__class__.__name__}' object has no " f"attribute 'num_edge_features'") def _infer_num_classes(self, y: Optional[Tensor]) -> int: if y is None: return 0 elif y.numel() == y.size(0) and not torch.is_floating_point(y): return int(y.max()) + 1 elif y.numel() == y.size(0) and torch.is_floating_point(y): num_classes = torch.unique(y).numel() if num_classes > 2: warnings.warn( "Found floating-point labels while calling " "`dataset.num_classes`. Returning the number of " "unique elements. 
Please make sure that this " "is expected before proceeding.", stacklevel=2) return num_classes else: return y.size(-1) @property def num_classes(self) -> int: r"""Returns the number of classes in the dataset.""" # We iterate over the dataset and collect all labels to determine the # maximum number of classes. Importantly, in rare cases, `__getitem__` # may produce a tuple of data objects (e.g., when used in combination # with `RandomLinkSplit`, so we take care of this case here as well: data_list = _get_flattened_data_list([data for data in self]) if 'y' in data_list[0] and isinstance(data_list[0].y, Tensor): y = torch.cat([data.y for data in data_list if 'y' in data], dim=0) else: y = torch.as_tensor([data.y for data in data_list if 'y' in data]) # Do not fill cache for `InMemoryDataset`: if hasattr(self, '_data_list') and self._data_list is not None: self._data_list = self.len() * [None] return self._infer_num_classes(y) @property def raw_paths(self) -> List[str]: r"""The absolute filepaths that must be present in order to skip downloading. """ files = self.raw_file_names # Prevent a common source of error in which `file_names` are not # defined as a property. if isinstance(files, Callable): files = files() return [osp.join(self.raw_dir, f) for f in to_list(files)] @property def processed_paths(self) -> List[str]: r"""The absolute filepaths that must be present in order to skip processing. """ files = self.processed_file_names # Prevent a common source of error in which `file_names` are not # defined as a property. 
if isinstance(files, Callable): files = files() return [osp.join(self.processed_dir, f) for f in to_list(files)] @property def has_download(self) -> bool: r"""Checks whether the dataset defines a :meth:`download` method.""" return overrides_method(self.__class__, 'download') def _download(self): if files_exist(self.raw_paths): # pragma: no cover return fs.makedirs(self.raw_dir, exist_ok=True) self.download() @property def has_process(self) -> bool: r"""Checks whether the dataset defines a :meth:`process` method.""" return overrides_method(self.__class__, 'process') def _process(self): f = osp.join(self.processed_dir, 'pre_transform.pt') if not self.force_reload and osp.exists(f) and torch.load( f, weights_only=False) != _repr(self.pre_transform): warnings.warn( "The `pre_transform` argument differs from the one used in " "the pre-processed version of this dataset. If you want to " "make use of another pre-processing technique, pass " "`force_reload=True` explicitly to reload the dataset.", stacklevel=2) f = osp.join(self.processed_dir, 'pre_filter.pt') if not self.force_reload and osp.exists(f) and torch.load( f, weights_only=False) != _repr(self.pre_filter): warnings.warn( "The `pre_filter` argument differs from the one used in " "the pre-processed version of this dataset. 
If you want to " "make use of another pre-fitering technique, pass " "`force_reload=True` explicitly to reload the dataset.", stacklevel=2) if not self.force_reload and files_exist(self.processed_paths): return if self.log and 'PYTEST_CURRENT_TEST' not in os.environ: print('Processing...', file=sys.stderr) fs.makedirs(self.processed_dir, exist_ok=True) self.process() path = osp.join(self.processed_dir, 'pre_transform.pt') fs.torch_save(_repr(self.pre_transform), path) path = osp.join(self.processed_dir, 'pre_filter.pt') fs.torch_save(_repr(self.pre_filter), path) if self.log and 'PYTEST_CURRENT_TEST' not in os.environ: print('Done!', file=sys.stderr) def __len__(self) -> int: r"""The number of examples in the dataset.""" return len(self.indices()) def __getitem__( self, idx: Union[int, np.integer, IndexType], ) -> Union['Dataset', BaseData]: r"""In case :obj:`idx` is of type integer, will return the data object at index :obj:`idx` (and transforms it in case :obj:`transform` is present). In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or bool, will return a subset of the dataset at the specified indices. """ if (isinstance(idx, (int, np.integer)) or (isinstance(idx, Tensor) and idx.dim() == 0) or (isinstance(idx, np.ndarray) and np.isscalar(idx))): data = self.get(self.indices()[idx]) data = data if self.transform is None else self.transform(data) return data else: return self.index_select(idx) def __iter__(self) -> Iterator[BaseData]: for i in range(len(self)): yield self[i] def index_select(self, idx: IndexType) -> 'Dataset': r"""Creates a subset of the dataset from specified indices :obj:`idx`. Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or bool. 
""" indices = self.indices() if isinstance(idx, slice): start, stop, step = idx.start, idx.stop, idx.step # Allow floating-point slicing, e.g., dataset[:0.9] if isinstance(start, float): start = round(start * len(self)) if isinstance(stop, float): stop = round(stop * len(self)) idx = slice(start, stop, step) indices = indices[idx] elif isinstance(idx, Tensor) and idx.dtype == torch.long: return self.index_select(idx.flatten().tolist()) elif isinstance(idx, Tensor) and idx.dtype == torch.bool: idx = idx.flatten().nonzero(as_tuple=False) return self.index_select(idx.flatten().tolist()) elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: return self.index_select(idx.flatten().tolist()) elif isinstance(idx, np.ndarray) and idx.dtype == bool: idx = idx.flatten().nonzero()[0] return self.index_select(idx.flatten().tolist()) elif isinstance(idx, Sequence) and not isinstance(idx, str): indices = [indices[i] for i in idx] else: raise IndexError( f"Only slices (':'), list, tuples, torch.tensor and " f"np.ndarray of dtype long or bool are valid indices (got " f"'{type(idx).__name__}')") dataset = copy.copy(self) dataset._indices = indices return dataset def shuffle( self, return_perm: bool = False, ) -> Union['Dataset', Tuple['Dataset', Tensor]]: r"""Randomly shuffles the examples in the dataset. Args: return_perm (bool, optional): If set to :obj:`True`, will also return the random permutation used to shuffle the dataset. 
(default: :obj:`False`) """ perm = torch.randperm(len(self)) dataset = self.index_select(perm) return (dataset, perm) if return_perm is True else dataset def __repr__(self) -> str: arg_repr = str(len(self)) if len(self) > 1 else '' return f'{self.__class__.__name__}({arg_repr})' def get_summary(self) -> Any: r"""Collects summary statistics for the dataset.""" from torch_geometric.data.summary import Summary return Summary.from_dataset(self) def print_summary(self, fmt: str = "psql") -> None: r"""Prints summary statistics of the dataset to the console. Args: fmt (str, optional): Summary tables format. Available table formats can be found `here `__. (default: :obj:`"psql"`) """ print(self.get_summary().format(fmt=fmt)) def to_datapipe(self) -> Any: r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`. The returned instance can then be used with :pyg:`PyG's` built-in :class:`DataPipes` for batching graphs as follows: .. code-block:: python from torch_geometric.datasets import QM9 dp = QM9(root='./data/QM9/').to_datapipe() dp = dp.batch_graphs(batch_size=2, drop_last=True) for batch in dp: pass See the `PyTorch tutorial `_ for further background on DataPipes. """ from torch_geometric.data.datapipes import DatasetAdapter return DatasetAdapter(self) def overrides_method(cls, method_name: str) -> bool: from torch_geometric.data import InMemoryDataset if method_name in cls.__dict__: return True out = False for base in cls.__bases__: if base != Dataset and base != InMemoryDataset: out |= overrides_method(base, method_name) return out def to_list(value: Any) -> Sequence: if isinstance(value, Sequence) and not isinstance(value, str): return value else: return [value] def files_exist(files: List[str]) -> bool: # NOTE: We return `False` in case `files` is empty, leading to a # re-processing of files on every instantiation. 
return len(files) != 0 and all([fs.exists(f) for f in files]) def _repr(obj: Any) -> str: if obj is None: return 'None' return re.sub('(<.*?)\\s.*(>)', r'\1\2', str(obj)) def _get_flattened_data_list(data_list: Iterable[Any]) -> List[BaseData]: outs: List[BaseData] = [] for data in data_list: if isinstance(data, BaseData): outs.append(data) elif isinstance(data, (tuple, list)): outs.extend(_get_flattened_data_list(data)) elif isinstance(data, dict): outs.extend(_get_flattened_data_list(data.values())) return outs ================================================ FILE: torch_geometric/data/download.py ================================================ import os import os.path as osp import ssl import sys import urllib from typing import Optional import fsspec from torch_geometric.io import fs def download_url( url: str, folder: str, log: bool = True, filename: Optional[str] = None, ): r"""Downloads the content of an URL to a specific folder. Args: url (str): The URL. folder (str): The folder. log (bool, optional): If :obj:`False`, will not print anything to the console. (default: :obj:`True`) filename (str, optional): The filename of the downloaded file. If set to :obj:`None`, will correspond to the filename given by the URL. (default: :obj:`None`) """ if filename is None: filename = url.rpartition('/')[2] filename = filename if filename[0] == '?' 
else filename.split('?')[0] path = osp.join(folder, filename) if fs.exists(path): # pragma: no cover if log and 'PYTEST_CURRENT_TEST' not in os.environ: print(f'Using existing file {filename}', file=sys.stderr) return path if log and 'PYTEST_CURRENT_TEST' not in os.environ: print(f'Downloading {url}', file=sys.stderr) os.makedirs(folder, exist_ok=True) context = ssl._create_unverified_context() data = urllib.request.urlopen(url, context=context) with fsspec.open(path, 'wb') as f: # workaround for https://bugs.python.org/issue42853 while True: chunk = data.read(10 * 1024 * 1024) if not chunk: break f.write(chunk) return path def download_google_url( id: str, folder: str, filename: str, log: bool = True, ): r"""Downloads the content of a Google Drive ID to a specific folder.""" url = f'https://drive.usercontent.google.com/download?id={id}&confirm=t' return download_url(url, folder, log, filename) ================================================ FILE: torch_geometric/data/extract.py ================================================ import bz2 import gzip import os import os.path as osp import sys import tarfile import zipfile def maybe_log(path: str, log: bool = True) -> None: if log and 'PYTEST_CURRENT_TEST' not in os.environ: print(f'Extracting {path}', file=sys.stderr) def extract_tar( path: str, folder: str, mode: str = 'r:gz', log: bool = True, ) -> None: r"""Extracts a tar archive to a specific folder. Args: path (str): The path to the tar archive. folder (str): The folder. mode (str, optional): The compression mode. (default: :obj:`"r:gz"`) log (bool, optional): If :obj:`False`, will not print anything to the console. (default: :obj:`True`) """ maybe_log(path, log) with tarfile.open(path, mode) as f: f.extractall(folder, filter='data') def extract_zip(path: str, folder: str, log: bool = True) -> None: r"""Extracts a zip archive to a specific folder. Args: path (str): The path to the tar archive. folder (str): The folder. 
log (bool, optional): If :obj:`False`, will not print anything to the console. (default: :obj:`True`) """ maybe_log(path, log) with zipfile.ZipFile(path, 'r') as f: f.extractall(folder) def extract_bz2(path: str, folder: str, log: bool = True) -> None: r"""Extracts a bz2 archive to a specific folder. Args: path (str): The path to the tar archive. folder (str): The folder. log (bool, optional): If :obj:`False`, will not print anything to the console. (default: :obj:`True`) """ maybe_log(path, log) path = osp.abspath(path) with bz2.open(path, 'r') as r: with open(osp.join(folder, '.'.join(path.split('.')[:-1])), 'wb') as w: w.write(r.read()) def extract_gz(path: str, folder: str, log: bool = True) -> None: r"""Extracts a gz archive to a specific folder. Args: path (str): The path to the tar archive. folder (str): The folder. log (bool, optional): If :obj:`False`, will not print anything to the console. (default: :obj:`True`) """ maybe_log(path, log) path = osp.abspath(path) with gzip.open(path, 'r') as r: with open(osp.join(folder, '.'.join(path.split('.')[:-1])), 'wb') as w: w.write(r.read()) ================================================ FILE: torch_geometric/data/feature_store.py ================================================ r"""This class defines the abstraction for a backend-agnostic feature store. The goal of the feature store is to abstract away all node and edge feature memory management so that varying implementations can allow for independent scale-out. This particular feature store abstraction makes a few key assumptions: * The features we care about storing are node and edge features of a graph. To this end, the attributes that the feature store supports include a `group_name` (e.g. a heterogeneous node name or a heterogeneous edge type), an `attr_name` (e.g. `x` or `edge_attr`), and an index. * A feature can be uniquely identified from any associated attributes specified in `TensorAttr`. 
It is the job of a feature store implementer class to handle these assumptions properly. For example, a simple in-memory feature store implementation may concatenate all metadata values with a feature index and use this as a unique index in a KV store. More complicated implementations may choose to partition features in interesting manners based on the provided metadata. Major TODOs for future implementation: * Async `put` and `get` functionality """ import copy from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from typing import Any, List, Optional, Tuple, Union import numpy as np import torch from torch import Tensor from torch_geometric.typing import FeatureTensorType, NodeType from torch_geometric.utils.mixin import CastMixin # We allow indexing with a tensor, numpy array, Python slicing, or a single # integer index. IndexType = Union[torch.Tensor, np.ndarray, slice, int] class _FieldStatus(Enum): UNSET = None @dataclass class TensorAttr(CastMixin): r"""Defines the attributes of a :class:`FeatureStore` tensor. It holds all the parameters necessary to uniquely identify a tensor from the :class:`FeatureStore`. Note that the order of the attributes is important; this is the order in which attributes must be provided for indexing calls. :class:`FeatureStore` implementations can define a different ordering by overriding :meth:`TensorAttr.__init__`. """ # The group name that the tensor corresponds to. Defaults to UNSET. group_name: Optional[NodeType] = _FieldStatus.UNSET # The name of the tensor within its group. Defaults to UNSET. attr_name: Optional[str] = _FieldStatus.UNSET # The node indices the rows of the tensor correspond to. Defaults to UNSET. 
index: Optional[IndexType] = _FieldStatus.UNSET # Convenience methods ##################################################### def is_set(self, key: str) -> bool: r"""Whether an attribute is set in :obj:`TensorAttr`.""" assert key in self.__dataclass_fields__ return getattr(self, key) != _FieldStatus.UNSET def is_fully_specified(self) -> bool: r"""Whether the :obj:`TensorAttr` has no unset fields.""" return all([self.is_set(key) for key in self.__dataclass_fields__]) def update(self, attr: 'TensorAttr') -> 'TensorAttr': r"""Updates an :class:`TensorAttr` with set attributes from another :class:`TensorAttr`. """ for key in self.__dataclass_fields__: if attr.is_set(key): setattr(self, key, getattr(attr, key)) return self class AttrView(CastMixin): r"""Defines a view of a :class:`FeatureStore` that is obtained from a specification of attributes on the feature store. The view stores a reference to the backing feature store as well as a :class:`TensorAttr` object that represents the view's state. Users can create views either using the :class:`AttrView` constructor, :meth:`FeatureStore.view`, or by incompletely indexing a feature store. For example, the following calls all create views: .. code-block:: python store[group_name] store[group_name].feat store[group_name, feat] While the following calls all materialize those views and produce tensors by either calling the view or fully-specifying the view: .. code-block:: python store[group_name]() store[group_name].feat[index] store[group_name, feat][index] """ def __init__(self, store: 'FeatureStore', attr: TensorAttr): self.__dict__['_store'] = store self.__dict__['_attr'] = attr # Advanced indexing ####################################################### def __getattr__(self, key: Any) -> Union['AttrView', FeatureTensorType]: r"""Sets the first unset field of the backing :class:`TensorAttr` object to the attribute. This allows for :class:`AttrView` to be indexed by different values of attributes, in order. 
In particular, for a feature store that we want to index by :obj:`group_name` and :obj:`attr_name`, the following code will do so: .. code-block:: python store[group, attr] store[group].attr store.group.attr """ out = copy.copy(self) # Find the first attribute name that is UNSET: attr_name: Optional[str] = None for field in out._attr.__dataclass_fields__: if getattr(out._attr, field) == _FieldStatus.UNSET: attr_name = field break if attr_name is None: raise AttributeError(f"Cannot access attribute '{key}' on view " f"'{out}' as all attributes have already " f"been set in this view") setattr(out._attr, attr_name, key) if out._attr.is_fully_specified(): return out._store.get_tensor(out._attr) return out def __getitem__(self, key: Any) -> Union['AttrView', FeatureTensorType]: r"""Sets the first unset field of the backing :class:`TensorAttr` object to the attribute via indexing. This allows for :class:`AttrView` to be indexed by different values of attributes, in order. In particular, for a feature store that we want to index by :obj:`group_name` and :obj:`attr_name`, the following code will do so: .. code-block:: python store[group, attr] store[group][attr] """ return self.__getattr__(key) # Setting attributes ###################################################### def __setattr__(self, key: str, value: Any): r"""Supports attribute assignment to the backing :class:`TensorAttr` of an :class:`AttrView`. This allows for :class:`AttrView` objects to set their backing attribute values. In particular, the following operation sets the :obj:`index` of an :class:`AttrView`: .. 
code-block:: python view = store.view(group_name) view.index = torch.tensor([1, 2, 3]) Raises: ValueError: If :obj:`key` is not a field of the backing :class:`TensorAttr`. """ if key not in self._attr.__dataclass_fields__: raise ValueError(f"Attempted to set nonexistent attribute '{key}' " f"(acceptable attributes are " f"{self._attr.__dataclass_fields__})") setattr(self._attr, key, value) def __setitem__(self, key: str, value: Any): r"""Supports attribute assignment to the backing :class:`TensorAttr` of an :class:`AttrView` via indexing. This allows for :class:`AttrView` objects to set their backing attribute values. In particular, the following operation sets the :obj:`index` of an :class:`AttrView`: .. code-block:: python view = store.view(TensorAttr(group_name)) view['index'] = torch.tensor([1, 2, 3]) Raises: ValueError: If :obj:`key` is not a field of the backing :class:`TensorAttr` (raised by :meth:`__setattr__`). """ self.__setattr__(key, value) # Miscellaneous built-ins ################################################# def __call__(self) -> FeatureTensorType: r"""Supports :class:`AttrView` as a callable to force retrieval from the currently specified attributes. In particular, this passes a copy of the current :class:`TensorAttr` object, in which every unset attribute is replaced by :obj:`None`, to a GET call, regardless of whether all attributes have been specified. It returns the result of this call. For example, the following operation returns a tensor by performing a GET operation on the backing feature store: .. code-block:: python store[group_name, attr_name]() """ attr = copy.copy(self._attr) for key in attr.__dataclass_fields__: # Set all UNSET values to None. if not attr.is_set(key): setattr(attr, key, None) return self._store.get_tensor(attr) def __copy__(self) -> 'AttrView': r"""Returns a shallow copy of the view. The backing :class:`TensorAttr` is copied as well, so setting attributes on the copy does not modify the original view, while the :class:`FeatureStore` reference is shared. """ out = self.__class__.__new__(self.__class__) for key, value in self.__dict__.items(): out.__dict__[key] = value out.__dict__['_attr'] = copy.copy(out.__dict__['_attr']) return out def __eq__(self, obj: Any) -> bool: r"""Compares two :class:`AttrView` objects by checking equality of their :class:`FeatureStore` references and :class:`TensorAttr` attributes. Returns :obj:`False` if :obj:`obj` is not an :class:`AttrView`. 
""" if not isinstance(obj, AttrView): return False return self._store == obj._store and self._attr == obj._attr def __repr__(self) -> str: return (f'{self.__class__.__name__}(store={self._store}, ' f'attr={self._attr})') # TODO (manan, matthias) Ideally, we want to let `FeatureStore` inherit from # `MutableMapping` to clearly indicate its behavior and usage to the user. # However, having `MutableMapping` as a base class leads to strange behavior # in combination with PyTorch and PyTorch Lightning, in particular since these # libraries use customized logic during mini-batch for `Mapping` base classes. class FeatureStore(ABC): r"""An abstract base class to access features from a remote feature store. Args: tensor_attr_cls (TensorAttr, optional): A user-defined :class:`TensorAttr` class to customize the required attributes and their ordering to uniquely identify tensor values. If not given, :class:`TensorAttr` is used. (default: :obj:`None`) """ _tensor_attr_cls: TensorAttr def __init__(self, tensor_attr_cls: Optional[Any] = None): super().__init__() self.__dict__['_tensor_attr_cls'] = tensor_attr_cls or TensorAttr # Core (CRUD) ############################################################# @abstractmethod def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: r"""To be implemented by :class:`FeatureStore` subclasses.""" def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool: r"""Synchronously adds a :obj:`tensor` to the :class:`FeatureStore`. Returns whether insertion was successful. Args: tensor (torch.Tensor or np.ndarray): The feature tensor to be added. *args: Arguments passed to :class:`TensorAttr`. **kwargs: Keyword arguments passed to :class:`TensorAttr`. Raises: ValueError: If the input :class:`TensorAttr` is not fully specified. """ attr = self._tensor_attr_cls.cast(*args, **kwargs) if not attr.is_fully_specified(): raise ValueError(f"The input TensorAttr '{attr}' is not fully " f"specified. 
Please fully-specify the input by " f"specifying all 'UNSET' fields") return self._put_tensor(tensor, attr) @abstractmethod def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: r"""To be implemented by :class:`FeatureStore` subclasses.""" def get_tensor( self, *args, convert_type: bool = False, **kwargs, ) -> FeatureTensorType: r"""Synchronously obtains a :obj:`tensor` from the :class:`FeatureStore`. Args: *args: Arguments passed to :class:`TensorAttr`. convert_type (bool, optional): Whether to convert the type of the output tensor to the type of the attribute index. (default: :obj:`False`) **kwargs: Keyword arguments passed to :class:`TensorAttr`. Raises: ValueError: If the input :class:`TensorAttr` is not fully specified. """ attr = self._tensor_attr_cls.cast(*args, **kwargs) if not attr.is_fully_specified(): raise ValueError(f"The input TensorAttr '{attr}' is not fully " f"specified. Please fully-specify the input by " f"specifying all 'UNSET' fields.") tensor = self._get_tensor(attr) if convert_type: tensor = self._to_type(attr, tensor) return tensor def _multi_get_tensor( self, attrs: List[TensorAttr], ) -> List[Optional[FeatureTensorType]]: r"""Obtains the tensors for all (already validated) :obj:`attrs`. Can be overridden by :class:`FeatureStore` subclasses; the default implementation calls :meth:`_get_tensor` once per attribute.""" return [self._get_tensor(attr) for attr in attrs] def multi_get_tensor( self, attrs: List[TensorAttr], convert_type: bool = False, ) -> List[FeatureTensorType]: r"""Synchronously obtains a list of tensors from the :class:`FeatureStore` for each tensor associated with the attributes in :obj:`attrs`. .. note:: The default implementation simply calls :meth:`_get_tensor` once per attribute (via :meth:`_multi_get_tensor`). Implementer classes that can provide additional, more performant functionality are recommended to override :meth:`_multi_get_tensor`, which keeps the input validation and type conversion performed by this method. Args: attrs (List[TensorAttr]): A list of input :class:`TensorAttr` objects that identify the tensors to obtain. convert_type (bool, optional): Whether to convert the type of the output tensor to the type of the attribute index. 
(default: :obj:`False`) Raises: ValueError: If any input :class:`TensorAttr` is not fully specified. """ attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs] bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()] if len(bad_attrs) > 0: raise ValueError( f"The input TensorAttr(s) '{bad_attrs}' are not fully " f"specified. Please fully-specify them by specifying all " f"'UNSET' fields") tensors = self._multi_get_tensor(attrs) if convert_type: tensors = [ self._to_type(attr, tensor) for attr, tensor in zip(attrs, tensors) ] return tensors @abstractmethod def _remove_tensor(self, attr: TensorAttr) -> bool: r"""To be implemented by :class:`FeatureStore` subclasses.""" def remove_tensor(self, *args, **kwargs) -> bool: r"""Removes a tensor from the :class:`FeatureStore`. Returns whether deletion was successful. Args: *args: Arguments passed to :class:`TensorAttr`. **kwargs: Keyword arguments passed to :class:`TensorAttr`. Raises: ValueError: If the input :class:`TensorAttr` is not fully specified. """ attr = self._tensor_attr_cls.cast(*args, **kwargs) if not attr.is_fully_specified(): raise ValueError(f"The input TensorAttr '{attr}' is not fully " f"specified. Please fully-specify the input by " f"specifying all 'UNSET' fields.") return self._remove_tensor(attr) def update_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool: r"""Updates a :obj:`tensor` in the :class:`FeatureStore` with a new value. Returns whether the update was successful. .. note:: Implementer classes can choose to define more efficient update methods; the default performs a removal and insertion. The result of the removal is ignored, so the returned value only reflects whether the insertion succeeded. Args: tensor (torch.Tensor or np.ndarray): The feature tensor to be updated. *args: Arguments passed to :class:`TensorAttr`. **kwargs: Keyword arguments passed to :class:`TensorAttr`. Raises: ValueError: If the input :class:`TensorAttr` is not fully specified (raised by :meth:`remove_tensor` before anything is inserted). 
""" attr = self._tensor_attr_cls.cast(*args, **kwargs) self.remove_tensor(attr) return self.put_tensor(tensor, attr) # Additional methods ###################################################### @abstractmethod def _get_tensor_size(self, attr: TensorAttr) -> Optional[Tuple[int, ...]]: r"""To be implemented by :class:`FeatureStore` subclasses.""" pass def get_tensor_size(self, *args, **kwargs) -> Optional[Tuple[int, ...]]: r"""Obtains the size of a tensor given its :class:`TensorAttr`, or :obj:`None` if the tensor does not exist. If the :obj:`index` attribute is not set, it is set to :obj:`None` before the lookup, so the size of the full tensor is requested. """ attr = self._tensor_attr_cls.cast(*args, **kwargs) if not attr.is_set('index'): attr.index = None return self._get_tensor_size(attr) @abstractmethod def get_all_tensor_attrs(self) -> List[TensorAttr]: r"""Returns all registered tensor attributes.""" # `AttrView` methods ###################################################### def view(self, *args, **kwargs) -> AttrView: r"""Returns a view of the :class:`FeatureStore` given a not yet fully-specified :class:`TensorAttr`. """ attr = self._tensor_attr_cls.cast(*args, **kwargs) return AttrView(self, attr) # Helper functions ######################################################## @staticmethod def _to_type( attr: TensorAttr, tensor: FeatureTensorType, ) -> FeatureTensorType: r"""Converts :obj:`tensor` to the container type of :obj:`attr.index`. A :class:`numpy.ndarray` is converted to a :class:`torch.Tensor` if the index is a :class:`torch.Tensor`. A :class:`torch.Tensor` is detached, moved to CPU and converted to a :class:`numpy.ndarray` if the index is a :class:`numpy.ndarray`. Otherwise, :obj:`tensor` is returned unchanged. """ if isinstance(attr.index, Tensor) and isinstance(tensor, np.ndarray): return torch.from_numpy(tensor) if isinstance(attr.index, np.ndarray) and isinstance(tensor, Tensor): return tensor.detach().cpu().numpy() return tensor # Python built-ins ######################################################## def __setitem__(self, key: TensorAttr, value: FeatureTensorType): r"""Supports :obj:`store[tensor_attr] = tensor`.""" # CastMixin will handle the case of key being a tuple or TensorAttr # object: key = self._tensor_attr_cls.cast(key) assert key.is_fully_specified() self.put_tensor(value, key) def __getitem__(self, key: TensorAttr) -> Any: r"""Supports pythonic indexing into the :class:`FeatureStore`. 
In particular, the following rules are followed for indexing: * A fully-specified :obj:`key` will produce a tensor output. * A partially-specified :obj:`key` will produce an :class:`AttrView` output, which is a view on the :class:`FeatureStore`. If a view is called, it will produce a tensor output from the corresponding (partially specified) attributes. """ # CastMixin will handle the case of key being a tuple or TensorAttr: attr = self._tensor_attr_cls.cast(key) if attr.is_fully_specified(): return self.get_tensor(attr) # If the view is not fully-specified, return a :class:`AttrView`: return self.view(attr) def __delitem__(self, attr: TensorAttr): r"""Supports :obj:`del store[tensor_attr]`.""" # CastMixin will handle the case of key being a tuple or TensorAttr # object: attr = self._tensor_attr_cls.cast(attr) attr = copy.copy(attr) for key in attr.__dataclass_fields__: # Set all UNSET values to None. if not attr.is_set(key): setattr(attr, key, None) self.remove_tensor(attr) def __iter__(self): raise NotImplementedError def __eq__(self, obj: object) -> bool: return id(self) == id(obj) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/data/graph_store.py ================================================ r"""This class defines the abstraction for a backend-agnostic graph store. The goal of the graph store is to abstract away all graph edge index memory management so that varying implementations can allow for independent scale-out. This particular graph store abstraction makes a few key assumptions: * The edge indices we care about storing are represented either in COO, CSC, or CSR format. They can be uniquely identified by an edge type (in PyG, this is a tuple of the source node, relation type, and destination node). * Edge indices are static once they are stored in the graph. 
That is, we do not support dynamic modification of edge indices once they have been inserted into the graph store. It is the job of a graph store implementer class to handle these assumptions properly. For example, a simple in-memory graph store implementation may concatenate all metadata values with an edge index and use this as a unique index in a KV store. More complicated implementations may choose to partition the graph in interesting manners based on the provided metadata. """ import copy from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional, Tuple from torch import Tensor from torch_geometric.index import index2ptr, ptr2index from torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor from torch_geometric.utils import index_sort from torch_geometric.utils.mixin import CastMixin # The output of converting between two types in the GraphStore is a Tuple of # dictionaries: row, col, and perm. The dictionaries are keyed by the edge # type of the input edge attribute. # * The row dictionary contains the row tensor for COO, the row pointer for # CSR, or the row tensor for CSC # * The col dictionary contains the col tensor for COO, the col tensor for # CSR, or the col pointer for CSC # * The perm dictionary contains the permutation of edges that was applied # in converting between formats, if applicable. ConversionOutputType = Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[EdgeType, OptTensor]] class EdgeLayout(Enum): COO = 'coo' CSC = 'csc' CSR = 'csr' @dataclass class EdgeAttr(CastMixin): r"""Defines the attributes of a :obj:`GraphStore` edge. It holds all the parameters necessary to uniquely identify an edge from the :class:`GraphStore`. Note that the order of the attributes is important; this is the order in which attributes must be provided for indexing calls. 
:class:`GraphStore` implementations can define a different ordering by overriding :meth:`EdgeAttr.__init__`. """ # The type of the edge: edge_type: EdgeType # The layout of the edge representation: layout: EdgeLayout # Whether the edge index is sorted by destination node. Useful for # avoiding sorting costs when performing neighbor sampling, and only # meaningful for COO (CSC is sorted and CSR is not sorted by definition): is_sorted: bool = False # The number of source and destination nodes in this edge type: size: Optional[Tuple[int, int]] = None # NOTE we define __init__ to force-cast layout def __init__( self, edge_type: EdgeType, layout: EdgeLayout, is_sorted: bool = False, size: Optional[Tuple[int, int]] = None, ): layout = EdgeLayout(layout) if layout == EdgeLayout.CSR and is_sorted: raise ValueError("Cannot create a 'CSR' edge attribute with " "option 'is_sorted=True'") if layout == EdgeLayout.CSC: is_sorted = True self.edge_type = edge_type self.layout = layout self.is_sorted = is_sorted self.size = size class GraphStore(ABC): r"""An abstract base class to access edges from a remote graph store. Args: edge_attr_cls (EdgeAttr, optional): A user-defined :class:`EdgeAttr` class to customize the required attributes and their ordering to uniquely identify edges. (default: :obj:`None`) """ def __init__(self, edge_attr_cls: Optional[Any] = None): super().__init__() self.__dict__['_edge_attr_cls'] = edge_attr_cls or EdgeAttr # Core (CRUD) ############################################################# @abstractmethod def _put_edge_index(self, edge_index: EdgeTensorType, edge_attr: EdgeAttr) -> bool: r"""To be implemented by :class:`GraphStore` subclasses.""" def put_edge_index(self, edge_index: EdgeTensorType, *args, **kwargs) -> bool: r"""Synchronously adds an :obj:`edge_index` tuple to the :class:`GraphStore`. Returns whether insertion was successful. 
Args: edge_index (Tuple[torch.Tensor, torch.Tensor]): The :obj:`edge_index` tuple in a format specified in :class:`EdgeAttr`. *args: Arguments passed to :class:`EdgeAttr`. **kwargs: Keyword arguments passed to :class:`EdgeAttr`. """ edge_attr = self._edge_attr_cls.cast(*args, **kwargs) return self._put_edge_index(edge_index, edge_attr) @abstractmethod def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]: r"""To be implemented by :class:`GraphStore` subclasses.""" def get_edge_index(self, *args, **kwargs) -> EdgeTensorType: r"""Synchronously obtains an :obj:`edge_index` tuple from the :class:`GraphStore`. Args: *args: Arguments passed to :class:`EdgeAttr`. **kwargs: Keyword arguments passed to :class:`EdgeAttr`. Raises: KeyError: If the :obj:`edge_index` corresponding to the input :class:`EdgeAttr` was not found. """ edge_attr = self._edge_attr_cls.cast(*args, **kwargs) edge_index = self._get_edge_index(edge_attr) if edge_index is None: raise KeyError(f"'edge_index' for '{edge_attr}' not found") return edge_index @abstractmethod def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool: r"""To be implemented by :class:`GraphStore` subclasses.""" def remove_edge_index(self, *args, **kwargs) -> bool: r"""Synchronously deletes an :obj:`edge_index` tuple from the :class:`GraphStore`. Returns whether deletion was successful. Args: *args: Arguments passed to :class:`EdgeAttr`. **kwargs: Keyword arguments passed to :class:`EdgeAttr`. """ edge_attr = self._edge_attr_cls.cast(*args, **kwargs) return self._remove_edge_index(edge_attr) @abstractmethod def get_all_edge_attrs(self) -> List[EdgeAttr]: r"""Returns all registered edge attributes.""" # Layout Conversion ####################################################### def coo( self, edge_types: Optional[List[Any]] = None, store: bool = False, ) -> ConversionOutputType: r"""Returns the edge indices in the :class:`GraphStore` in COO format. 
Args: edge_types (List[Any], optional): The edge types of edge indices to obtain. If set to :obj:`None`, will return the edge indices of all existing edge types. (default: :obj:`None`) store (bool, optional): Whether to store converted edge indices in the :class:`GraphStore`. (default: :obj:`False`) """ return self._edges_to_layout(EdgeLayout.COO, edge_types, store) def csr( self, edge_types: Optional[List[Any]] = None, store: bool = False, ) -> ConversionOutputType: r"""Returns the edge indices in the :class:`GraphStore` in CSR format. Args: edge_types (List[Any], optional): The edge types of edge indices to obtain. If set to :obj:`None`, will return the edge indices of all existing edge types. (default: :obj:`None`) store (bool, optional): Whether to store converted edge indices in the :class:`GraphStore`. (default: :obj:`False`) """ return self._edges_to_layout(EdgeLayout.CSR, edge_types, store) def csc( self, edge_types: Optional[List[Any]] = None, store: bool = False, ) -> ConversionOutputType: r"""Returns the edge indices in the :class:`GraphStore` in CSC format. Args: edge_types (List[Any], optional): The edge types of edge indices to obtain. If set to :obj:`None`, will return the edge indices of all existing edge types. (default: :obj:`None`) store (bool, optional): Whether to store converted edge indices in the :class:`GraphStore`. 
(default: :obj:`False`) """ return self._edges_to_layout(EdgeLayout.CSC, edge_types, store) # Python built-ins ######################################################## def __setitem__(self, key: EdgeAttr, value: EdgeTensorType): r"""Supports :obj:`graph_store[edge_attr] = edge_index` via :meth:`put_edge_index`.""" self.put_edge_index(value, key) def __getitem__(self, key: EdgeAttr) -> Optional[EdgeTensorType]: r"""Supports :obj:`graph_store[edge_attr]` via :meth:`get_edge_index`, which raises a :class:`KeyError` if the edge index is not found.""" return self.get_edge_index(key) def __delitem__(self, key: EdgeAttr): r"""Supports :obj:`del graph_store[edge_attr]` via :meth:`remove_edge_index`, whose result is returned.""" return self.remove_edge_index(key) def __repr__(self) -> str: return f'{self.__class__.__name__}()' # Helper methods ########################################################## def _edge_to_layout( self, attr: EdgeAttr, layout: EdgeLayout, store: bool = False, ) -> Tuple[Tensor, Tensor, OptTensor]: r"""Converts the edge index identified by :obj:`attr` into :obj:`layout`. Returns a tuple :obj:`(row, col, perm)`, where :obj:`perm` is the permutation applied to the edges while sorting them, or :obj:`None` if no sorting was performed. When the number of rows or columns is needed, it is taken from :obj:`attr.size` if set and otherwise inferred from the maximum index; for stores exposing a :obj:`meta` dictionary with :obj:`is_hetero` set, the number of columns is always inferred. If :obj:`store` is :obj:`True` and the layout changes, the converted edge index is additionally written back via :meth:`put_edge_index`. """ (row, col), perm = self.get_edge_index(attr), None if layout == EdgeLayout.COO: # COO output requested: if attr.layout == EdgeLayout.CSR: # CSR->COO row = ptr2index(row) elif attr.layout == EdgeLayout.CSC: # CSC->COO col = ptr2index(col) elif layout == EdgeLayout.CSR: # CSR output requested: if attr.layout == EdgeLayout.CSC: # CSC->COO col = ptr2index(col) if attr.layout != EdgeLayout.CSR: # COO->CSR num_rows = attr.size[0] if attr.size is not None else int( row.max()) + 1 row, perm = index_sort(row, max_value=num_rows) col = col[perm] row = index2ptr(row, num_rows) else: # CSC output requested: if attr.layout == EdgeLayout.CSR: # CSR->COO row = ptr2index(row) if attr.layout != EdgeLayout.CSC: # COO->CSC if hasattr(self, 'meta') and self.meta.get('is_hetero', False): # Hotfix for `LocalGraphStore`, where in heterogeneous # graphs, edge indices for different edge types have # continuous indices not starting at 0. num_cols = int(col.max()) + 1 elif attr.size is not None: num_cols = attr.size[1] else: num_cols = int(col.max()) + 1 if not attr.is_sorted: # Not sorted by destination. 
col, perm = index_sort(col, max_value=num_cols) row = row[perm] col = index2ptr(col, num_cols) if attr.layout != layout and store: attr = copy.copy(attr) attr.layout = layout if perm is not None: attr.is_sorted = False self.put_edge_index((row, col), attr) return row, col, perm def _edges_to_layout( self, layout: EdgeLayout, edge_types: Optional[List[Any]] = None, store: bool = False, ) -> ConversionOutputType: r"""Converts the edge indices of all edge types (or only those in :obj:`edge_types`) into :obj:`layout`. Heterogeneity is read from :obj:`self.meta['is_hetero']` if a :obj:`meta` attribute exists, and otherwise assumed when every registered edge attribute has a non-:obj:`None` edge type. For heterogeneous stores, dictionaries of row, column and permutation tensors keyed by edge type are returned. Per edge type, a stored edge index that already has :obj:`layout` is used if present; otherwise COO, then CSC, then CSR is chosen as the conversion source. For homogeneous stores, the :obj:`(row, col, perm)` tuple of the first registered edge attribute is returned directly (not dictionaries, despite the return annotation), and :obj:`edge_types` is ignored. Raises: ValueError: If a requested edge type does not exist in the store. """ edge_attrs: List[EdgeAttr] = self.get_all_edge_attrs() if hasattr(self, 'meta'): # `LocalGraphStore` hack. is_hetero = self.meta.get('is_hetero', False) else: is_hetero = all(attr.edge_type is not None for attr in edge_attrs) if not is_hetero: return self._edge_to_layout(edge_attrs[0], layout, store) # Obtain all edge attributes, grouped by type: edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list) for attr in self.get_all_edge_attrs(): edge_type_attrs[attr.edge_type].append(attr) # Check that requested edge types exist and filter: if edge_types is not None: for edge_type in edge_types: if edge_type not in edge_type_attrs: raise ValueError(f"The 'edge_index' of type '{edge_type}' " f"was not found in the graph store.") edge_type_attrs = { key: attr for key, attr in edge_type_attrs.items() if key in edge_types } # Convert layout from its most favorable original layout: row_dict, col_dict, perm_dict = {}, {}, {} for edge_type, attrs in edge_type_attrs.items(): layouts = [attr.layout for attr in attrs] if layout in layouts: # No conversion needed. attr = attrs[layouts.index(layout)] elif EdgeLayout.COO in layouts: # Prefer COO for conversion. 
attr = attrs[layouts.index(EdgeLayout.COO)] elif EdgeLayout.CSC in layouts: attr = attrs[layouts.index(EdgeLayout.CSC)] elif EdgeLayout.CSR in layouts: attr = attrs[layouts.index(EdgeLayout.CSR)] row_dict[edge_type], col_dict[edge_type], perm_dict[edge_type] = ( self._edge_to_layout(attr, layout, store)) return row_dict, col_dict, perm_dict ================================================ FILE: torch_geometric/data/hetero_data.py ================================================ import copy import re import warnings from collections import defaultdict, namedtuple from collections.abc import Mapping from itertools import chain from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union import torch from torch import Tensor from typing_extensions import Self from torch_geometric import Index from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr from torch_geometric.data.data import BaseData, Data, size_repr, warn_or_raise from torch_geometric.data.graph_store import EdgeLayout from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage from torch_geometric.typing import ( DEFAULT_REL, EdgeTensorType, EdgeType, FeatureTensorType, NodeOrEdgeType, NodeType, QueryType, SparseTensor, TensorFrame, torch_frame, ) from torch_geometric.utils import ( bipartite_subgraph, contains_isolated_nodes, is_sparse, is_undirected, mask_select, ) NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage] _DISPLAYED_TYPE_NAME_WARNING: bool = False class HeteroData(BaseData, FeatureStore, GraphStore): r"""A data object describing a heterogeneous graph, holding multiple node and/or edge types in disjunct storage objects. Storage objects can hold either node-level, link-level or graph-level attributes. In general, :class:`~torch_geometric.data.HeteroData` tries to mimic the behavior of a regular **nested** :python:`Python` dictionary. 
In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities. .. code-block:: from torch_geometric.data import HeteroData data = HeteroData() # Create two node types "paper" and "author" holding a feature matrix: data['paper'].x = torch.randn(num_papers, num_paper_features) data['author'].x = torch.randn(num_authors, num_author_features) # Create an edge type "(author, writes, paper)" and build the # graph connectivity: data['author', 'writes', 'paper'].edge_index = ... # [2, num_edges] data['paper'].num_nodes >>> 23 data['author', 'writes', 'paper'].num_edges >>> 52 # PyTorch tensor functionality: data = data.pin_memory() data = data.to('cuda:0', non_blocking=True) Note that there exist multiple ways to create heterogeneous graph data, *e.g.*: * To initialize a node of type :obj:`"paper"` holding a node feature matrix :obj:`x_paper` named :obj:`x`: .. code-block:: python from torch_geometric.data import HeteroData # (1) Assign attributes after initialization, data = HeteroData() data['paper'].x = x_paper # or (2) pass them as keyword arguments during initialization, data = HeteroData(paper={ 'x': x_paper }) # or (3) pass them as dictionaries during initialization, data = HeteroData({'paper': { 'x': x_paper }}) * To initialize an edge from source node type :obj:`"author"` to destination node type :obj:`"paper"` with relation type :obj:`"writes"` holding a graph connectivity matrix :obj:`edge_index_author_paper` named :obj:`edge_index`: .. 
code-block:: python # (1) Assign attributes after initialization, data = HeteroData() data['author', 'writes', 'paper'].edge_index = edge_index_author_paper # or (2) pass them as keyword arguments during initialization, data = HeteroData(author__writes__paper={ 'edge_index': edge_index_author_paper }) # or (3) pass them as dictionaries during initialization, data = HeteroData({ ('author', 'writes', 'paper'): { 'edge_index': edge_index_author_paper } }) """ def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs): super().__init__() self.__dict__['_global_store'] = BaseStorage(_parent=self) self.__dict__['_node_store_dict'] = {} self.__dict__['_edge_store_dict'] = {} for key, value in chain((_mapping or {}).items(), kwargs.items()): if '__' in key and isinstance(value, Mapping): key = tuple(key.split('__')) if isinstance(value, Mapping): self[key].update(value) else: setattr(self, key, value) @classmethod def from_dict(cls, mapping: Dict[str, Any]) -> Self: r"""Creates a :class:`~torch_geometric.data.HeteroData` object from a dictionary. """ out = cls() for key, value in mapping.items(): if key == '_global_store': out.__dict__['_global_store'] = BaseStorage( _parent=out, **value) elif isinstance(key, str): out._node_store_dict[key] = NodeStorage( _parent=out, _key=key, **value) else: out._edge_store_dict[key] = EdgeStorage( _parent=out, _key=key, **value) return out def __getattr__(self, key: str) -> Any: # `data.*_dict` => Link to node and edge stores. # `data.*` => Link to the `_global_store`. # Using `data.*_dict` is the same as using `collect()` for collecting # nodes and edges features. if hasattr(self._global_store, key): return getattr(self._global_store, key) elif bool(re.search('_dict$', key)): return self.collect(key[:-5]) raise AttributeError(f"'{self.__class__.__name__}' has no " f"attribute '{key}'") def __setattr__(self, key: str, value: Any): # NOTE: We aim to prevent duplicates in node or edge types. 
if key in self.node_types: raise AttributeError(f"'{key}' is already present as a node type") elif key in self.edge_types: raise AttributeError(f"'{key}' is already present as an edge type") setattr(self._global_store, key, value) def __delattr__(self, key: str): delattr(self._global_store, key) def __getitem__(self, *args: QueryType) -> Any: # `data[*]` => Link to either `_global_store`, _node_store_dict` or # `_edge_store_dict`. # If neither is present, we create a new `Storage` object for the given # node/edge-type. key = self._to_canonical(*args) out = self._global_store.get(key, None) if out is not None: return out if isinstance(key, tuple): return self.get_edge_store(*key) else: return self.get_node_store(key) def __setitem__(self, key: str, value: Any): if key in self.node_types: raise AttributeError(f"'{key}' is already present as a node type") elif key in self.edge_types: raise AttributeError(f"'{key}' is already present as an edge type") self._global_store[key] = value def __delitem__(self, *args: QueryType): # `del data[*]` => Link to `_node_store_dict` or `_edge_store_dict`. 
key = self._to_canonical(*args) if key in self.edge_types: del self._edge_store_dict[key] elif key in self.node_types: del self._node_store_dict[key] else: del self._global_store[key] def __copy__(self): out = self.__class__.__new__(self.__class__) for key, value in self.__dict__.items(): out.__dict__[key] = value out.__dict__['_global_store'] = copy.copy(self._global_store) out._global_store._parent = out out.__dict__['_node_store_dict'] = {} for key, store in self._node_store_dict.items(): out._node_store_dict[key] = copy.copy(store) out._node_store_dict[key]._parent = out out.__dict__['_edge_store_dict'] = {} for key, store in self._edge_store_dict.items(): out._edge_store_dict[key] = copy.copy(store) out._edge_store_dict[key]._parent = out return out def __deepcopy__(self, memo): out = self.__class__.__new__(self.__class__) for key, value in self.__dict__.items(): out.__dict__[key] = copy.deepcopy(value, memo) out._global_store._parent = out for key in self._node_store_dict.keys(): out._node_store_dict[key]._parent = out for key in out._edge_store_dict.keys(): out._edge_store_dict[key]._parent = out return out def __repr__(self) -> str: info1 = [size_repr(k, v, 2) for k, v in self._global_store.items()] info2 = [size_repr(k, v, 2) for k, v in self._node_store_dict.items()] info3 = [size_repr(k, v, 2) for k, v in self._edge_store_dict.items()] info = ',\n'.join(info1 + info2 + info3) info = f'\n{info}\n' if len(info) > 0 else info return f'{self.__class__.__name__}({info})' def stores_as(self, data: Self): for node_type in data.node_types: self.get_node_store(node_type) for edge_type in data.edge_types: self.get_edge_store(*edge_type) return self @property def stores(self) -> List[BaseStorage]: r"""Returns a list of all storages of the graph.""" return ([self._global_store] + list(self.node_stores) + list(self.edge_stores)) @property def node_types(self) -> List[NodeType]: r"""Returns a list of all node types of the graph.""" return 
list(self._node_store_dict.keys()) @property def node_stores(self) -> List[NodeStorage]: r"""Returns a list of all node storages of the graph.""" return list(self._node_store_dict.values()) @property def edge_types(self) -> List[EdgeType]: r"""Returns a list of all edge types of the graph.""" return list(self._edge_store_dict.keys()) @property def edge_stores(self) -> List[EdgeStorage]: r"""Returns a list of all edge storages of the graph.""" return list(self._edge_store_dict.values()) def node_items(self) -> List[Tuple[NodeType, NodeStorage]]: r"""Returns a list of node type and node storage pairs.""" return list(self._node_store_dict.items()) def edge_items(self) -> List[Tuple[EdgeType, EdgeStorage]]: r"""Returns a list of edge type and edge storage pairs.""" return list(self._edge_store_dict.items()) @property def input_type(self) -> Optional[Union[NodeType, EdgeType]]: r"""Returns the seed/input node/edge type of the graph in case it refers to a sampled subgraph, *e.g.*, obtained via :class:`~torch_geometric.loader.NeighborLoader` or :class:`~torch_geometric.loader.LinkNeighborLoader`. 
""" for node_type, store in self.node_items(): if hasattr(store, 'input_id'): return node_type for edge_type, store in self.edge_items(): if hasattr(store, 'input_id'): return edge_type return None def to_dict(self) -> Dict[str, Any]: out_dict: Dict[str, Any] = {} out_dict['_global_store'] = self._global_store.to_dict() for key, store in chain(self._node_store_dict.items(), self._edge_store_dict.items()): out_dict[key] = store.to_dict() return out_dict def to_namedtuple(self) -> NamedTuple: field_names = list(self._global_store.keys()) field_values = list(self._global_store.values()) field_names += [ '__'.join(key) if isinstance(key, tuple) else key for key in self.node_types + self.edge_types ] field_values += [ store.to_namedtuple() for store in self.node_stores + self.edge_stores ] DataTuple = namedtuple('DataTuple', field_names) return DataTuple(*field_values) def set_value_dict( self, key: str, value_dict: Dict[str, Any], ) -> Self: r"""Sets the values in the dictionary :obj:`value_dict` to the attribute with name :obj:`key` to all node/edge types present in the dictionary. .. 
code-block:: python data = HeteroData() data.set_value_dict('x', { 'paper': torch.randn(4, 16), 'author': torch.randn(8, 32), }) print(data['paper'].x) """ for k, v in (value_dict or {}).items(): self[k][key] = v return self def update(self, data: Self) -> Self: for store in data.stores: for key, value in store.items(): self[store._key][key] = value return self def __cat_dim__(self, key: str, value: Any, store: Optional[NodeOrEdgeStorage] = None, *args, **kwargs) -> Any: if is_sparse(value) and ('adj' in key or 'edge_index' in key): return (0, 1) elif isinstance(store, EdgeStorage) and 'index' in key: return -1 return 0 def __inc__(self, key: str, value: Any, store: Optional[NodeOrEdgeStorage] = None, *args, **kwargs) -> Any: if 'batch' in key and isinstance(value, Tensor): if isinstance(value, Index): return value.get_dim_size() return int(value.max()) + 1 elif isinstance(store, EdgeStorage) and 'index' in key: return torch.tensor(store.size()).view(2, 1) else: return 0 @property def num_nodes(self) -> Optional[int]: r"""Returns the number of nodes in the graph.""" return super().num_nodes @property def num_node_features(self) -> Dict[NodeType, int]: r"""Returns the number of features per node type in the graph.""" return { key: store.num_node_features for key, store in self._node_store_dict.items() } @property def num_features(self) -> Dict[NodeType, int]: r"""Returns the number of features per node type in the graph. Alias for :py:attr:`~num_node_features`. 
""" return self.num_node_features @property def num_edge_features(self) -> Dict[EdgeType, int]: r"""Returns the number of features per edge type in the graph.""" return { key: store.num_edge_features for key, store in self._edge_store_dict.items() } def has_isolated_nodes(self) -> bool: r"""Returns :obj:`True` if the graph contains isolated nodes.""" edge_index, _, _ = to_homogeneous_edge_index(self) return contains_isolated_nodes(edge_index, num_nodes=self.num_nodes) def is_undirected(self) -> bool: r"""Returns :obj:`True` if graph edges are undirected.""" edge_index, _, _ = to_homogeneous_edge_index(self) return is_undirected(edge_index, num_nodes=self.num_nodes) def validate(self, raise_on_error: bool = True) -> bool: r"""Validates the correctness of the data.""" cls_name = self.__class__.__name__ status = True node_types = set(self.node_types) num_src_node_types = {src for src, _, _ in self.edge_types} num_dst_node_types = {dst for _, _, dst in self.edge_types} dangling_types = (num_src_node_types | num_dst_node_types) - node_types if len(dangling_types) > 0: status = False warn_or_raise( f"The node types {dangling_types} are referenced in edge " f"types but do not exist as node types", raise_on_error) dangling_types = node_types - (num_src_node_types | num_dst_node_types) if len(dangling_types) > 0: warn_or_raise( # May be intended. 
f"The node types {dangling_types} are isolated and are not " f"referenced by any edge type ", raise_on_error=False) for edge_type, store in self._edge_store_dict.items(): src, _, dst = edge_type num_src_nodes = self[src].num_nodes num_dst_nodes = self[dst].num_nodes if num_src_nodes is None: status = False warn_or_raise( f"'num_nodes' is undefined in node type '{src}' of " f"'{cls_name}'", raise_on_error) if num_dst_nodes is None: status = False warn_or_raise( f"'num_nodes' is undefined in node type '{dst}' of " f"'{cls_name}'", raise_on_error) if 'edge_index' in store: if (store.edge_index.dim() != 2 or store.edge_index.size(0) != 2): status = False warn_or_raise( f"'edge_index' of edge type {edge_type} needs to be " f"of shape [2, num_edges] in '{cls_name}' (found " f"{store.edge_index.size()})", raise_on_error) if 'edge_index' in store and store.edge_index.numel() > 0: if store.edge_index.min() < 0: status = False warn_or_raise( f"'edge_index' of edge type {edge_type} contains " f"negative indices in '{cls_name}' " f"(found {int(store.edge_index.min())})", raise_on_error) if (num_src_nodes is not None and store.edge_index[0].max() >= num_src_nodes): status = False warn_or_raise( f"'edge_index' of edge type {edge_type} contains " f"larger source indices than the number of nodes " f"({num_src_nodes}) of this node type in '{cls_name}' " f"(found {int(store.edge_index[0].max())})", raise_on_error) if (num_dst_nodes is not None and store.edge_index[1].max() >= num_dst_nodes): status = False warn_or_raise( f"'edge_index' of edge type {edge_type} contains " f"larger destination indices than the number of nodes " f"({num_dst_nodes}) of this node type in '{cls_name}' " f"(found {int(store.edge_index[1].max())})", raise_on_error) return status def connected_components(self) -> List[Self]: r"""Extracts connected components of the heterogeneous graph using a union-find algorithm. The components are returned as a list of :class:`~torch_geometric.data.HeteroData` objects. .. 
code-block:: data = HeteroData() data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) data["blue"].x = torch.tensor([[5.0], [6.0]]) data["red", "to", "red"].edge_index = torch.tensor( [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long ) components = data.connected_components() print(len(components)) >>> 4 print(components[0]) >>> HeteroData( red={x: tensor([[1.], [2.]])}, blue={x: tensor([[]])}, red, to, red={edge_index: tensor([[0, 1], [1, 0]])} ) Returns: List[HeteroData]: A list of connected components. """ # Initialize union-find structures self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {} self._ranks: Dict[Tuple[str, int], int] = {} # Union-Find algorithm to find connected components for edge_type in self.edge_types: src, _, dst = edge_type edge_index = self[edge_type].edge_index for src_node, dst_node in edge_index.t().tolist(): self._union((src, src_node), (dst, dst_node)) # Rerun _find_parent to ensure all nodes are covered correctly for node_type in self.node_types: for node_index in range(self[node_type].num_nodes): self._find_parent((node_type, node_index)) # Group nodes by their representative parent components_map = defaultdict(list) for node, parent in self._parents.items(): components_map[parent].append(node) del self._parents del self._ranks components: List[Self] = [] for nodes in components_map.values(): # Prefill subset_dict with all node types to ensure all are present subset_dict = {node_type: [] for node_type in self.node_types} # Convert the list of (node_type, node_id) tuples to a subset_dict for node_type, node_id in nodes: subset_dict[node_type].append(node_id) # Convert lists to tensors for node_type, node_ids in subset_dict.items(): subset_dict[node_type] = torch.tensor(node_ids, dtype=torch.long) # Use the existing subgraph function to do all the heavy lifting component_data = self.subgraph(subset_dict) components.append(component_data) return components def debug(self): pass # TODO 
    ###########################################################################

    def _to_canonical(self, *args: QueryType) -> NodeOrEdgeType:
        # Converts a given `QueryType` to its "canonical type":
        # 1. `relation_type` will get mapped to the unique
        #    `(src_node_type, relation_type, dst_node_type)` tuple.
        # 2. `(src_node_type, dst_node_type)` will get mapped to the unique
        #    `(src_node_type, *, dst_node_type)` tuple, and
        #    `(src_node_type, 'to', dst_node_type)` otherwise.
        # Inputs that match none of these rules are returned unchanged.
        if len(args) == 1:
            args = args[0]

        if isinstance(args, str):
            # A string that names an existing node type wins over a relation:
            node_types = [key for key in self.node_types if key == args]
            if len(node_types) == 1:
                args = node_types[0]
                return args

            # Try to map to edge type based on unique relation type:
            edge_types = [key for key in self.edge_types if key[1] == args]
            if len(edge_types) == 1:
                args = edge_types[0]
                return args

        elif len(args) == 2:
            # Try to find the unique source/destination node tuple:
            edge_types = [
                key for key in self.edge_types
                if key[0] == args[0] and key[-1] == args[-1]
            ]
            if len(edge_types) == 1:
                args = edge_types[0]
                return args
            elif len(edge_types) == 0:
                args = (args[0], DEFAULT_REL, args[1])
                return args

        return args

    def metadata(self) -> Tuple[List[NodeType], List[EdgeType]]:
        r"""Returns the heterogeneous meta-data, *i.e.* its node and edge
        types.

        .. code-block:: python

            data = HeteroData()
            data['paper'].x = ...
            data['author'].x = ...
            data['author', 'writes', 'paper'].edge_index = ...

            print(data.metadata())
            >>> (['paper', 'author'], [('author', 'writes', 'paper')])
        """
        return self.node_types, self.edge_types

    def collect(
        self,
        key: str,
        allow_empty: bool = False,
    ) -> Dict[NodeOrEdgeType, Any]:
        r"""Collects the attribute :attr:`key` from all node and edge types.

        .. code-block:: python

            data = HeteroData()
            data['paper'].x = ...
            data['author'].x = ...

            print(data.collect('x'))
            >>> { 'paper': ..., 'author': ...}

        .. note::

            This is equivalent to writing :obj:`data.x_dict`.

        Args:
            key (str): The attribute to collect from all node and edge types.
            allow_empty (bool, optional): If set to :obj:`True`, will not
                raise an error in case the attribute does not exist in any
                node or edge type. (default: :obj:`False`)

        Raises:
            KeyError: If :obj:`key` is found in no node or edge type and
                :obj:`allow_empty` is :obj:`False`.
        """
        mapping = {}
        for subtype, store in chain(self._node_store_dict.items(),
                                    self._edge_store_dict.items()):
            if hasattr(store, key):
                mapping[subtype] = getattr(store, key)
        if not allow_empty and len(mapping) == 0:
            raise KeyError(f"Tried to collect '{key}' but did not find any "
                           f"occurrences of it in any node and/or edge type")
        return mapping

    def _check_type_name(self, name: str):
        r"""Warns if a type name contains a double underscore. The warning is
        shown at most once, tracked via the module-level
        :obj:`_DISPLAYED_TYPE_NAME_WARNING` flag.
        """
        global _DISPLAYED_TYPE_NAME_WARNING
        if not _DISPLAYED_TYPE_NAME_WARNING and '__' in name:
            _DISPLAYED_TYPE_NAME_WARNING = True
            warnings.warn(
                f"There exist type names in the "
                f"'{self.__class__.__name__}' object that contain "
                f"double underscores '__' (e.g., '{name}'). This "
                f"may lead to unexpected behavior. To avoid any "
                f"issues, ensure that your type names only contain "
                f"single underscores.", stacklevel=2)

    def get_node_store(self, key: NodeType) -> NodeStorage:
        r"""Gets the :class:`~torch_geometric.data.storage.NodeStorage` object
        of a particular node type :attr:`key`.
        If the storage is not present yet, will create a new
        :class:`torch_geometric.data.storage.NodeStorage` object for the given
        node type.

        .. code-block:: python

            data = HeteroData()
            node_storage = data.get_node_store('paper')
        """
        out = self._node_store_dict.get(key, None)
        if out is None:
            self._check_type_name(key)
            out = NodeStorage(_parent=self, _key=key)
            self._node_store_dict[key] = out
        return out

    def get_edge_store(self, src: str, rel: str, dst: str) -> EdgeStorage:
        r"""Gets the :class:`~torch_geometric.data.storage.EdgeStorage` object
        of a particular edge type given by the tuple :obj:`(src, rel, dst)`.
        If the storage is not present yet, will create a new
        :class:`torch_geometric.data.storage.EdgeStorage` object for the given
        edge type.

        .. code-block:: python

            data = HeteroData()
            edge_storage = data.get_edge_store('author', 'writes', 'paper')
        """
        key = (src, rel, dst)
        out = self._edge_store_dict.get(key, None)
        if out is None:
            # Only the relation name is checked for double underscores here:
            self._check_type_name(rel)
            out = EdgeStorage(_parent=self, _key=key)
            self._edge_store_dict[key] = out
        return out

    def rename(self, name: NodeType, new_name: NodeType) -> Self:
        r"""Renames the node type :obj:`name` to :obj:`new_name` in-place.

        All edge types that use :obj:`name` as source and/or destination are
        re-keyed accordingly. Raises a :class:`KeyError` if :obj:`name` is
        not an existing node type.
        """
        node_store = self._node_store_dict.pop(name)
        node_store._key = new_name
        self._node_store_dict[new_name] = node_store

        # `self.edge_types` is evaluated once, so re-keying the dictionary
        # inside the loop is presumably safe (assumes it returns a fresh list
        # -- NOTE(review): confirm).
        for edge_type in self.edge_types:
            src, rel, dst = edge_type
            if src == name or dst == name:
                edge_store = self._edge_store_dict.pop(edge_type)
                src = new_name if src == name else src
                dst = new_name if dst == name else dst
                edge_type = (src, rel, dst)
                edge_store._key = edge_type
                self._edge_store_dict[edge_type] = edge_store

        return self

    def subgraph(self, subset_dict: Dict[NodeType, Tensor]) -> Self:
        r"""Returns the induced subgraph containing the node types and
        corresponding nodes in :obj:`subset_dict`.

        If a node type is not a key in :obj:`subset_dict` then all nodes of
        that type remain in the graph.

        .. code-block:: python

            data = HeteroData()
            data['paper'].x = ...
            data['author'].x = ...
            data['conference'].x = ...
            data['paper', 'cites', 'paper'].edge_index = ...
            data['author', 'paper'].edge_index = ...
            data['paper', 'conference'].edge_index = ...
            print(data)
            >>> HeteroData(
                paper={ x=[10, 16] },
                author={ x=[5, 32] },
                conference={ x=[5, 8] },
                (paper, cites, paper)={ edge_index=[2, 50] },
                (author, to, paper)={ edge_index=[2, 30] },
                (paper, to, conference)={ edge_index=[2, 25] }
            )

            subset_dict = {
                'paper': torch.tensor([3, 4, 5, 6]),
                'author': torch.tensor([0, 2]),
            }

            print(data.subgraph(subset_dict))
            >>> HeteroData(
                paper={ x=[4, 16] },
                author={ x=[2, 32] },
                conference={ x=[5, 8] },
                (paper, cites, paper)={ edge_index=[2, 24] },
                (author, to, paper)={ edge_index=[2, 5] },
                (paper, to, conference)={ edge_index=[2, 10] }
            )

        Args:
            subset_dict (Dict[str, LongTensor or BoolTensor]):
                A dictionary holding the nodes to keep for each node type.
        """
        data = copy.copy(self)
        subset_dict = copy.copy(subset_dict)

        # Filter node-level attributes; other attributes are carried over:
        for node_type, subset in subset_dict.items():
            for key, value in self[node_type].items():
                if key == 'num_nodes':
                    if subset.dtype == torch.bool:
                        data[node_type].num_nodes = int(subset.sum())
                    else:
                        data[node_type].num_nodes = subset.size(0)
                elif self[node_type].is_node_attr(key):
                    data[node_type][key] = value[subset]
                else:
                    data[node_type][key] = value

        for edge_type in self.edge_types:
            # Edge types without a COO `edge_index` are left untouched:
            if 'edge_index' not in self[edge_type]:
                continue

            src, _, dst = edge_type

            # Node types not in `subset_dict` keep all of their nodes:
            src_subset = subset_dict.get(src)
            if src_subset is None:
                src_subset = torch.arange(data[src].num_nodes)
            dst_subset = subset_dict.get(dst)
            if dst_subset is None:
                dst_subset = torch.arange(data[dst].num_nodes)

            edge_index, _, edge_mask = bipartite_subgraph(
                (src_subset, dst_subset),
                self[edge_type].edge_index,
                relabel_nodes=True,
                size=(self[src].num_nodes, self[dst].num_nodes),
                return_edge_mask=True,
            )

            for key, value in self[edge_type].items():
                if key == 'edge_index':
                    data[edge_type].edge_index = edge_index
                elif self[edge_type].is_edge_attr(key):
                    data[edge_type][key] = value[edge_mask]
                else:
                    data[edge_type][key] = value

        return data

    def edge_subgraph(
        self,
        subset_dict: Dict[EdgeType, Tensor],
    ) -> Self:
        r"""Returns the induced subgraph given by the edge indices in
        :obj:`subset_dict` for certain edge types.
        Will currently preserve all the nodes in the graph, even if they are
        isolated after subgraph computation.

        Args:
            subset_dict (Dict[Tuple[str, str, str], LongTensor or BoolTensor]):
                A dictionary holding the edges to keep for each edge type.
        """
        data = copy.copy(self)

        for edge_type, subset in subset_dict.items():
            edge_store, new_edge_store = self[edge_type], data[edge_type]
            for key, value in edge_store.items():
                if edge_store.is_edge_attr(key):
                    # Select along the concatenation dimension of this key
                    # (e.g. the last dimension for `edge_index`):
                    dim = self.__cat_dim__(key, value, edge_store)
                    if subset.dtype == torch.bool:
                        new_edge_store[key] = mask_select(value, dim, subset)
                    else:
                        new_edge_store[key] = value.index_select(dim, subset)

        return data

    def node_type_subgraph(self, node_types: List[NodeType]) -> Self:
        r"""Returns the subgraph induced by the given :obj:`node_types`, *i.e.*
        the returned :class:`HeteroData` object only contains the node types
        which are included in :obj:`node_types`, and only contains the edge
        types where both end points are included in :obj:`node_types`.
        """
        data = copy.copy(self)
        for edge_type in self.edge_types:
            src, _, dst = edge_type
            if src not in node_types or dst not in node_types:
                del data[edge_type]
        for node_type in self.node_types:
            if node_type not in node_types:
                del data[node_type]
        return data

    def edge_type_subgraph(self, edge_types: List[EdgeType]) -> Self:
        r"""Returns the subgraph induced by the given :obj:`edge_types`, *i.e.*
        the returned :class:`HeteroData` object only contains the edge types
        which are included in :obj:`edge_types`, and only contains the node
        types that are end points of the edge types in :obj:`edge_types`.

        Each entry of :obj:`edge_types` is first canonicalized, so
        :obj:`(src, dst)` tuples and unique relation names are accepted, too.
        """
        edge_types = [self._to_canonical(e) for e in edge_types]

        data = copy.copy(self)
        for edge_type in self.edge_types:
            if edge_type not in edge_types:
                del data[edge_type]

        node_types = {e[0] for e in edge_types}
        node_types |= {e[-1] for e in edge_types}
        for node_type in self.node_types:
            if node_type not in node_types:
                del data[node_type]

        return data

    def to_homogeneous(
        self,
        node_attrs: Optional[List[str]] = None,
        edge_attrs: Optional[List[str]] = None,
        add_node_type: bool = True,
        add_edge_type: bool = True,
        dummy_values: bool = True,
    ) -> Data:
        """Converts a :class:`~torch_geometric.data.HeteroData` object to a
        homogeneous :class:`~torch_geometric.data.Data` object.
        By default, all features with same feature dimensionality across
        different types will be merged into a single representation, unless
        otherwise specified via the :obj:`node_attrs` and :obj:`edge_attrs`
        arguments.
        Furthermore, attributes named :obj:`node_type` and :obj:`edge_type`
        will be added to the returned :class:`~torch_geometric.data.Data`
        object, denoting node-level and edge-level vectors holding the
        node and edge type as integers, respectively.

        Two-dimensional node features whose widths differ across node types
        are right-padded with zeros to the largest width before merging.

        Args:
            node_attrs (List[str], optional): The node features to combine
                across all node types. These node features need to be of the
                same feature dimensionality. If set to :obj:`None`, will
                automatically determine which node features to combine.
                (default: :obj:`None`)
            edge_attrs (List[str], optional): The edge features to combine
                across all edge types. These edge features need to be of the
                same feature dimensionality. If set to :obj:`None`, will
                automatically determine which edge features to combine.
                (default: :obj:`None`)
            add_node_type (bool, optional): If set to :obj:`False`, will not
                add the node-level vector :obj:`node_type` to the returned
                :class:`~torch_geometric.data.Data` object.
                (default: :obj:`True`)
            add_edge_type (bool, optional): If set to :obj:`False`, will not
                add the edge-level vector :obj:`edge_type` to the returned
                :class:`~torch_geometric.data.Data` object.
                (default: :obj:`True`)
            dummy_values (bool, optional): If set to :obj:`True`, will fill
                attributes of remaining types with dummy values.
                Dummy values are :obj:`NaN` for floating point attributes,
                :obj:`False` for booleans, and :obj:`-1` for integers.
                (default: :obj:`True`)
        """
        def get_sizes(stores: List[BaseStorage]) -> Dict[str, List[Tuple]]:
            # Maps each tensor attribute to its per-store shapes, with the
            # concatenation dimension removed (connectivity keys skipped).
            sizes_dict = defaultdict(list)
            for store in stores:
                for key, value in store.items():
                    if key in [
                            'edge_index', 'edge_label_index', 'adj', 'adj_t'
                    ]:
                        continue
                    if isinstance(value, Tensor):
                        dim = self.__cat_dim__(key, value, store)
                        size = value.size()[:dim] + value.size()[dim + 1:]
                        sizes_dict[key].append(tuple(size))
            return sizes_dict

        def fill_dummy_(stores: List[BaseStorage],
                        keys: Optional[List[str]] = None):
            # Adds dummy-filled tensors, in-place, for attributes that some
            # stores lack, provided all present shapes agree.
            sizes_dict = get_sizes(stores)

            if keys is not None:
                sizes_dict = {
                    key: sizes
                    for key, sizes in sizes_dict.items() if key in keys
                }

            sizes_dict = {
                key: sizes
                for key, sizes in sizes_dict.items() if len(set(sizes)) == 1
            }

            for store in stores:
                # Fill stores with dummy features:
                for key, sizes in sizes_dict.items():
                    if key not in store:
                        # Use the first store holding `key` as dtype/device
                        # reference:
                        ref = list(self.collect(key).values())[0]
                        dim = self.__cat_dim__(key, ref, store)
                        if ref.is_floating_point():
                            dummy = float('NaN')
                        elif ref.dtype == torch.bool:
                            dummy = False
                        else:
                            dummy = -1
                        if isinstance(store, NodeStorage):
                            dim_size = store.num_nodes
                        else:
                            dim_size = store.num_edges
                        shape = sizes[0][:dim] + (dim_size, ) + sizes[0][dim:]
                        store[key] = torch.full(shape, dummy, dtype=ref.dtype,
                                                device=ref.device)

        def _consistent_size(stores: List[BaseStorage]) -> List[str]:
            # Returns the attribute keys that can be merged across all
            # `stores` (tensors with compatible shapes, or `TensorFrame`s
            # with identical column names).
            sizes_dict = get_sizes(stores)
            keys = []
            for key, sizes in sizes_dict.items():
                # The attribute needs to exist in all types:
                if len(sizes) != len(stores):
                    continue
                # The attributes needs to have the same number of dimensions:
                lengths = {len(size) for size in sizes}
                if len(lengths) != 1:
                    continue
                # The attributes needs to have the same size in all dimensions:
                # (single remaining dimensions may differ; they are padded)
                if len(sizes[0]) != 1 and len(set(sizes)) != 1:
                    continue
                keys.append(key)

            # Check for consistent column names in `TensorFrame`:
            tf_cols = defaultdict(list)
            for store in stores:
                for key, value in store.items():
                    if isinstance(value, TensorFrame):
                        cols = tuple(chain(*value.col_names_dict.values()))
                        tf_cols[key].append(cols)

            for key, cols in tf_cols.items():
                # The attribute needs to exist in all types:
                if len(cols) != len(stores):
                    continue
                # The attributes needs to have the same column names:
                lengths = set(cols)
                if len(lengths) != 1:
                    continue
                keys.append(key)

            return keys

        if dummy_values:
            # Work on a shallow copy so that the dummy-filled attributes are
            # added to the copy's stores -- NOTE(review): assumes the copy's
            # stores are independent of the original's; confirm.
            self = copy.copy(self)
            fill_dummy_(self.node_stores, node_attrs)
            fill_dummy_(self.edge_stores, edge_attrs)

        edge_index, node_slices, edge_slices = to_homogeneous_edge_index(self)
        device = edge_index.device if edge_index is not None else None

        data = Data(**self._global_store.to_dict())
        if edge_index is not None:
            data.edge_index = edge_index
        data._node_type_names = list(node_slices.keys())
        data._edge_type_names = list(edge_slices.keys())

        # Combine node attributes into a single tensor:
        if node_attrs is None:
            node_attrs = _consistent_size(self.node_stores)
        for key in node_attrs:
            if key in {'ptr'}:
                continue
            values = [store[key] for store in self.node_stores]
            if isinstance(values[0], TensorFrame):
                value = torch_frame.cat(values, dim=0)
            else:
                dim = self.__cat_dim__(key, values[0], self.node_stores[0])
                dim = values[0].dim() + dim if dim < 0 else dim
                # For two-dimensional features, we allow arbitrary shapes and
                # pad them with zeros if necessary in case their size doesn't
                # match:
                if values[0].dim() == 2 and dim == 0:
                    _max = max([value.size(-1) for value in values])
                    for i, v in enumerate(values):
                        if v.size(-1) < _max:
                            pad = v.new_zeros(v.size(0), _max - v.size(-1))
                            values[i] = torch.cat([v, pad], dim=-1)
                value = torch.cat(values, dim)
            data[key] = value

        # Fall back to the total node count (end of the last node slice):
        if not data.can_infer_num_nodes:
            data.num_nodes = list(node_slices.values())[-1][1]

        # Combine edge attributes into a single tensor:
        if edge_attrs is None:
            edge_attrs = _consistent_size(self.edge_stores)
        for key in edge_attrs:
            values = [store[key] for store in self.edge_stores]
            dim = self.__cat_dim__(key, values[0], self.edge_stores[0])
            value = torch.cat(values, dim) if len(values) > 1 else values[0]
            data[key] = value

        # Shift supervision edges into the homogeneous node index space:
        if 'edge_label_index' in self:
            edge_label_index_dict = self.edge_label_index_dict
            for edge_type, edge_label_index in edge_label_index_dict.items():
                edge_label_index = edge_label_index.clone()
                edge_label_index[0] += node_slices[edge_type[0]][0]
                edge_label_index[1] += node_slices[edge_type[-1]][0]
                edge_label_index_dict[edge_type] = edge_label_index
            data.edge_label_index = torch.cat(
                list(edge_label_index_dict.values()), dim=-1)

        if add_node_type:
            sizes = [offset[1] - offset[0] for offset in node_slices.values()]
            sizes = torch.tensor(sizes, dtype=torch.long, device=device)
            node_type = torch.arange(len(sizes), device=device)
            data.node_type = node_type.repeat_interleave(sizes)

        if add_edge_type and edge_index is not None:
            sizes = [offset[1] - offset[0] for offset in edge_slices.values()]
            sizes = torch.tensor(sizes, dtype=torch.long, device=device)
            edge_type = torch.arange(len(sizes), device=device)
            data.edge_type = edge_type.repeat_interleave(sizes)

        return data

    # FeatureStore interface ##################################################

    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        r"""Stores :obj:`tensor` as attribute :obj:`attr.attr_name` of node
        type :obj:`attr.group_name`. An unset index is treated as
        :obj:`None`. If the attribute already exists, it is written in-place
        at :obj:`attr.index`; otherwise it is created, which requires
        :obj:`attr.index` to be :obj:`None`. Always returns :obj:`True`.
        """
        if not attr.is_set('index'):
            attr.index = None

        out = self._node_store_dict.get(attr.group_name, None)
        # NOTE: `if out:` tests storage truthiness, so an existing but empty
        # node store takes the `else` branch (same effect: plain setattr).
        if out:
            # Group name exists, handle index or create new attribute name:
            val = getattr(out, attr.attr_name, None)
            if val is not None:
                val[attr.index] = tensor
            else:
                assert attr.index is None
                setattr(self[attr.group_name], attr.attr_name, tensor)
        else:
            # No node storage found, just store tensor in new one:
            setattr(self[attr.group_name], attr.attr_name, tensor)
        return True

    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
        r"""Returns attribute :obj:`attr.attr_name` of group
        :obj:`attr.group_name`, indexed by :obj:`attr.index` unless that is
        :obj:`None`, or :obj:`None` if the attribute is missing.
        """
        # Retrieve tensor and index accordingly:
        tensor = getattr(self[attr.group_name], attr.attr_name, None)
        if tensor is not None:
            # TODO this behavior is a bit odd, since TensorAttr requires that
            # we set `index`. So, we assume here that indexing by `None` is
            # equivalent to not indexing at all, which is not in line with
            # Python semantics.
            return tensor[attr.index] if attr.index is not None else tensor
        return None

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        r"""Deletes the attribute; returns :obj:`False` if it did not exist.
        """
        # Remove tensor entirely:
        if hasattr(self[attr.group_name], attr.attr_name):
            delattr(self[attr.group_name], attr.attr_name)
            return True
        return False

    def _get_tensor_size(self, attr: TensorAttr) -> Tuple:
        r"""Returns the size of the (indexed) tensor. Raises an
        :class:`AttributeError` if the tensor is missing, since
        :meth:`_get_tensor` then returns :obj:`None`.
        """
        return self._get_tensor(attr).size()

    def get_all_tensor_attrs(self) -> List[TensorAttr]:
        r"""Returns a :class:`TensorAttr` for every node-level attribute of
        every node type. Edge attributes are not included.
        """
        out = []
        for group_name, group in self.node_items():
            for attr_name in group:
                if group.is_node_attr(attr_name):
                    out.append(TensorAttr(group_name, attr_name))
        return out

    # GraphStore interface ####################################################

    def _put_edge_index(self, edge_index: EdgeTensorType,
                        edge_attr: EdgeAttr) -> bool:
        r"""Stores :obj:`edge_index` for :obj:`edge_attr.edge_type` in the
        requested layout: COO as a stacked :obj:`edge_index`, CSR as
        :obj:`adj` and CSC as :obj:`adj_t` (both :class:`SparseTensor`, with
        the CSC size reversed). The :obj:`edge_attr` is remembered in
        :obj:`self._edge_attrs`, keyed by :obj:`(edge_type, layout)`.
        Always returns :obj:`True`.
        """
        if not hasattr(self, '_edge_attrs'):
            self._edge_attrs = {}
        self._edge_attrs[(edge_attr.edge_type, edge_attr.layout)] = edge_attr

        row, col = edge_index
        store = self[edge_attr.edge_type]

        if edge_attr.layout == EdgeLayout.COO:
            store.edge_index = torch.stack([row, col], dim=0)
        elif edge_attr.layout == EdgeLayout.CSR:
            store.adj = SparseTensor(
                rowptr=row,
                col=col,
                sparse_sizes=edge_attr.size,
                is_sorted=True,
                trust_data=True,
            )
        else:  # edge_attr.layout == EdgeLayout.CSC:
            size = edge_attr.size[::-1] if edge_attr.size is not None else None
            store.adj_t = SparseTensor(
                rowptr=col,
                col=row,
                sparse_sizes=size,
                is_sorted=True,
                trust_data=True,
            )
        return True

    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
        r"""Gets an edge index from edge storage, in the specified layout.

        Prefers a previously stored :obj:`EdgeAttr` for the same
        :obj:`(edge_type, layout)` and fills in its :obj:`size` in-place if it
        is unset. Returns :obj:`None` if the layout is not stored.
        """
        store = self[edge_attr.edge_type]
        edge_attrs = getattr(self, '_edge_attrs', {})
        if (edge_attr.edge_type, edge_attr.layout) in edge_attrs:
            edge_attr = edge_attrs[(edge_attr.edge_type, edge_attr.layout)]
        if edge_attr.size is None:
            edge_attr.size = store.size()  # Modify in-place.

        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in store:
            row, col = store.edge_index
            return row, col
        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in store:
            rowptr, col, _ = store.adj.csr()
            return rowptr, col
        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in store:
            # `adj_t` is the transpose, so its CSR view is our CSC view:
            colptr, row, _ = store.adj_t.csr()
            return row, colptr
        return None

    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
        r"""Deletes the stored edge index of the given layout (and its
        remembered :obj:`EdgeAttr`); returns :obj:`False` if none was stored.
        """
        edge_type = edge_attr.edge_type
        store = self[edge_type]
        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in store:
            del store.edge_index
            if hasattr(self, '_edge_attrs'):
                self._edge_attrs.pop((edge_type, EdgeLayout.COO), None)
            return True
        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in store:
            del store.adj
            if hasattr(self, '_edge_attrs'):
                self._edge_attrs.pop((edge_type, EdgeLayout.CSR), None)
            return True
        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in store:
            del store.adj_t
            if hasattr(self, '_edge_attrs'):
                self._edge_attrs.pop((edge_type, EdgeLayout.CSC), None)
            return True
        return False

    def get_all_edge_attrs(self) -> List[EdgeAttr]:
        r"""Returns an :obj:`EdgeAttr` for every stored edge layout.

        Layouts present in the stores but missing from
        :obj:`self._edge_attrs` are added with default settings. When
        :obj:`self._edge_attrs` exists, it is updated in-place.
        """
        edge_attrs = getattr(self, '_edge_attrs', {})

        for store in self.edge_stores:
            if ('edge_index' in store
                    and (store._key, EdgeLayout.COO) not in edge_attrs):
                edge_attrs[(store._key, EdgeLayout.COO)] = EdgeAttr(
                    store._key, 'coo', is_sorted=False)
            if ('adj' in store
                    and (store._key, EdgeLayout.CSR) not in edge_attrs):
                size = store.adj.sparse_sizes()
                edge_attrs[(store._key, EdgeLayout.CSR)] = EdgeAttr(
                    store._key, 'csr', size=size)
            if ('adj_t' in store
                    and (store._key, EdgeLayout.CSC) not in edge_attrs):
                size = store.adj_t.sparse_sizes()[::-1]
                edge_attrs[(store._key, EdgeLayout.CSC)] = EdgeAttr(
                    store._key, 'csc', size=size)

        return list(edge_attrs.values())

    # 
Connected Components Helper Functions ################################### def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]: r"""Finds and returns the representative parent of the given node in a disjoint-set (union-find) data structure. Implements path compression to optimize future queries. Args: node (tuple[str, int]): The node for which to find the parent. First element is the node type, second is the node index. Returns: tuple[str, int]: The representative parent of the node. """ if node not in self._parents: self._parents[node] = node self._ranks[node] = 0 if self._parents[node] != node: self._parents[node] = self._find_parent(self._parents[node]) return self._parents[node] def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]): r"""Merges the node1 and node2 in the disjoint-set data structure. Finds the root parents of node1 and node2 using the _find_parent method. If they belong to different sets, updates the parent of root2 to be root1, effectively merging the two sets. Args: node1 (Tuple[str, int]): The first node to union. First element is the node type, second is the node index. node2 (Tuple[str, int]): The second node to union. First element is the node type, second is the node index. 
""" root1 = self._find_parent(node1) root2 = self._find_parent(node2) if root1 != root2: if self._ranks[root1] < self._ranks[root2]: self._parents[root1] = root2 elif self._ranks[root1] > self._ranks[root2]: self._parents[root2] = root1 else: self._parents[root2] = root1 self._ranks[root1] += 1 # Helper functions ############################################################ def get_node_slices(num_nodes: Dict[str, int]) -> Dict[str, Tuple[int, int]]: r"""Returns the boundaries of each node type in a graph.""" node_slices: Dict[NodeType, Tuple[int, int]] = {} cumsum = 0 for node_type, N in num_nodes.items(): node_slices[node_type] = (cumsum, cumsum + N) cumsum += N return node_slices def offset_edge_index( node_slices: Dict[NodeType, Tuple[int, int]], edge_type: EdgeType, edge_index: Tensor, ) -> Tensor: r"""Increases the edge indices by the offsets of source and destination node types. """ src, _, dst = edge_type offset = [[node_slices[src][0]], [node_slices[dst][0]]] offset = torch.tensor(offset, device=edge_index.device) return edge_index + offset def to_homogeneous_edge_index( data: HeteroData, ) -> Tuple[Optional[Tensor], Dict[NodeType, Any], Dict[EdgeType, Any]]: r"""Converts a heterogeneous graph into a homogeneous typed graph.""" # Record slice information per node type: node_slices = get_node_slices(data.num_nodes_dict) # Record edge indices and slice information per edge type: cumsum = 0 edge_indices: List[Tensor] = [] edge_slices: Dict[EdgeType, Tuple[int, int]] = {} for edge_type, edge_index in data.collect('edge_index', True).items(): edge_index = offset_edge_index(node_slices, edge_type, edge_index) edge_indices.append(edge_index) edge_slices[edge_type] = (cumsum, cumsum + edge_index.size(1)) cumsum += edge_index.size(1) edge_index: Optional[Tensor] = None if len(edge_indices) == 1: # Memory-efficient `torch.cat`: edge_index = edge_indices[0] elif len(edge_indices) > 1: edge_index = torch.cat(edge_indices, dim=-1) return edge_index, node_slices, 
edge_slices ================================================ FILE: torch_geometric/data/hypergraph_data.py ================================================ import copy import warnings from typing import Any, List, Optional import torch from torch import Tensor from typing_extensions import Self from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType, NodeType, OptTensor from torch_geometric.utils import select from torch_geometric.utils._subgraph import hyper_subgraph class HyperGraphData(Data): r"""A data object describing a hypergraph. The data object can hold node-level, link-level and graph-level attributes. This object differs from a standard :obj:`~torch_geometric.data.Data` object by having hyperedges, i.e. edges that connect more than two nodes. For example, in the hypergraph scenario :math:`\mathcal{G} = (\mathcal{V}, \mathcal{E})` with :math:`\mathcal{V} = \{ 0, 1, 2, 3, 4 \}` and :math:`\mathcal{E} = \{ \{ 0, 1, 2 \}, \{ 1, 2, 3, 4 \} \}`, the hyperedge index :obj:`edge_index` is represented as: .. code-block:: python # hyper graph with two hyperedges # connecting 3 and 4 nodes, respectively edge_index = torch.tensor([ [0, 1, 2, 1, 2, 3, 4], [0, 0, 0, 1, 1, 1, 1], ]) Args: x (torch.Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`) edge_index (LongTensor, optional): Hyperedge tensor with shape :obj:`[2, num_edges*num_nodes_per_edge]`. Where `edge_index[1]` denotes the hyperedge index and `edge_index[0]` denotes the node indices that are connected by the hyperedge. (default: :obj:`None`) (default: :obj:`None`) edge_attr (torch.Tensor, optional): Edge feature matrix with shape :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) y (torch.Tensor, optional): Graph-level or node-level ground-truth labels with arbitrary shape. (default: :obj:`None`) pos (torch.Tensor, optional): Node position matrix with shape :obj:`[num_nodes, num_dimensions]`. 
(default: :obj:`None`) **kwargs (optional): Additional attributes. """ def __init__( self, x: OptTensor = None, edge_index: OptTensor = None, edge_attr: OptTensor = None, y: OptTensor = None, pos: OptTensor = None, **kwargs: Any, ) -> None: super().__init__( x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos, **kwargs, ) @property def num_edges(self) -> int: r"""Returns the number of hyperedges in the hypergraph.""" if self.edge_index is None: return 0 return max(self.edge_index[1]) + 1 @property def num_nodes(self) -> Optional[int]: num_nodes = super().num_nodes # For hypergraphs, `edge_index[1]` does not contain node indices. # Therefore, the below code is used to prevent `num_nodes` being # estimated as the number of hyperedges. if (self.edge_index is not None and num_nodes == self.num_edges): return max(self.edge_index[0]) + 1 return num_nodes @num_nodes.setter def num_nodes(self, num_nodes: Optional[int]) -> None: self._store.num_nodes = num_nodes def is_edge_attr(self, key: str) -> bool: val = super().is_edge_attr(key) if not val and self.edge_index is not None: return key in self and self[key].size(0) == self.num_edges return val def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any: if key == 'edge_index': return torch.tensor([[self.num_nodes], [self.num_edges]]) else: return super().__inc__(key, value, *args, **kwargs) def subgraph(self, subset: Tensor) -> 'HyperGraphData': r"""Returns the induced subgraph given by the node indices :obj:`subset`. .. note:: If only a subset of a hyperedge's nodes are to be selected in the subgraph, the hyperedge will remain in the subgraph, but only the selected nodes will be connected by the hyperedge. Hyperedges that only connects one node in the subgraph will be removed. Examples: >>> x = torch.randn(4, 16) >>> edge_index = torch.tensor([ ... [0, 1, 0, 2, 1, 1, 2, 4], ... 
[0, 0, 1, 1, 1, 2, 2, 2] >>> ]) >>> data = HyperGraphData(x = x, edge_index = edge_index) >>> subset = torch.tensor([1, 2, 4]) >>> subgraph = data.subgraph(subset) >>> subgraph.edge_index tensor([[2, 1, 1, 2, 4], [0, 0, 1, 1, 1]]) Args: subset (LongTensor or BoolTensor): The nodes to keep. """ assert self.edge_index is not None out = hyper_subgraph(subset, self.edge_index, relabel_nodes=True, num_nodes=self.num_nodes, return_edge_mask=True) edge_index, _, edge_mask = out data = copy.copy(self) for key, value in self.items(): if key == 'edge_index': data.edge_index = edge_index elif key == 'num_nodes': if subset.dtype == torch.bool: data.num_nodes = int(subset.sum()) else: data.num_nodes = subset.size(0) elif self.is_node_attr(key): cat_dim = self.__cat_dim__(key, value) data[key] = select(value, subset, dim=cat_dim) elif self.is_edge_attr(key): cat_dim = self.__cat_dim__(key, value) data[key] = select(value, edge_mask, dim=cat_dim) return data def edge_subgraph(self, subset: Tensor) -> Self: raise NotImplementedError def to_heterogeneous( self, node_type: Optional[Tensor] = None, edge_type: Optional[Tensor] = None, node_type_names: Optional[List[NodeType]] = None, edge_type_names: Optional[List[EdgeType]] = None, ) -> HeteroData: raise NotImplementedError def has_isolated_nodes(self) -> bool: if self.edge_index is None: return False return torch.unique(self.edge_index[0]).size(0) < self.num_nodes def is_directed(self) -> bool: raise NotImplementedError def is_undirected(self) -> bool: raise NotImplementedError def has_self_loops(self) -> bool: raise NotImplementedError def validate(self, raise_on_error: bool = True) -> bool: r"""Validates the correctness of the data.""" cls_name = self.__class__.__name__ status = True num_nodes = self.num_nodes if num_nodes is None: status = False warn_or_raise(f"'num_nodes' is undefined in '{cls_name}'", raise_on_error) if self.edge_index is not None: if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: status = False 
warn_or_raise( f"'edge_index' needs to be of shape [2, num_edges] in " f"'{cls_name}' (found {self.edge_index.size()})", raise_on_error) if self.edge_index is not None and self.edge_index.numel() > 0: if self.edge_index.min() < 0: status = False warn_or_raise( f"'edge_index' contains negative indices in " f"'{cls_name}' (found {int(self.edge_index.min())})", raise_on_error) if num_nodes is not None and self.edge_index[0].max() >= num_nodes: status = False warn_or_raise( f"'edge_index' contains larger indices than the number " f"of nodes ({num_nodes}) in '{cls_name}' " f"(found {int(self.edge_index.max())})", raise_on_error) return status def warn_or_raise(msg: str, raise_on_error: bool = True) -> None: if raise_on_error: raise ValueError(msg) else: warnings.warn(msg, stacklevel=2) ================================================ FILE: torch_geometric/data/in_memory_dataset.py ================================================ import copy import os.path as osp import warnings from typing import ( Any, Callable, Dict, Iterable, List, Mapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, ) import torch from torch import Tensor from tqdm import tqdm import torch_geometric from torch_geometric.data import Batch, Data from torch_geometric.data.collate import collate from torch_geometric.data.data import BaseData from torch_geometric.data.dataset import Dataset, IndexType from torch_geometric.data.separate import separate from torch_geometric.io import fs class InMemoryDataset(Dataset): r"""Dataset base class for creating graph datasets which easily fit into CPU memory. See `here `__ for the accompanying tutorial. Args: root (str, optional): Root directory where the dataset should be saved. (optional: :obj:`None`) transform (callable, optional): A function/transform that takes in a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) log (bool, optional): Whether to print any console output while downloading and processing the dataset. (default: :obj:`True`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ @property def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]: raise NotImplementedError @property def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]: raise NotImplementedError def __init__( self, root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, log: bool = True, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, log, force_reload) self._data: Optional[BaseData] = None self.slices: Optional[Dict[str, Tensor]] = None self._data_list: Optional[MutableSequence[Optional[BaseData]]] = None @property def num_classes(self) -> int: if self.transform is None: return self._infer_num_classes(self._data.y) return super().num_classes def len(self) -> int: if self.slices is None: return 1 for _, value in nested_iter(self.slices): return len(value) - 1 return 0 def get(self, idx: int) -> BaseData: # TODO (matthias) Avoid unnecessary copy here. 
if self.len() == 1: return copy.copy(self._data) if not hasattr(self, '_data_list') or self._data_list is None: self._data_list = self.len() * [None] elif self._data_list[idx] is not None: return copy.copy(self._data_list[idx]) data = separate( cls=self._data.__class__, batch=self._data, idx=idx, slice_dict=self.slices, decrement=False, ) self._data_list[idx] = copy.copy(data) return data @classmethod def save(cls, data_list: Sequence[BaseData], path: str) -> None: r"""Saves a list of data objects to the file path :obj:`path`.""" data, slices = cls.collate(data_list) fs.torch_save((data.to_dict(), slices, data.__class__), path) def load(self, path: str, data_cls: Type[BaseData] = Data) -> None: r"""Loads the dataset from the file path :obj:`path`.""" out = fs.torch_load(path) assert isinstance(out, tuple) assert len(out) == 2 or len(out) == 3 if len(out) == 2: # Backward compatibility. data, self.slices = out else: data, self.slices, data_cls = out if not isinstance(data, dict): # Backward compatibility. self.data = data else: self.data = data_cls.from_dict(data) @staticmethod def collate( data_list: Sequence[BaseData], ) -> Tuple[BaseData, Optional[Dict[str, Tensor]]]: r"""Collates a list of :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects to the internal storage format of :class:`~torch_geometric.data.InMemoryDataset`. """ if len(data_list) == 1: return data_list[0], None data, slices, _ = collate( data_list[0].__class__, data_list=data_list, increment=False, add_batch=False, ) return data, slices def copy(self, idx: Optional[IndexType] = None) -> 'InMemoryDataset': r"""Performs a deep-copy of the dataset. If :obj:`idx` is not given, will clone the full dataset. Otherwise, will only clone a subset of the dataset from indices :obj:`idx`. Indices can be slices, lists, tuples, and a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or bool. 
""" if idx is None: data_list = [self.get(i) for i in self.indices()] else: data_list = [self.get(i) for i in self.index_select(idx).indices()] dataset = copy.copy(self) dataset._indices = None dataset._data_list = None dataset.data, dataset.slices = self.collate(data_list) return dataset def to_on_disk_dataset( self, root: Optional[str] = None, backend: str = 'sqlite', log: bool = True, ) -> 'torch_geometric.data.OnDiskDataset': r"""Converts the :class:`InMemoryDataset` to a :class:`OnDiskDataset` variant. Useful for distributed training and hardware instances with limited amount of shared memory. root (str, optional): Root directory where the dataset should be saved. If set to :obj:`None`, will save the dataset in :obj:`root/on_disk`. Note that it is important to specify :obj:`root` to account for different dataset splits. (optional: :obj:`None`) backend (str): The :class:`Database` backend to use. (default: :obj:`"sqlite"`) log (bool, optional): Whether to print any console output while processing the dataset. (default: :obj:`True`) """ if root is None and (self.root is None or not osp.exists(self.root)): raise ValueError(f"The root directory of " f"'{self.__class__.__name__}' is not specified. 
" f"Please pass in 'root' when creating on-disk " f"datasets from it.") root = root or osp.join(self.root, 'on_disk') in_memory_dataset = self ref_data = in_memory_dataset.get(0) if not isinstance(ref_data, Data): raise NotImplementedError( f"`{self.__class__.__name__}.to_on_disk_dataset()` is " f"currently only supported on homogeneous graphs") # Parse the schema ==================================================== schema: Dict[str, Any] = {} for key, value in ref_data.to_dict().items(): if isinstance(value, (int, float, str)): schema[key] = value.__class__ elif isinstance(value, Tensor) and value.dim() == 0: schema[key] = dict(dtype=value.dtype, size=(-1, )) elif isinstance(value, Tensor): size = list(value.size()) size[ref_data.__cat_dim__(key, value)] = -1 schema[key] = dict(dtype=value.dtype, size=tuple(size)) else: schema[key] = object # Create the on-disk dataset ========================================== class OnDiskDataset(torch_geometric.data.OnDiskDataset): def __init__( self, root: str, transform: Optional[Callable] = None, ): super().__init__( root=root, transform=transform, backend=backend, schema=schema, ) def process(self): _iter = [ in_memory_dataset.get(i) for i in in_memory_dataset.indices() ] if log: # pragma: no cover _iter = tqdm(_iter, desc='Converting to OnDiskDataset') data_list: List[Data] = [] for i, data in enumerate(_iter): data_list.append(data) if i + 1 == len(in_memory_dataset) or (i + 1) % 1000 == 0: self.extend(data_list) data_list = [] def serialize(self, data: Data) -> Dict[str, Any]: return data.to_dict() def deserialize(self, data: Dict[str, Any]) -> Data: return Data.from_dict(data) def __repr__(self) -> str: arg_repr = str(len(self)) if len(self) > 1 else '' return (f'OnDisk{in_memory_dataset.__class__.__name__}(' f'{arg_repr})') return OnDiskDataset(root, transform=in_memory_dataset.transform) @property def data(self) -> Any: msg1 = ("It is not recommended to directly access the internal " "storage format `data` of an 
'InMemoryDataset'.") msg2 = ("The given 'InMemoryDataset' only references a subset of " "examples of the full dataset, but 'data' will contain " "information of the full dataset.") msg3 = ("The data of the dataset is already cached, so any " "modifications to `data` will not be reflected when accessing " "its elements. Clearing the cache now by removing all " "elements in `dataset._data_list`.") msg4 = ("If you are absolutely certain what you are doing, access the " "internal storage via `InMemoryDataset._data` instead to " "suppress this warning. Alternatively, you can access stacked " "individual attributes of every graph via " "`dataset.{attr_name}`.") msg = msg1 if self._indices is not None: msg += f' {msg2}' if self._data_list is not None: msg += f' {msg3}' self._data_list = None msg += f' {msg4}' warnings.warn(msg, stacklevel=2) return self._data @data.setter def data(self, value: Any): self._data = value self._data_list = None def __getattr__(self, key: str) -> Any: data = self.__dict__.get('_data') if isinstance(data, Data) and key in data: if self._indices is None and data.__inc__(key, data[key]) == 0: return data[key] else: data_list = [self.get(i) for i in self.indices()] return Batch.from_data_list(data_list)[key] raise AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{key}'") def to(self, device: Union[int, str]) -> 'InMemoryDataset': r"""Performs device conversion of the whole dataset.""" if self._indices is not None: raise ValueError("The given 'InMemoryDataset' only references a " "subset of examples of the full dataset") if self._data_list is not None: raise ValueError("The data of the dataset is already cached") self._data.to(device) return self def cpu(self, *args: str) -> 'InMemoryDataset': r"""Moves the dataset to CPU memory.""" return self.to(torch.device('cpu')) def cuda( self, device: Optional[Union[int, str]] = None, ) -> 'InMemoryDataset': r"""Moves the dataset toto CUDA memory.""" if isinstance(device, int): 
device = f'cuda:{int}' elif device is None: device = 'cuda' return self.to(device) def nested_iter(node: Union[Mapping, Sequence]) -> Iterable: if isinstance(node, Mapping): for value in node.values(): yield from nested_iter(value) elif isinstance(node, Sequence): yield from enumerate(node) else: yield None, node ================================================ FILE: torch_geometric/data/lightning/__init__.py ================================================ from .datamodule import LightningDataset, LightningNodeData, LightningLinkData __all__ = classes = [ 'LightningDataset', 'LightningNodeData', 'LightningLinkData', ] ================================================ FILE: torch_geometric/data/lightning/datamodule.py ================================================ import copy import inspect import warnings from typing import Any, Dict, Optional, Tuple, Type, Union import torch from torch_geometric.data import Data, Dataset, HeteroData from torch_geometric.loader import DataLoader, LinkLoader, NodeLoader from torch_geometric.sampler import BaseSampler, NeighborSampler from torch_geometric.typing import InputEdges, InputNodes, OptTensor try: from lightning.pytorch import LightningDataModule as _LightningDataModule _pl_is_available = True except ImportError: try: from pytorch_lightning import \ LightningDataModule as _LightningDataModule _pl_is_available = True except ImportError: _pl_is_available = False _LightningDataModule = object class LightningDataModule(_LightningDataModule): def __init__(self, has_val: bool, has_test: bool, **kwargs: Any) -> None: super().__init__() if not _pl_is_available: raise ModuleNotFoundError( "No module named 'pytorch_lightning' (or 'lightning') found " "in your Python environment. 
Run 'pip install " "pytorch_lightning' or 'pip install lightning'") if not has_val: self.val_dataloader = None # type: ignore if not has_test: self.test_dataloader = None # type: ignore kwargs.setdefault('batch_size', 1) kwargs.setdefault('num_workers', 0) kwargs.setdefault('pin_memory', True) kwargs.setdefault('persistent_workers', kwargs.get('num_workers', 0) > 0) if 'shuffle' in kwargs: warnings.warn( f"The 'shuffle={kwargs['shuffle']}' option is " f"ignored in '{self.__class__.__name__}'. Remove it " f"from the argument list to disable this warning", stacklevel=2) del kwargs['shuffle'] self.kwargs = kwargs def __repr__(self) -> str: return f'{self.__class__.__name__}({kwargs_repr(**self.kwargs)})' class LightningData(LightningDataModule): def __init__( self, data: Union[Data, HeteroData], has_val: bool, has_test: bool, loader: str = 'neighbor', graph_sampler: Optional[BaseSampler] = None, eval_loader_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: kwargs.setdefault('batch_size', 1) kwargs.setdefault('num_workers', 0) if graph_sampler is not None: loader = 'custom' # For full-batch training, we use reasonable defaults for a lot of # data-loading options: if loader not in ['full', 'neighbor', 'link_neighbor', 'custom']: raise ValueError(f"Undefined 'loader' option (got '{loader}')") if loader == 'full' and kwargs['batch_size'] != 1: warnings.warn( f"Re-setting 'batch_size' to 1 in " f"'{self.__class__.__name__}' for loader='full' " f"(got '{kwargs['batch_size']}')", stacklevel=2) kwargs['batch_size'] = 1 if loader == 'full' and kwargs['num_workers'] != 0: warnings.warn( f"Re-setting 'num_workers' to 0 in " f"'{self.__class__.__name__}' for loader='full' " f"(got '{kwargs['num_workers']}')", stacklevel=2) kwargs['num_workers'] = 0 if loader == 'full' and kwargs.get('sampler') is not None: warnings.warn( "'sampler' option is not supported for " "loader='full'", stacklevel=2) kwargs.pop('sampler', None) if loader == 'full' and 
kwargs.get('batch_sampler') is not None: warnings.warn( "'batch_sampler' option is not supported for " "loader='full'", stacklevel=2) kwargs.pop('batch_sampler', None) super().__init__(has_val, has_test, **kwargs) if loader == 'full': if kwargs.get('pin_memory', False): warnings.warn( f"Re-setting 'pin_memory' to 'False' in " f"'{self.__class__.__name__}' for loader='full' " f"(got 'True')", stacklevel=2) self.kwargs['pin_memory'] = False self.data = data self.loader = loader # Determine sampler and loader arguments ############################## if loader in ['neighbor', 'link_neighbor']: # Define a new `NeighborSampler` to be re-used across data loaders: sampler_kwargs, self.loader_kwargs = split_kwargs( self.kwargs, NeighborSampler, ) sampler_kwargs.setdefault('share_memory', self.kwargs['num_workers'] > 0) self.graph_sampler: BaseSampler = NeighborSampler( data, **sampler_kwargs) elif graph_sampler is not None: sampler_kwargs, self.loader_kwargs = split_kwargs( self.kwargs, graph_sampler.__class__, ) if len(sampler_kwargs) > 0: warnings.warn( f"Ignoring the arguments " f"{list(sampler_kwargs.keys())} in " f"'{self.__class__.__name__}' since a custom " f"'graph_sampler' was passed", stacklevel=2) self.graph_sampler = graph_sampler else: assert loader == 'full' self.loader_kwargs = self.kwargs # Determine validation sampler and loader arguments ################### self.eval_loader_kwargs = copy.copy(self.loader_kwargs) if eval_loader_kwargs is not None: # If the user wants to override certain values during evaluation, # we shallow-copy the graph sampler and update its attributes. 
if hasattr(self, 'graph_sampler'): self.eval_graph_sampler = copy.copy(self.graph_sampler) eval_sampler_kwargs, eval_loader_kwargs = split_kwargs( eval_loader_kwargs, self.graph_sampler.__class__, ) for key, value in eval_sampler_kwargs.items(): setattr(self.eval_graph_sampler, key, value) self.eval_loader_kwargs.update(eval_loader_kwargs) elif hasattr(self, 'graph_sampler'): self.eval_graph_sampler = self.graph_sampler self.eval_loader_kwargs.pop('sampler', None) self.eval_loader_kwargs.pop('batch_sampler', None) if 'batch_sampler' in self.loader_kwargs: self.loader_kwargs.pop('batch_size', None) @property def train_shuffle(self) -> bool: shuffle = self.loader_kwargs.get('sampler', None) is None shuffle &= self.loader_kwargs.get('batch_sampler', None) is None return shuffle def prepare_data(self) -> None: if self.loader == 'full': assert self.trainer is not None try: num_devices = self.trainer.num_devices except AttributeError: # PyTorch Lightning < 1.6 backward compatibility: num_devices = self.trainer.num_processes # type: ignore num_gpus = self.trainer.num_gpus # type: ignore num_devices = max(num_devices, num_gpus) if num_devices > 1: raise ValueError( f"'{self.__class__.__name__}' with loader='full' requires " f"training on a single device") super().prepare_data() def full_dataloader(self, **kwargs: Any) -> torch.utils.data.DataLoader: warnings.filterwarnings('ignore', '.*does not have many workers.*') warnings.filterwarnings('ignore', '.*data loading bottlenecks.*') return torch.utils.data.DataLoader( [self.data], # type: ignore collate_fn=lambda xs: xs[0], **kwargs, ) def __repr__(self) -> str: kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs) return f'{self.__class__.__name__}({kwargs})' class LightningDataset(LightningDataModule): r"""Converts a set of :class:`~torch_geometric.data.Dataset` objects into a :class:`pytorch_lightning.LightningDataModule` variant. 
It can then be automatically used as a :obj:`datamodule` for multi-GPU graph-level training via :lightning:`null` `PyTorch Lightning `__. :class:`LightningDataset` will take care of providing mini-batches via :class:`~torch_geometric.loader.DataLoader`. .. note:: Currently only the :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and :class:`pytorch_lightning.strategies.DDPStrategy` training strategies of :lightning:`null` `PyTorch Lightning `__ are supported in order to correctly share data across all devices/processes: .. code-block:: python import pytorch_lightning as pl trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4) trainer.fit(model, datamodule) Args: train_dataset (Dataset): The training dataset. val_dataset (Dataset, optional): The validation dataset. (default: :obj:`None`) test_dataset (Dataset, optional): The test dataset. (default: :obj:`None`) pred_dataset (Dataset, optional): The prediction dataset. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.loader.DataLoader`. 
""" def __init__( self, train_dataset: Dataset, val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, pred_dataset: Optional[Dataset] = None, **kwargs: Any, ) -> None: super().__init__( has_val=val_dataset is not None, has_test=test_dataset is not None, **kwargs, ) self.train_dataset = train_dataset self.val_dataset = val_dataset self.test_dataset = test_dataset self.pred_dataset = pred_dataset def dataloader(self, dataset: Dataset, **kwargs: Any) -> DataLoader: return DataLoader(dataset, **kwargs) def train_dataloader(self) -> DataLoader: from torch.utils.data import IterableDataset shuffle = not isinstance(self.train_dataset, IterableDataset) shuffle &= self.kwargs.get('sampler', None) is None shuffle &= self.kwargs.get('batch_sampler', None) is None return self.dataloader( self.train_dataset, shuffle=shuffle, **self.kwargs, ) def val_dataloader(self) -> DataLoader: assert self.val_dataset is not None kwargs = copy.copy(self.kwargs) kwargs.pop('sampler', None) kwargs.pop('batch_sampler', None) return self.dataloader(self.val_dataset, shuffle=False, **kwargs) def test_dataloader(self) -> DataLoader: assert self.test_dataset is not None kwargs = copy.copy(self.kwargs) kwargs.pop('sampler', None) kwargs.pop('batch_sampler', None) return self.dataloader(self.test_dataset, shuffle=False, **kwargs) def predict_dataloader(self) -> DataLoader: assert self.pred_dataset is not None kwargs = copy.copy(self.kwargs) kwargs.pop('sampler', None) kwargs.pop('batch_sampler', None) return self.dataloader(self.pred_dataset, shuffle=False, **kwargs) def __repr__(self) -> str: kwargs = kwargs_repr( train_dataset=self.train_dataset, val_dataset=self.val_dataset, test_dataset=self.test_dataset, pred_dataset=self.pred_dataset, **self.kwargs, ) return f'{self.__class__.__name__}({kwargs})' class LightningNodeData(LightningData): r"""Converts a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object into a 
:class:`pytorch_lightning.LightningDataModule` variant. It can then be automatically used as a :obj:`datamodule` for multi-GPU node-level training via :lightning:`null` `PyTorch Lightning `__. :class:`LightningDataset` will take care of providing mini-batches via :class:`~torch_geometric.loader.NeighborLoader`. .. note:: Currently only the :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and :class:`pytorch_lightning.strategies.DDPStrategy` training strategies of :lightning:`null` `PyTorch Lightning `__ are supported in order to correctly share data across all devices/processes: .. code-block:: python import pytorch_lightning as pl trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4) trainer.fit(model, datamodule) Args: data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The indices of training nodes. If not given, will try to automatically infer them from the :obj:`data` object by searching for :obj:`train_mask`, :obj:`train_idx`, or :obj:`train_index` attributes. (default: :obj:`None`) input_train_time (torch.Tensor, optional): The timestamp of training nodes. (default: :obj:`None`) input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The indices of validation nodes. If not given, will try to automatically infer them from the :obj:`data` object by searching for :obj:`val_mask`, :obj:`valid_mask`, :obj:`val_idx`, :obj:`valid_idx`, :obj:`val_index`, or :obj:`valid_index` attributes. (default: :obj:`None`) input_val_time (torch.Tensor, optional): The timestamp of validation edges. (default: :obj:`None`) input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The indices of test nodes. If not given, will try to automatically infer them from the :obj:`data` object by searching for :obj:`test_mask`, :obj:`test_idx`, or :obj:`test_index` attributes. 
(default: :obj:`None`) input_test_time (torch.Tensor, optional): The timestamp of test nodes. (default: :obj:`None`) input_pred_nodes (torch.Tensor or str or (str, torch.Tensor)): The indices of prediction nodes. If not given, will try to automatically infer them from the :obj:`data` object by searching for :obj:`pred_mask`, :obj:`pred_idx`, or :obj:`pred_index` attributes. (default: :obj:`None`) input_pred_time (torch.Tensor, optional): The timestamp of prediction nodes. (default: :obj:`None`) loader (str): The scalability technique to use (:obj:`"full"`, :obj:`"neighbor"`). (default: :obj:`"neighbor"`) node_sampler (BaseSampler, optional): A custom sampler object to generate mini-batches. If set, will ignore the :obj:`loader` option. (default: :obj:`None`) eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments that override the :class:`torch_geometric.loader.NeighborLoader` configuration during evaluation. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.loader.NeighborLoader`. 
""" def __init__( self, data: Union[Data, HeteroData], input_train_nodes: InputNodes = None, input_train_time: OptTensor = None, input_val_nodes: InputNodes = None, input_val_time: OptTensor = None, input_test_nodes: InputNodes = None, input_test_time: OptTensor = None, input_pred_nodes: InputNodes = None, input_pred_time: OptTensor = None, loader: str = 'neighbor', node_sampler: Optional[BaseSampler] = None, eval_loader_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: if input_train_nodes is None: input_train_nodes = infer_input_nodes(data, split='train') if input_val_nodes is None: input_val_nodes = infer_input_nodes(data, split='val') if input_val_nodes is None: input_val_nodes = infer_input_nodes(data, split='valid') if input_test_nodes is None: input_test_nodes = infer_input_nodes(data, split='test') if input_pred_nodes is None: input_pred_nodes = infer_input_nodes(data, split='pred') super().__init__( data=data, has_val=input_val_nodes is not None, has_test=input_test_nodes is not None, loader=loader, graph_sampler=node_sampler, eval_loader_kwargs=eval_loader_kwargs, **kwargs, ) self.input_train_nodes = input_train_nodes self.input_train_time = input_train_time self.input_train_id: OptTensor = None self.input_val_nodes = input_val_nodes self.input_val_time = input_val_time self.input_val_id: OptTensor = None self.input_test_nodes = input_test_nodes self.input_test_time = input_test_time self.input_test_id: OptTensor = None self.input_pred_nodes = input_pred_nodes self.input_pred_time = input_pred_time self.input_pred_id: OptTensor = None def dataloader( self, input_nodes: InputNodes, input_time: OptTensor = None, input_id: OptTensor = None, node_sampler: Optional[BaseSampler] = None, **kwargs: Any, ) -> torch.utils.data.DataLoader: if self.loader == 'full': return self.full_dataloader(**kwargs) assert node_sampler is not None return NodeLoader( self.data, node_sampler=node_sampler, input_nodes=input_nodes, input_time=input_time, 
input_id=input_id, **kwargs, ) def train_dataloader(self) -> torch.utils.data.DataLoader: return self.dataloader( self.input_train_nodes, self.input_train_time, self.input_train_id, node_sampler=getattr(self, 'graph_sampler', None), shuffle=self.train_shuffle, **self.loader_kwargs, ) def val_dataloader(self) -> torch.utils.data.DataLoader: return self.dataloader( self.input_val_nodes, self.input_val_time, self.input_val_id, node_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) def test_dataloader(self) -> torch.utils.data.DataLoader: return self.dataloader( self.input_test_nodes, self.input_test_time, self.input_test_id, node_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) def predict_dataloader(self) -> torch.utils.data.DataLoader: return self.dataloader( self.input_pred_nodes, self.input_pred_time, self.input_pred_id, node_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) class LightningLinkData(LightningData): r"""Converts a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object into a :class:`pytorch_lightning.LightningDataModule` variant. It can then be automatically used as a :obj:`datamodule` for multi-GPU link-level training via :lightning:`null` `PyTorch Lightning `__. :class:`LightningDataset` will take care of providing mini-batches via :class:`~torch_geometric.loader.LinkNeighborLoader`. .. note:: Currently only the :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and :class:`pytorch_lightning.strategies.DDPStrategy` training strategies of :lightning:`null` `PyTorch Lightning `__ are supported in order to correctly share data across all devices/processes: .. 
code-block:: python import pytorch_lightning as pl trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4) trainer.fit(model, datamodule) Args: data (Data or HeteroData or Tuple[FeatureStore, GraphStore]): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object, or a tuple of a :class:`~torch_geometric.data.FeatureStore` and :class:`~torch_geometric.data.GraphStore` objects. input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The training edges. (default: :obj:`None`) input_train_labels (torch.Tensor, optional): The labels of training edges. (default: :obj:`None`) input_train_time (torch.Tensor, optional): The timestamp of training edges. (default: :obj:`None`) input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The validation edges. (default: :obj:`None`) input_val_labels (torch.Tensor, optional): The labels of validation edges. (default: :obj:`None`) input_val_time (torch.Tensor, optional): The timestamp of validation edges. (default: :obj:`None`) input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The test edges. (default: :obj:`None`) input_test_labels (torch.Tensor, optional): The labels of test edges. (default: :obj:`None`) input_test_time (torch.Tensor, optional): The timestamp of test edges. (default: :obj:`None`) input_pred_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The prediction edges. (default: :obj:`None`) input_pred_labels (torch.Tensor, optional): The labels of prediction edges. (default: :obj:`None`) input_pred_time (torch.Tensor, optional): The timestamp of prediction edges. (default: :obj:`None`) loader (str): The scalability technique to use (:obj:`"full"`, :obj:`"neighbor"`). (default: :obj:`"neighbor"`) link_sampler (BaseSampler, optional): A custom sampler object to generate mini-batches. If set, will ignore the :obj:`loader` option. 
(default: :obj:`None`) eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments that override the :class:`torch_geometric.loader.LinkNeighborLoader` configuration during evaluation. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.loader.LinkNeighborLoader`. """ def __init__( self, data: Union[Data, HeteroData], input_train_edges: InputEdges = None, input_train_labels: OptTensor = None, input_train_time: OptTensor = None, input_val_edges: InputEdges = None, input_val_labels: OptTensor = None, input_val_time: OptTensor = None, input_test_edges: InputEdges = None, input_test_labels: OptTensor = None, input_test_time: OptTensor = None, input_pred_edges: InputEdges = None, input_pred_labels: OptTensor = None, input_pred_time: OptTensor = None, loader: str = 'neighbor', link_sampler: Optional[BaseSampler] = None, eval_loader_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: super().__init__( data=data, has_val=input_val_edges is not None, has_test=input_test_edges is not None, loader=loader, graph_sampler=link_sampler, eval_loader_kwargs=eval_loader_kwargs, **kwargs, ) self.input_train_edges = input_train_edges self.input_train_labels = input_train_labels self.input_train_time = input_train_time self.input_train_id: OptTensor = None self.input_val_edges = input_val_edges self.input_val_labels = input_val_labels self.input_val_time = input_val_time self.input_val_id: OptTensor = None self.input_test_edges = input_test_edges self.input_test_labels = input_test_labels self.input_test_time = input_test_time self.input_test_id: OptTensor = None self.input_pred_edges = input_pred_edges self.input_pred_labels = input_pred_labels self.input_pred_time = input_pred_time self.input_pred_id: OptTensor = None def dataloader( self, input_edges: InputEdges, input_labels: OptTensor = None, input_time: OptTensor = None, input_id: OptTensor = None, link_sampler: Optional[BaseSampler] = None, **kwargs: Any, ) -> 
torch.utils.data.DataLoader: if self.loader == 'full': return self.full_dataloader(**kwargs) assert link_sampler is not None return LinkLoader( self.data, link_sampler=link_sampler, edge_label_index=input_edges, edge_label=input_labels, edge_label_time=input_time, input_id=input_id, **kwargs, ) def train_dataloader(self) -> torch.utils.data.DataLoader: return self.dataloader( self.input_train_edges, self.input_train_labels, self.input_train_time, self.input_train_id, link_sampler=getattr(self, 'graph_sampler', None), shuffle=self.train_shuffle, **self.loader_kwargs, ) def val_dataloader(self) -> torch.utils.data.DataLoader: return self.dataloader( self.input_val_edges, self.input_val_labels, self.input_val_time, self.input_val_id, link_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) def test_dataloader(self) -> torch.utils.data.DataLoader: return self.dataloader( self.input_test_edges, self.input_test_labels, self.input_test_time, self.input_test_id, link_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) def predict_dataloader(self) -> torch.utils.data.DataLoader: return self.dataloader( self.input_pred_edges, self.input_pred_labels, self.input_pred_time, self.input_pred_id, link_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) ############################################################################### # TODO Support Tuple[FeatureStore, GraphStore] def infer_input_nodes(data: Union[Data, HeteroData], split: str) -> InputNodes: attr_name: Optional[str] = None if f'{split}_mask' in data: attr_name = f'{split}_mask' elif f'{split}_idx' in data: attr_name = f'{split}_idx' elif f'{split}_index' in data: attr_name = f'{split}_index' if attr_name is None: return None if isinstance(data, Data): return data[attr_name] if isinstance(data, HeteroData): input_nodes_dict = { node_type: store[attr_name] for node_type, store in 
data.node_items() if attr_name in store } if len(input_nodes_dict) != 1: raise ValueError(f"Could not automatically determine the input " f"nodes of {data} since there exists multiple " f"types with attribute '{attr_name}'") return list(input_nodes_dict.items())[0] return None def kwargs_repr(**kwargs: Any) -> str: return ', '.join([f'{k}={v}' for k, v in kwargs.items() if v is not None]) def split_kwargs( kwargs: Dict[str, Any], sampler_cls: Type, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: r"""Splits keyword arguments into sampler and loader arguments.""" sampler_args = inspect.signature(sampler_cls).parameters sampler_kwargs: Dict[str, Any] = {} loader_kwargs: Dict[str, Any] = {} for key, value in kwargs.items(): if key in sampler_args: sampler_kwargs[key] = value else: loader_kwargs[key] = value return sampler_kwargs, loader_kwargs ================================================ FILE: torch_geometric/data/makedirs.py ================================================ from torch_geometric.deprecation import deprecated from torch_geometric.io import fs @deprecated("use 'os.makedirs(path, exist_ok=True)' instead") def makedirs(path: str): r"""Recursively creates a directory. .. warning:: :meth:`makedirs` is deprecated and will be removed soon. Please use :obj:`os.makedirs(path, exist_ok=True)` instead. Args: path (str): The path to create. 
""" fs.makedirs(path, exist_ok=True) ================================================ FILE: torch_geometric/data/on_disk_dataset.py ================================================ import os from typing import Any, Callable, Iterable, List, Optional, Sequence, Union from torch import Tensor from torch_geometric.data import Database, RocksDatabase, SQLiteDatabase from torch_geometric.data.data import BaseData from torch_geometric.data.database import Schema from torch_geometric.data.dataset import Dataset class OnDiskDataset(Dataset): r"""Dataset base class for creating large graph datasets which do not easily fit into CPU memory at once by leveraging a :class:`Database` backend for on-disk storage and access of data objects. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) backend (str): The :class:`Database` backend to use (one of :obj:`"sqlite"` or :obj:`"rocksdb"`). (default: :obj:`"sqlite"`) schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of the input data. Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a dictionary with :obj:`dtype` and :obj:`size` keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema will improve efficiency, since by default the database will use python pickling for serializing and deserializing. 
If specified to anything different than :obj:`object`, implementations of :class:`OnDiskDataset` need to override :meth:`serialize` and :meth:`deserialize` methods. (default: :obj:`object`) log (bool, optional): Whether to print any console output while downloading and processing the dataset. (default: :obj:`True`) """ BACKENDS = { 'sqlite': SQLiteDatabase, 'rocksdb': RocksDatabase, } def __init__( self, root: str, transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, backend: str = 'sqlite', schema: Schema = object, log: bool = True, ) -> None: if backend not in self.BACKENDS: raise ValueError(f"Database backend must be one of " f"{set(self.BACKENDS.keys())} " f"(got '{backend}')") self.backend = backend self.schema = schema self._db: Optional[Database] = None self._numel: Optional[int] = None super().__init__(root, transform, pre_filter=pre_filter, log=log) @property def processed_file_names(self) -> str: return f'{self.backend}.db' @property def db(self) -> Database: r"""Returns the underlying :class:`Database`.""" if self._db is not None: return self._db kwargs = {} cls = self.BACKENDS[self.backend] if issubclass(cls, SQLiteDatabase): kwargs['name'] = self.__class__.__name__ os.makedirs(self.processed_dir, exist_ok=True) path = self.processed_paths[0] self._db = cls(path=path, schema=self.schema, **kwargs) self._numel = len(self._db) return self._db def close(self) -> None: r"""Closes the connection to the underlying database.""" if self._db is not None: self._db.close() def serialize(self, data: BaseData) -> Any: r"""Serializes the :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object into the expected DB schema. 
""" if self.schema == object: return data raise NotImplementedError(f"`{self.__class__.__name__}.serialize()` " f"needs to be overridden in case a " f"non-default schema was passed") def deserialize(self, data: Any) -> BaseData: r"""Deserializes the DB entry into a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object. """ if self.schema == object: return data raise NotImplementedError(f"`{self.__class__.__name__}.deserialize()` " f"needs to be overridden in case a " f"non-default schema was passed") def append(self, data: BaseData) -> None: r"""Appends the data object to the dataset.""" index = len(self) self.db.insert(index, self.serialize(data)) self._numel += 1 def extend( self, data_list: Sequence[BaseData], batch_size: Optional[int] = None, ) -> None: r"""Extends the dataset by a list of data objects.""" start = len(self) end = start + len(data_list) data_list = [self.serialize(data) for data in data_list] self.db.multi_insert(range(start, end), data_list, batch_size) self._numel += (end - start) def get(self, idx: int) -> BaseData: r"""Gets the data object at index :obj:`idx`.""" return self.deserialize(self.db.get(idx)) def multi_get( self, indices: Union[Iterable[int], Tensor, slice, range], batch_size: Optional[int] = None, ) -> List[BaseData]: r"""Gets a list of data objects from the specified indices.""" if len(indices) == 1: data_list = [self.db.get(indices[0])] else: data_list = self.db.multi_get(indices, batch_size) data_list = [self.deserialize(data) for data in data_list] if self.transform is not None: data_list = [self.transform(data) for data in data_list] return data_list def __getitems__(self, indices: List[int]) -> List[BaseData]: return self.multi_get(indices) def len(self) -> int: if self._numel is None: self._numel = len(self.db) return self._numel def __repr__(self) -> str: return f'{self.__class__.__name__}({len(self)})' ================================================ FILE: 
torch_geometric/data/remote_backend_utils.py ================================================ # This file defines a set of utilities for remote backends (backends that are # characterize as Tuple[FeatureStore, GraphStore]). TODO support for # non-heterogeneous graphs (feature stores with a group_name=None). from typing import Optional, Tuple, Union, overload from torch_geometric.data import FeatureStore, GraphStore from torch_geometric.typing import EdgeType, NodeType @overload def _internal_num_nodes( feature_store: FeatureStore, graph_store: GraphStore, query: NodeType, ) -> int: pass @overload def _internal_num_nodes( feature_store: FeatureStore, graph_store: GraphStore, query: EdgeType, ) -> Tuple[int, int]: pass # NOTE PyG also supports querying by a relation type `rel` in an edge type # (src, rel, dst). It may be worth supporting this in remote backends as well. def _internal_num_nodes( feature_store: FeatureStore, graph_store: GraphStore, query: Union[NodeType, EdgeType], ) -> Union[int, Tuple[int, int]]: r"""Returns the number of nodes in the node type or the number of source and destination nodes in an edge type by sequentially accessing attributes in the feature and graph stores that reveal this number. """ def _matches_node_type( query: Union[NodeType, EdgeType], node_type: Optional[NodeType], ) -> bool: if isinstance(query, (list, tuple)): # EdgeType: return query[0] == node_type or query[-1] == node_type else: return query == node_type node_query = isinstance(query, NodeType) # TODO: In general, a feature store and graph store should be able to # expose methods that allow for easy access to individual attributes, # instead of requiring iteration to identify a particular attribute. # Implementing this should reduce the iteration below. # 1. 
Check the edges in the GraphStore, for each node type in each edge: num_rows = num_cols = None for edge_attr in graph_store.get_all_edge_attrs(): if edge_attr.size is None: continue if _matches_node_type(query, edge_attr.edge_type[0]): num_rows = num_rows or edge_attr.size[0] if _matches_node_type(query, edge_attr.edge_type[-1]): num_cols = num_cols or edge_attr.size[-1] if node_query and num_rows is not None: return num_rows if node_query and num_cols is not None: return num_cols if not node_query and num_rows is not None and num_cols is not None: return num_rows, num_cols # 2. Check the node types stored in the FeatureStore: tensor_attrs = feature_store.get_all_tensor_attrs() matching_attrs = [ attr for attr in tensor_attrs if _matches_node_type(query, attr.group_name) ] if node_query: if len(matching_attrs) > 0: size = feature_store.get_tensor_size(matching_attrs[0]) if size is not None: return size[0] else: matching_src_attrs = [ attr for attr in matching_attrs if attr.group_name == query[0] ] matching_dst_attrs = [ attr for attr in matching_attrs if attr.group_name == query[-1] ] if len(matching_src_attrs) > 0 and len(matching_dst_attrs) > 0: src_size = feature_store.get_tensor_size(matching_src_attrs[0]) dst_size = feature_store.get_tensor_size(matching_dst_attrs[0]) if src_size is not None and dst_size is not None: return src_size[0], dst_size[0] raise ValueError( f"Unable to accurately infer the number of nodes corresponding to " f"query {query} from feature store {feature_store} and graph store " f"{graph_store}. Please consider either adding an edge containing " f"the nodes in this query or feature tensors for the nodes in this " f"query.") def num_nodes( feature_store: FeatureStore, graph_store: GraphStore, query: NodeType, ) -> int: r"""Returns the number of nodes in a given node type stored in a remote backend. 
""" return _internal_num_nodes(feature_store, graph_store, query) def size( feature_store: FeatureStore, graph_store: GraphStore, query: EdgeType, ) -> Tuple[int, int]: r"""Returns the size of an edge (number of source nodes, number of destination nodes) in an edge stored in a remote backend. """ return _internal_num_nodes(feature_store, graph_store, query) ================================================ FILE: torch_geometric/data/separate.py ================================================ from collections.abc import Mapping, Sequence from typing import Any, Type, TypeVar from torch import Tensor from torch_geometric import EdgeIndex, Index from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage from torch_geometric.typing import SparseTensor, TensorFrame from torch_geometric.utils import narrow T = TypeVar('T') def separate( cls: Type[T], batch: Any, idx: int, slice_dict: Any, inc_dict: Any = None, decrement: bool = True, ) -> T: # Separates the individual element from a `batch` at index `idx`. # `separate` can handle both homogeneous and heterogeneous data objects by # individually separating all their stores. # In addition, `separate` can handle nested data structures such as # dictionaries and lists. 
data = cls().stores_as(batch) # Iterate over each storage object and recursively separate its attributes: for batch_store, data_store in zip(batch.stores, data.stores): key = batch_store._key if key is not None: # Heterogeneous: attrs = slice_dict[key].keys() else: # Homogeneous: attrs = set(batch_store.keys()) attrs = [attr for attr in slice_dict.keys() if attr in attrs] for attr in attrs: if key is not None: slices = slice_dict[key][attr] incs = inc_dict[key][attr] if decrement else None else: slices = slice_dict[attr] incs = inc_dict[attr] if decrement else None data_store[attr] = _separate(attr, batch_store[attr], idx, slices, incs, batch, batch_store, decrement) # The `num_nodes` attribute needs special treatment, as we cannot infer # the real number of nodes from the total number of nodes alone: if hasattr(batch_store, '_num_nodes'): data_store.num_nodes = batch_store._num_nodes[idx] return data def _separate( key: str, values: Any, idx: int, slices: Any, incs: Any, batch: BaseData, store: BaseStorage, decrement: bool, ) -> Any: if isinstance(values, Tensor): # Narrow a `torch.Tensor` based on `slices`. # NOTE: We need to take care of decrementing elements appropriately. 
key = str(key) cat_dim = batch.__cat_dim__(key, values, store) start, end = int(slices[idx]), int(slices[idx + 1]) value = narrow(values, cat_dim or 0, start, end - start) value = value.squeeze(0) if cat_dim is None else value if isinstance(values, Index) and values._cat_metadata is not None: # Reconstruct original `Index` metadata: value._dim_size = values._cat_metadata.dim_size[idx] value._is_sorted = values._cat_metadata.is_sorted[idx] if isinstance(values, EdgeIndex) and values._cat_metadata is not None: # Reconstruct original `EdgeIndex` metadata: value._sparse_size = values._cat_metadata.sparse_size[idx] value._sort_order = values._cat_metadata.sort_order[idx] value._is_undirected = values._cat_metadata.is_undirected[idx] if (decrement and incs is not None and (incs.dim() > 1 or int(incs[idx]) != 0)): value = value - incs[idx].to(value.device) return value elif isinstance(values, SparseTensor) and decrement: # Narrow a `SparseTensor` based on `slices`. # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. key = str(key) cat_dim = batch.__cat_dim__(key, values, store) cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim for i, dim in enumerate(cat_dims): start, end = int(slices[idx][i]), int(slices[idx + 1][i]) values = values.narrow(dim, start, end - start) return values elif isinstance(values, TensorFrame): key = str(key) start, end = int(slices[idx]), int(slices[idx + 1]) value = values[start:end] return value elif isinstance(values, Mapping): # Recursively separate elements of dictionaries. 
return { key: _separate( key, value, idx, slices=slices[key], incs=incs[key] if decrement else None, batch=batch, store=store, decrement=decrement, ) for key, value in values.items() } elif (isinstance(values, Sequence) and isinstance(values[0], Sequence) and not isinstance(values[0], str) and len(values[0]) > 0 and isinstance(values[0][0], (Tensor, SparseTensor)) and isinstance(slices, Sequence)): # Recursively separate elements of lists of lists. return [value[idx] for value in values] elif (isinstance(values, Sequence) and not isinstance(values, str) and isinstance(values[0], (Tensor, SparseTensor)) and isinstance(slices, Sequence)): # Recursively separate elements of lists of Tensors/SparseTensors. return [ _separate( key, value, idx, slices=slices[i], incs=incs[i] if decrement else None, batch=batch, store=store, decrement=decrement, ) for i, value in enumerate(values) ] else: return values[idx] ================================================ FILE: torch_geometric/data/storage.py ================================================ import copy import warnings import weakref from collections import defaultdict, namedtuple from collections.abc import Mapping, MutableMapping, Sequence from enum import Enum from typing import ( Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Set, Tuple, Union, overload, ) import numpy as np import torch from torch import Tensor from typing_extensions import Self from torch_geometric import EdgeIndex from torch_geometric.data.view import ItemsView, KeysView, ValuesView from torch_geometric.typing import ( EdgeType, NodeType, SparseTensor, TensorFrame, ) from torch_geometric.utils import ( coalesce, contains_isolated_nodes, is_torch_sparse_tensor, is_undirected, select, sort_edge_index, ) N_KEYS = {'x', 'feat', 'pos', 'batch', 'node_type', 'n_id', 'tf'} E_KEYS = {'edge_index', 'edge_weight', 'edge_attr', 'edge_type', 'e_id'} class AttrType(Enum): NODE = 'NODE' EDGE = 'EDGE' OTHER = 'OTHER' class 
BaseStorage(MutableMapping): # This class wraps a Python dictionary and extends it as follows: # 1. It allows attribute assignments, e.g.: # `storage.x = ...` in addition to `storage['x'] = ...` # 2. It allows private attributes that are not exposed to the user, e.g.: # `storage._{key} = ...` and accessible via `storage._{key}` # 3. It holds an (optional) weak reference to its parent object, e.g.: # `storage._parent = weakref.ref(parent)` # 4. It allows iterating over only a subset of keys, e.g.: # `storage.values('x', 'y')` or `storage.items('x', 'y') # 5. It adds additional PyTorch Tensor functionality, e.g.: # `storage.cpu()`, `storage.cuda()` or `storage.share_memory_()`. def __init__( self, _mapping: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: super().__init__() self._mapping: Dict[str, Any] = {} for key, value in (_mapping or {}).items(): setattr(self, key, value) for key, value in kwargs.items(): setattr(self, key, value) @property def _key(self) -> Any: return None def _pop_cache(self, key: str) -> None: for cache in getattr(self, '_cached_attr', {}).values(): cache.discard(key) def __len__(self) -> int: return len(self._mapping) def __getattr__(self, key: str) -> Any: if key == '_mapping': self._mapping = {} return self._mapping try: return self[key] except KeyError: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{key}'" ) from None def __setattr__(self, key: str, value: Any) -> None: propobj = getattr(self.__class__, key, None) if propobj is not None and getattr(propobj, 'fset', None) is not None: propobj.fset(self, value) elif key == '_parent': self.__dict__[key] = weakref.ref(value) elif key[:1] == '_': self.__dict__[key] = value else: self[key] = value def __delattr__(self, key: str) -> None: if key[:1] == '_': del self.__dict__[key] else: del self[key] def __getitem__(self, key: str) -> Any: return self._mapping[key] def __setitem__(self, key: str, value: Any) -> None: self._pop_cache(key) if value is None 
and key in self._mapping: del self._mapping[key] elif value is not None: self._mapping[key] = value def __delitem__(self, key: str) -> None: if key in self._mapping: self._pop_cache(key) del self._mapping[key] def __iter__(self) -> Iterator[Any]: return iter(self._mapping) def __copy__(self) -> Self: out = self.__class__.__new__(self.__class__) for key, value in self.__dict__.items(): if key != '_cached_attr': out.__dict__[key] = value out._mapping = copy.copy(out._mapping) return out def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> Self: out = self.__class__.__new__(self.__class__) for key, value in self.__dict__.items(): out.__dict__[key] = value out._mapping = copy.deepcopy(out._mapping, memo) return out def __getstate__(self) -> Dict[str, Any]: out = self.__dict__.copy() _parent = out.get('_parent', None) if _parent is not None: out['_parent'] = _parent() return out def __setstate__(self, mapping: Dict[str, Any]) -> None: for key, value in mapping.items(): self.__dict__[key] = value _parent = self.__dict__.get('_parent', None) if _parent is not None: self.__dict__['_parent'] = weakref.ref(_parent) def __repr__(self) -> str: return repr(self._mapping) # Allow iterating over subsets ############################################ # In contrast to standard `keys()`, `values()` and `items()` functions of # Python dictionaries, we allow to only iterate over a subset of items # denoted by a list of keys `args`. # This is especially useful for adding PyTorch Tensor functionality to the # storage object, e.g., in case we only want to transfer a subset of keys # to the GPU (i.e. the ones that are relevant to the deep learning model). 
def keys(self, *args: str) -> KeysView: # type: ignore return KeysView(self._mapping, *args) def values(self, *args: str) -> ValuesView: # type: ignore return ValuesView(self._mapping, *args) def items(self, *args: str) -> ItemsView: # type: ignore return ItemsView(self._mapping, *args) def apply_(self, func: Callable, *args: str) -> Self: r"""Applies the in-place function :obj:`func`, either to all attributes or only the ones given in :obj:`*args`. """ for value in self.values(*args): recursive_apply_(value, func) return self def apply(self, func: Callable, *args: str) -> Self: r"""Applies the function :obj:`func`, either to all attributes or only the ones given in :obj:`*args`. """ for key, value in self.items(*args): self[key] = recursive_apply(value, func) return self # Additional functionality ################################################ def get(self, key: str, value: Optional[Any] = None) -> Any: return self._mapping.get(key, value) def to_dict(self) -> Dict[str, Any]: r"""Returns a dictionary of stored key/value pairs.""" out_dict = copy.copy(self._mapping) # Needed to preserve individual `num_nodes` attributes when calling # `BaseData.collate`. # TODO (matthias) Try to make this more generic. if '_num_nodes' in self.__dict__: out_dict['_num_nodes'] = self.__dict__['_num_nodes'] return out_dict def to_namedtuple(self) -> NamedTuple: r"""Returns a :obj:`NamedTuple` of stored key/value pairs.""" field_names = list(self.keys()) typename = f'{self.__class__.__name__}Tuple' StorageTuple = namedtuple(typename, field_names) # type: ignore return StorageTuple(*[self[key] for key in field_names]) def clone(self, *args: str) -> Self: r"""Performs a deep-copy of the object.""" return copy.deepcopy(self) def contiguous(self, *args: str) -> Self: r"""Ensures a contiguous memory layout, either for all attributes or only the ones given in :obj:`*args`. 
""" return self.apply(lambda x: x.contiguous(), *args) def to( self, device: Union[int, str], *args: str, non_blocking: bool = False, ) -> Self: r"""Performs tensor dtype and/or device conversion, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply( lambda x: x.to(device=device, non_blocking=non_blocking), *args) def cpu(self, *args: str) -> Self: r"""Copies attributes to CPU memory, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.cpu(), *args) def cuda( self, device: Optional[Union[int, str]] = None, *args: str, non_blocking: bool = False, ) -> Self: # pragma: no cover r"""Copies attributes to CUDA memory, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.cuda(device, non_blocking=non_blocking), *args) def pin_memory(self, *args: str) -> Self: r"""Copies attributes to pinned memory, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.pin_memory(), *args) def share_memory_(self, *args: str) -> Self: r"""Moves attributes to shared memory, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.share_memory_(), *args) def detach_(self, *args: str) -> Self: r"""Detaches attributes from the computation graph, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.detach_(), *args) def detach(self, *args: str) -> Self: r"""Detaches attributes from the computation graph by creating a new tensor, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply(lambda x: x.detach(), *args) def requires_grad_(self, *args: str, requires_grad: bool = True) -> Self: r"""Tracks gradient computation, either for all attributes or only the ones given in :obj:`*args`. 
""" return self.apply( lambda x: x.requires_grad_(requires_grad=requires_grad), *args) def record_stream(self, stream: torch.cuda.Stream, *args: str) -> Self: r"""Ensures that the tensor memory is not reused for another tensor until all current work queued on :obj:`stream` has been completed, either for all attributes or only the ones given in :obj:`*args`. """ return self.apply_(lambda x: x.record_stream(stream), *args) # Time Handling ########################################################### def _cat_dims(self, keys: Iterable[str]) -> Dict[str, int]: return { key: self._parent().__cat_dim__(key, self[key], self) for key in keys } def _select( self, keys: Iterable[str], index_or_mask: Tensor, ) -> Self: for key, dim in self._cat_dims(keys).items(): self[key] = select(self[key], index_or_mask, dim) return self def concat(self, other: Self) -> Self: if not (set(self.keys()) == set(other.keys())): raise AttributeError('Given storage is not compatible') for key, dim in self._cat_dims(self.keys()).items(): value1 = self[key] value2 = other[key] if key in {'num_nodes', 'num_edges'}: self[key] = value1 + value2 elif isinstance(value1, list): self[key] = value1 + value2 elif isinstance(value1, Tensor): self[key] = torch.cat([value1, value2], dim=dim) else: raise NotImplementedError( f"'{self.__class__.__name__}.concat' not yet implemented " f"for '{type(value1)}'") return self def is_sorted_by_time(self) -> bool: if 'time' in self: return bool(torch.all(self.time[:-1] <= self.time[1:])) return True def sort_by_time(self) -> Self: if self.is_sorted_by_time(): return self if 'time' in self: _, perm = torch.sort(self.time, stable=True) if self.is_node_attr('time'): keys = self.node_attrs() elif self.is_edge_attr('time'): keys = self.edge_attrs() self._select(keys, perm) return self def snapshot( self, start_time: Union[float, int], end_time: Union[float, int], attr: str = 'time', ) -> Self: if attr in self: time = self[attr] mask = (time >= start_time) & (time <= end_time) 
if self.is_node_attr(attr): keys = self.node_attrs() elif self.is_edge_attr(attr): keys = self.edge_attrs() self._select(keys, mask) if self.is_node_attr(attr) and 'num_nodes' in self: self.num_nodes: Optional[int] = int(mask.sum()) return self def up_to(self, time: Union[float, int]) -> Self: if 'time' in self: return self.snapshot(self.time.min().item(), time) return self class NodeStorage(BaseStorage): r"""A storage for node-level information.""" @property def _key(self) -> NodeType: key = self.__dict__.get('_key', None) if key is None or not isinstance(key, str): raise ValueError("'_key' does not denote a valid node type") return key @property def can_infer_num_nodes(self) -> bool: keys = set(self.keys()) num_node_keys = { 'num_nodes', 'x', 'pos', 'batch', 'adj', 'adj_t', 'edge_index', 'face' } if len(keys & num_node_keys) > 0: return True elif len([key for key in keys if 'node' in key]) > 0: return True else: return False @property def num_nodes(self) -> Optional[int]: # We sequentially access attributes that reveal the number of nodes. 
if 'num_nodes' in self: return self['num_nodes'] for key, value in self.items(): if isinstance(value, Tensor) and key in N_KEYS: cat_dim = self._parent().__cat_dim__(key, value, self) return value.size(cat_dim) if isinstance(value, np.ndarray) and key in N_KEYS: cat_dim = self._parent().__cat_dim__(key, value, self) return value.shape[cat_dim] if isinstance(value, TensorFrame) and key in N_KEYS: return value.num_rows for key, value in self.items(): if isinstance(value, Tensor) and 'node' in key: cat_dim = self._parent().__cat_dim__(key, value, self) return value.size(cat_dim) if isinstance(value, np.ndarray) and 'node' in key: cat_dim = self._parent().__cat_dim__(key, value, self) return value.shape[cat_dim] if isinstance(value, TensorFrame) and 'node' in key: return value.num_rows if 'edge_index' in self and isinstance(self.edge_index, EdgeIndex): if self.edge_index.sparse_size(0) is not None: return self.edge_index.sparse_size(0) if self.edge_index.sparse_size(1) is not None: return self.edge_index.sparse_size(1) if 'adj' in self and isinstance(self.adj, (Tensor, SparseTensor)): return self.adj.size(0) if 'adj_t' in self and isinstance(self.adj_t, (Tensor, SparseTensor)): return self.adj_t.size(1) warnings.warn( f"Unable to accurately infer 'num_nodes' from the attribute set " f"'{set(self.keys())}'. 
Please explicitly set 'num_nodes' as an " f"attribute of " + ("'data'" if self._key is None else f"'data[{self._key}]'") + " to suppress this warning", stacklevel=2) if 'edge_index' in self and isinstance(self.edge_index, Tensor): if self.edge_index.numel() > 0: return int(self.edge_index.max()) + 1 return 0 if 'face' in self and isinstance(self.face, Tensor): if self.face.numel() > 0: return int(self.face.max()) + 1 return 0 return None @num_nodes.setter def num_nodes(self, num_nodes: Optional[int]) -> None: self['num_nodes'] = num_nodes @property def num_node_features(self) -> int: x: Optional[Any] = self.get('x') if isinstance(x, Tensor): return 1 if x.dim() == 1 else x.size(-1) if isinstance(x, np.ndarray): return 1 if x.ndim == 1 else x.shape[-1] if isinstance(x, SparseTensor): return 1 if x.dim() == 1 else x.size(-1) if isinstance(x, TensorFrame): return x.num_cols tf: Optional[Any] = self.get('tf') if isinstance(tf, TensorFrame): return tf.num_cols return 0 @property def num_features(self) -> int: return self.num_node_features def is_node_attr(self, key: str) -> bool: if '_cached_attr' not in self.__dict__: self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set) if key in self._cached_attr[AttrType.NODE]: return True if key in self._cached_attr[AttrType.OTHER]: return False value = self[key] if (isinstance(value, (list, tuple, TensorFrame)) and len(value) == self.num_nodes): self._cached_attr[AttrType.NODE].add(key) return True if not isinstance(value, (Tensor, np.ndarray)): self._cached_attr[AttrType.OTHER].add(key) return False if value.ndim == 0: self._cached_attr[AttrType.OTHER].add(key) return False cat_dim = self._parent().__cat_dim__(key, value, self) if value.shape[cat_dim] != self.num_nodes: self._cached_attr[AttrType.OTHER].add(key) return False self._cached_attr[AttrType.NODE].add(key) return True def is_edge_attr(self, key: str) -> bool: return False def node_attrs(self) -> List[str]: return [key for key in self.keys() if 
self.is_node_attr(key)] class EdgeStorage(BaseStorage): r"""A storage for edge-level information. We support multiple ways to store edge connectivity in a :class:`EdgeStorage` object: * :obj:`edge_index`: A :class:`torch.LongTensor` holding edge indices in COO format with shape :obj:`[2, num_edges]` (the default format) * :obj:`adj`: A :class:`torch_sparse.SparseTensor` holding edge indices in a sparse format, supporting both COO and CSR format. * :obj:`adj_t`: A **transposed** :class:`torch_sparse.SparseTensor` holding edge indices in a sparse format, supporting both COO and CSR format. This is the most efficient one for graph-based deep learning models as indices are sorted based on target nodes. """ @property def _key(self) -> EdgeType: key = self.__dict__.get('_key', None) if key is None or not isinstance(key, tuple) or not len(key) == 3: raise ValueError("'_key' does not denote a valid edge type") return key @property def edge_index(self) -> Tensor: if 'edge_index' in self: return self['edge_index'] if 'adj' in self and isinstance(self.adj, SparseTensor): return torch.stack(self.adj.coo()[:2], dim=0) if 'adj_t' in self and isinstance(self.adj_t, SparseTensor): return torch.stack(self.adj_t.coo()[:2][::-1], dim=0) raise AttributeError( f"'{self.__class__.__name__}' object has no attribute " f"'edge_index', 'adj' or 'adj_t'") @edge_index.setter def edge_index(self, edge_index: Optional[Tensor]) -> None: self['edge_index'] = edge_index @property def num_edges(self) -> int: # We sequentially access attributes that reveal the number of edges. 
if 'num_edges' in self: return self['num_edges'] for key, value in self.items(): if isinstance(value, Tensor) and key in E_KEYS: cat_dim = self._parent().__cat_dim__(key, value, self) return value.size(cat_dim) if isinstance(value, np.ndarray) and key in E_KEYS: cat_dim = self._parent().__cat_dim__(key, value, self) return value.shape[cat_dim] if isinstance(value, TensorFrame) and key in E_KEYS: return value.num_rows for key, value in self.items(): if isinstance(value, Tensor) and 'edge' in key: cat_dim = self._parent().__cat_dim__(key, value, self) return value.size(cat_dim) if isinstance(value, np.ndarray) and 'edge' in key: cat_dim = self._parent().__cat_dim__(key, value, self) return value.shape[cat_dim] if isinstance(value, TensorFrame) and 'edge' in key: return value.num_rows for value in self.values('adj', 'adj_t'): if isinstance(value, SparseTensor): return value.nnz() elif is_torch_sparse_tensor(value): return value._nnz() return 0 @property def num_edge_features(self) -> int: edge_attr: Optional[Any] = self.get('edge_attr') if isinstance(edge_attr, Tensor): return 1 if edge_attr.dim() == 1 else edge_attr.size(-1) if isinstance(edge_attr, np.ndarray): return 1 if edge_attr.ndim == 1 else edge_attr.shape[-1] if isinstance(edge_attr, TensorFrame): return edge_attr.num_cols return 0 @property def num_features(self) -> int: return self.num_edge_features @overload def size(self) -> Tuple[Optional[int], Optional[int]]: pass @overload def size(self, dim: int) -> Optional[int]: pass def size( self, dim: Optional[int] = None ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: if self._key is None: raise NameError("Unable to infer 'size' without explicit " "'_key' assignment") size = (self._parent()[self._key[0]].num_nodes, self._parent()[self._key[-1]].num_nodes) return size if dim is None else size[dim] def is_node_attr(self, key: str) -> bool: return False def is_edge_attr(self, key: str) -> bool: if '_cached_attr' not in self.__dict__: 
self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set) if key in self._cached_attr[AttrType.EDGE]: return True if key in self._cached_attr[AttrType.OTHER]: return False value = self[key] if (isinstance(value, (list, tuple, TensorFrame)) and len(value) == self.num_edges): self._cached_attr[AttrType.EDGE].add(key) return True if not isinstance(value, (Tensor, np.ndarray)): self._cached_attr[AttrType.OTHER].add(key) return False if value.ndim == 0: self._cached_attr[AttrType.OTHER].add(key) return False cat_dim = self._parent().__cat_dim__(key, value, self) if value.shape[cat_dim] != self.num_edges: self._cached_attr[AttrType.OTHER].add(key) return False self._cached_attr[AttrType.EDGE].add(key) return True def edge_attrs(self) -> List[str]: return [key for key in self.keys() if self.is_edge_attr(key)] def is_sorted(self, sort_by_row: bool = True) -> bool: if 'edge_index' in self: index = self.edge_index[0] if sort_by_row else self.edge_index[1] return bool(torch.all(index[:-1] <= index[1:])) return True def sort(self, sort_by_row: bool = True) -> Self: if 'edge_index' in self: edge_attrs = self.edge_attrs() edge_attrs.remove('edge_index') edge_feats = [self[edge_attr] for edge_attr in edge_attrs] self.edge_index, edge_feats = sort_edge_index( self.edge_index, edge_feats, sort_by_row=sort_by_row) for key, edge_feat in zip(edge_attrs, edge_feats): self[key] = edge_feat return self def is_coalesced(self) -> bool: for value in self.values('adj', 'adj_t'): return value.is_coalesced() if 'edge_index' in self: size = [s for s in self.size() if s is not None] num_nodes = max(size) if len(size) > 0 else None new_edge_index = coalesce(self.edge_index, num_nodes=num_nodes) return (self.edge_index.numel() == new_edge_index.numel() and torch.equal(self.edge_index, new_edge_index)) return True def coalesce(self, reduce: str = 'sum') -> Self: for key, value in self.items('adj', 'adj_t'): self[key] = value.coalesce(reduce) if 'edge_index' in self: size = [s for s in 
self.size() if s is not None] num_nodes = max(size) if len(size) > 0 else None self.edge_index, self.edge_attr = coalesce( self.edge_index, edge_attr=self.get('edge_attr'), num_nodes=num_nodes, ) return self def has_isolated_nodes(self) -> bool: edge_index, num_nodes = self.edge_index, self.size(1) if num_nodes is None: raise NameError("Unable to infer 'num_nodes'") if self.is_bipartite(): return torch.unique(edge_index[1]).numel() < num_nodes else: return contains_isolated_nodes(edge_index, num_nodes) def has_self_loops(self) -> bool: if self.is_bipartite(): return False edge_index = self.edge_index return int((edge_index[0] == edge_index[1]).sum()) > 0 def is_undirected(self) -> bool: if self.is_bipartite(): return False for value in self.values('adj', 'adj_t'): return value.is_symmetric() edge_index = self.edge_index edge_attr = self.edge_attr if 'edge_attr' in self else None return is_undirected(edge_index, edge_attr, num_nodes=self.size(0)) def is_directed(self) -> bool: return not self.is_undirected() def is_bipartite(self) -> bool: return self._key is not None and self._key[0] != self._key[-1] class GlobalStorage(NodeStorage, EdgeStorage): r"""A storage for both node-level and edge-level information.""" @property def _key(self) -> Any: return None @property def num_features(self) -> int: return self.num_node_features @overload def size(self) -> Tuple[Optional[int], Optional[int]]: pass @overload def size(self, dim: int) -> Optional[int]: pass def size( self, dim: Optional[int] = None ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: size = (self.num_nodes, self.num_nodes) return size if dim is None else size[dim] def is_node_attr(self, key: str) -> bool: if '_cached_attr' not in self.__dict__: self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set) if key in self._cached_attr[AttrType.NODE]: return True if key in self._cached_attr[AttrType.EDGE]: return False if key in self._cached_attr[AttrType.OTHER]: return False value = self[key] if 
(isinstance(value, (list, tuple, TensorFrame)) and len(value) == self.num_nodes): self._cached_attr[AttrType.NODE].add(key) return True if not isinstance(value, (Tensor, np.ndarray)): return False if value.ndim == 0: self._cached_attr[AttrType.OTHER].add(key) return False cat_dim = self._parent().__cat_dim__(key, value, self) if not isinstance(cat_dim, int): return False num_nodes, num_edges = self.num_nodes, self.num_edges if value.shape[cat_dim] != num_nodes: if value.shape[cat_dim] == num_edges: self._cached_attr[AttrType.EDGE].add(key) else: self._cached_attr[AttrType.OTHER].add(key) return False if num_nodes != num_edges: self._cached_attr[AttrType.NODE].add(key) return True if 'edge' not in key: self._cached_attr[AttrType.NODE].add(key) return True else: self._cached_attr[AttrType.EDGE].add(key) return False def is_edge_attr(self, key: str) -> bool: if '_cached_attr' not in self.__dict__: self._cached_attr = defaultdict(set) if key in self._cached_attr[AttrType.EDGE]: return True if key in self._cached_attr[AttrType.NODE]: return False if key in self._cached_attr[AttrType.OTHER]: return False value = self[key] if (isinstance(value, (list, tuple, TensorFrame)) and len(value) == self.num_edges): self._cached_attr[AttrType.EDGE].add(key) return True if not isinstance(value, (Tensor, np.ndarray)): return False if value.ndim == 0: self._cached_attr[AttrType.OTHER].add(key) return False cat_dim = self._parent().__cat_dim__(key, value, self) if not isinstance(cat_dim, int): return False num_nodes, num_edges = self.num_nodes, self.num_edges if value.shape[cat_dim] != num_edges: if value.shape[cat_dim] == num_nodes: self._cached_attr[AttrType.NODE].add(key) else: self._cached_attr[AttrType.OTHER].add(key) return False if num_edges != num_nodes: self._cached_attr[AttrType.EDGE].add(key) return True if 'edge' in key: self._cached_attr[AttrType.EDGE].add(key) return True else: self._cached_attr[AttrType.NODE].add(key) return False def recursive_apply_(data: Any, func: 
Callable) -> Any:
    r"""Applies :obj:`func` to every leaf of :obj:`data` for its side effects.

    The function recurses into namedtuples, non-string sequences and mappings.
    Return values of :obj:`func` are discarded. Tensors are always passed to
    :obj:`func`. Other leaves are tried on a best-effort basis, and any
    exception raised for them is ignored.
    """
    if isinstance(data, Tensor):
        func(data)
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        for value in data:
            recursive_apply_(value, func)
    elif isinstance(data, Sequence) and not isinstance(data, str):
        for value in data:
            recursive_apply_(value, func)
    elif isinstance(data, Mapping):
        for value in data.values():
            recursive_apply_(value, func)
    else:
        try:
            func(data)
        except Exception:
            pass


def recursive_apply(data: Any, func: Callable) -> Any:
    r"""Returns a copy of :obj:`data` with :obj:`func` applied to every leaf.

    Each kind of input is handled as follows:

    * Tensors and :class:`PackedSequence` objects are passed to :obj:`func`.
    * Namedtuples are rebuilt with their original type.
    * Other non-string sequences become lists.
    * Mappings become dicts.
    * Any other leaf is replaced by :obj:`func(data)`, or returned unchanged
      if :obj:`func` raises.
    """
    if isinstance(data, Tensor):
        return func(data)
    elif isinstance(data, torch.nn.utils.rnn.PackedSequence):
        return func(data)
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(recursive_apply(d, func) for d in data))
    elif isinstance(data, Sequence) and not isinstance(data, str):
        return [recursive_apply(d, func) for d in data]
    elif isinstance(data, Mapping):
        return {key: recursive_apply(data[key], func) for key in data}
    else:
        try:
            return func(data)
        except Exception:
            return data
================================================
FILE: torch_geometric/data/summary.py
================================================
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch
from tqdm import tqdm
from typing_extensions import Self

from torch_geometric.data import Dataset, HeteroData
from torch_geometric.typing import EdgeType, NodeType


@dataclass
class Stats:
    r"""Descriptive statistics of a collection of numbers."""
    mean: float
    std: float
    min: float
    quantile25: float  # 25th percentile
    median: float
    quantile75: float  # 75th percentile
    max: float

    @classmethod
    def from_data(
        cls,
        data: Union[List[int], List[float], torch.Tensor],
    ) -> Self:
        r"""Computes :class:`Stats` from a list or tensor of numbers.

        The input is converted to a float tensor first. An empty input is
        not handled here and presumably raises inside :obj:`torch` -- TODO
        confirm.
        """
        if not isinstance(data, torch.Tensor):
            data = torch.tensor(data)
        data = data.to(torch.float)

        return cls(
            mean=data.mean().item(),
            std=data.std().item(),
            min=data.min().item(),
            quantile25=data.quantile(0.25).item(),
            median=data.median().item(),
            quantile75=data.quantile(0.75).item(),
            max=data.max().item(),
        )


@dataclass(repr=False)
class Summary:
    r"""Node/edge count statistics of a dataset, optionally broken down per
    node and edge type.
    """
    name: str
    num_graphs: int
    num_nodes: Stats
    num_edges: Stats
    num_nodes_per_type: Optional[Dict[NodeType, Stats]] = None
    num_edges_per_type: Optional[Dict[EdgeType, Stats]] = None

    @classmethod
    def from_dataset(
        cls,
        dataset: Dataset,
        progress_bar: Optional[bool] = None,
        per_type: bool = True,
    ) -> Self:
        r"""Creates a summary of a :class:`~torch_geometric.data.Dataset`
        object.

        Args:
            dataset (Dataset): The dataset.
            progress_bar (bool, optional): If set to :obj:`True`, will show a
                progress bar during stats computation. If set to
                :obj:`None`, will automatically decide whether to show a
                progress bar based on dataset size. (default: :obj:`None`)
            per_type (bool, optional): If set to :obj:`True`, will separate
                statistics per node and edge type (only applicable in
                heterogeneous graph datasets). (default: :obj:`True`)
        """
        name = dataset.__class__.__name__

        # Automatic mode: only show a progress bar for >= 10,000 graphs.
        if progress_bar is None:
            progress_bar = len(dataset) >= 10000

        if progress_bar:
            dataset = tqdm(dataset)

        num_nodes, num_edges = [], []
        _num_nodes_per_type = defaultdict(list)
        _num_edges_per_type = defaultdict(list)

        for data in dataset:
            assert data.num_nodes is not None
            num_nodes.append(data.num_nodes)
            num_edges.append(data.num_edges)

            if per_type and isinstance(data, HeteroData):
                for node_type in data.node_types:
                    _num_nodes_per_type[node_type].append(
                        data[node_type].num_nodes)
                for edge_type in data.edge_types:
                    _num_edges_per_type[edge_type].append(
                        data[edge_type].num_edges)

        # Per-type stats stay `None` unless at least one heterogeneous graph
        # was seen.
        num_nodes_per_type = None
        if len(_num_nodes_per_type) > 0:
            num_nodes_per_type = {
                node_type: Stats.from_data(num_nodes_list)
                for node_type, num_nodes_list in _num_nodes_per_type.items()
            }

        num_edges_per_type = None
        if len(_num_edges_per_type) > 0:
            num_edges_per_type = {
                edge_type: Stats.from_data(num_edges_list)
                for edge_type, num_edges_list in _num_edges_per_type.items()
            }

        return cls(
            name=name,
            num_graphs=len(dataset),
            num_nodes=Stats.from_data(num_nodes),
            num_edges=Stats.from_data(num_edges),
            num_nodes_per_type=num_nodes_per_type,
            num_edges_per_type=num_edges_per_type,
        )

    def format(self, fmt: str = "psql") -> str:
        r"""Formats summary statistics of the dataset.

        Args:
            fmt (str, optional): Summary tables format. Available table
                formats can be found `here `__. (default: :obj:`"psql"`)
        """
        # Imported lazily so that `tabulate` is only needed for formatting.
        from tabulate import tabulate

        body = f'{self.name} (#graphs={self.num_graphs}):\n'

        # One row per `Stats` field, one column per statistic source.
        content = [['', '#nodes', '#edges']]
        stats = [self.num_nodes, self.num_edges]
        for field in Stats.__dataclass_fields__:
            row = [field] + [f'{getattr(s, field):.1f}' for s in stats]
            content.append(row)
        body += tabulate(content, headers='firstrow', tablefmt=fmt)

        if self.num_nodes_per_type is not None:
            content = [['']]
            content[0] += list(self.num_nodes_per_type.keys())

            for field in Stats.__dataclass_fields__:
                row = [field] + [
                    f'{getattr(s, field):.1f}'
                    for s in self.num_nodes_per_type.values()
                ]
                content.append(row)
            body += "\nNumber of nodes per node type:\n"
            body += tabulate(content, headers='firstrow', tablefmt=fmt)

        if self.num_edges_per_type is not None:
            content = [['']]
            content[0] += [
                f"({', '.join(edge_type)})"
                for edge_type in self.num_edges_per_type.keys()
            ]

            for field in Stats.__dataclass_fields__:
                row = [field] + [
                    f'{getattr(s, field):.1f}'
                    for s in self.num_edges_per_type.values()
                ]
                content.append(row)
            body += "\nNumber of edges per edge type:\n"
            body += tabulate(content, headers='firstrow', tablefmt=fmt)

        return body

    def __repr__(self) -> str:
        return self.format()
================================================
FILE: torch_geometric/data/temporal.py
================================================
import copy
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch
from torch import Tensor

from torch_geometric.data.data import BaseData, size_repr
from torch_geometric.data.storage import (
    BaseStorage,
    EdgeStorage,
    GlobalStorage,
    NodeStorage,
)


class TemporalData(BaseData):
    r"""A data object composed of a stream of events describing a temporal
    graph.
The :class:`~torch_geometric.data.TemporalData` object can hold a list of events (that can be understood as temporal edges in a graph) with structured messages. An event is composed by a source node, a destination node, a timestamp and a message. Any *Continuous-Time Dynamic Graph* (CTDG) can be represented with these four values. In general, :class:`~torch_geometric.data.TemporalData` tries to mimic the behavior of a regular :python:`Python` dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities. .. code-block:: python from torch import Tensor from torch_geometric.data import TemporalData events = TemporalData( src=Tensor([1,2,3,4]), dst=Tensor([2,3,4,5]), t=Tensor([1000,1010,1100,2000]), msg=Tensor([1,1,0,0]) ) # Add additional arguments to `events`: events.y = Tensor([1,1,0,0]) # It is also possible to set additional arguments in the constructor events = TemporalData( ..., y=Tensor([1,1,0,0]) ) # Get the number of events: events.num_events >>> 4 # Analyzing the graph structure: events.num_nodes >>> 5 # PyTorch tensor functionality: events = events.pin_memory() events = events.to('cuda:0', non_blocking=True) Args: src (torch.Tensor, optional): A list of source nodes for the events with shape :obj:`[num_events]`. (default: :obj:`None`) dst (torch.Tensor, optional): A list of destination nodes for the events with shape :obj:`[num_events]`. (default: :obj:`None`) t (torch.Tensor, optional): The timestamps for each event with shape :obj:`[num_events]`. (default: :obj:`None`) msg (torch.Tensor, optional): Messages feature matrix with shape :obj:`[num_events, num_msg_features]`. (default: :obj:`None`) **kwargs (optional): Additional attributes. .. note:: The shape of :obj:`src`, :obj:`dst`, :obj:`t` and the first dimension of :obj`msg` should be the same (:obj:`num_events`). 
""" def __init__( self, src: Optional[Tensor] = None, dst: Optional[Tensor] = None, t: Optional[Tensor] = None, msg: Optional[Tensor] = None, **kwargs, ): super().__init__() self.__dict__['_store'] = GlobalStorage(_parent=self) self.src = src self.dst = dst self.t = t self.msg = msg for key, value in kwargs.items(): setattr(self, key, value) @classmethod def from_dict(cls, mapping: Dict[str, Any]) -> 'TemporalData': r"""Creates a :class:`~torch_geometric.data.TemporalData` object from a Python dictionary. """ return cls(**mapping) def index_select(self, idx: Any) -> 'TemporalData': idx = prepare_idx(idx) data = copy.copy(self) for key, value in data._store.items(): if value.size(0) == self.num_events: data[key] = value[idx] return data def __getitem__(self, idx: Any) -> Any: if isinstance(idx, str): return self._store[idx] return self.index_select(idx) def __setitem__(self, key: str, value: Any): """Sets the attribute :obj:`key` to :obj:`value`.""" self._store[key] = value def __delitem__(self, key: str): if key in self._store: del self._store[key] def __getattr__(self, key: str) -> Any: if '_store' not in self.__dict__: raise RuntimeError( "The 'data' object was created by an older version of PyG. 
" "If this error occurred while loading an already existing " "dataset, remove the 'processed/' directory in the dataset's " "root folder and try again.") return getattr(self._store, key) def __setattr__(self, key: str, value: Any): setattr(self._store, key, value) def __delattr__(self, key: str): delattr(self._store, key) def __iter__(self) -> Iterable: for i in range(self.num_events): yield self[i] def __len__(self) -> int: return self.num_events def __call__(self, *args: List[str]) -> Iterable: yield from self._store.items(*args) def __copy__(self): out = self.__class__.__new__(self.__class__) for key, value in self.__dict__.items(): out.__dict__[key] = value out.__dict__['_store'] = copy.copy(self._store) out._store._parent = out return out def __deepcopy__(self, memo): out = self.__class__.__new__(self.__class__) for key, value in self.__dict__.items(): out.__dict__[key] = copy.deepcopy(value, memo) out._store._parent = out return out def stores_as(self, data: 'TemporalData'): return self @property def stores(self) -> List[BaseStorage]: return [self._store] @property def node_stores(self) -> List[NodeStorage]: return [self._store] @property def edge_stores(self) -> List[EdgeStorage]: return [self._store] def to_dict(self) -> Dict[str, Any]: return self._store.to_dict() def to_namedtuple(self) -> NamedTuple: return self._store.to_namedtuple() def debug(self): pass # TODO @property def num_nodes(self) -> int: r"""Returns the number of nodes in the graph.""" return max(int(self.src.max()), int(self.dst.max())) + 1 @property def num_events(self) -> int: r"""Returns the number of events loaded. .. note:: In a :class:`~torch_geometric.data.TemporalData`, each row denotes an event. Thus, they can be also understood as edges. 
""" return self.src.size(0) @property def num_edges(self) -> int: r"""Alias for :meth:`~torch_geometric.data.TemporalData.num_events`.""" return self.num_events @property def edge_index(self) -> Tensor: r"""Returns the edge indices of the graph.""" if 'edge_index' in self: return self._store['edge_index'] if self.src is not None and self.dst is not None: return torch.stack([self.src, self.dst], dim=0) raise ValueError(f"{self.__class__.__name__} does not contain " f"'edge_index' information") def size( self, dim: Optional[int] = None ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: r"""Returns the size of the adjacency matrix induced by the graph.""" size = (int(self.src.max()), int(self.dst.max())) return size if dim is None else size[dim] def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: return 0 def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: if 'batch' in key and isinstance(value, Tensor): return int(value.max()) + 1 elif key in ['src', 'dst']: return self.num_nodes else: return 0 def __repr__(self) -> str: cls = self.__class__.__name__ info = ', '.join([size_repr(k, v) for k, v in self._store.items()]) return f'{cls}({info})' ########################################################################### def train_val_test_split(self, val_ratio: float = 0.15, test_ratio: float = 0.15): r"""Splits the data in training, validation and test sets according to time. Args: val_ratio (float, optional): The proportion (in percents) of the dataset to include in the validation split. (default: :obj:`0.15`) test_ratio (float, optional): The proportion (in percents) of the dataset to include in the test split. (default: :obj:`0.15`) """ val_time, test_time = np.quantile( self.t.cpu().numpy(), [1. - val_ratio - test_ratio, 1. 
- test_ratio]) val_idx = int((self.t <= val_time).sum()) test_idx = int((self.t <= test_time).sum()) return self[:val_idx], self[val_idx:test_idx], self[test_idx:] ########################################################################### def coalesce(self): raise NotImplementedError def has_isolated_nodes(self) -> bool: raise NotImplementedError def has_self_loops(self) -> bool: raise NotImplementedError def is_undirected(self) -> bool: raise NotImplementedError def is_directed(self) -> bool: raise NotImplementedError ############################################################################### def prepare_idx(idx): if isinstance(idx, int): return slice(idx, idx + 1) if isinstance(idx, (list, tuple)): return torch.tensor(idx) elif isinstance(idx, slice): return idx elif isinstance(idx, torch.Tensor) and idx.dtype == torch.long: return idx elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool: return idx raise IndexError( f"Only strings, integers, slices (`:`), list, tuples, and long or " f"bool tensors are valid indices (got '{type(idx).__name__}')") ================================================ FILE: torch_geometric/data/view.py ================================================ from typing import Any, Iterator, List, Mapping, Tuple class MappingView: def __init__(self, mapping: Mapping[str, Any], *args: str): self._mapping = mapping self._args = args def _keys(self) -> List[str]: if len(self._args) == 0: return list(self._mapping.keys()) else: return [arg for arg in self._args if arg in self._mapping] def __len__(self) -> int: return len(self._keys()) def __repr__(self) -> str: mapping = {key: self._mapping[key] for key in self._keys()} return f'{self.__class__.__name__}({mapping})' __class_getitem__ = classmethod(type([])) # type: ignore class KeysView(MappingView): def __iter__(self) -> Iterator[str]: yield from self._keys() class ValuesView(MappingView): def __iter__(self) -> Iterator[Any]: for key in self._keys(): yield self._mapping[key] class 
ItemsView(MappingView): def __iter__(self) -> Iterator[Tuple[str, Any]]: for key in self._keys(): yield (key, self._mapping[key]) ================================================ FILE: torch_geometric/datasets/__init__.py ================================================ # flake8: noqa from .karate import KarateClub from .tu_dataset import TUDataset from .gnn_benchmark_dataset import GNNBenchmarkDataset from .planetoid import Planetoid from .nell import NELL from .citation_full import CitationFull, CoraFull from .coauthor import Coauthor from .amazon import Amazon from .ppi import PPI from .reddit import Reddit from .reddit2 import Reddit2 from .flickr import Flickr from .yelp import Yelp from .amazon_products import AmazonProducts from .qm7 import QM7b from .qm9 import QM9 from .md17 import MD17 from .zinc import ZINC from .aqsol import AQSOL from .molecule_net import MoleculeNet from .pcqm4m import PCQM4Mv2 from .entities import Entities from .rel_link_pred_dataset import RelLinkPredDataset from .ged_dataset import GEDDataset from .attributed_graph_dataset import AttributedGraphDataset from .mnist_superpixels import MNISTSuperpixels from .faust import FAUST from .dynamic_faust import DynamicFAUST from .shapenet import ShapeNet from .modelnet import ModelNet from .medshapenet import MedShapeNet from .coma import CoMA from .shrec2016 import SHREC2016 from .tosca import TOSCA from .pcpnet_dataset import PCPNetDataset from .s3dis import S3DIS from .geometry import GeometricShapes from .bitcoin_otc import BitcoinOTC from .gdelt_lite import GDELTLite from .icews import ICEWS18 from .gdelt import GDELT from .willow_object_class import WILLOWObjectClass from .pascal import PascalVOCKeypoints from .pascal_pf import PascalPF from .snap_dataset import SNAPDataset from .suite_sparse import SuiteSparseMatrixCollection from .word_net import WordNet18, WordNet18RR from .freebase import FB15k_237 from .wikics import WikiCS from .webkb import WebKB from .wikipedia_network import 
WikipediaNetwork from .heterophilous_graph_dataset import HeterophilousGraphDataset from .actor import Actor from .upfd import UPFD from .github import GitHub from .facebook import FacebookPagePage from .lastfm_asia import LastFMAsia from .deezer_europe import DeezerEurope from .gemsec import GemsecDeezer from .twitch import Twitch from .airports import Airports from .lrgb import LRGBDataset from .malnet_tiny import MalNetTiny from .omdb import OMDB from .polblogs import PolBlogs from .email_eu_core import EmailEUCore from .linkx_dataset import LINKXDataset from .elliptic import EllipticBitcoinDataset from .elliptic_temporal import EllipticBitcoinTemporalDataset from .dgraph import DGraphFin from .hydro_net import HydroNet from .airfrans import AirfRANS from .jodie import JODIEDataset from .wikidata import Wikidata5M from .myket import MyketDataset from .brca_tgca import BrcaTcga from .neurograph import NeuroGraphDataset from .web_qsp_dataset import WebQSPDataset, CWQDataset from .git_mol_dataset import GitMolDataset from .molecule_gpt_dataset import MoleculeGPTDataset from .instruct_mol_dataset import InstructMolDataset from .protein_mpnn_dataset import ProteinMPNNDataset from .tag_dataset import TAGDataset from .city import CityNetwork from .teeth3ds import Teeth3DS from .dbp15k import DBP15K from .aminer import AMiner from .ogb_mag import OGB_MAG from .dblp import DBLP from .movie_lens import MovieLens from .movie_lens_100k import MovieLens100K from .movie_lens_1m import MovieLens1M from .imdb import IMDB from .last_fm import LastFM from .hgb_dataset import HGBDataset from .taobao import Taobao from .igmc_dataset import IGMCDataset from .amazon_book import AmazonBook from .hm import HM from .ose_gvcs import OSE_GVCS from .rcdd import RCDD from .opf import OPFDataset from .cornell import CornellTemporalHyperGraphDataset from .fake import FakeDataset, FakeHeteroDataset from .sbm_dataset import StochasticBlockModelDataset from .sbm_dataset import 
RandomPartitionGraphDataset from .mixhop_synthetic_dataset import MixHopSyntheticDataset from .explainer_dataset import ExplainerDataset from .infection_dataset import InfectionDataset from .ba2motif_dataset import BA2MotifDataset from .ba_multi_shapes import BAMultiShapesDataset from .ba_shapes import BAShapes import torch_geometric.datasets.utils homo_datasets = [ 'KarateClub', 'TUDataset', 'GNNBenchmarkDataset', 'Planetoid', 'NELL', 'CitationFull', 'CoraFull', 'Coauthor', 'Amazon', 'PPI', 'Reddit', 'Reddit2', 'Flickr', 'Yelp', 'AmazonProducts', 'QM7b', 'QM9', 'MD17', 'ZINC', 'AQSOL', 'MoleculeNet', 'PCQM4Mv2', 'Entities', 'RelLinkPredDataset', 'GEDDataset', 'AttributedGraphDataset', 'MNISTSuperpixels', 'FAUST', 'DynamicFAUST', 'ShapeNet', 'ModelNet', 'MedShapeNet', 'CoMA', 'SHREC2016', 'TOSCA', 'PCPNetDataset', 'S3DIS', 'GeometricShapes', 'BitcoinOTC', 'GDELTLite', 'ICEWS18', 'GDELT', 'WILLOWObjectClass', 'PascalVOCKeypoints', 'PascalPF', 'SNAPDataset', 'SuiteSparseMatrixCollection', 'WordNet18', 'WordNet18RR', 'FB15k_237', 'WikiCS', 'WebKB', 'WikipediaNetwork', 'HeterophilousGraphDataset', 'Actor', 'UPFD', 'GitHub', 'FacebookPagePage', 'LastFMAsia', 'DeezerEurope', 'GemsecDeezer', 'Twitch', 'Airports', 'LRGBDataset', 'MalNetTiny', 'OMDB', 'PolBlogs', 'EmailEUCore', 'LINKXDataset', 'EllipticBitcoinDataset', 'EllipticBitcoinTemporalDataset', 'DGraphFin', 'HydroNet', 'AirfRANS', 'JODIEDataset', 'Wikidata5M', 'MyketDataset', 'BrcaTcga', 'NeuroGraphDataset', 'WebQSPDataset', 'CWQDataset', 'GitMolDataset', 'MoleculeGPTDataset', 'InstructMolDataset', 'ProteinMPNNDataset', 'TAGDataset', 'CityNetwork', 'Teeth3DS', ] hetero_datasets = [ 'DBP15K', 'AMiner', 'OGB_MAG', 'DBLP', 'MovieLens', 'MovieLens100K', 'MovieLens1M', 'IMDB', 'LastFM', 'HGBDataset', 'Taobao', 'IGMCDataset', 'AmazonBook', 'HM', 'OSE_GVCS', 'RCDD', 'OPFDataset', ] hyper_datasets = [ 'CornellTemporalHyperGraphDataset', ] synthetic_datasets = [ 'FakeDataset', 'FakeHeteroDataset', 
'StochasticBlockModelDataset',
    'RandomPartitionGraphDataset',
    'MixHopSyntheticDataset',
    'ExplainerDataset',
    'InfectionDataset',
    'BA2MotifDataset',
    'BAMultiShapesDataset',
    'BAShapes',
]

# The public API is the union of all dataset registries above.
__all__ = homo_datasets + hetero_datasets + hyper_datasets + synthetic_datasets
================================================
FILE: torch_geometric/datasets/actor.py
================================================
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.utils import coalesce


class Actor(InMemoryDataset):
    r"""The actor-only induced subgraph of the film-director-actor-writer
    network used in the
    `"Geom-GCN: Geometric Graph Convolutional Networks"
    `_ paper.
    Each node corresponds to an actor, and the edge between two nodes denotes
    co-occurrence on the same Wikipedia page.
    Node features correspond to some keywords in the Wikipedia pages.
    The task is to classify the nodes into five categories based on the
    words of the actor's Wikipedia page.

    Args:
        root: Root directory where the dataset should be saved.
        transform: A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
        pre_transform: A function/transform that takes in an
            :class:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk.
        force_reload: Whether to re-process the dataset.

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #classes
        * - 7,600
          - 30,019
          - 932
          - 5
    """

    url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # Node features/labels, the edge list, and ten predefined
        # train/val/test split files.
        return ['out1_node_feature_label.txt', 'out1_graph_edges.txt'
                ] + [f'film_split_0.6_0.2_{i}.npz' for i in range(10)]

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        r"""Downloads the node/edge files from :obj:`new_data/film/` and the
        split files from :obj:`splits/` of the Geom-GCN repository.
        """
        for f in self.raw_file_names[:2]:
            download_url(f'{self.url}/new_data/film/{f}', self.raw_dir)
        for f in self.raw_file_names[2:]:
            download_url(f'{self.url}/splits/{f}', self.raw_dir)

    def process(self) -> None:
        r"""Builds a single :class:`~torch_geometric.data.Data` object.

        The object holds multi-hot node features :obj:`x`, labels :obj:`y`,
        a coalesced :obj:`edge_index`, and one mask column per split file,
        and is saved to :obj:`processed_paths[0]`.
        """
        with open(self.raw_paths[0]) as f:
            # `[1:-1]` drops the first line (presumably a header) and the
            # empty string after the trailing newline -- TODO confirm.
            # Each row is tab-separated: node id, comma-separated feature
            # indices, label.
            node_data = [x.split('\t') for x in f.read().split('\n')[1:-1]]

            rows, cols = [], []
            for n_id, line, _ in node_data:
                indices = [int(x) for x in line.split(',')]
                rows += [int(n_id)] * len(indices)
                cols += indices
            row, col = torch.tensor(rows), torch.tensor(cols)

            # Multi-hot encoding of the active feature indices per node.
            x = torch.zeros(int(row.max()) + 1, int(col.max()) + 1)
            x[row, col] = 1.

            y = torch.empty(len(node_data), dtype=torch.long)
            for n_id, _, label in node_data:
                y[int(n_id)] = int(label)

        with open(self.raw_paths[1]) as f:
            # Each row holds a tab-separated pair of node ids.
            edge_data = f.read().split('\n')[1:-1]
            edge_indices = [[int(v) for v in r.split('\t')] for r in edge_data]
            edge_index = torch.tensor(edge_indices).t().contiguous()
            edge_index = coalesce(edge_index, num_nodes=x.size(0))

        train_masks, val_masks, test_masks = [], [], []
        for path in self.raw_paths[2:]:
            tmp = np.load(path)
            train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]
            val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]
            test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]
        # One column per split file (assuming 1-D masks per file).
        train_mask = torch.stack(train_masks, dim=1)
        val_mask = torch.stack(val_masks, dim=1)
        test_mask = torch.stack(test_masks, dim=1)

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                    val_mask=val_mask, test_mask=test_mask)
        data = data if self.pre_transform is None else self.pre_transform(data)
        self.save([data], self.processed_paths[0])
================================================
FILE: torch_geometric/datasets/airfrans.py
================================================
import json
import os
from typing import Callable, List, Optional

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import fs


class AirfRANS(InMemoryDataset):
    r"""The AirfRANS dataset from the `"AirfRANS: High Fidelity Computational
    Fluid Dynamics Dataset for Approximating Reynolds-Averaged Navier-Stokes
    Solutions" `_ paper, consisting of 1,000 simulations of steady-state
    aerodynamics over 2D airfoils in a subsonic flight regime.
    The different tasks (:obj:`"full"`, :obj:`"scarce"`, :obj:`"reynolds"`,
    :obj:`"aoa"`) define the utilized training and test splits.

    Each simulation is given as a point cloud defined as the nodes of the
    simulation mesh.
Each point of a point cloud is described via 5 features: the inlet velocity (two components in meter per second), the distance to the airfoil (one component in meter), and the normals (two components in meter, set to :obj:`0` if the point is not on the airfoil). Each point is given a target of 4 components for the underlying regression task: the velocity (two components in meter per second), the pressure divided by the specific mass (one component in meter squared per second squared), the turbulent kinematic viscosity (one component in meter squared per second). Finally, a boolean is attached to each point to inform if this point lies on the airfoil or not. A library for manipulating simulations of the dataset is available `here `_. The dataset is released under the `ODbL v1.0 License `_. .. note:: Data objects contain no edge indices to be agnostic to the simulation mesh. You are free to build a graph via the :obj:`torch_geometric.transforms.RadiusGraph` transform. Args: root: Root directory where the dataset should be saved. task: The task to study (:obj:`"full"`, :obj:`"scarce"`, :obj:`"reynolds"`, :obj:`"aoa"`) that defines the utilized training and test splits. train: If :obj:`True`, loads the training dataset, otherwise the test dataset. transform: A function/transform that takes in an :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. pre_transform: A function/transform that takes in an :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. pre_filter: A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. force_reload: Whether to re-process the dataset. **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #tasks * - 1,000 - ~180,000 - 0 - 5 - 4 """ url = 'https://data.isir.upmc.fr/extrality/pytorch_geometric/AirfRANS.zip' tasks = ['full', 'scarce', 'reynolds', 'aoa'] def __init__( self, root: str, task: str, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: if task not in self.tasks: raise ValueError(f"Expected 'task' to be in {self.tasks} " f"got '{task}'") self.task = 'full' if task == 'scarce' and not train else task self.split = 'train' if train else 'test' super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['AirfRANS.pt', 'manifest.json'] @property def processed_file_names(self) -> str: return f'{self.task}_{self.split}.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: with open(self.raw_paths[1]) as f: manifest = json.load(f) total = manifest['full_train'] + manifest['full_test'] partial = set(manifest[f'{self.task}_{self.split}']) data_list = [] raw_data = fs.torch_load(self.raw_paths[0]) for k, s in enumerate(total): if s in partial: data = Data(**raw_data[k]) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'task={self.task}, split={self.split})') ================================================ FILE: torch_geometric/datasets/airports.py ================================================ import os.path as osp from typing import Callable, List, Optional import torch from 
torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.utils import coalesce class Airports(InMemoryDataset): r"""The Airports dataset from the `"struc2vec: Learning Node Representations from Structural Identity" `_ paper, where nodes denote airports and labels correspond to activity levels. Features are given by one-hot encoded node identifiers, as described in the `"GraLSP: Graph Neural Networks with Local Structural Patterns" `_ paper. Args: root: Root directory where the dataset should be saved. name: The name of the dataset (:obj:`"USA"`, :obj:`"Brazil"`, :obj:`"Europe"`). transform: A function/transform that takes in an :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. pre_transform (callable, optional): A function/transform that takes in :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. force_reload: Whether to re-process the dataset. 
""" edge_url = ('https://github.com/leoribeiro/struc2vec/' 'raw/master/graph/{}-airports.edgelist') label_url = ('https://github.com/leoribeiro/struc2vec/' 'raw/master/graph/labels-{}-airports.txt') def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in ['usa', 'brazil', 'europe'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: return [ f'{self.name}-airports.edgelist', f'labels-{self.name}-airports.txt', ] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.edge_url.format(self.name), self.raw_dir) download_url(self.label_url.format(self.name), self.raw_dir) def process(self) -> None: index_map, ys = {}, [] with open(self.raw_paths[1]) as f: rows = f.read().split('\n')[1:-1] for i, row in enumerate(rows): idx, label = row.split() index_map[int(idx)] = i ys.append(int(label)) y = torch.tensor(ys, dtype=torch.long) x = torch.eye(y.size(0)) edge_indices = [] with open(self.raw_paths[0]) as f: rows = f.read().split('\n')[:-1] for row in rows: src, dst = row.split() edge_indices.append([index_map[int(src)], index_map[int(dst)]]) edge_index = torch.tensor(edge_indices).t().contiguous() edge_index = coalesce(edge_index, num_nodes=y.size(0)) data = Data(x=x, edge_index=edge_index, y=y) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name.capitalize()}Airports()' ================================================ FILE: torch_geometric/datasets/amazon.py 
================================================ import os.path as osp from typing import Callable, Optional from torch_geometric.data import InMemoryDataset, download_url from torch_geometric.io import read_npz class Amazon(InMemoryDataset): r"""The Amazon Computers and Amazon Photo networks from the `"Pitfalls of Graph Neural Network Evaluation" `_ paper. Nodes represent goods and edges represent that two goods are frequently bought together. Given product reviews as bag-of-words node features, the task is to map goods to their respective product category. Args: root: Root directory where the dataset should be saved. name: The name of the dataset (:obj:`"Computers"`, :obj:`"Photo"`). transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. pre_transform: A function/transform that takes in an :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. force_reload: Whether to re-process the dataset. **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - Computers - 13,752 - 491,722 - 767 - 10 * - Photo - 7,650 - 238,162 - 745 - 8 """ url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/' def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in ['computers', 'photo'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name.capitalize(), 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name.capitalize(), 'processed') @property def raw_file_names(self) -> str: return f'amazon_electronics_{self.name.lower()}.npz' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.url + self.raw_file_names, self.raw_dir) def process(self) -> None: data = read_npz(self.raw_paths[0], to_undirected=True) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.__class__.__name__}{self.name.capitalize()}()' ================================================ FILE: torch_geometric/datasets/amazon_book.py ================================================ from typing import Callable, List, Optional import torch from torch_geometric.data import HeteroData, InMemoryDataset, download_url class AmazonBook(InMemoryDataset): r"""A subset of the AmazonBook rating dataset from the `"LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation" `_ paper. This is a heterogeneous dataset consisting of 52,643 users and 91,599 books with approximately 2.9 million ratings between them. No labels or features are provided. 
Args: root: Root directory where the dataset should be saved. transform: A function/transform that takes in an :class:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. pre_transform: A function/transform that takes in an :class:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. force_reload: Whether to re-process the dataset. """ url = ('https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/' 'master/data/amazon-book') def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return ['user_list.txt', 'item_list.txt', 'train.txt', 'test.txt'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for name in self.raw_file_names: download_url(f'{self.url}/{name}', self.raw_dir) def process(self) -> None: import pandas as pd data = HeteroData() # Process number of nodes for each node type: node_types = ['user', 'book'] for path, node_type in zip(self.raw_paths, node_types): df = pd.read_csv(path, sep=' ', header=0) data[node_type].num_nodes = len(df) # Process edge information for training and testing: attr_names = ['edge_index', 'edge_label_index'] for path, attr_name in zip(self.raw_paths[2:], attr_names): rows, cols = [], [] with open(path) as f: lines = f.readlines() for line in lines: indices = line.strip().split(' ') for dst in indices[1:]: rows.append(int(indices[0])) cols.append(int(dst)) index = torch.tensor([rows, cols]) data['user', 'rates', 'book'][attr_name] = index if attr_name == 'edge_index': data['book', 'rated_by', 'user'][attr_name] = index.flip([0]) if 
self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/amazon_products.py ================================================ import json import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_google_url class AmazonProducts(InMemoryDataset): r"""The Amazon dataset from the `"GraphSAINT: Graph Sampling Based Inductive Learning Method" `_ paper, containing products and its categories. Args: root: Root directory where the dataset should be saved. transform: A function/transform that takes in an :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. pre_transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. force_reload: Whether to re-process the dataset. **STATS:** .. 
list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 1,569,960 - 264,339,468 - 200 - 107 """ adj_full_id = '17qhNA8H1IpbkkR-T2BmPQm8QNW5do-aa' feats_id = '10SW8lCvAj-kb6ckkfTOC5y0l8XXdtMxj' class_map_id = '1LIl4kimLfftj4-7NmValuWyCQE8AaE7P' role_id = '1npK9xlmbnjNkV80hK2Q68wTEVOFjnt4K' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz') download_google_url(self.feats_id, self.raw_dir, 'feats.npy') download_google_url(self.class_map_id, self.raw_dir, 'class_map.json') download_google_url(self.role_id, self.raw_dir, 'role.json') def process(self) -> None: import scipy.sparse as sp f = np.load(osp.join(self.raw_dir, 'adj_full.npz')) adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape']) adj = adj.tocoo() row = torch.from_numpy(adj.row).to(torch.long) col = torch.from_numpy(adj.col).to(torch.long) edge_index = torch.stack([row, col], dim=0) x = np.load(osp.join(self.raw_dir, 'feats.npy')) x = torch.from_numpy(x).to(torch.float) ys = [-1] * x.size(0) with open(osp.join(self.raw_dir, 'class_map.json')) as f: class_map = json.load(f) for key, item in class_map.items(): ys[int(key)] = item y = torch.tensor(ys) with open(osp.join(self.raw_dir, 'role.json')) as f: role = json.load(f) train_mask = torch.zeros(x.size(0), dtype=torch.bool) train_mask[torch.tensor(role['tr'])] = True val_mask = torch.zeros(x.size(0), dtype=torch.bool) val_mask[torch.tensor(role['va'])] = True test_mask = torch.zeros(x.size(0), dtype=torch.bool) 
test_mask[torch.tensor(role['te'])] = True data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/aminer.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs from torch_geometric.utils import coalesce class AMiner(InMemoryDataset): r"""The heterogeneous AMiner dataset from the `"metapath2vec: Scalable Representation Learning for Heterogeneous Networks" `_ paper, consisting of nodes from type :obj:`"paper"`, :obj:`"author"` and :obj:`"venue"`. Venue categories and author research interests are available as ground truth labels for a subset of nodes. Args: root: Root directory where the dataset should be saved. transform: A function/transform that takes in a :class:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. pre_transform: A function/transform that takes in a :class:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. force_reload: Whether to re-process the dataset. 
""" url = 'https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1' y_url = 'https://www.dropbox.com/s/nkocx16rpl4ydde/label.zip?dl=1' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return [ 'id_author.txt', 'id_conf.txt', 'paper.txt', 'paper_author.txt', 'paper_conf.txt', 'label' ] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: fs.rm(self.raw_dir) path = download_url(self.url, self.root) extract_zip(path, self.root) os.rename(osp.join(self.root, 'net_aminer'), self.raw_dir) os.unlink(path) path = download_url(self.y_url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: import pandas as pd data = HeteroData() # Get author labels. path = osp.join(self.raw_dir, 'id_author.txt') author = pd.read_csv(path, sep='\t', names=['idx', 'name'], index_col=1) path = osp.join(self.raw_dir, 'label', 'googlescholar.8area.author.label.txt') df = pd.read_csv(path, sep=' ', names=['name', 'y']) df = df.join(author, on='name') data['author'].y = torch.from_numpy(df['y'].values) - 1 data['author'].y_index = torch.from_numpy(df['idx'].values) # Get venue labels. path = osp.join(self.raw_dir, 'id_conf.txt') venue = pd.read_csv(path, sep='\t', names=['idx', 'name'], index_col=1) path = osp.join(self.raw_dir, 'label', 'googlescholar.8area.venue.label.txt') df = pd.read_csv(path, sep=' ', names=['name', 'y']) df = df.join(venue, on='name') data['venue'].y = torch.from_numpy(df['y'].values) - 1 data['venue'].y_index = torch.from_numpy(df['idx'].values) # Get paper<->author connectivity. 
path = osp.join(self.raw_dir, 'paper_author.txt') paper_author = pd.read_csv(path, sep='\t', header=None) paper_author = torch.from_numpy(paper_author.values) paper_author = paper_author.t().contiguous() M, N = int(paper_author[0].max() + 1), int(paper_author[1].max() + 1) paper_author = coalesce(paper_author, num_nodes=max(M, N)) data['paper'].num_nodes = M data['author'].num_nodes = N data['paper', 'written_by', 'author'].edge_index = paper_author data['author', 'writes', 'paper'].edge_index = paper_author.flip([0]) # Get paper<->venue connectivity. path = osp.join(self.raw_dir, 'paper_conf.txt') paper_venue = pd.read_csv(path, sep='\t', header=None) paper_venue = torch.from_numpy(paper_venue.values) paper_venue = paper_venue.t().contiguous() M, N = int(paper_venue[0].max() + 1), int(paper_venue[1].max() + 1) paper_venue = coalesce(paper_venue, num_nodes=max(M, N)) data['venue'].num_nodes = N data['paper', 'published_in', 'venue'].edge_index = paper_venue data['venue', 'publishes', 'paper'].edge_index = paper_venue.flip([0]) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/aqsol.py ================================================ import os import os.path as osp import pickle from typing import Callable, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class AQSOL(InMemoryDataset): r"""The AQSOL dataset from the `Benchmarking Graph Neural Networks `_ paper based on `AqSolDB `_, a standardized database of 9,982 molecular graphs with their aqueous solubility values, collected from 9 different data sources. The aqueous solubility targets are collected from experimental measurements and standardized to LogS units in AqSolDB. These final values denote the property to regress in the :class:`AQSOL` dataset. 
After filtering out few graphs with no bonds/edges, the total number of molecular graphs is 9,833. For each molecular graph, the node features are the types of heavy atoms and the edge features are the types of bonds between them, similar as in the :class:`~torch_geometric.datasets.ZINC` dataset. Args: root: Root directory where the dataset should be saved. split: If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. pre_transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. pre_filter (callable, optional): A function that takes in an :class:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. force_reload: Whether to re-process the dataset. **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 9,833 - ~17.6 - ~35.8 - 1 - 1 """ url = 'https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1' def __init__( self, root: str, split: str = 'train', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ): assert split in ['train', 'val', 'test'] super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = osp.join(self.processed_dir, f'{split}.pt') self.load(path) @property def raw_file_names(self) -> List[str]: return [ 'train.pickle', 'val.pickle', 'test.pickle', 'atom_dict.pickle', 'bond_dict.pickle' ] @property def processed_file_names(self) -> List[str]: return ['train.pt', 'val.pt', 'test.pt'] def download(self) -> None: fs.rm(self.raw_dir) path = download_url(self.url, self.root) extract_zip(path, self.root) os.rename(osp.join(self.root, 'asqol_graph_raw'), self.raw_dir) os.unlink(path) def process(self) -> None: for raw_path, path in zip(self.raw_paths, self.processed_paths): with open(raw_path, 'rb') as f: graphs = pickle.load(f) data_list: List[Data] = [] for graph in graphs: x, edge_attr, edge_index, y = graph x = torch.from_numpy(x) edge_attr = torch.from_numpy(edge_attr) edge_index = torch.from_numpy(edge_index) y = torch.tensor([y]).float() if edge_index.numel() == 0: continue # Skipping for graphs with no bonds/edges. 
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, path) def atoms(self) -> List[str]: return [ 'Br', 'C', 'N', 'O', 'Cl', 'Zn', 'F', 'P', 'S', 'Na', 'Al', 'Si', 'Mo', 'Ca', 'W', 'Pb', 'B', 'V', 'Co', 'Mg', 'Bi', 'Fe', 'Ba', 'K', 'Ti', 'Sn', 'Cd', 'I', 'Re', 'Sr', 'H', 'Cu', 'Ni', 'Lu', 'Pr', 'Te', 'Ce', 'Nd', 'Gd', 'Zr', 'Mn', 'As', 'Hg', 'Sb', 'Cr', 'Se', 'La', 'Dy', 'Y', 'Pd', 'Ag', 'In', 'Li', 'Rh', 'Nb', 'Hf', 'Cs', 'Ru', 'Au', 'Sm', 'Ta', 'Pt', 'Ir', 'Be', 'Ge' ] def bonds(self) -> List[str]: return ['NONE', 'SINGLE', 'DOUBLE', 'AROMATIC', 'TRIPLE'] ================================================ FILE: torch_geometric/datasets/attributed_graph_dataset.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_google_url, extract_zip, ) from torch_geometric.io import fs class AttributedGraphDataset(InMemoryDataset): r"""A variety of attributed graph datasets from the `"Scaling Attributed Network Embedding to Massive Graphs" `_ paper. Args: root: Root directory where the dataset should be saved. name: The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`, :obj:`"CiteSeer"`, :obj:`"PubMed"`, :obj:`"BlogCatalog"`, :obj:`"PPI"`, :obj:`"Flickr"`, :obj:`"Facebook"`, :obj:`"Twitter"`, :obj:`"TWeibo"`, :obj:`"MAG"`). transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. pre_transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. 
force_reload: Whether to re-process the dataset. **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - Wiki - 2,405 - 17,981 - 4,973 - 17 * - Cora - 2,708 - 5,429 - 1,433 - 7 * - CiteSeer - 3,312 - 4,715 - 3,703 - 6 * - PubMed - 19,717 - 44,338 - 500 - 3 * - BlogCatalog - 5,196 - 343,486 - 8,189 - 6 * - PPI - 56,944 - 1,612,348 - 50 - 121 * - Flickr - 7,575 - 479,476 - 12,047 - 9 * - Facebook - 4,039 - 88,234 - 1,283 - 193 * - TWeibo - 2,320,895 - 9,840,066 - 1,657 - 8 * - MAG - 59,249,719 - 978,147,253 - 2,000 - 100 """ datasets = { 'wiki': '1EPhlbziZTQv19OsTrKrAJwsElbVPEbiV', 'cora': '1FyVnpdsTT-lhkVPotUW8OVeuCi1vi3Ey', 'citeseer': '1d3uQIpHiemWJPgLgTafi70RFYye7hoCp', 'pubmed': '1DOK3FfslyJoGXUSCSrK5lzdyLfIwOz6k', 'blogcatalog': '178PqGqh67RUYMMP6-SoRHDoIBh8ku5FS', 'ppi': '1dvwRpPT4gGtOcNP_Q-G1TKl9NezYhtez', 'flickr': '1tZp3EB20fAC27SYWwa-x66_8uGsuU62X', 'facebook': '12aJWAGCM4IvdGI2fiydDNyWzViEOLZH8', 'twitter': '1fUYggzZlDrt9JsLsSdRUHiEzQRW1kSA4', 'tweibo': '1-2xHDPFCsuBuFdQN_7GLleWa8R_t50qU', 'mag': '1ggraUMrQgdUyA3DjSRzzqMv0jFkU65V5', } def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in self.datasets.keys() super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: return ['attrs.npz', 'edgelist.txt', 'labels.txt'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: id = self.datasets[self.name] path = download_google_url(id, self.raw_dir, 'data.zip') extract_zip(path, self.raw_dir) os.unlink(path) path = osp.join(self.raw_dir, 
f'{self.name}.attr') if self.name == 'mag': path = osp.join(self.raw_dir, self.name) for name in self.raw_file_names: os.rename(osp.join(path, name), osp.join(self.raw_dir, name)) fs.rm(path) def process(self) -> None: import pandas as pd import scipy.sparse as sp x = sp.load_npz(self.raw_paths[0]).tocsr() if x.shape[-1] > 10000 or self.name == 'mag': x = torch.sparse_csr_tensor( crow_indices=x.indptr, col_indices=x.indices, values=x.data, size=x.shape, ) else: x = torch.from_numpy(x.todense()).to(torch.float) df = pd.read_csv(self.raw_paths[1], header=None, sep=None, engine='python') edge_index = torch.from_numpy(df.values).t().contiguous() with open(self.raw_paths[2]) as f: rows = f.read().split('\n')[:-1] ys = [[int(y) - 1 for y in row.split()[1:]] for row in rows] multilabel = max([len(y) for y in ys]) > 1 if not multilabel: y = torch.tensor(ys).view(-1) else: num_classes = max([y for row in ys for y in row]) + 1 y = torch.zeros((len(ys), num_classes), dtype=torch.float) for i, row in enumerate(ys): for j in row: y[i, j] = 1. data = Data(x=x, edge_index=edge_index, y=y) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name.capitalize()}()' ================================================ FILE: torch_geometric/datasets/ba2motif_dataset.py ================================================ import pickle from typing import Callable, List, Optional import torch from torch_geometric.data import Data, InMemoryDataset, download_url class BA2MotifDataset(InMemoryDataset): r"""The synthetic BA-2motifs graph classification dataset for evaluating explainabilty algorithms, as described in the `"Parameterized Explainer for Graph Neural Network" `_ paper. :class:`~torch_geometric.datasets.BA2MotifDataset` contains 1000 random Barabasi-Albert (BA) graphs. 
Half of the graphs are attached with a :class:`~torch_geometric.datasets.motif_generator.HouseMotif`, and the rest are attached with a five-node :class:`~torch_geometric.datasets.motif_generator.CycleMotif`. The graphs are assigned to one of the two classes according to the type of attached motifs. This dataset is pre-computed from the official implementation. If you want to create own variations of it, you can make use of the :class:`~torch_geometric.datasets.ExplainerDataset`: .. code-block:: python import torch from torch_geometric.datasets import ExplainerDataset from torch_geometric.datasets.graph_generator import BAGraph from torch_geometric.datasets.motif_generator import HouseMotif from torch_geometric.datasets.motif_generator import CycleMotif dataset1 = ExplainerDataset( graph_generator=BAGraph(num_nodes=25, num_edges=1), motif_generator=HouseMotif(), num_motifs=1, num_graphs=500, ) dataset2 = ExplainerDataset( graph_generator=BAGraph(num_nodes=25, num_edges=1), motif_generator=CycleMotif(5), num_motifs=1, num_graphs=500, ) dataset = torch.utils.data.ConcatDataset([dataset1, dataset2]) Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 1000 - 25 - ~51.0 - 10 - 2 """ url = 'https://github.com/flyingdoog/PGExplainer/raw/master/dataset' filename = 'BA-2motif.pkl' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> str: return self.filename @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(f'{self.url}/{self.filename}', self.raw_dir) def process(self) -> None: with open(self.raw_paths[0], 'rb') as f: adj, x, y = pickle.load(f) adjs = torch.from_numpy(adj) xs = torch.from_numpy(x).to(torch.float) ys = torch.from_numpy(y) data_list: List[Data] = [] for i in range(xs.size(0)): edge_index = adjs[i].nonzero().t() x = xs[i] y = int(ys[i].nonzero()) data = Data(x=x, edge_index=edge_index, y=y) if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/ba_multi_shapes.py ================================================ import pickle from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url class BAMultiShapesDataset(InMemoryDataset): r"""The synthetic BA-Multi-Shapes graph classification dataset for evaluating explainabilty algorithms, as described in the `"Global Explainability of GNNs via Logic Combination of Learned Concepts" `_ paper. 
Given three atomic motifs, namely House (H), Wheel (W), and Grid (G), :class:`~torch_geometric.datasets.BAMultiShapesDataset` contains 1,000 graphs where each graph is obtained by attaching the motifs to a random Barabasi-Albert (BA) as follows: * class 0: :math:`\emptyset \lor H \lor W \lor G \lor \{ H, W, G \}` * class 1: :math:`(H \land W) \lor (H \land G) \lor (W \land G)` This dataset is pre-computed from the official implementation. Args: root: Root directory where the dataset should be saved. transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. pre_transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. pre_filter: A function that takes in a :class:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. force_reload: Whether to re-process the dataset. **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 1000 - 40 - ~87.0 - 10 - 2 """ url = ('https://github.com/steveazzolin/gnn_logic_global_expl/raw/master/' 'datasets/BAMultiShapes/BAMultiShapes.pkl') def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> str: return 'BAMultiShapes.pkl' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.url, self.raw_dir) def process(self) -> None: with open(self.raw_paths[0], 'rb') as f: adjs, xs, ys = pickle.load(f) data_list: List[Data] = [] for adj, x, y in zip(adjs, xs, ys): edge_index = torch.from_numpy(adj).nonzero().t() x = torch.from_numpy(np.array(x)).to(torch.float) data = Data(x=x, edge_index=edge_index, y=y) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/ba_shapes.py ================================================ from typing import Callable, Optional, Tuple import torch from torch import Tensor from torch_geometric.data import Data, InMemoryDataset from torch_geometric.deprecation import deprecated from torch_geometric.utils import barabasi_albert_graph def house() -> Tuple[Tensor, Tensor]: edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4], [1, 3, 4, 4, 2, 0, 1, 3, 2, 0, 0, 1]]) label = torch.tensor([1, 1, 2, 2, 3]) return edge_index, label @deprecated("use 'datasets.ExplainerDataset' in combination with " "'datasets.graph_generator.BAGraph' instead") class 
BAShapes(InMemoryDataset): r"""The BA-Shapes dataset from the `"GNNExplainer: Generating Explanations for Graph Neural Networks" `__ paper, containing a Barabasi-Albert (BA) graph with 300 nodes and a set of 80 "house"-structured graphs connected to it. .. warning:: :class:`BAShapes` is deprecated and will be removed in a future release. Use :class:`ExplainerDataset` in combination with :class:`torch_geometric.datasets.graph_generator.BAGraph` instead. Args: connection_distribution: Specifies how the houses and the BA graph get connected. Valid inputs are :obj:`"random"` (random BA graph nodes are selected for connection to the houses), and :obj:`"uniform"` (uniformly distributed BA graph nodes are selected for connection to the houses). transform: A function/transform that takes in a :class:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. """ def __init__( self, connection_distribution: str = "random", transform: Optional[Callable] = None, ) -> None: super().__init__(None, transform) assert connection_distribution in ['random', 'uniform'] # Build the Barabasi-Albert graph: num_nodes = 300 edge_index = barabasi_albert_graph(num_nodes, num_edges=5) edge_label = torch.zeros(edge_index.size(1), dtype=torch.int64) node_label = torch.zeros(num_nodes, dtype=torch.int64) # Select nodes to connect shapes: num_houses = 80 if connection_distribution == 'random': connecting_nodes = torch.randperm(num_nodes)[:num_houses] else: step = num_nodes // num_houses connecting_nodes = torch.arange(0, num_nodes, step) # Connect houses to Barabasi-Albert graph: edge_indices = [edge_index] edge_labels = [edge_label] node_labels = [node_label] for i in range(num_houses): house_edge_index, house_label = house() edge_indices.append(house_edge_index + num_nodes) edge_indices.append( torch.tensor([[int(connecting_nodes[i]), num_nodes], [num_nodes, int(connecting_nodes[i])]])) edge_labels.append( 
torch.ones(house_edge_index.size(1), dtype=torch.long)) edge_labels.append(torch.zeros(2, dtype=torch.long)) node_labels.append(house_label) num_nodes += 5 edge_index = torch.cat(edge_indices, dim=1) edge_label = torch.cat(edge_labels, dim=0) node_label = torch.cat(node_labels, dim=0) x = torch.ones((num_nodes, 10), dtype=torch.float) expl_mask = torch.zeros(num_nodes, dtype=torch.bool) expl_mask[torch.arange(400, num_nodes, 5)] = True data = Data(x=x, edge_index=edge_index, y=node_label, expl_mask=expl_mask, edge_label=edge_label) self.data, self.slices = self.collate([data]) ================================================ FILE: torch_geometric/datasets/bitcoin_otc.py ================================================ import datetime import os from typing import Callable, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_gz, ) class BitcoinOTC(InMemoryDataset): r"""The Bitcoin-OTC dataset from the `"EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs" `_ paper, consisting of 138 who-trusts-whom networks of sequential time steps. Args: root (str): Root directory where the dataset should be saved. edge_window_size (int, optional): The window size for the existence of an edge in the graph sequence since its initial creation. (default: :obj:`10`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 138 - 6,005 - ~2,573.2 - 0 - 0 """ url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz' def __init__( self, root: str, edge_window_size: int = 10, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.edge_window_size = edge_window_size super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> str: return 'soc-sign-bitcoinotc.csv' @property def processed_file_names(self) -> str: return 'data.pt' @property def num_nodes(self) -> int: assert isinstance(self._data, Data) assert self._data.edge_index is not None return int(self._data.edge_index.max()) + 1 def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_gz(path, self.raw_dir) os.unlink(path) def process(self) -> None: with open(self.raw_paths[0]) as f: lines = [[x for x in line.split(',')] for line in f.read().split('\n')[:-1]] edge_indices = [[int(line[0]), int(line[1])] for line in lines] edge_index = torch.tensor(edge_indices, dtype=torch.long) edge_index = edge_index - edge_index.min() edge_index = edge_index.t().contiguous() num_nodes = int(edge_index.max()) + 1 edge_attrs = [int(line[2]) for line in lines] edge_attr = torch.tensor(edge_attrs, dtype=torch.long) stamps = [ datetime.datetime.fromtimestamp(int(float(line[3]))) for line in lines ] offset = datetime.timedelta(days=13.8) # Results in 138 time steps. 
graph_indices, factor = [], 1 for t in stamps: factor = factor if t < stamps[0] + factor * offset else factor + 1 graph_indices.append(factor - 1) graph_idx = torch.tensor(graph_indices, dtype=torch.long) data_list = [] for i in range(int(graph_idx.max()) + 1): mask = (graph_idx > (i - self.edge_window_size)) & (graph_idx <= i) data = Data() data.edge_index = edge_index[:, mask] data.edge_attr = edge_attr[mask] data.num_nodes = num_nodes data_list.append(data) if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/brca_tgca.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class BrcaTcga(InMemoryDataset): r"""The breast cancer (BRCA TCGA Pan-Cancer Atlas) dataset consisting of patients with survival information and gene expression data from `cBioPortal `_ and a network of biological interactions between those nodes from `Pathway Commons `_. The dataset contains the gene features of 1,082 patients, and the overall survival time (in months) of each patient as label. Pre-processing and example model codes on how to use this dataset can be found `here `_. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. 
The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features * - 1,082 - 9,288 - 271,771 - 1,082 """ url = 'https://zenodo.org/record/8251328/files/brca_tcga.zip?download=1' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['graph_idx.csv', 'graph_labels.csv', 'edge_index.pt'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.root) os.unlink(path) fs.rm(self.raw_dir) os.rename(osp.join(self.root, 'brca_tcga'), self.raw_dir) def process(self) -> None: import pandas as pd graph_feat = pd.read_csv(self.raw_paths[0], index_col=0).values graph_feat = torch.from_numpy(graph_feat).to(torch.float) graph_labels = np.loadtxt(self.raw_paths[1], delimiter=',') graph_label = torch.from_numpy(graph_labels).to(torch.float) edge_index = fs.torch_load(self.raw_paths[2]) data_list = [] for x, y in zip(graph_feat, graph_label): data = Data(x=x.view(-1, 1), edge_index=edge_index, y=y) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) 
================================================ FILE: torch_geometric/datasets/citation_full.py ================================================ import os.path as osp from typing import Callable, Optional from torch_geometric.data import InMemoryDataset, download_url from torch_geometric.io import read_npz class CitationFull(InMemoryDataset): r"""The full citation network datasets from the `"Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking" `_ paper. Nodes represent documents and edges represent citation links. Datasets include :obj:`"Cora"`, :obj:`"Cora_ML"`, :obj:`"CiteSeer"`, :obj:`"DBLP"`, :obj:`"PubMed"`. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"Cora"`, :obj:`"Cora_ML"` :obj:`"CiteSeer"`, :obj:`"DBLP"`, :obj:`"PubMed"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) to_undirected (bool, optional): Whether the original graph is converted to an undirected one. (default: :obj:`True`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - Cora - 19,793 - 126,842 - 8,710 - 70 * - Cora_ML - 2,995 - 16,316 - 2,879 - 7 * - CiteSeer - 4,230 - 10,674 - 602 - 6 * - DBLP - 17,716 - 105,734 - 1,639 - 4 * - PubMed - 19,717 - 88,648 - 500 - 3 """ url = 'https://github.com/abojchevski/graph2gauss/raw/master/data/{}.npz' def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, to_undirected: bool = True, force_reload: bool = False, ) -> None: self.name = name.lower() self.to_undirected = to_undirected assert self.name in ['cora', 'cora_ml', 'citeseer', 'dblp', 'pubmed'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> str: return f'{self.name}.npz' @property def processed_file_names(self) -> str: suffix = 'undirected' if self.to_undirected else 'directed' return f'data_{suffix}.pt' def download(self) -> None: download_url(self.url.format(self.name), self.raw_dir) def process(self) -> None: data = read_npz(self.raw_paths[0], to_undirected=self.to_undirected) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name.capitalize()}Full()' class CoraFull(CitationFull): r"""Alias for :class:`~torch_geometric.datasets.CitationFull` with :obj:`name="Cora"`. **STATS:** .. 
list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 19,793 - 126,842 - 8,710 - 70 """ def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, ) -> None: super().__init__(root, 'cora', transform, pre_transform) def download(self) -> None: super().download() def process(self) -> None: super().process() ================================================ FILE: torch_geometric/datasets/city.py ================================================ import os.path as osp from typing import Callable, Optional from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_tar, ) from torch_geometric.io import fs class CityNetwork(InMemoryDataset): r"""The City-Networks are introduced in `"Towards Quantifying Long-Range Interactions in Graph Machine Learning: a Large Graph Dataset and a Measurement" `_ paper. The dataset contains four city networks: `paris`, `shanghai`, `la`, and `london`, where nodes represent junctions and edges represent undirected road segments. The task is to predict each node's eccentricity score, which is approximated based on its 16-hop neighborhood and naturally requires long-range information. The score indicates how accessible one node is in the network, and is mapped to 10 quantiles for transductive classification. See the original `source code `_ for more details on the individual networks. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (``"paris"``, ``"shanghai"``, ``"la"``, ``"london"``). augmented (bool, optional): Whether to use the augmented node features from edge features.(default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :class:`~torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :class:`~torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - paris - 114,127 - 182,511 - 37 - 10 * - shanghai - 183,917 - 262,092 - 37 - 10 * - la - 240,587 - 341,523 - 37 - 10 * - london - 568,795 - 756,502 - 37 - 10 """ url = "https://github.com/LeonResearch/City-Networks/raw/refs/heads/main/data/" # noqa: E501 def __init__( self, root: str, name: str, augmented: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, delete_raw: bool = False, ) -> None: self.name = name.lower() assert self.name in ["paris", "shanghai", "la", "london"] self.augmented = augmented self.delete_raw = delete_raw super().__init__( root, transform, pre_transform, force_reload=force_reload, ) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, "raw") @property def processed_dir(self) -> str: return osp.join(self.root, self.name, "processed") @property def raw_file_names(self) -> str: return f"{self.name}.json" @property def processed_file_names(self) -> str: return "data.pt" def download(self) -> None: self.download_path = download_url( self.url + f"{self.name}.tar.gz", self.raw_dir, ) def process(self) -> None: extract_tar(self.download_path, self.raw_dir) data_path = osp.join(self.raw_dir, self.name) node_feat = fs.torch_load( osp.join( data_path, f"node_features{'_augmented' if self.augmented else ''}.pt", )) edge_index = fs.torch_load(osp.join(data_path, "edge_indices.pt")) label = fs.torch_load( osp.join(data_path, "10-chunk_16-hop_node_labels.pt")) 
train_mask = fs.torch_load(osp.join(data_path, "train_mask.pt")) val_mask = fs.torch_load(osp.join(data_path, "valid_mask.pt")) test_mask = fs.torch_load(osp.join(data_path, "test_mask.pt")) data = Data( x=node_feat, edge_index=edge_index, y=label, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask, ) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) if self.delete_raw: fs.rm(data_path) def __repr__(self) -> str: return (f"{self.__class__.__name__}(" f"root='{self.root}', " f"name='{self.name}', " f"augmented={self.augmented})") ================================================ FILE: torch_geometric/datasets/coauthor.py ================================================ import os.path as osp from typing import Callable, Optional from torch_geometric.data import InMemoryDataset, download_url from torch_geometric.io import read_npz class Coauthor(InMemoryDataset): r"""The Coauthor CS and Coauthor Physics networks from the `"Pitfalls of Graph Neural Network Evaluation" `_ paper. Nodes represent authors that are connected by an edge if they co-authored a paper. Given paper keywords for each author's papers, the task is to map authors to their respective field of study. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"CS"`, :obj:`"Physics"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - CS - 18,333 - 163,788 - 6,805 - 15 * - Physics - 34,493 - 495,924 - 8,415 - 5 """ url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/' def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert name.lower() in ['cs', 'physics'] self.name = 'CS' if name.lower() == 'cs' else 'Physics' super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> str: return f'ms_academic_{self.name[:3].lower()}.npz' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.url + self.raw_file_names, self.raw_dir) def process(self) -> None: data = read_npz(self.raw_paths[0], to_undirected=True) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.__class__.__name__}{self.name}()' ================================================ FILE: torch_geometric/datasets/coma.py ================================================ import os.path as osp from glob import glob from typing import Callable, List, Optional import torch from torch_geometric.data import InMemoryDataset, extract_zip from torch_geometric.io import read_ply class CoMA(InMemoryDataset): r"""The CoMA 3D faces dataset from the `"Generating 3D faces using Convolutional Mesh Autoencoders" `_ paper, containing 20,466 meshes of extreme expressions captured over 12 different subjects. .. note:: Data objects hold mesh faces instead of edge indices. 
To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (str): Root directory where the dataset should be saved. train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 20,465 - 5,023 - 29,990 - 3 - 12 """ url = 'https://coma.is.tue.mpg.de/' categories = [ 'bareteeth', 'cheeks_in', 'eyebrow', 'high_smile', 'lips_back', 'lips_up', 'mouth_down', 'mouth_extreme', 'mouth_middle', 'mouth_open', 'mouth_side', 'mouth_up', ] def __init__( self, root: str, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] if train else self.processed_paths[1] self.load(path) @property def raw_file_names(self) -> str: return 'COMA_data.zip' @property def processed_file_names(self) -> List[str]: return ['training.pt', 'test.pt'] def download(self) -> None: raise RuntimeError( f"Dataset not found. Please download 'COMA_data.zip' from " f"'{self.url}' and move it to '{self.raw_dir}'") def process(self) -> None: folders = sorted(glob(osp.join(self.raw_dir, 'FaceTalk_*'))) if len(folders) == 0: extract_zip(self.raw_paths[0], self.raw_dir, log=False) folders = sorted(glob(osp.join(self.raw_dir, 'FaceTalk_*'))) train_data_list, test_data_list = [], [] for folder in folders: for i, category in enumerate(self.categories): files = sorted(glob(osp.join(folder, category, '*.ply'))) for j, f in enumerate(files): data = read_ply(f) data.y = torch.tensor([i], dtype=torch.long) if self.pre_filter is not None and\ not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) if (j % 100) < 90: train_data_list.append(data) else: test_data_list.append(data) self.save(train_data_list, self.processed_paths[0]) self.save(test_data_list, self.processed_paths[1]) ================================================ FILE: torch_geometric/datasets/cornell.py 
================================================ import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import InMemoryDataset, download_url from torch_geometric.data.hypergraph_data import HyperGraphData class CornellTemporalHyperGraphDataset(InMemoryDataset): r"""A collection of temporal higher-order network datasets from the `"Simplicial Closure and higher-order link prediction" `_ paper. Each of the datasets is a timestamped sequence of simplices, where a simplex is a set of :math:`k` nodes. See the original `datasets page `_ for more details about individual datasets. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) setting (str, optional): If :obj:`"transductive"`, loads the dataset for transductive training. If :obj:`"inductive"`, loads the dataset for inductive training. (default: :obj:`"transductive"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ names = [ 'email-Eu', 'email-Enron', 'NDC-classes', 'tags-math-sx', 'email-Eu-25', 'NDC-substances', 'congress-bills', 'tags-ask-ubuntu', 'email-Enron-25', 'NDC-classes-25', 'threads-ask-ubuntu', 'contact-high-school', 'NDC-substances-25', 'congress-bills-25', 'contact-primary-school', ] url = ('https://huggingface.co/datasets/SauravMaheshkar/{}/raw/main/' 'processed/{}/{}') def __init__( self, root: str, name: str, split: str = 'train', setting: str = 'transductive', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert name in self.names assert setting in ['transductive', 'inductive'] self.name = name self.setting = setting super().__init__(root, transform, pre_transform, pre_filter, force_reload) if split == 'train': path = self.processed_paths[0] elif split == 'val': path = self.processed_paths[1] elif split == 'test': path = self.processed_paths[2] else: raise ValueError(f"Split '{split}' not found") self.load(path) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, self.setting, 'raw') @property def raw_file_names(self) -> List[str]: return ['train_df.csv', 'val_df.csv', 'test_df.csv'] @property def processed_dir(self) -> str: return osp.join(self.root, self.name, self.setting, 'processed') @property def processed_file_names(self) -> List[str]: return ['train_data.pt', 'val_data.pt', 'test_data.pt'] def download(self) -> None: for filename in self.raw_file_names: url = self.url.format(self.name, self.setting, filename) download_url(url, self.raw_dir) def process(self) -> None: import pandas as pd for raw_path, path in zip(self.raw_paths, self.processed_paths): df = pd.read_csv(raw_path) data_list = [] for i, row in df.iterrows(): edge_indices: List[List[int]] = [[], []] for node in eval(row['nodes']): # str(list) -> list: edge_indices[0].append(node) edge_indices[1].append(i) # Use `i` as hyper-edge 
index. x = torch.tensor([[row['timestamp']]], dtype=torch.float) edge_index = torch.tensor(edge_indices) data = HyperGraphData(x=x, edge_index=edge_index) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, path) ================================================ FILE: torch_geometric/datasets/dblp.py ================================================ import os import os.path as osp from itertools import product from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) class DBLP(InMemoryDataset): r"""A subset of the DBLP computer science bibliography website, as collected in the `"MAGNN: Metapath Aggregated Graph Neural Network for Heterogeneous Graph Embedding" `_ paper. DBLP is a heterogeneous graph containing four types of entities - authors (4,057 nodes), papers (14,328 nodes), terms (7,723 nodes), and conferences (20 nodes). The authors are divided into four research areas (database, data mining, artificial intelligence, information retrieval). Each author is described by a bag-of-words representation of their paper keywords. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 20 10 10 10 :header-rows: 1 * - Node/Edge Type - #nodes/#edges - #features - #classes * - Author - 4,057 - 334 - 4 * - Paper - 14,328 - 4,231 - * - Term - 7,723 - 50 - * - Conference - 20 - 0 - * - Author-Paper - 196,425 - - * - Paper-Term - 85,810 - - * - Conference-Paper - 14,328 - - """ url = 'https://www.dropbox.com/s/yh4grpeks87ugr2/DBLP_processed.zip?dl=1' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return [ 'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npy', 'labels.npy', 'node_types.npy', 'train_val_test_idx.npz' ] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.remove(path) def process(self) -> None: import scipy.sparse as sp data = HeteroData() node_types = ['author', 'paper', 'term', 'conference'] for i, node_type in enumerate(node_types[:2]): x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz')) data[node_type].x = torch.from_numpy(x.todense()).to(torch.float) x = np.load(osp.join(self.raw_dir, 'features_2.npy')) data['term'].x = torch.from_numpy(x).to(torch.float) node_type_idx = np.load(osp.join(self.raw_dir, 'node_types.npy')) node_type_idx = torch.from_numpy(node_type_idx).to(torch.long) data['conference'].num_nodes = int((node_type_idx == 3).sum()) y = np.load(osp.join(self.raw_dir, 'labels.npy')) data['author'].y = torch.from_numpy(y).to(torch.long) split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz')) for name in ['train', 'val', 'test']: idx = split[f'{name}_idx'] idx = torch.from_numpy(idx).to(torch.long) mask = torch.zeros(data['author'].num_nodes, dtype=torch.bool) mask[idx] = True 
data['author'][f'{name}_mask'] = mask s = {} N_a = data['author'].num_nodes N_p = data['paper'].num_nodes N_t = data['term'].num_nodes N_c = data['conference'].num_nodes s['author'] = (0, N_a) s['paper'] = (N_a, N_a + N_p) s['term'] = (N_a + N_p, N_a + N_p + N_t) s['conference'] = (N_a + N_p + N_t, N_a + N_p + N_t + N_c) A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')) for src, dst in product(node_types, node_types): A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo() if A_sub.nnz > 0: row = torch.from_numpy(A_sub.row).to(torch.long) col = torch.from_numpy(A_sub.col).to(torch.long) data[src, dst].edge_index = torch.stack([row, col], dim=0) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/datasets/dbp15k.py ================================================ import os import os.path as osp from typing import Callable, Dict, List, Optional, Tuple import torch from torch import Tensor from torch_geometric.data import ( Data, InMemoryDataset, download_google_url, extract_zip, ) from torch_geometric.io import fs, read_txt_array from torch_geometric.utils import sort_edge_index class DBP15K(InMemoryDataset): r"""The DBP15K dataset from the `"Cross-lingual Entity Alignment via Joint Attribute-Preserving Embedding" `_ paper, where Chinese, Japanese and French versions of DBpedia were linked to its English version. Node features are given by pre-trained and aligned monolingual word embeddings from the `"Cross-lingual Knowledge Graph Alignment via Graph Matching Neural Network" `_ paper. Args: root (str): Root directory where the dataset should be saved. pair (str): The pair of languages (:obj:`"en_zh"`, :obj:`"en_fr"`, :obj:`"en_ja"`, :obj:`"zh_en"`, :obj:`"fr_en"`, :obj:`"ja_en"`). 
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ file_id = '1ggYlYf2_kTyi7oF9g07oTNn3VDhjl7so' def __init__( self, root: str, pair: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert pair in ['en_zh', 'en_fr', 'en_ja', 'zh_en', 'fr_en', 'ja_en'] self.pair = pair super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['en_zh', 'en_fr', 'en_ja', 'zh_en', 'fr_en', 'ja_en'] @property def processed_file_names(self) -> str: return f'{self.pair}.pt' def download(self) -> None: path = download_google_url(self.file_id, self.root, 'data.zip') extract_zip(path, self.root) os.unlink(path) fs.rm(self.raw_dir) os.rename(osp.join(self.root, 'DBP15K'), self.raw_dir) def process(self) -> None: embs = {} with open(osp.join(self.raw_dir, 'sub.glove.300d')) as f: for line in f: info = line.strip().split(' ') if len(info) > 300: embs[info[0]] = torch.tensor([float(x) for x in info[1:]]) else: embs['**UNK**'] = torch.tensor([float(x) for x in info]) g1_path = osp.join(self.raw_dir, self.pair, 'triples_1') x1_path = osp.join(self.raw_dir, self.pair, 'id_features_1') g2_path = osp.join(self.raw_dir, self.pair, 'triples_2') x2_path = osp.join(self.raw_dir, self.pair, 'id_features_2') x1, edge_index1, rel1, assoc1 = self.process_graph( g1_path, x1_path, embs) x2, edge_index2, rel2, assoc2 = 
self.process_graph( g2_path, x2_path, embs) train_path = osp.join(self.raw_dir, self.pair, 'train.examples.20') train_y = self.process_y(train_path, assoc1, assoc2) test_path = osp.join(self.raw_dir, self.pair, 'test.examples.1000') test_y = self.process_y(test_path, assoc1, assoc2) data = Data(x1=x1, edge_index1=edge_index1, rel1=rel1, x2=x2, edge_index2=edge_index2, rel2=rel2, train_y=train_y, test_y=test_y) self.save([data], self.processed_paths[0]) def process_graph( self, triple_path: str, feature_path: str, embeddings: Dict[str, Tensor], ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: g1 = read_txt_array(triple_path, sep='\t', dtype=torch.long) subj, rel, obj = g1.t() x_dict = {} with open(feature_path) as f: for line in f: info = line.strip().split('\t') info = info if len(info) == 2 else info + ['**UNK**'] seq = info[1].lower().split() hs = [embeddings.get(w, embeddings['**UNK**']) for w in seq] x_dict[int(info[0])] = torch.stack(hs, dim=0) idx = torch.tensor(list(x_dict.keys())) assoc = torch.full((int(idx.max()) + 1, ), -1, dtype=torch.long) assoc[idx] = torch.arange(idx.size(0)) subj, obj = assoc[subj], assoc[obj] edge_index = torch.stack([subj, obj], dim=0) edge_index, rel = sort_edge_index(edge_index, rel) xs = list(x_dict.values()) for i in x_dict.keys(): xs[assoc[i]] = x_dict[i] x = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True) return x, edge_index, rel, assoc def process_y(self, path: str, assoc1: Tensor, assoc2: Tensor) -> Tensor: row, col, mask = read_txt_array(path, sep='\t', dtype=torch.long).t() mask = mask.to(torch.bool) return torch.stack([assoc1[row[mask]], assoc2[col[mask]]], dim=0) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.pair})' ================================================ FILE: torch_geometric/datasets/deezer_europe.py ================================================ from typing import Callable, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url 
class DeezerEurope(InMemoryDataset):
    r"""The Deezer Europe dataset introduced in the `"Characteristic Functions
    on Graphs: Birds of a Feather, from Statistical Descriptors to Parametric
    Models" `_ paper.
    Nodes represent European users of Deezer and edges are mutual follower
    relationships.
    It contains 28,281 nodes, 185,504 edges, 128 node features and 2 classes.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """

    url = 'https://graphmining.ai/datasets/ptg/deezer_europe.npz'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> str:
        return 'deezer_europe.npz'

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        download_url(self.url, self.raw_dir)

    def process(self) -> None:
        r"""Reads features, targets and the edge list from the raw
        :obj:`.npz` archive and saves them as a single :class:`Data` object.
        """
        # `np.load` on an `.npz` returns an `NpzFile` that keeps the archive
        # open; use it as a context manager so the handle is always closed.
        # NOTE(review): `allow_pickle=True` lets NumPy unpickle object arrays
        # from a downloaded file; it is kept for compatibility with the
        # published archive, but only load this file from a trusted source.
        with np.load(self.raw_paths[0], 'r', allow_pickle=True) as raw:
            x = torch.from_numpy(raw['features']).to(torch.float)
            y = torch.from_numpy(raw['target']).to(torch.long)
            edge_index = torch.from_numpy(raw['edges']).to(torch.long)

        # The archive stores edges as (num_edges, 2); PyG expects (2, E).
        edge_index = edge_index.t().contiguous()

        data = Data(x=x, y=y, edge_index=edge_index)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])
================================================ FILE: torch_geometric/datasets/dgraph.py ================================================ import os.path as osp from typing import Callable, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, extract_zip from torch_geometric.utils import index_to_mask class DGraphFin(InMemoryDataset): r"""The DGraphFin networks from the `"DGraph: A Large-Scale Financial Dataset for Graph Anomaly Detection" `_ paper. It is a directed, unweighted dynamic graph consisting of millions of nodes and edges, representing a realistic user-to-user social network in financial industry. Node represents a Finvolution user, and an edge from one user to another means that the user regards the other user as the emergency contact person. Each edge is associated with a timestamp ranging from 1 to 821 and a type of emergency contact ranging from 0 to 11. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 3,700,550 - 4,300,999 - 17 - 2 """ url = "https://dgraph.xinye.com" def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) def download(self) -> None: raise RuntimeError( f"Dataset not found. Please download '{self.raw_file_names}' from " f"'{self.url}' and move it to '{self.raw_dir}'") @property def raw_file_names(self) -> str: return 'DGraphFin.zip' @property def processed_file_names(self) -> str: return 'data.pt' @property def num_classes(self) -> int: return 2 def process(self) -> None: extract_zip(self.raw_paths[0], self.raw_dir, log=False) path = osp.join(self.raw_dir, "dgraphfin.npz") with np.load(path) as loader: x = torch.from_numpy(loader['x']).to(torch.float) y = torch.from_numpy(loader['y']).to(torch.long) edge_index = torch.from_numpy(loader['edge_index']).to(torch.long) edge_type = torch.from_numpy(loader['edge_type']).to(torch.long) edge_time = torch.from_numpy(loader['edge_timestamp']).to( torch.long) train_nodes = torch.from_numpy(loader['train_mask']).to(torch.long) val_nodes = torch.from_numpy(loader['valid_mask']).to(torch.long) test_nodes = torch.from_numpy(loader['test_mask']).to(torch.long) train_mask = index_to_mask(train_nodes, size=x.size(0)) val_mask = index_to_mask(val_nodes, size=x.size(0)) test_mask = index_to_mask(test_nodes, size=x.size(0)) data = Data(x=x, edge_index=edge_index.t(), edge_type=edge_type, edge_time=edge_time, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/dynamic_faust.py 
================================================ from itertools import product from typing import Callable, List, Optional import torch from torch_geometric.data import Data, InMemoryDataset class DynamicFAUST(InMemoryDataset): r"""The dynamic FAUST humans dataset from the `"Dynamic FAUST: Registering Human Bodies in Motion" `_ paper. .. note:: Data objects hold mesh faces instead of edge indices. To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (str): Root directory where the dataset should be saved. subjects (list, optional): List of subjects to include in the dataset. Can include the subjects :obj:`"50002"`, :obj:`"50004"`, :obj:`"50007"`, :obj:`"50009"`, :obj:`"50020"`, :obj:`"50021"`, :obj:`"50022"`, :obj:`"50025"`, :obj:`"50026"`, :obj:`"50027"`. If set to :obj:`None`, the dataset will contain all subjects. (default: :obj:`None`) categories (list, optional): List of categories to include in the dataset. Can include the categories :obj:`"chicken_wings"`, :obj:`"hips"`, :obj:`"jiggle_on_toes"`, :obj:`"jumping_jacks"`, :obj:`"knees"`, :obj:`"light_hopping_loose"`, :obj:`"light_hopping_stiff"`, :obj:`"one_leg_jump"`, :obj:`"one_leg_loose"`, :obj:`"personal_move"`, :obj:`"punching"`, :obj:`"running_on_spot"`, :obj:`"running_on_spot_bugfix"`, :obj:`"shake_arms"`, :obj:`"shake_hips"`, :obj:`"shoulders"`. If set to :obj:`None`, the dataset will contain all categories. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'http://dfaust.is.tue.mpg.de/' subjects = [ '50002', '50004', '50007', '50009', '50020', '50021', '50022', '50025', '50026', '50027' ] categories = [ 'chicken_wings', 'hips', 'jiggle_on_toes', 'jumping_jacks', 'knees', 'light_hopping_loose', 'light_hopping_stiff', 'one_leg_jump', 'one_leg_loose', 'personal_move', 'punching', 'running_on_spot', 'running_on_spot_bugfix', 'shake_arms', 'shake_hips', 'shake_shoulders' ] def __init__( self, root: str, subjects: Optional[List[str]] = None, categories: Optional[List[str]] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: subjects = self.subjects if subjects is None else subjects subjects = [sid.lower() for sid in subjects] for sid in subjects: assert sid in self.subjects self.subjects = subjects categories = self.categories if categories is None else categories categories = [cat.lower() for cat in categories] for cat in categories: assert cat in self.categories self.categories = categories super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['registrations_m.hdf5', 'registrations_f.hdf5'] @property def processed_file_names(self) -> str: sids = '_'.join([sid[-2:] for sid in self.subjects]) cats = '_'.join([ 
''.join([w[0] for w in cat.split('_')]) for cat in self.categories ]) return f'{sids}_{cats}.pt' def download(self) -> None: raise RuntimeError( f"Dataset not found. Please download male registrations " f"'registrations_m.hdf5' and female registrations " f"'registrations_f.hdf5' from '{self.url}' and move it to " f"'{self.raw_dir}'") def process(self) -> None: import h5py fm = h5py.File(self.raw_paths[0], 'r') ff = h5py.File(self.raw_paths[1], 'r') face = torch.from_numpy(fm['faces'][()]).to(torch.long) face = face.t().contiguous() data_list = [] for (sid, cat) in product(self.subjects, self.categories): idx = f'{sid}_{cat}' if idx in fm: pos = torch.from_numpy(fm[idx][()]) elif idx in ff: pos = torch.from_numpy(ff[idx][()]) else: continue pos = pos.permute(2, 0, 1).contiguous() data_list.append(Data(pos=pos, face=face, num_nodes=pos.size(1))) if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/elliptic.py ================================================ from typing import Any, Callable, List, Optional, Tuple import torch from torch_geometric.data import Data, InMemoryDataset from torch_geometric.io import fs class EllipticBitcoinDataset(InMemoryDataset): r"""The Elliptic Bitcoin dataset of Bitcoin transactions from the `"Anti-Money Laundering in Bitcoin: Experimenting with Graph Convolutional Networks for Financial Forensics" `_ paper. :class:`EllipticBitcoinDataset` maps Bitcoin transactions to real entities belonging to licit categories (exchanges, wallet providers, miners, licit services, etc.) versus illicit ones (scams, malware, terrorist organizations, ransomware, Ponzi schemes, etc.) 
There exists 203,769 node transactions and 234,355 directed edge payments flows, with two percent of nodes (4,545) labelled as illicit, and twenty-one percent of nodes (42,019) labelled as licit. The remaining transactions are unknown. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 203,769 - 234,355 - 165 - 2 """ url = 'https://data.pyg.org/datasets/elliptic' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return [ 'elliptic_txs_features.csv', 'elliptic_txs_edgelist.csv', 'elliptic_txs_classes.csv', ] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for file_name in self.raw_file_names: fs.cp(f'{self.url}/{file_name}.zip', self.raw_dir, extract=True) def _process_df(self, feat_df: Any, edge_df: Any, class_df: Any) -> Tuple[Any, Any, Any]: return feat_df, edge_df, class_df def process(self) -> None: import pandas as pd feat_df = pd.read_csv(self.raw_paths[0], header=None) edge_df = pd.read_csv(self.raw_paths[1]) class_df = pd.read_csv(self.raw_paths[2]) columns = {0: 'txId', 1: 'time_step'} 
feat_df = feat_df.rename(columns=columns) feat_df, edge_df, class_df = self._process_df( feat_df, edge_df, class_df, ) x = torch.from_numpy(feat_df.loc[:, 2:].values).to(torch.float) # There exists 3 different classes in the dataset: # 0=licit, 1=illicit, 2=unknown mapping = {'unknown': 2, '1': 1, '2': 0} class_df['class'] = class_df['class'].map(mapping) y = torch.from_numpy(class_df['class'].values) mapping = {idx: i for i, idx in enumerate(feat_df['txId'].values)} edge_df['txId1'] = edge_df['txId1'].map(mapping) edge_df['txId2'] = edge_df['txId2'].map(mapping) edge_index = torch.from_numpy(edge_df.values).t().contiguous() # Timestamp based split: # train_mask: 1 - 34 time_step, test_mask: 35-49 time_step time_step = torch.from_numpy(feat_df['time_step'].values) train_mask = (time_step < 35) & (y != 2) test_mask = (time_step >= 35) & (y != 2) data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) @property def num_classes(self) -> int: return 2 ================================================ FILE: torch_geometric/datasets/elliptic_temporal.py ================================================ from typing import Any, Callable, Optional, Tuple from torch_geometric.datasets import EllipticBitcoinDataset class EllipticBitcoinTemporalDataset(EllipticBitcoinDataset): r"""The time-step aware Elliptic Bitcoin dataset of Bitcoin transactions from the `"Anti-Money Laundering in Bitcoin: Experimenting with Graph Convolutional Networks for Financial Forensics" `_ paper. :class:`EllipticBitcoinTemporalDataset` maps Bitcoin transactions to real entities belonging to licit categories (exchanges, wallet providers, miners, licit services, etc.) versus illicit ones (scams, malware, terrorist organizations, ransomware, Ponzi schemes, etc.) 
There exists 203,769 node transactions and 234,355 directed edge payments flows, with two percent of nodes (4,545) labelled as illicit, and twenty-one percent of nodes (42,019) labelled as licit. The remaining transactions are unknown. .. note:: In contrast to :class:`EllipticBitcoinDataset`, this dataset returns Bitcoin transactions only for a given timestamp :obj:`t`. Args: root (str): Root directory where the dataset should be saved. t (int): The Timestep for which nodes should be selected (from :obj:`1` to :obj:`49`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 203,769 - 234,355 - 165 - 2 """ def __init__( self, root: str, t: int, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ): if t < 1 or t > 49: raise ValueError("'t' needs to be between 1 and 49") self.t = t super().__init__(root, transform, pre_transform, force_reload=force_reload) @property def processed_file_names(self) -> str: return f'data_t_{self.t}.pt' def _process_df(self, feat_df: Any, edge_df: Any, class_df: Any) -> Tuple[Any, Any, Any]: feat_df = feat_df[feat_df['time_step'] == self.t] mask = edge_df['txId1'].isin(feat_df['txId'].values) edge_df = edge_df[mask] class_df = class_df.merge(feat_df[['txId']], how='right', left_on='txId', right_on='txId') return feat_df, edge_df, class_df ================================================ FILE: torch_geometric/datasets/email_eu_core.py ================================================ import os from typing import Callable, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_gz, ) class EmailEUCore(InMemoryDataset): r"""An e-mail communication network of a large European research institution, taken from the `"Local Higher-order Graph Clustering" `_ paper. Nodes indicate members of the institution. An edge between a pair of members indicates that they exchanged at least one email. Node labels indicate membership to one of the 42 departments. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. 
The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ urls = [ 'https://snap.stanford.edu/data/email-Eu-core.txt.gz', 'https://snap.stanford.edu/data/email-Eu-core-department-labels.txt.gz' ] def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['email-Eu-core.txt', 'email-Eu-core-department-labels.txt'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for url in self.urls: path = download_url(url, self.raw_dir) extract_gz(path, self.raw_dir) os.unlink(path) def process(self) -> None: import pandas as pd edge_index = pd.read_csv(self.raw_paths[0], sep=' ', header=None) edge_index = torch.from_numpy(edge_index.values).t().contiguous() y = pd.read_csv(self.raw_paths[1], sep=' ', header=None, usecols=[1]) y = torch.from_numpy(y.values).view(-1) data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0)) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/entities.py ================================================ import logging import os import os.path as osp from collections import Counter from typing import Any, Callable, List, Optional import torch from torch_geometric.data import ( Data, HeteroData, InMemoryDataset, download_url, extract_tar, ) from torch_geometric.utils import index_sort class Entities(InMemoryDataset): r"""The relational entities networks :obj:`"AIFB"`, :obj:`"MUTAG"`, :obj:`"BGS"` and :obj:`"AM"` from the `"Modeling Relational Data with Graph Convolutional Networks" `_ paper. 
Training and test splits are given by node indices. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"AIFB"`, :obj:`"MUTAG"`, :obj:`"BGS"`, :obj:`"AM"`). hetero (bool, optional): If set to :obj:`True`, will save the dataset as a :class:`~torch_geometric.data.HeteroData` object. (default: :obj:`False`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - AIFB - 8,285 - 58,086 - 0 - 4 * - AM - 1,666,764 - 11,976,642 - 0 - 11 * - MUTAG - 23,644 - 148,454 - 0 - 2 * - BGS - 333,845 - 1,832,398 - 0 - 2 """ url = 'https://data.dgl.ai/dataset/{}.tgz' def __init__( self, root: str, name: str, hetero: bool = False, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() self.hetero = hetero assert self.name in ['aifb', 'am', 'mutag', 'bgs'] super().__init__(root, transform, pre_transform, force_reload=force_reload) if hetero: self.load(self.processed_paths[0], data_cls=HeteroData) else: self.load(self.processed_paths[0], data_cls=Data) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def num_relations(self) -> int: return int(self._data.edge_type.max()) + 1 # type: ignore @property def 
num_classes(self) -> int: return int(self._data.train_y.max()) + 1 # type: ignore @property def raw_file_names(self) -> List[str]: return [ f'{self.name}_stripped.nt.gz', 'completeDataset.tsv', 'trainingSet.tsv', 'testSet.tsv', ] @property def processed_file_names(self) -> str: return 'hetero_data.pt' if self.hetero else 'data.pt' def download(self) -> None: path = download_url(self.url.format(self.name), self.root) extract_tar(path, self.raw_dir) os.unlink(path) def process(self) -> None: import gzip import pandas as pd import rdflib as rdf graph_file, task_file, train_file, test_file = self.raw_paths with hide_stdout(): g = rdf.Graph() with gzip.open(graph_file, 'rb') as f: g.parse(file=f, format='nt') # type: ignore freq = Counter(g.predicates()) relations = sorted(set(g.predicates()), key=lambda p: -freq.get(p, 0)) subjects = set(g.subjects()) objects = set(g.objects()) nodes = list(subjects.union(objects)) N = len(nodes) R = 2 * len(relations) relations_dict = {rel: i for i, rel in enumerate(relations)} nodes_dict = {str(node): i for i, node in enumerate(nodes)} edges = [] for s, p, o in g.triples((None, None, None)): src, dst = nodes_dict[str(s)], nodes_dict[str(o)] rel = relations_dict[p] edges.append([src, dst, 2 * rel]) edges.append([dst, src, 2 * rel + 1]) edge = torch.tensor(edges, dtype=torch.long).t().contiguous() _, perm = index_sort(N * R * edge[0] + R * edge[1] + edge[2]) edge = edge[:, perm] edge_index, edge_type = edge[:2], edge[2] if self.name == 'am': label_header = 'label_cateogory' nodes_header = 'proxy' elif self.name == 'aifb': label_header = 'label_affiliation' nodes_header = 'person' elif self.name == 'mutag': label_header = 'label_mutagenic' nodes_header = 'bond' elif self.name == 'bgs': label_header = 'label_lithogenesis' nodes_header = 'rock' labels_df = pd.read_csv(task_file, sep='\t') labels_set = set(labels_df[label_header].values.tolist()) labels_dict = {lab: i for i, lab in enumerate(list(labels_set))} train_labels_df = 
pd.read_csv(train_file, sep='\t') train_indices, train_labels = [], [] for nod, lab in zip(train_labels_df[nodes_header].values, train_labels_df[label_header].values): train_indices.append(nodes_dict[nod]) train_labels.append(labels_dict[lab]) train_idx = torch.tensor(train_indices, dtype=torch.long) train_y = torch.tensor(train_labels, dtype=torch.long) test_labels_df = pd.read_csv(test_file, sep='\t') test_indices, test_labels = [], [] for nod, lab in zip(test_labels_df[nodes_header].values, test_labels_df[label_header].values): test_indices.append(nodes_dict[nod]) test_labels.append(labels_dict[lab]) test_idx = torch.tensor(test_indices, dtype=torch.long) test_y = torch.tensor(test_labels, dtype=torch.long) data = Data(edge_index=edge_index, edge_type=edge_type, train_idx=train_idx, train_y=train_y, test_idx=test_idx, test_y=test_y, num_nodes=N) if self.hetero: data = data.to_heterogeneous(node_type_names=['v']) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name.upper()}{self.__class__.__name__}()' class hide_stdout: def __enter__(self) -> None: self.level = logging.getLogger().level logging.getLogger().setLevel(logging.ERROR) def __exit__(self, *args: Any) -> None: logging.getLogger().setLevel(self.level) ================================================ FILE: torch_geometric/datasets/explainer_dataset.py ================================================ from typing import Any, Callable, Dict, Optional, Union import torch from torch import Tensor from torch_geometric.data import InMemoryDataset from torch_geometric.datasets.graph_generator import GraphGenerator from torch_geometric.datasets.motif_generator import MotifGenerator from torch_geometric.explain import Explanation class ExplainerDataset(InMemoryDataset): r"""Generates a synthetic dataset for evaluating explainabilty algorithms, as described in the `"GNNExplainer: Generating Explanations for Graph Neural Networks" `__ paper. 
The :class:`~torch_geometric.datasets.ExplainerDataset` creates synthetic graphs coming from a :class:`~torch_geometric.datasets.graph_generator.GraphGenerator`, and randomly attaches :obj:`num_motifs` many motifs to it coming from a :class:`~torch_geometric.datasets.graph_generator.MotifGenerator`. Ground-truth node-level and edge-level explainabilty masks are given based on whether nodes and edges are part of a certain motif or not. For example, to generate a random Barabasi-Albert (BA) graph with 300 nodes, in which we want to randomly attach 80 :obj:`"house"` motifs, write: .. code-block:: python from torch_geometric.datasets import ExplainerDataset from torch_geometric.datasets.graph_generator import BAGraph dataset = ExplainerDataset( graph_generator=BAGraph(num_nodes=300, num_edges=5), motif_generator='house', num_motifs=80, ) .. note:: For an example of using :class:`ExplainerDataset`, see `examples/explain/gnn_explainer_ba_shapes.py `_. Args: graph_generator (GraphGenerator or str): The graph generator to be used, *e.g.*, :class:`torch.geometric.datasets.graph_generator.BAGraph` (or any string that automatically resolves to it). motif_generator (MotifGenerator): The motif generator to be used, *e.g.*, :class:`torch_geometric.datasets.motif_generator.HouseMotif` (or any string that automatically resolves to it). num_motifs (int): The number of motifs to attach to the graph. num_graphs (int, optional): The number of graphs to generate. (default: :obj:`1`) graph_generator_kwargs (Dict[str, Any], optional): Arguments passed to the respective graph generator module in case it gets automatically resolved. (default: :obj:`None`) motif_generator_kwargs (Dict[str, Any], optional): Arguments passed to the respective motif generator module in case it gets automatically resolved. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. 
The data object will be transformed before every access. (default: :obj:`None`) """ def __init__( self, graph_generator: Union[GraphGenerator, str], motif_generator: Union[MotifGenerator, str], num_motifs: int, num_graphs: int = 1, graph_generator_kwargs: Optional[Dict[str, Any]] = None, motif_generator_kwargs: Optional[Dict[str, Any]] = None, transform: Optional[Callable] = None, ): super().__init__(root=None, transform=transform) if num_motifs <= 0: raise ValueError(f"At least one motif needs to be attached to the " f"graph (got {num_motifs})") self.graph_generator = GraphGenerator.resolve( graph_generator, **(graph_generator_kwargs or {}), ) self.motif_generator = MotifGenerator.resolve( motif_generator, **(motif_generator_kwargs or {}), ) self.num_motifs = num_motifs # TODO (matthias) support on-the-fly graph generation. data_list = [self.get_graph() for _ in range(num_graphs)] self.data, self.slices = self.collate(data_list) def get_graph(self) -> Explanation: data = self.graph_generator() assert data.num_nodes is not None assert data.edge_index is not None edge_indices = [data.edge_index] num_nodes = data.num_nodes node_masks = [torch.zeros(data.num_nodes)] edge_masks = [torch.zeros(data.num_edges)] ys = [torch.zeros(num_nodes, dtype=torch.long)] connecting_nodes = torch.randperm(num_nodes)[:self.num_motifs] for i in connecting_nodes.tolist(): motif = self.motif_generator() assert motif.num_nodes is not None assert motif.edge_index is not None # Add motif to the graph. edge_indices.append(motif.edge_index + num_nodes) node_masks.append(torch.ones(motif.num_nodes)) edge_masks.append(torch.ones(motif.num_edges)) # Add random motif connection to the graph. 
j = int(torch.randint(0, motif.num_nodes, (1, ))) + num_nodes edge_indices.append(torch.tensor([[i, j], [j, i]])) edge_masks.append(torch.zeros(2)) if isinstance(motif.y, Tensor): ys.append(motif.y + 1 if motif.y.min() == 0 else motif.y) else: ys.append(torch.ones(motif.num_nodes, dtype=torch.long)) num_nodes += motif.num_nodes return Explanation( edge_index=torch.cat(edge_indices, dim=1), y=torch.cat(ys, dim=0), edge_mask=torch.cat(edge_masks, dim=0), node_mask=torch.cat(node_masks, dim=0), ) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'graph_generator={self.graph_generator}, ' f'motif_generator={self.motif_generator}, ' f'num_motifs={self.num_motifs})') ================================================ FILE: torch_geometric/datasets/facebook.py ================================================ from typing import Callable, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url class FacebookPagePage(InMemoryDataset): r"""The Facebook Page-Page network dataset introduced in the `"Multi-scale Attributed Node Embedding" `_ paper. Nodes represent verified pages on Facebook and edges are mutual likes. It contains 22,470 nodes, 342,004 edges, 128 node features and 4 classes. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = 'https://graphmining.ai/datasets/ptg/facebook.npz' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> str: return 'facebook.npz' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.url, self.raw_dir) def process(self) -> None: data = np.load(self.raw_paths[0], 'r', allow_pickle=True) x = torch.from_numpy(data['features']).to(torch.float) y = torch.from_numpy(data['target']).to(torch.long) edge_index = torch.from_numpy(data['edges']).to(torch.long) edge_index = edge_index.t().contiguous() data = Data(x=x, y=y, edge_index=edge_index) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/fake.py ================================================ import random from collections import defaultdict from itertools import product from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData, InMemoryDataset from torch_geometric.utils import coalesce, remove_self_loops, to_undirected class FakeDataset(InMemoryDataset): r"""A fake dataset that returns randomly generated :class:`~torch_geometric.data.Data` objects. Args: num_graphs (int, optional): The number of graphs. (default: :obj:`1`) avg_num_nodes (int, optional): The average number of nodes in a graph. (default: :obj:`1000`) avg_degree (float, optional): The average degree per node. (default: :obj:`10.0`) num_channels (int, optional): The number of node features. (default: :obj:`64`) edge_dim (int, optional): The number of edge features. 
(default: :obj:`0`) num_classes (int, optional): The number of classes in the dataset. (default: :obj:`10`) task (str, optional): Whether to return node-level or graph-level labels (:obj:`"node"`, :obj:`"graph"`, :obj:`"auto"`). If set to :obj:`"auto"`, will return graph-level labels if :obj:`num_graphs > 1`, and node-level labels other-wise. (default: :obj:`"auto"`) is_undirected (bool, optional): Whether the graphs to generate are undirected. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) **kwargs (optional): Additional attributes and their shapes *e.g.* :obj:`global_features=5`. """ def __init__( self, num_graphs: int = 1, avg_num_nodes: int = 1000, avg_degree: float = 10.0, num_channels: int = 64, edge_dim: int = 0, num_classes: int = 10, task: str = 'auto', is_undirected: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, **kwargs: Union[int, Tuple[int, ...]], ) -> None: super().__init__(None, transform) if task == 'auto': task = 'graph' if num_graphs > 1 else 'node' assert task in ['node', 'graph'] self.avg_num_nodes = max(avg_num_nodes, int(avg_degree)) self.avg_degree = max(avg_degree, 1) self.num_channels = num_channels self.edge_dim = edge_dim self._num_classes = num_classes self.task = task self.is_undirected = is_undirected self.kwargs = kwargs data_list = [self.generate_data() for _ in range(max(num_graphs, 1))] self.data, self.slices = self.collate(data_list) def generate_data(self) -> Data: num_nodes = get_num_nodes(self.avg_num_nodes, self.avg_degree) data = Data() if self._num_classes > 0 and self.task == 'node': data.y = torch.randint(self._num_classes, (num_nodes, )) elif self._num_classes > 0 and self.task == 'graph': data.y = torch.tensor([random.randint(0, self._num_classes - 1)]) data.edge_index = 
get_edge_index(num_nodes, num_nodes, self.avg_degree, self.is_undirected, remove_loops=True) if self.num_channels > 0: x = torch.randn(num_nodes, self.num_channels) if self._num_classes > 0 and self.task == 'node': assert isinstance(data.y, Tensor) x = x + data.y.unsqueeze(1) elif self._num_classes > 0 and self.task == 'graph': assert isinstance(data.y, Tensor) x = x + data.y data.x = x else: data.num_nodes = num_nodes if self.edge_dim > 1: data.edge_attr = torch.rand(data.num_edges, self.edge_dim) elif self.edge_dim == 1: data.edge_weight = torch.rand(data.num_edges) for feature_name, feature_shape in self.kwargs.items(): setattr(data, feature_name, torch.randn(feature_shape)) return data class FakeHeteroDataset(InMemoryDataset): r"""A fake dataset that returns randomly generated :class:`~torch_geometric.data.HeteroData` objects. Args: num_graphs (int, optional): The number of graphs. (default: :obj:`1`) num_node_types (int, optional): The number of node types. (default: :obj:`3`) num_edge_types (int, optional): The number of edge types. (default: :obj:`6`) avg_num_nodes (int, optional): The average number of nodes in a graph. (default: :obj:`1000`) avg_degree (float, optional): The average degree per node. (default: :obj:`10.0`) avg_num_channels (int, optional): The average number of node features. (default: :obj:`64`) edge_dim (int, optional): The number of edge features. (default: :obj:`0`) num_classes (int, optional): The number of classes in the dataset. (default: :obj:`10`) task (str, optional): Whether to return node-level or graph-level labels (:obj:`"node"`, :obj:`"graph"`, :obj:`"auto"`). If set to :obj:`"auto"`, will return graph-level labels if :obj:`num_graphs > 1`, and node-level labels other-wise. (default: :obj:`"auto"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) **kwargs (optional): Additional attributes and their shapes *e.g.* :obj:`global_features=5`. """ def __init__( self, num_graphs: int = 1, num_node_types: int = 3, num_edge_types: int = 6, avg_num_nodes: int = 1000, avg_degree: float = 10.0, avg_num_channels: int = 64, edge_dim: int = 0, num_classes: int = 10, task: str = "auto", transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, **kwargs: Union[int, Tuple[int, ...]], ) -> None: super().__init__(None, transform) if task == 'auto': task = 'graph' if num_graphs > 1 else 'node' assert task in ['node', 'graph'] self.node_types = [f'v{i}' for i in range(max(num_node_types, 1))] edge_types: List[Tuple[str, str]] = [] edge_type_product = list(product(self.node_types, self.node_types)) while len(edge_types) < max(num_edge_types, 1): edge_types.extend(edge_type_product) random.shuffle(edge_types) self.edge_types: List[Tuple[str, str, str]] = [] count: Dict[Tuple[str, str], int] = defaultdict(int) for edge_type in edge_types[:max(num_edge_types, 1)]: rel = f'e{count[edge_type]}' count[edge_type] += 1 self.edge_types.append((edge_type[0], rel, edge_type[1])) self.avg_num_nodes = max(avg_num_nodes, int(avg_degree)) self.avg_degree = max(avg_degree, 1) self.num_channels = [ get_num_channels(avg_num_channels) for _ in self.node_types ] self.edge_dim = edge_dim self._num_classes = num_classes self.task = task self.kwargs = kwargs data_list = [self.generate_data() for _ in range(max(num_graphs, 1))] self.data, self.slices = self.collate(data_list) def generate_data(self) -> HeteroData: data = HeteroData() iterator = zip(self.node_types, self.num_channels) for i, (node_type, num_channels) in enumerate(iterator): num_nodes = get_num_nodes(self.avg_num_nodes, self.avg_degree) store = data[node_type] if num_channels > 0: store.x = torch.randn(num_nodes, num_channels) else: store.num_nodes = num_nodes if self._num_classes > 0 and self.task == 'node' and i == 0: store.y = 
torch.randint(self._num_classes, (num_nodes, )) for (src, rel, dst) in self.edge_types: store = data[(src, rel, dst)] store.edge_index = get_edge_index( data[src].num_nodes, data[dst].num_nodes, self.avg_degree, is_undirected=False, remove_loops=False, ) if self.edge_dim > 1: store.edge_attr = torch.rand(store.num_edges, self.edge_dim) elif self.edge_dim == 1: store.edge_weight = torch.rand(store.num_edges) if self._num_classes > 0 and self.task == 'graph': data.y = torch.tensor([random.randint(0, self._num_classes - 1)]) for feature_name, feature_shape in self.kwargs.items(): setattr(data, feature_name, torch.randn(feature_shape)) return data ############################################################################### def get_num_nodes(avg_num_nodes: int, avg_degree: float) -> int: min_num_nodes = max(3 * avg_num_nodes // 4, int(avg_degree)) max_num_nodes = 5 * avg_num_nodes // 4 return random.randint(min_num_nodes, max_num_nodes) def get_num_channels(num_channels: int) -> int: min_num_channels = 3 * num_channels // 4 max_num_channels = 5 * num_channels // 4 return random.randint(min_num_channels, max_num_channels) def get_edge_index( num_src_nodes: int, num_dst_nodes: int, avg_degree: float, is_undirected: bool = False, remove_loops: bool = False, ) -> Tensor: num_edges = int(num_src_nodes * avg_degree) row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.int64) col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.int64) edge_index = torch.stack([row, col], dim=0) if remove_loops: edge_index, _ = remove_self_loops(edge_index) num_nodes = max(num_src_nodes, num_dst_nodes) if is_undirected: edge_index = to_undirected(edge_index, num_nodes=num_nodes) else: edge_index = coalesce(edge_index, num_nodes=num_nodes) return edge_index ================================================ FILE: torch_geometric/datasets/faust.py ================================================ import os.path as osp from typing import Callable, List, Optional import torch from 
torch_geometric.data import InMemoryDataset, extract_zip from torch_geometric.io import fs, read_ply class FAUST(InMemoryDataset): r"""The FAUST humans dataset from the `"FAUST: Dataset and Evaluation for 3D Mesh Registration" `_ paper, containing 100 watertight meshes representing 10 different poses for 10 different subjects. .. note:: Data objects hold mesh faces instead of edge indices. To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (str): Root directory where the dataset should be saved. train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 100 - 6,890 - 41,328 - 3 - 10 """ url = 'http://faust.is.tue.mpg.de/' def __init__( self, root: str, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] if train else self.processed_paths[1] self.load(path) @property def raw_file_names(self) -> str: return 'MPI-FAUST.zip' @property def processed_file_names(self) -> List[str]: return ['training.pt', 'test.pt'] def download(self) -> None: raise RuntimeError( f"Dataset not found. Please download '{self.raw_file_names}' from " f"'{self.url}' and move it to '{self.raw_dir}'") def process(self) -> None: extract_zip(self.raw_paths[0], self.raw_dir, log=False) path = osp.join(self.raw_dir, 'MPI-FAUST', 'training', 'registrations') path = osp.join(path, 'tr_reg_{0:03d}.ply') data_list = [] for i in range(100): data = read_ply(path.format(i)) data.y = torch.tensor([i % 10], dtype=torch.long) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list[:80], self.processed_paths[0]) self.save(data_list[80:], self.processed_paths[1]) fs.rm(osp.join(self.raw_dir, 'MPI-FAUST')) ================================================ FILE: torch_geometric/datasets/flickr.py ================================================ import json import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_google_url class Flickr(InMemoryDataset): r"""The Flickr dataset from the `"GraphSAINT: Graph Sampling Based Inductive Learning Method" `_ paper, containing descriptions and common 
properties of images. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 89,250 - 899,756 - 500 - 7 """ adj_full_id = '1crmsTbd1-2sEXsGwa2IKnIB7Zd3TmUsy' feats_id = '1join-XdvX3anJU_MLVtick7MgeAQiWIZ' class_map_id = '1uxIkbtg5drHTsKt-PAsZZ4_yJmgFmle9' role_id = '1htXCtuktuCW8TR8KiKfrFDAxUgekQoV7' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz') download_google_url(self.feats_id, self.raw_dir, 'feats.npy') download_google_url(self.class_map_id, self.raw_dir, 'class_map.json') download_google_url(self.role_id, self.raw_dir, 'role.json') def process(self) -> None: import scipy.sparse as sp f = np.load(osp.join(self.raw_dir, 'adj_full.npz')) adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape']) adj = adj.tocoo() row = torch.from_numpy(adj.row).to(torch.long) col = torch.from_numpy(adj.col).to(torch.long) 
edge_index = torch.stack([row, col], dim=0) x = np.load(osp.join(self.raw_dir, 'feats.npy')) x = torch.from_numpy(x).to(torch.float) ys = [-1] * x.size(0) with open(osp.join(self.raw_dir, 'class_map.json')) as f: class_map = json.load(f) for key, item in class_map.items(): ys[int(key)] = item y = torch.tensor(ys) with open(osp.join(self.raw_dir, 'role.json')) as f: role = json.load(f) train_mask = torch.zeros(x.size(0), dtype=torch.bool) train_mask[torch.tensor(role['tr'])] = True val_mask = torch.zeros(x.size(0), dtype=torch.bool) val_mask[torch.tensor(role['va'])] = True test_mask = torch.zeros(x.size(0), dtype=torch.bool) test_mask[torch.tensor(role['te'])] = True data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/freebase.py ================================================ from typing import Callable, Dict, List, Optional import torch from torch_geometric.data import Data, InMemoryDataset, download_url class FB15k_237(InMemoryDataset): r"""The FB15K237 dataset from the `"Translating Embeddings for Modeling Multi-Relational Data" `_ paper, containing 14,541 entities, 237 relations and 310,116 fact triples. .. note:: The original :class:`FB15k` dataset suffers from major test leakage through inverse relations, where a large number of test triples could be obtained by inverting triples in the training set. In order to create a dataset without this characteristic, the :class:`~torch_geometric.datasets.FB15k_237` describes a subset of :class:`FB15k` where inverse relations are removed. Args: root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. 
(default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = ('https://raw.githubusercontent.com/villmow/' 'datasets_knowledge_embedding/master/FB15k-237') def __init__( self, root: str, split: str = "train", transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) if split not in {'train', 'val', 'test'}: raise ValueError(f"Invalid 'split' argument (got {split})") path = self.processed_paths[['train', 'val', 'test'].index(split)] self.load(path) @property def raw_file_names(self) -> List[str]: return ['train.txt', 'valid.txt', 'test.txt'] @property def processed_file_names(self) -> List[str]: return ['train_data.pt', 'val_data.pt', 'test_data.pt'] def download(self) -> None: for filename in self.raw_file_names: download_url(f'{self.url}/{filename}', self.raw_dir) def process(self) -> None: data_list: List[Data] = [] node_dict: Dict[str, int] = {} rel_dict: Dict[str, int] = {} for path in self.raw_paths: with open(path) as f: lines = [x.split('\t') for x in f.read().split('\n')[:-1]] edge_index = torch.empty((2, len(lines)), dtype=torch.long) edge_type = torch.empty(len(lines), dtype=torch.long) for i, (src, rel, dst) in enumerate(lines): if src not in node_dict: node_dict[src] = len(node_dict) if dst not in node_dict: node_dict[dst] = len(node_dict) if rel not in rel_dict: rel_dict[rel] = len(rel_dict) 
edge_index[0, i] = node_dict[src] edge_index[1, i] = node_dict[dst] edge_type[i] = rel_dict[rel] data = Data(edge_index=edge_index, edge_type=edge_type) data_list.append(data) for data, path in zip(data_list, self.processed_paths): data.num_nodes = len(node_dict) self.save([data], path) ================================================ FILE: torch_geometric/datasets/gdelt.py ================================================ from typing import Callable, List, Optional import torch from torch import Tensor from torch_geometric.data import download_url from torch_geometric.datasets.icews import EventDataset from torch_geometric.io import read_txt_array class GDELT(EventDataset): r"""The Global Database of Events, Language, and Tone (GDELT) dataset used in the, *e.g.*, `"Recurrent Event Network for Reasoning over Temporal Knowledge Graphs" `_ paper, consisting of events collected from 1/1/2018 to 1/31/2018 (15 minutes time granularity). Args: root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = 'https://github.com/INK-USC/RE-Net/raw/master/data/GDELT' splits = [0, 1734399, 1973164, 2278405] # Train/Val/Test splits. def __init__( self, root: str, split: str = "train", transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert split in ['train', 'val', 'test'] super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) idx = self.processed_file_names.index(f'{split}.pt') self.load(self.processed_paths[idx]) @property def num_nodes(self) -> int: return 7691 @property def num_rels(self) -> int: return 240 @property def raw_file_names(self) -> List[str]: return [f'{name}.txt' for name in ['train', 'valid', 'test']] @property def processed_file_names(self) -> List[str]: return ['train.pt', 'val.pt', 'test.pt'] def download(self) -> None: for filename in self.raw_file_names: download_url(f'{self.url}/{filename}', self.raw_dir) def process_events(self) -> Tensor: events = [] for path in self.raw_paths: data = read_txt_array(path, sep='\t', end=4, dtype=torch.long) data[:, 3] = data[:, 3] // 15 events += [data] return torch.cat(events, dim=0) def process(self) -> None: s = self.splits data_list = self._process_data_list() self.save(data_list[s[0]:s[1]], self.processed_paths[0]) self.save(data_list[s[1]:s[2]], self.processed_paths[1]) self.save(data_list[s[2]:s[3]], self.processed_paths[2]) ================================================ FILE: torch_geometric/datasets/gdelt_lite.py ================================================ import os from typing import Callable, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class GDELTLite(InMemoryDataset): r"""The (reduced) version of the Global Database of Events, Language, and Tone (GDELT) dataset used in the `"Do We Really Need Complicated Model 
Architectures for Temporal Networks?" `_ paper, consisting of events
    collected from 2016 to 2020.

    Each node (actor) holds a 413-dimensional multi-hot feature vector that
    represents CAMEO codes attached to the corresponding actor to server.

    Each edge (event) holds a timestamp and a 186-dimensional multi-hot vector
    representing CAMEO codes attached to the corresponding event to server.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #classes
        * - 8,831
          - 1,912,909
          - 413
          -
    """
    url = 'https://data.pyg.org/datasets/gdelt_lite.zip'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # Order matters: `process()` reads these files by position.
        return ['node_features.pt', 'edges.csv', 'edge_features.pt']

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        # The archive itself is deleted after extraction:
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self) -> None:
        r"""Builds a single :class:`~torch_geometric.data.Data` object from
        the node features, the edge-list CSV (columns :obj:`src`, :obj:`dst`
        and :obj:`time`) and the edge features.
        """
        import pandas as pd

        x = fs.torch_load(self.raw_paths[0])
        df = pd.read_csv(self.raw_paths[1])
        edge_attr = fs.torch_load(self.raw_paths[2])

        row = torch.from_numpy(df['src'].values)
        col = torch.from_numpy(df['dst'].values)
        edge_index = torch.stack([row, col], dim=0)
        # Per-edge event timestamps, cast to integers:
        time = torch.from_numpy(df['time'].values).to(torch.long)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, time=time)
        data = data if self.pre_transform is None else self.pre_transform(data)

        self.save([data], self.processed_paths[0])
================================================
FILE: torch_geometric/datasets/ged_dataset.py
================================================
import glob
import os
import os.path as osp
import pickle
from typing import Callable, List, Optional

import torch

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_google_url,
    extract_tar,
    extract_zip,
)
from torch_geometric.io import fs
from torch_geometric.utils import one_hot, to_undirected


class GEDDataset(InMemoryDataset):
    r"""The GED datasets from the `"Graph Edit Distance Computation via
    Graph Neural Networks" `_ paper.
GEDs can be accessed via the global attributes :obj:`ged` and
    :obj:`norm_ged` for all train/train graph pairs and all train/test graph
    pairs:

    .. code-block:: python

        dataset = GEDDataset(root, name="LINUX")
        data1, data2 = dataset[0], dataset[1]
        ged = dataset.ged[data1.i, data2.i]  # GED between `data1` and `data2`.

    Note that GEDs are not available if both graphs are from the test set.
    For evaluation, it is recommended to pair up each graph from the test set
    with each graph in the training set.

    .. note::

        :obj:`ALKANE` is missing GEDs for train/test graph pairs since they are
        not provided in the `official datasets `_.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (one of :obj:`"AIDS700nef"`,
            :obj:`"LINUX"`, :obj:`"ALKANE"`, :obj:`"IMDBMulti"`).
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 20 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - AIDS700nef
          - 700
          - ~8.9
          - ~17.6
          - 29
          - 0
        * - LINUX
          - 1,000
          - ~7.6
          - ~13.9
          - 0
          - 0
        * - ALKANE
          - 150
          - ~8.9
          - ~15.8
          - 0
          - 0
        * - IMDBMulti
          - 1,500
          - ~13.0
          - ~131.9
          - 0
          - 0
    """
    # Per dataset: Google Drive ID of the graph archive ('id'), the function
    # used to extract it ('extract'), and the Google Drive ID of the file
    # with pre-computed GEDs ('pickle'):
    datasets = {
        'AIDS700nef': {
            'id': '10czBPJDEzEDI2tq7Z7mkBjLhj55F-a2z',
            'extract': extract_zip,
            'pickle': '1OpV4bCHjBkdpqI6H5Mg0-BqlA2ee2eBW',
        },
        'LINUX': {
            'id': '1nw0RRVgyLpit4V4XFQyDy0pI6wUEXSOI',
            'extract': extract_tar,
            'pickle': '14FDm3NSnrBvB7eNpLeGy5Bz6FjuCSF5v',
        },
        'ALKANE': {
            'id': '1-LmxaWW3KulLh00YqscVEflbqr0g4cXt',
            'extract': extract_tar,
            'pickle': '15BpvMuHx77-yUGYgM27_sQett02HQNYu',
        },
        'IMDBMulti': {
            'id': '12QxZ7EhYA7pJiF4cO-HuE8szhSOWcfST',
            'extract': extract_zip,
            'pickle': '1wy9VbZvZodkixxVIOuRllC-Lp-0zdoYZ',
        },
    }

    # List of atoms contained in the AIDS700nef dataset:
    types = [
        'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
        'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
        'Sb', 'Se', 'Ni', 'Te'
    ]

    def __init__(
        self,
        root: str,
        name: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        self.name = name
        assert self.name in self.datasets.keys()
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        # Processed file 0 holds the training graphs, file 1 the test graphs:
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.load(path)
        # The (normalized) GED matrices are shared by both splits and are
        # indexed by the global graph index `data.i`:
        path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
        self.ged = fs.torch_load(path)
        path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
        self.norm_ged = fs.torch_load(path)

    @property
    def raw_file_names(self) -> List[str]:
        # Returns, e.g., ['LINUX/train', 'LINUX/test']
        return [osp.join(self.name, s) for s in ['train', 'test']]

    @property
    def processed_file_names(self) -> List[str]:
        # Returns, e.g., ['LINUX_training.pt', 'LINUX_test.pt']
        return [f'{self.name}_{s}.pt'
for s in ['training', 'test']] def download(self) -> None: # Downloads the .tar/.zip file of the graphs and extracts them: id = self.datasets[self.name]['id'] assert isinstance(id, str) path = download_google_url(id, self.raw_dir, 'data') extract_fn = self.datasets[self.name]['extract'] assert callable(extract_fn) extract_fn(path, self.raw_dir) os.unlink(path) # Downloads the pickle file containing pre-computed GEDs: id = self.datasets[self.name]['pickle'] assert isinstance(id, str) path = download_google_url(id, self.raw_dir, 'ged.pickle') def process(self) -> None: import networkx as nx ids, Ns = [], [] # Iterating over paths for raw and processed data (train + test): for r_path, p_path in zip(self.raw_paths, self.processed_paths): # Find the paths of all raw graphs: names = glob.glob(osp.join(r_path, '*.gexf')) # Get sorted graph IDs given filename: 123.gexf -> 123 ids.append(sorted([int(osp.basename(i)[:-5]) for i in names])) data_list = [] # Convert graphs in .gexf format to a NetworkX Graph: for i, idx in enumerate(ids[-1]): i = i if len(ids) == 1 else i + len(ids[0]) # Reading the raw `*.gexf` graph: G = nx.read_gexf(osp.join(r_path, f'{idx}.gexf')) # Mapping of nodes in `G` to a contiguous number: mapping = {name: j for j, name in enumerate(G.nodes())} G = nx.relabel_nodes(G, mapping) Ns.append(G.number_of_nodes()) edge_index = torch.tensor(list(G.edges)).t().contiguous() if edge_index.numel() == 0: edge_index = torch.empty((2, 0), dtype=torch.long) edge_index = to_undirected(edge_index, num_nodes=Ns[-1]) data = Data(edge_index=edge_index, i=i) data.num_nodes = Ns[-1] # Create a one-hot encoded feature matrix denoting the atom # type (for the `AIDS700nef` dataset): if self.name == 'AIDS700nef': assert data.num_nodes is not None x = torch.zeros(data.num_nodes, dtype=torch.long) for node, info in G.nodes(data=True): x[int(node)] = self.types.index(info['type']) data.x = one_hot(x, num_classes=len(self.types)) if self.pre_filter is not None and not 
self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, p_path) assoc = {idx: i for i, idx in enumerate(ids[0])} assoc.update({idx: i + len(ids[0]) for i, idx in enumerate(ids[1])}) # Extracting ground-truth GEDs from the GED pickle file path = osp.join(self.raw_dir, self.name, 'ged.pickle') # Initialize GEDs as float('inf'): mat = torch.full((len(assoc), len(assoc)), float('inf')) with open(path, 'rb') as f: obj = pickle.load(f) xs, ys, gs = [], [], [] for (_x, _y), g in obj.items(): xs += [assoc[_x]] ys += [assoc[_y]] gs += [g] # The pickle file does not contain GEDs for test graph pairs, i.e. # GEDs for (test_graph, test_graph) pairs are still float('inf'): x, y = torch.tensor(xs), torch.tensor(ys) ged = torch.tensor(gs, dtype=torch.float) mat[x, y], mat[y, x] = ged, ged path = osp.join(self.processed_dir, f'{self.name}_ged.pt') torch.save(mat, path) # Calculate the normalized GEDs: N = torch.tensor(Ns, dtype=torch.float) norm_mat = mat / (0.5 * (N.view(-1, 1) + N.view(1, -1))) path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt') torch.save(norm_mat, path) def __repr__(self) -> str: return f'{self.name}({len(self)})' ================================================ FILE: torch_geometric/datasets/gemsec.py ================================================ import os.path as osp from typing import Callable, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url class GemsecDeezer(InMemoryDataset): r"""The Deezer User Network datasets introduced in the `"GEMSEC: Graph Embedding with Self Clustering" `_ paper. Nodes represent Deezer user and edges are mutual friendships. The task is multi-label multi-class node classification about the genres liked by the users. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"HU"`, :obj:`"HR"`, :obj:`"RO"`). 
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://graphmining.ai/datasets/ptg/gemsec' def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name assert self.name in ['HU', 'HR', 'RO'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> str: return f'{self.name}.npz' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(osp.join(self.url, self.name + '.npz'), self.raw_dir) def process(self) -> None: data = np.load(self.raw_paths[0], 'r', allow_pickle=True) y = torch.from_numpy(data['target']).to(torch.long) edge_index = torch.from_numpy(data['edges']).to(torch.long) edge_index = edge_index.t().contiguous() data = Data(y=y, edge_index=edge_index) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/geometry.py ================================================ import glob import os import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data 
import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import read_off class GeometricShapes(InMemoryDataset): r"""Synthetic dataset of various geometric shapes like cubes, spheres or pyramids. .. note:: Data objects hold mesh faces instead of edge indices. To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (str): Root directory where the dataset should be saved. train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 80 - ~148.8 - ~859.5 - 3 - 40 """ url = 'https://github.com/Yannick-S/geometric_shapes/raw/master/raw.zip' def __init__( self, root: str, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] if train else self.processed_paths[1] self.load(path) @property def raw_file_names(self) -> str: return '2d_circle' @property def processed_file_names(self) -> List[str]: return ['training.pt', 'test.pt'] def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.root) os.unlink(path) def process(self) -> None: self.save(self.process_set('train'), self.processed_paths[0]) self.save(self.process_set('test'), self.processed_paths[1]) def process_set(self, dataset: str) -> List[Data]: categories = glob.glob(osp.join(self.raw_dir, '*', '')) categories = sorted([x.split(os.sep)[-2] for x in categories]) data_list = [] for target, category in enumerate(categories): folder = osp.join(self.raw_dir, category, dataset) paths = glob.glob(f'{folder}/*.off') for path in paths: data = read_off(path) assert data.pos is not None data.pos = data.pos - data.pos.mean(dim=0, keepdim=True) data.y = torch.tensor([target]) data_list.append(data) if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] return data_list ================================================ FILE: torch_geometric/datasets/git_mol_dataset.py ================================================ import sys from typing import Any, Callable, Dict, List, Optional import numpy as np import torch from tqdm import tqdm from 
torch_geometric.data import ( Data, InMemoryDataset, download_google_url, extract_zip, ) from torch_geometric.io import fs def safe_index(lst: List[Any], e: int) -> int: return lst.index(e) if e in lst else len(lst) - 1 class GitMolDataset(InMemoryDataset): r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text" `_ paper. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) split (int, optional): Datasets split, train/valid/test=0/1/2. 
(default: :obj:`0`) """ raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, split: int = 0, ): from torchvision import transforms self.split = split if self.split == 0: self.img_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) else: self.img_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl'] @property def processed_file_names(self) -> str: return ['train.pt', 'valid.pt', 'test.pt'][self.split] def download(self) -> None: file_path = download_google_url( self.raw_url_id, self.raw_dir, 'gitmol.zip', ) extract_zip(file_path, self.raw_dir) def process(self) -> None: import pandas as pd from PIL import Image try: from rdkit import Chem, RDLogger RDLogger.DisableLog('rdApp.*') # type: ignore[attr-defined] WITH_RDKIT = True except ImportError: WITH_RDKIT = False if not WITH_RDKIT: print(("Using a pre-processed version of the dataset. 
Please " "install 'rdkit' to alternatively process the raw data."), file=sys.stderr) data_list = fs.torch_load(self.raw_paths[0]) data_list = [Data(**data_dict) for data_dict in data_list] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[0]) return allowable_features: Dict[str, List[Any]] = { 'possible_atomic_num_list': list(range(1, 119)) + ['misc'], 'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], 'possible_chirality_list': [ Chem.rdchem.ChiralType.CHI_UNSPECIFIED, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER ], 'possible_hybridization_list': [ Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc' ], 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6], 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], 'possible_is_aromatic_list': [False, True], 'possible_is_in_ring_list': [False, True], 'possible_bond_type_list': [ Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC, Chem.rdchem.BondType.ZERO ], 'possible_bond_dirs': [ # only for double bond stereo information Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT, Chem.rdchem.BondDir.ENDDOWNRIGHT ], 'possible_bond_stereo_list': [ Chem.rdchem.BondStereo.STEREONONE, Chem.rdchem.BondStereo.STEREOZ, Chem.rdchem.BondStereo.STEREOE, Chem.rdchem.BondStereo.STEREOCIS, Chem.rdchem.BondStereo.STEREOTRANS, Chem.rdchem.BondStereo.STEREOANY, ], 'possible_is_conjugated_list': 
[False, True] } data = pd.read_pickle( f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}') data_list = [] for _, r in tqdm(data.iterrows(), total=data.shape[0]): smiles = r['isosmiles'] mol = Chem.MolFromSmiles(smiles.strip('\n')) if mol is not None: # text summary = r['summary'] # image cid = r['cid'] img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png' img = Image.open(img_file).convert('RGB') img = self.img_transform(img).unsqueeze(0) # graph atom_features_list = [] for atom in mol.GetAtoms(): atom_feature = [ safe_index( allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()), allowable_features['possible_chirality_list'].index( atom.GetChiralTag()), safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()), safe_index( allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()), safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()), safe_index( allowable_features[ 'possible_number_radical_e_list'], atom.GetNumRadicalElectrons()), safe_index( allowable_features['possible_hybridization_list'], atom.GetHybridization()), allowable_features['possible_is_aromatic_list'].index( atom.GetIsAromatic()), allowable_features['possible_is_in_ring_list'].index( atom.IsInRing()), ] atom_features_list.append(atom_feature) x = torch.tensor(np.array(atom_features_list), dtype=torch.long) edges_list = [] edge_features_list = [] for bond in mol.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_feature = [ safe_index( allowable_features['possible_bond_type_list'], bond.GetBondType()), allowable_features['possible_bond_stereo_list'].index( bond.GetStereo()), allowable_features['possible_is_conjugated_list']. 
index(bond.GetIsConjugated()), ] edges_list.append((i, j)) edge_features_list.append(edge_feature) edges_list.append((j, i)) edge_features_list.append(edge_feature) edge_index = torch.tensor( np.array(edges_list).T, dtype=torch.long, ) edge_attr = torch.tensor( np.array(edge_features_list), dtype=torch.long, ) data = Data( x=x, edge_index=edge_index, smiles=smiles, edge_attr=edge_attr, image=img, caption=summary, ) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/github.py ================================================ from typing import Callable, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url class GitHub(InMemoryDataset): r"""The GitHub Web and ML Developers dataset introduced in the `"Multi-scale Attributed Node Embedding" `_ paper. Nodes represent developers on :obj:`github:`GitHub` and edges are mutual follower relationships. It contains 37,300 nodes, 578,006 edges, 128 node features and 2 classes. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 37,700 - 578,006 - 0 - 2 """ url = 'https://graphmining.ai/datasets/ptg/github.npz' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> str: return 'github.npz' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.url, self.raw_dir) def process(self) -> None: data = np.load(self.raw_paths[0], 'r', allow_pickle=True) x = torch.from_numpy(data['features']).to(torch.float) y = torch.from_numpy(data['target']).to(torch.long) edge_index = torch.from_numpy(data['edges']).to(torch.long) edge_index = edge_index.t().contiguous() data = Data(x=x, y=y, edge_index=edge_index) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/gnn_benchmark_dataset.py ================================================ import logging import os import os.path as osp import pickle from typing import Callable, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs from torch_geometric.utils import remove_self_loops class GNNBenchmarkDataset(InMemoryDataset): r"""A variety of artificially and semi-artificially generated graph datasets from the `"Benchmarking Graph Neural Networks" `_ paper. .. note:: The ZINC dataset is provided via :class:`torch_geometric.datasets.ZINC`. Args: root (str): Root directory where the dataset should be saved. 
name (str): The name of the dataset (one of :obj:`"PATTERN"`, :obj:`"CLUSTER"`, :obj:`"MNIST"`, :obj:`"CIFAR10"`, :obj:`"TSP"`, :obj:`"CSL"`) split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 20 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes * - PATTERN - 14,000 - ~118.9 - ~6,098.9 - 3 - 2 * - CLUSTER - 12,000 - ~117.2 - ~4,303.9 - 7 - 6 * - MNIST - 70,000 - ~70.6 - ~564.5 - 3 - 10 * - CIFAR10 - 60,000 - ~117.6 - ~941.2 - 5 - 10 * - TSP - 12,000 - ~275.4 - ~6,885.0 - 2 - 2 * - CSL - 150 - ~41.0 - ~164.0 - 0 - 10 """ names = ['PATTERN', 'CLUSTER', 'MNIST', 'CIFAR10', 'TSP', 'CSL'] root_url = 'https://data.pyg.org/datasets/benchmarking-gnns' urls = { 'PATTERN': f'{root_url}/PATTERN_v2.zip', 'CLUSTER': f'{root_url}/CLUSTER_v2.zip', 'MNIST': f'{root_url}/MNIST_v2.zip', 'CIFAR10': f'{root_url}/CIFAR10_v2.zip', 'TSP': f'{root_url}/TSP_v2.zip', 'CSL': 'https://www.dropbox.com/s/rnbkp5ubgk82ocu/CSL.zip?dl=1', } def __init__( self, root: str, name: str, split: str = "train", transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name assert self.name in self.names if self.name == 'CSL' and split != 'train': split = 'train' logging.warning( "Dataset 'CSL' does not provide a standardized splitting. 
" "Instead, it is recommended to perform 5-fold cross " "validation with stratifed sampling") super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) if split == 'train': path = self.processed_paths[0] elif split == 'val': path = self.processed_paths[1] elif split == 'test': path = self.processed_paths[2] else: raise ValueError(f"Split '{split}' found, but expected either " f"'train', 'val', or 'test'") self.load(path) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: if self.name == 'CSL': return [ 'graphs_Kary_Deterministic_Graphs.pkl', 'y_Kary_Deterministic_Graphs.pt' ] else: name = self.urls[self.name].split('/')[-1][:-4] return [f'{name}.pt'] @property def processed_file_names(self) -> List[str]: if self.name == 'CSL': return ['data.pt'] else: return ['train_data.pt', 'val_data.pt', 'test_data.pt'] def download(self) -> None: path = download_url(self.urls[self.name], self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: if self.name == 'CSL': data_list = self.process_CSL() self.save(data_list, self.processed_paths[0]) else: inputs = fs.torch_load(self.raw_paths[0]) for i in range(len(inputs)): data_list = [Data(**data_dict) for data_dict in inputs[i]] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[i]) def process_CSL(self) -> List[Data]: with open(self.raw_paths[0], 'rb') as f: adjs = pickle.load(f) ys = fs.torch_load(self.raw_paths[1]).tolist() data_list = [] for adj, y in zip(adjs, ys): row, col = torch.from_numpy(adj.row), torch.from_numpy(adj.col) edge_index = torch.stack([row, col], dim=0).to(torch.long) edge_index, _ = remove_self_loops(edge_index) 
data = Data(edge_index=edge_index, y=y, num_nodes=adj.shape[0])

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        return data_list

    def __repr__(self) -> str:
        return f'{self.name}({len(self)})'
================================================
FILE: torch_geometric/datasets/graph_generator/__init__.py
================================================
from .base import GraphGenerator
from .ba_graph import BAGraph
from .er_graph import ERGraph
from .grid_graph import GridGraph
from .tree_graph import TreeGraph

# `classes` is bound to the same list object as `__all__`.
__all__ = classes = [
    'GraphGenerator',
    'BAGraph',
    'ERGraph',
    'GridGraph',
    'TreeGraph',
]
================================================
FILE: torch_geometric/datasets/graph_generator/ba_graph.py
================================================
from torch_geometric.data import Data
from torch_geometric.datasets.graph_generator import GraphGenerator
from torch_geometric.utils import barabasi_albert_graph


class BAGraph(GraphGenerator):
    r"""Generates random Barabasi-Albert (BA) graphs.
    See :meth:`~torch_geometric.utils.barabasi_albert_graph` for more
    information.

    Args:
        num_nodes (int): The number of nodes.
        num_edges (int): The number of edges from a new node to existing
            nodes.
    """
    def __init__(self, num_nodes: int, num_edges: int):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_edges = num_edges

    def __call__(self) -> Data:
        # Every call invokes `barabasi_albert_graph` again with the stored
        # parameters, i.e. each call generates a new graph:
        edge_index = barabasi_albert_graph(self.num_nodes, self.num_edges)
        return Data(num_nodes=self.num_nodes, edge_index=edge_index)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '
                f'num_edges={self.num_edges})')
================================================
FILE: torch_geometric/datasets/graph_generator/base.py
================================================
from abc import ABC, abstractmethod
from typing import Any

from torch_geometric.data import Data
from torch_geometric.resolver import resolver


class GraphGenerator(ABC):
    r"""An abstract base class for generating synthetic graphs."""
    @abstractmethod
    def __call__(self) -> Data:
        r"""To be implemented by :class:`GraphGenerator` subclasses."""
        raise NotImplementedError

    @staticmethod
    def resolve(query: Any, *args: Any, **kwargs: Any) -> 'GraphGenerator':
        r"""Resolves :obj:`query` into a :class:`GraphGenerator` via
        :obj:`resolver`. The candidates are all :class:`GraphGenerator`
        subclasses exposed by the :obj:`graph_generator` package.
        :obj:`*args` and :obj:`**kwargs` are forwarded to :obj:`resolver`.
        """
        # Imported inside the method, presumably to avoid a circular import:
        # the package `__init__` itself imports this module.
        import torch_geometric.datasets.graph_generator as _graph_generators
        graph_generators = [
            gen for gen in vars(_graph_generators).values()
            if isinstance(gen, type) and issubclass(gen, GraphGenerator)
        ]
        return resolver(graph_generators, {}, query, GraphGenerator, 'Graph',
                        *args, **kwargs)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
================================================
FILE: torch_geometric/datasets/graph_generator/er_graph.py
================================================
from torch_geometric.data import Data
from torch_geometric.datasets.graph_generator import GraphGenerator
from torch_geometric.utils import erdos_renyi_graph


class ERGraph(GraphGenerator):
    r"""Generates random Erdos-Renyi (ER) graphs.
    See :meth:`~torch_geometric.utils.erdos_renyi_graph` for more information.

    Args:
        num_nodes (int): The number of nodes.
        edge_prob (float): Probability of an edge.
""" def __init__(self, num_nodes: int, edge_prob: float): super().__init__() self.num_nodes = num_nodes self.edge_prob = edge_prob def __call__(self) -> Data: edge_index = erdos_renyi_graph(self.num_nodes, self.edge_prob) return Data(num_nodes=self.num_nodes, edge_index=edge_index) def __repr__(self) -> str: return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, ' f'edge_prob={self.edge_prob})') ================================================ FILE: torch_geometric/datasets/graph_generator/grid_graph.py ================================================ from typing import Optional import torch from torch_geometric.data import Data from torch_geometric.datasets.graph_generator import GraphGenerator from torch_geometric.utils import grid class GridGraph(GraphGenerator): r"""Generates two-dimensional grid graphs. See :meth:`~torch_geometric.utils.grid` for more information. Args: height (int): The height of the grid. width (int): The width of the grid. dtype (:obj:`torch.dtype`, optional): The desired data type of the returned position tensor. 
(default: :obj:`None`) """ def __init__( self, height: int, width: int, dtype: Optional[torch.dtype] = None, ): super().__init__() self.height = height self.width = width self.dtype = dtype def __call__(self) -> Data: edge_index, pos = grid(height=self.height, width=self.width, dtype=self.dtype) return Data(edge_index=edge_index, pos=pos) def __repr__(self) -> str: return (f'{self.__class__.__name__}(height={self.height}, ' f'width={self.width})') ================================================ FILE: torch_geometric/datasets/graph_generator/tree_graph.py ================================================ from typing import List, Optional, Tuple import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.datasets.graph_generator import GraphGenerator from torch_geometric.utils import to_undirected def tree( depth: int, branch: int = 2, undirected: bool = False, device: Optional[torch.device] = None, ) -> Tuple[Tensor, Tensor]: """Generates a tree graph with the given depth and branch size, along with node-level depth indicators. Args: depth (int): The depth of the tree. branch (int, optional): The branch size of the tree. (default: :obj:`2`) undirected (bool, optional): If set to :obj:`True`, the tree graph will be undirected. (default: :obj:`False`) device (torch.device, optional): The desired device of the returned tensors. 
(default: :obj:`None`) """ edges: List[Tuple[int, int]] = [] depths: List[int] = [0] def add_edges(node: int, current_depth: int) -> None: node_count = len(depths) if current_depth < depth: for i in range(branch): edges.append((node, node_count + i)) depths.append(current_depth + 1) for i in range(branch): add_edges(node=node_count + i, current_depth=current_depth + 1) add_edges(node=0, current_depth=0) edge_index = torch.tensor(edges, device=device).t().contiguous() if undirected: edge_index = to_undirected(edge_index, num_nodes=len(depths)) return edge_index, torch.tensor(depths, device=device) class TreeGraph(GraphGenerator): r"""Generates tree graphs. Args: depth (int): The depth of the tree. branch (int, optional): The branch size of the tree. (default: :obj:`2`) undirected (bool, optional): If set to :obj:`True`, the tree graph will be undirected. (default: :obj:`False`) """ def __init__( self, depth: int, branch: int = 2, undirected: bool = False, ) -> None: super().__init__() self.depth = depth self.branch = branch self.undirected = undirected def __call__(self) -> Data: edge_index, depth = tree(self.depth, self.branch, self.undirected) num_nodes = depth.numel() return Data(edge_index=edge_index, depth=depth, num_nodes=num_nodes) def __repr__(self) -> str: return (f'{self.__class__.__name__}(depth={self.depth}, ' f'branch={self.branch}, undirected={self.undirected})') ================================================ FILE: torch_geometric/datasets/heterophilous_graph_dataset.py ================================================ import os.path as osp from typing import Callable, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.utils import to_undirected class HeterophilousGraphDataset(InMemoryDataset): r"""The heterophilous graphs :obj:`"Roman-empire"`, :obj:`"Amazon-ratings"`, :obj:`"Minesweeper"`, :obj:`"Tolokers"` and :obj:`"Questions"` from the `"A Critical Look at the 
Evaluation of GNNs under Heterophily: Are We Really Making Progress?" `_ paper. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"Roman-empire"`, :obj:`"Amazon-ratings"`, :obj:`"Minesweeper"`, :obj:`"Tolokers"`, :obj:`"Questions"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - Roman-empire - 22,662 - 32,927 - 300 - 18 * - Amazon-ratings - 24,492 - 93,050 - 300 - 5 * - Minesweeper - 10,000 - 39,402 - 7 - 2 * - Tolokers - 11,758 - 519,000 - 10 - 2 * - Questions - 48,921 - 153,540 - 301 - 2 """ url = ('https://github.com/yandex-research/heterophilous-graphs/raw/' 'main/data') def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower().replace('-', '_') assert self.name in [ 'roman_empire', 'amazon_ratings', 'minesweeper', 'tolokers', 'questions', ] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> str: return f'{self.name}.npz' @property def processed_file_names(self) -> str: return 'data.pt' def 
download(self) -> None: download_url(f'{self.url}/{self.name}.npz', self.raw_dir) def process(self) -> None: raw = np.load(self.raw_paths[0], 'r') x = torch.from_numpy(raw['node_features']) y = torch.from_numpy(raw['node_labels']) edge_index = torch.from_numpy(raw['edges']).t().contiguous() edge_index = to_undirected(edge_index, num_nodes=x.size(0)) train_mask = torch.from_numpy(raw['train_masks']).t().contiguous() val_mask = torch.from_numpy(raw['val_masks']).t().contiguous() test_mask = torch.from_numpy(raw['test_masks']).t().contiguous() data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.__class__.__name__}(name={self.name})' ================================================ FILE: torch_geometric/datasets/hgb_dataset.py ================================================ import json import os import os.path as osp from collections import defaultdict from typing import Callable, Dict, List, Optional import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_google_url, extract_zip, ) class HGBDataset(InMemoryDataset): r"""A variety of heterogeneous graph benchmark datasets from the `"Are We Really Making Much Progress? Revisiting, Benchmarking, and Refining Heterogeneous Graph Neural Networks" `_ paper. .. note:: Test labels are randomly given to prevent data leakage issues. If you want to obtain final test performance, you will need to submit your model predictions to the `HGB leaderboard `_. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (one of :obj:`"ACM"`, :obj:`"DBLP"`, :obj:`"Freebase"`, :obj:`"IMDB"`) transform (callable, optional): A function/transform that takes in an :class:`torch_geometric.data.HeteroData` object and returns a transformed version. 
The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :class:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ names = { 'acm': 'ACM', 'dblp': 'DBLP', 'freebase': 'Freebase', 'imdb': 'IMDB', } file_ids = { 'acm': '1xbJ4QE9pcDJOcALv7dYhHDCPITX2Iddz', 'dblp': '1fLLoy559V7jJaQ_9mQEsC06VKd6Qd3SC', 'freebase': '1vw-uqbroJZfFsWpriC1CWbtHCJMGdWJ7', 'imdb': '18qXmmwKJBrEJxVQaYwKTL3Ny3fPqJeJ2', } def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in set(self.names.keys()) super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: x = ['info.dat', 'node.dat', 'link.dat', 'label.dat', 'label.dat.test'] return [osp.join(self.names[self.name], f) for f in x] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: id = self.file_ids[self.name] path = download_google_url(id, self.raw_dir, 'data.zip') extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: data = HeteroData() # node_types = {0: 'paper', 1, 'author', ...} # edge_types = {0: ('paper', 'cite', 'paper'), ...} if self.name in ['acm', 'dblp', 'imdb']: with open(self.raw_paths[0]) as f: # `info.dat` info = json.load(f) n_types = info['node.dat']['node type'] n_types = {int(k): v for k, v in n_types.items()} e_types = 
info['link.dat']['link type'] e_types = {int(k): tuple(v.values()) for k, v in e_types.items()} for key, (src, dst, rel) in e_types.items(): src, dst = n_types[int(src)], n_types[int(dst)] rel = rel.split('-')[1] rel = rel if rel != dst and rel[1:] != dst else 'to' e_types[key] = (src, rel, dst) num_classes = len(info['label.dat']['node type']['0']) elif self.name in ['freebase']: with open(self.raw_paths[0]) as f: # `info.dat` info = f.read().split('\n') start = info.index('TYPE\tMEANING') + 1 end = info[start:].index('') n_types = [v.split('\t\t') for v in info[start:start + end]] n_types = {int(k): v.lower() for k, v in n_types} e_types = {} start = info.index('LINK\tSTART\tEND\tMEANING') + 1 end = info[start:].index('') for key, row in enumerate(info[start:start + end]): edge = row.split('\t')[1:] src, dst, rel = (v for v in edge if v != '') src, dst = n_types[int(src)], n_types[int(dst)] rel = rel.split('-')[1] e_types[key] = (src, rel, dst) else: # Link prediction: raise NotImplementedError # Extract node information: mapping_dict = {} # Maps global node indices to local ones. x_dict = defaultdict(list) num_nodes_dict: Dict[str, int] = defaultdict(int) with open(self.raw_paths[1]) as f: # `node.dat` xs = [v.split('\t') for v in f.read().split('\n')[:-1]] for x in xs: n_id, n_type = int(x[0]), n_types[int(x[2])] mapping_dict[n_id] = num_nodes_dict[n_type] num_nodes_dict[n_type] += 1 if len(x) >= 4: # Extract features (in case they are given). 
x_dict[n_type].append([float(v) for v in x[3].split(',')]) for n_type in n_types.values(): if len(x_dict[n_type]) == 0: data[n_type].num_nodes = num_nodes_dict[n_type] else: data[n_type].x = torch.tensor(x_dict[n_type]) edge_index_dict = defaultdict(list) edge_weight_dict = defaultdict(list) with open(self.raw_paths[2]) as f: # `link.dat` edges = [v.split('\t') for v in f.read().split('\n')[:-1]] for src, dst, rel, weight in edges: e_type = e_types[int(rel)] src, dst = mapping_dict[int(src)], mapping_dict[int(dst)] edge_index_dict[e_type].append([src, dst]) edge_weight_dict[e_type].append(float(weight)) for e_type in e_types.values(): edge_index = torch.tensor(edge_index_dict[e_type]) edge_weight = torch.tensor(edge_weight_dict[e_type]) data[e_type].edge_index = edge_index.t().contiguous() # Only add "weighted" edgel to the graph: if not torch.allclose(edge_weight, torch.ones_like(edge_weight)): data[e_type].edge_weight = edge_weight # Node classification: if self.name in ['acm', 'dblp', 'freebase', 'imdb']: with open(self.raw_paths[3]) as f: # `label.dat` train_ys = [v.split('\t') for v in f.read().split('\n')[:-1]] with open(self.raw_paths[4]) as f: # `label.dat.test` test_ys = [v.split('\t') for v in f.read().split('\n')[:-1]] for y in train_ys: n_id, n_type = mapping_dict[int(y[0])], n_types[int(y[2])] if not hasattr(data[n_type], 'y'): num_nodes = data[n_type].num_nodes if self.name in ['imdb']: # multi-label data[n_type].y = torch.zeros((num_nodes, num_classes)) else: data[n_type].y = torch.full((num_nodes, ), -1).long() data[n_type].train_mask = torch.zeros(num_nodes).bool() data[n_type].test_mask = torch.zeros(num_nodes).bool() if data[n_type].y.dim() > 1: # multi-label for v in y[3].split(','): data[n_type].y[n_id, int(v)] = 1 else: data[n_type].y[n_id] = int(y[3]) data[n_type].train_mask[n_id] = True for y in test_ys: n_id, n_type = mapping_dict[int(y[0])], n_types[int(y[2])] if data[n_type].y.dim() > 1: # multi-label for v in y[3].split(','): 
data[n_type].y[n_id, int(v)] = 1 else: data[n_type].y[n_id] = int(y[3]) data[n_type].test_mask[n_id] = True else: # Link prediction: raise NotImplementedError if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.names[self.name]}()' ================================================ FILE: torch_geometric/datasets/hm.py ================================================ from typing import Callable, List, Optional import torch from torch_geometric.data import HeteroData, InMemoryDataset class HM(InMemoryDataset): r"""The heterogeneous H&M dataset from the `Kaggle H&M Personalized Fashion Recommendations `_ challenge. The task is to develop product recommendations based on data from previous transactions, as well as from customer and product meta data. Args: root (str): Root directory where the dataset should be saved. use_all_tables_as_node_types (bool, optional): If set to :obj:`True`, will use the transaction table as a distinct node type. (default: :obj:`False`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = ('https://www.kaggle.com/competitions/' 'h-and-m-personalized-fashion-recommendations/data') def __init__( self, root: str, use_all_tables_as_node_types: bool = False, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.use_all_tables_as_node_types = use_all_tables_as_node_types super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return [ 'customers.csv.zip', 'articles.csv.zip', 'transactions_train.csv.zip' ] @property def processed_file_names(self) -> str: if self.use_all_tables_as_node_types: return 'data.pt' else: return 'data_merged.pt' def download(self) -> None: raise RuntimeError( f"Dataset not found. Please download {self.raw_file_names} from " f"'{self.url}' and move it to '{self.raw_dir}'") def process(self) -> None: import pandas as pd data = HeteroData() # Process customer data ############################################### df = pd.read_csv(self.raw_paths[0], index_col='customer_id') customer_map = {idx: i for i, idx in enumerate(df.index)} xs = [] for name in [ 'Active', 'FN', 'club_member_status', 'fashion_news_frequency' ]: x = pd.get_dummies(df[name]).values xs.append(torch.from_numpy(x).to(torch.float)) x = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1) x = x.nan_to_num(nan=x.nanmean()) # type: ignore xs.append(x / x.max()) data['customer'].x = torch.cat(xs, dim=-1) # Process article data ################################################ df = pd.read_csv(self.raw_paths[1], index_col='article_id') article_map = {idx: i for i, idx in enumerate(df.index)} xs = [] for name in [ # We drop a few columns here that are high cardinality. # 'product_code', # Drop. # 'prod_name', # Drop. 
'product_type_no', 'product_type_name', 'product_group_name', 'graphical_appearance_no', 'graphical_appearance_name', 'colour_group_code', 'colour_group_name', 'perceived_colour_value_id', 'perceived_colour_value_name', 'perceived_colour_master_id', 'perceived_colour_master_name', # 'department_no', # Drop. # 'department_name', # Drop. 'index_code', 'index_name', 'index_group_no', 'index_group_name', 'section_no', 'section_name', 'garment_group_no', 'garment_group_name', # 'detail_desc', # Drop. ]: x = pd.get_dummies(df[name]).values xs.append(torch.from_numpy(x).to(torch.float)) data['article'].x = torch.cat(xs, dim=-1) # Process transaction data ############################################ df = pd.read_csv(self.raw_paths[2], parse_dates=['t_dat']) x1 = pd.get_dummies(df['sales_channel_id']).values x1 = torch.from_numpy(x1).to(torch.float) x2 = torch.from_numpy(df['price'].values).to(torch.float).view(-1, 1) x = torch.cat([x1, x2], dim=-1) time = torch.from_numpy(df['t_dat'].values.astype(int)) time = time // (60 * 60 * 24 * 10**9) # Convert nanoseconds to days. 
src = torch.tensor([customer_map[idx] for idx in df['customer_id']]) dst = torch.tensor([article_map[idx] for idx in df['article_id']]) if self.use_all_tables_as_node_types: data['transaction'].x = x data['transaction'].time = time edge_index = torch.stack([src, torch.arange(len(df))], dim=0) data['customer', 'to', 'transaction'].edge_index = edge_index edge_index = edge_index.flip([0]) data['transaction', 'rev_to', 'customer'].edge_index = edge_index edge_index = torch.stack([dst, torch.arange(len(df))], dim=0) data['article', 'to', 'transaction'].edge_index = edge_index edge_index = edge_index.flip([0]) data['transaction', 'rev_to', 'article'].edge_index = edge_index else: edge_index = torch.stack([src, dst], dim=0) data['customer', 'to', 'article'].edge_index = edge_index data['customer', 'to', 'article'].time = time data['customer', 'to', 'article'].edge_attr = x edge_index = edge_index.flip([0]) data['article', 'rev_to', 'customer'].edge_index = edge_index data['article', 'rev_to', 'customer'].time = time data['article', 'rev_to', 'customer'].edge_attr = x if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/hydro_net.py ================================================ import copy import os import os.path as osp from dataclasses import dataclass from functools import cached_property from glob import glob from pathlib import Path from typing import Callable, List, MutableSequence, Optional, Tuple, Union import numpy as np import torch from torch.utils.data import ConcatDataset, Subset from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.data.data import BaseData class HydroNet(InMemoryDataset): r"""The HydroNet dataest from the `"HydroNet: Benchmark Tasks for Preserving Intermolecular Interactions and Structural Motifs in Predictive and Generative Models for Molecular 
Data" `_ paper, consisting of 5 million water clusters held together by hydrogen bonding networks. This dataset provides atomic coordinates and total energy in kcal/mol for the cluster. Args: root (str): Root directory where the dataset should be saved. name (str, optional): Name of the subset of the full dataset to use: :obj:`"small"` uses 500k graphs sampled from the :obj:`"medium"` dataset, :obj:`"medium"` uses 2.7m graphs with maximum size of 75 nodes. Mutually exclusive option with the clusters argument. (default :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) num_workers (int): Number of multiprocessing workers to use for pre-processing the dataset. (default :obj:`8`) clusters (int or List[int], optional): Select a subset of clusters from the full dataset. If set to :obj:`None`, will select all. (default :obj:`None`) use_processed (bool): Option to use a pre-processed version of the original :obj:`xyz` dataset. 
(default: :obj:`True`) """ def __init__( self, root: str, name: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, num_workers: int = 8, clusters: Optional[Union[int, List[int]]] = None, use_processed: bool = True, ) -> None: self.name = name self.num_workers = num_workers self.use_processed = use_processed super().__init__(root, transform, pre_transform, force_reload=force_reload) self.select_clusters(clusters) @property def raw_file_names(self) -> List[str]: return [f'W{c}_geoms_all.zip' for c in range(3, 31)] @property def processed_file_names(self) -> List[str]: return [f'W{c}_geoms_all.npz' for c in range(3, 31)] def download(self) -> None: token_file = Path(osp.join(self.raw_dir, 'use_processed')) if self.use_processed and token_file.exists(): return file = RemoteFile.hydronet_splits() file.unpack_to(self.raw_dir) if self.use_processed: file = RemoteFile.processed_dataset() file.unpack_to(self.raw_dir) token_file.touch() return file = RemoteFile.raw_dataset() file.unpack_to(self.raw_dir) folder_name, _ = osp.splitext(file.name) files = glob(osp.join(self.raw_dir, folder_name, '*.zip')) for f in files: dst = osp.join(self.raw_dir, osp.basename(f)) os.rename(f, dst) os.rmdir(osp.join(self.raw_dir, folder_name)) def process(self) -> None: if self.use_processed: return self._unpack_processed() from tqdm.contrib.concurrent import process_map self._partitions = process_map( self._create_partitions, self.raw_paths, max_workers=self.num_workers, position=0, leave=True, ) def _unpack_processed(self) -> None: files = glob(osp.join(self.raw_dir, '*.npz')) for f in files: dst = osp.join(self.processed_dir, osp.basename(f)) os.rename(f, dst) def _create_partitions(self, file: str) -> 'Partition': name = osp.basename(file) name, _ = osp.splitext(name) return Partition(self.root, name, self.transform, self.pre_transform) def select_clusters( self, clusters: Optional[Union[int, List[int]]], ) -> 
None: if self.name is not None: clusters = self._validate_name(clusters) self._partitions = [self._create_partitions(f) for f in self.raw_paths] if clusters is None: return clusters = [clusters] if isinstance(clusters, int) else clusters def is_valid_cluster(x: Union[int, List[int]]) -> bool: return isinstance(x, int) and x >= 3 and x <= 30 if not all([is_valid_cluster(x) for x in clusters]): raise ValueError( "Selected clusters must be an integer in the range [3, 30]") self._partitions = [self._partitions[c - 3] for c in clusters] def _validate_name( self, clusters: Optional[Union[int, List[int]]], ) -> List[int]: if clusters is not None: raise ValueError("'name' and 'clusters' are mutually exclusive") if self.name not in ['small', 'medium']: raise ValueError(f"Invalid subset name '{self.name}'. " f"Must be either 'small' or 'medium'") return list(range(3, 26)) @cached_property def _dataset(self) -> Union[ConcatDataset, Subset]: dataset: ConcatDataset = ConcatDataset(self._partitions) if self.name == "small": return self._load_small_split(dataset) return dataset def _load_small_split(self, dataset: ConcatDataset) -> Subset: split_file = osp.join(self.processed_dir, 'split_00_small.npz') with np.load(split_file) as split: train_idx = split['train_idx'] val_idx = split['val_idx'] all_idx = np.concatenate([train_idx, val_idx]) return Subset(dataset, all_idx) def len(self) -> int: return len(self._dataset) def get(self, idx: int) -> Data: return self._dataset[idx] def get_num_clusters(filepath: str) -> int: name = osp.basename(filepath) return int(name[1:name.find('_')]) def read_energy(file: str, chunk_size: int) -> np.ndarray: import pandas as pd def skipatoms(i: int) -> bool: return (i - 1) % chunk_size != 0 if chunk_size - 2 == 11 * 3: # Manually handle bad lines in W11 df = pd.read_table(file, header=None, dtype="string", skiprows=skipatoms) df = df[0].str.split().str[-1].astype(np.float32) else: df = pd.read_table(file, sep=r'\s+', names=["label", "energy"], 
dtype=np.float32, skiprows=skipatoms, usecols=['energy'], memory_map=True) return df.to_numpy().squeeze() def read_atoms(file: str, chunk_size: int) -> Tuple[np.ndarray, np.ndarray]: import pandas as pd def skipheaders(i: int) -> bool: return i % chunk_size == 0 or (i - 1) % chunk_size == 0 dtypes = { 'atom': 'string', 'x': np.float16, 'y': np.float16, 'z': np.float16 } df = pd.read_table(file, sep=r'\s+', names=list(dtypes.keys()), dtype=dtypes, skiprows=skipheaders, memory_map=True) z = np.ones(len(df), dtype=np.int8) z[(df.atom == 'O').to_numpy(dtype=np.bool_)] = 8 pos = df.iloc[:, 1:4].to_numpy() num_nodes = (chunk_size - 2) num_graphs = z.shape[0] // num_nodes z.shape = (num_graphs, num_nodes) pos.shape = (num_graphs, num_nodes, 3) return (z, pos) @dataclass class RemoteFile: url: str name: str def unpack_to(self, dest_folder: str) -> None: file = download_url(self.url, dest_folder, filename=self.name) extract_zip(file, dest_folder) os.unlink(file) @staticmethod def raw_dataset() -> 'RemoteFile': return RemoteFile( url='https://figshare.com/ndownloader/files/38063847', name='W3-W30_all_geoms_TTM2.1-F.zip') @staticmethod def processed_dataset() -> 'RemoteFile': return RemoteFile( url='https://figshare.com/ndownloader/files/38075781', name='W3-W30_pyg_processed.zip') @staticmethod def hydronet_splits() -> 'RemoteFile': return RemoteFile( url="https://figshare.com/ndownloader/files/38075904", name="hydronet_splits.zip") class Partition(InMemoryDataset): def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, ) -> None: self.name = name self.num_clusters = get_num_clusters(name) super().__init__(root, transform, pre_transform, pre_filter=None, log=False) self.is_loaded = False @property def raw_file_names(self) -> List[str]: return [self.name + ".zip"] @property def processed_file_names(self) -> List[str]: return [self.name + '.npz'] def process(self) -> None: num_nodes = self.num_clusters * 3 
chunk_size = num_nodes + 2 z, pos = read_atoms(self.raw_paths[0], chunk_size) y = read_energy(self.raw_paths[0], chunk_size) np.savez(self.processed_paths[0], z=z, pos=pos, y=y, num_graphs=z.shape[0]) def _load(self) -> None: if self.is_loaded: return None with np.load(self.processed_paths[0]) as npzfile: self.z = npzfile['z'] self.pos = npzfile['pos'] self.y = npzfile['y'] numel = int(npzfile['num_graphs']) self._data_list: MutableSequence[Optional[BaseData]] = [None] * numel self.is_loaded = True @cached_property def num_graphs(self) -> int: with np.load(self.processed_paths[0]) as npzfile: return int(npzfile['num_graphs']) def len(self) -> int: return self.num_graphs def get(self, idx: int) -> Data: self._load() if self._data_list[idx] is not None: cached_data = self._data_list[idx] assert isinstance(cached_data, Data) return copy.copy(cached_data) data = Data( z=torch.from_numpy(self.z[idx, :]), pos=torch.from_numpy(self.pos[idx, :, :]), y=torch.tensor(self.y[idx]), ) if self.pre_transform is not None: data = self.pre_transform(data) self._data_list[idx] = copy.copy(data) return data ================================================ FILE: torch_geometric/datasets/icews.py ================================================ from typing import Callable, List, Optional import torch from torch import Tensor from torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.io import read_txt_array class EventDataset(InMemoryDataset): def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) @property def num_nodes(self) -> int: raise NotImplementedError @property def num_rels(self) -> int: raise NotImplementedError def process_events(self) -> Tensor: raise NotImplementedError def _process_data_list(self) -> List[Data]: events = 
self.process_events() events = events - events.min(dim=0, keepdim=True)[0] data_list = [] for (sub, rel, obj, t) in events.tolist(): data = Data(sub=sub, rel=rel, obj=obj, t=t) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) return data_list class ICEWS18(EventDataset): r"""The Integrated Crisis Early Warning System (ICEWS) dataset used in the, *e.g.*, `"Recurrent Event Network for Reasoning over Temporal Knowledge Graphs" `_ paper, consisting of events collected from 1/1/2018 to 10/31/2018 (24 hours time granularity). Args: root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://github.com/INK-USC/RE-Net/raw/master/data/ICEWS18' splits = [0, 373018, 419013, 468558] # Train/Val/Test splits. 
def __init__( self, root: str, split: str = 'train', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert split in ['train', 'val', 'test'] super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) idx = self.processed_file_names.index(f'{split}.pt') self.load(self.processed_paths[idx]) @property def num_nodes(self) -> int: return 23033 @property def num_rels(self) -> int: return 256 @property def raw_file_names(self) -> List[str]: return [f'{name}.txt' for name in ['train', 'valid', 'test']] @property def processed_file_names(self) -> List[str]: return ['train.pt', 'val.pt', 'test.pt'] def download(self) -> None: for filename in self.raw_file_names: download_url(f'{self.url}/{filename}', self.raw_dir) def process_events(self) -> Tensor: events = [] for path in self.raw_paths: data = read_txt_array(path, sep='\t', end=4, dtype=torch.long) data[:, 3] = data[:, 3] // 24 events += [data] return torch.cat(events, dim=0) def process(self) -> None: s = self.splits data_list = self._process_data_list() self.save(data_list[s[0]:s[1]], self.processed_paths[0]) self.save(data_list[s[1]:s[2]], self.processed_paths[1]) self.save(data_list[s[2]:s[3]], self.processed_paths[2]) ================================================ FILE: torch_geometric/datasets/igmc_dataset.py ================================================ import os.path as osp from typing import Callable, Optional import torch from torch import Tensor from torch_geometric.data import HeteroData, InMemoryDataset, download_url class IGMCDataset(InMemoryDataset): r"""The user-item heterogeneous rating datasets :obj:`"Douban"`, :obj:`"Flixster"` and :obj:`"Yahoo-Music"` from the `"Inductive Matrix Completion Based on Graph Neural Networks" `_ paper. Nodes represent users and items. 
Edges and features between users and items represent a (training) rating of the item given by the user. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"Douban"`, :obj:`"Flixster"`, :obj:`"Yahoo-Music"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://github.com/muhanzhang/IGMC/raw/master/raw_data' def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower().replace('-', '_') assert self.name in ['flixster', 'douban', 'yahoo_music'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> str: return 'training_test_dataset.mat' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = f'{self.url}/{self.name}/training_test_dataset.mat' download_url(path, self.raw_dir) @staticmethod def load_matlab_file(path_file: str, name: str) -> Tensor: import h5py import numpy as np db = h5py.File(path_file, 'r') out = torch.from_numpy(np.asarray(db[name])).to(torch.float).t() db.close() return out def process(self) -> None: data = HeteroData() 
M = self.load_matlab_file(self.raw_paths[0], 'M') if self.name == 'flixster': user_x = self.load_matlab_file(self.raw_paths[0], 'W_users') item_x = self.load_matlab_file(self.raw_paths[0], 'W_movies') elif self.name == 'douban': user_x = self.load_matlab_file(self.raw_paths[0], 'W_users') item_x = torch.eye(M.size(1)) elif self.name == 'yahoo_music': user_x = torch.eye(M.size(0)) item_x = self.load_matlab_file(self.raw_paths[0], 'W_tracks') data['user'].x = user_x data['item'].x = item_x train_mask = self.load_matlab_file(self.raw_paths[0], 'Otraining') train_mask = train_mask.to(torch.bool) edge_index = train_mask.nonzero().t() rating = M[edge_index[0], edge_index[1]] data['user', 'rates', 'item'].edge_index = edge_index data['user', 'rates', 'item'].rating = rating data['item', 'rated_by', 'user'].edge_index = edge_index.flip([0]) data['item', 'rated_by', 'user'].rating = rating test_mask = self.load_matlab_file(self.raw_paths[0], 'Otest') test_mask = test_mask.to(torch.bool) edge_label_index = test_mask.nonzero().t() edge_label = M[edge_label_index[0], edge_label_index[1]] data['user', 'rates', 'item'].edge_label_index = edge_label_index data['user', 'rates', 'item'].edge_label = edge_label if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.__class__.__name__}(name={self.name})' ================================================ FILE: torch_geometric/datasets/imdb.py ================================================ import os import os.path as osp from itertools import product from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) class IMDB(InMemoryDataset): r"""A subset of the Internet Movie Database (IMDB), as collected in the `"MAGNN: Metapath Aggregated Graph Neural Network for Heterogeneous Graph Embedding" `_ paper. 
IMDB is a heterogeneous graph containing three types of entities - movies (4,278 nodes), actors (5,257 nodes), and directors (2,081 nodes). The movies are divided into three classes (action, comedy, drama) according to their genre. Movie features correspond to elements of a bag-of-words representation of its plot keywords. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://www.dropbox.com/s/g0btk9ctr1es39x/IMDB_processed.zip?dl=1' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return [ 'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npz', 'labels.npy', 'train_val_test_idx.npz' ] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.remove(path) def process(self) -> None: import scipy.sparse as sp data = HeteroData() node_types = ['movie', 'director', 'actor'] for i, node_type in enumerate(node_types): x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz')) data[node_type].x = torch.from_numpy(x.todense()).to(torch.float) y = np.load(osp.join(self.raw_dir, 
'labels.npy')) data['movie'].y = torch.from_numpy(y).to(torch.long) split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz')) for name in ['train', 'val', 'test']: idx = split[f'{name}_idx'] idx = torch.from_numpy(idx).to(torch.long) mask = torch.zeros(data['movie'].num_nodes, dtype=torch.bool) mask[idx] = True data['movie'][f'{name}_mask'] = mask s = {} N_m = data['movie'].num_nodes N_d = data['director'].num_nodes N_a = data['actor'].num_nodes s['movie'] = (0, N_m) s['director'] = (N_m, N_m + N_d) s['actor'] = (N_m + N_d, N_m + N_d + N_a) A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')) for src, dst in product(node_types, node_types): A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo() if A_sub.nnz > 0: row = torch.from_numpy(A_sub.row).to(torch.long) col = torch.from_numpy(A_sub.col).to(torch.long) data[src, dst].edge_index = torch.stack([row, col], dim=0) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/datasets/infection_dataset.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import torch from torch_geometric.data import InMemoryDataset from torch_geometric.datasets.graph_generator import GraphGenerator from torch_geometric.explain import Explanation from torch_geometric.utils import k_hop_subgraph class InfectionDataset(InMemoryDataset): r"""Generates a synthetic infection dataset for evaluating explainabilty algorithms, as described in the `"Explainability Techniques for Graph Convolutional Networks" `__ paper. The :class:`~torch_geometric.datasets.InfectionDataset` creates synthetic graphs coming from a :class:`~torch_geometric.datasets.graph_generator.GraphGenerator` with :obj:`num_infected` randomly assigned infected nodes. 
The dataset describes a node classification task of predicting the length of the shortest path to infected nodes, with corresponding ground-truth edge-level masks. For example, to generate a random Erdos-Renyi (ER) infection graph with :obj:`500` nodes and :obj:`0.004` edge probability, write: .. code-block:: python from torch_geometric.datasets import InfectionDataset from torch_geometric.datasets.graph_generator import ERGraph dataset = InfectionDataset( graph_generator=ERGraph(num_nodes=500, edge_prob=0.004), num_infected_nodes=50, max_path_length=3, ) Args: graph_generator (GraphGenerator or str): The graph generator to be used, *e.g.*, :class:`torch.geometric.datasets.graph_generator.BAGraph` (or any string that automatically resolves to it). num_infected_nodes (int or List[int]): The number of randomly selected infected nodes in the graph. If given as a list, will select a different number of infected nodes for different graphs. max_path_length (int, List[int]): The maximum shortest path length to determine whether a node will be infected. If given as a list, will apply different shortest path lengths for different graphs. (default: :obj:`5`) num_graphs (int, optional): The number of graphs to generate. The number of graphs will be automatically determined by :obj:`len(num_infected_nodes)` or :obj:`len(max_path_length)` in case either of them is given as a list, and should only be set in case one wants to create multiple graphs while :obj:`num_infected_nodes` and :obj:`max_path_length` are given as an integer. (default: :obj:`None`) graph_generator_kwargs (Dict[str, Any], optional): Arguments passed to the respective graph generator module in case it gets automatically resolved. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) """ def __init__( self, graph_generator: Union[GraphGenerator, str], num_infected_nodes: Union[int, List[int]], max_path_length: Union[int, List[int]], num_graphs: Optional[int] = None, graph_generator_kwargs: Optional[Dict[str, Any]] = None, transform: Optional[Callable] = None, ): super().__init__(root=None, transform=transform) assert isinstance(num_infected_nodes, (int, list)) assert isinstance(max_path_length, (int, list)) if (num_graphs is None and isinstance(num_infected_nodes, int) and isinstance(max_path_length, int)): num_graphs = 1 if num_graphs is None and isinstance(num_infected_nodes, list): num_graphs = len(num_infected_nodes) if num_graphs is None and isinstance(max_path_length, list): num_graphs = len(max_path_length) assert num_graphs is not None self.graph_generator = GraphGenerator.resolve( graph_generator, **(graph_generator_kwargs or {}), ) self.num_infected_nodes = num_infected_nodes self.max_path_length = max_path_length self.num_graphs = num_graphs if isinstance(num_infected_nodes, int): num_infected_nodes = [num_infected_nodes] * num_graphs if isinstance(max_path_length, int): max_path_length = [max_path_length] * num_graphs if len(num_infected_nodes) != num_graphs: raise ValueError(f"The length of 'num_infected_nodes' " f"(got {len(num_infected_nodes)} does not match " f"with the number of graphs (got {num_graphs})") if len(max_path_length) != num_graphs: raise ValueError(f"The length of 'max_path_length' " f"(got {len(max_path_length)} does not match " f"with the number of graphs (got {num_graphs})") if any(num_infected_nodes) <= 0: raise ValueError(f"'num_infected_nodes' needs to be positive " f"(got {min(num_infected_nodes)})") if any(max_path_length) <= 0: raise ValueError(f"'max_path_length' needs to be positive " f"(got {min(max_path_length)})") data_list: List[Explanation] = [] for N, L in zip(num_infected_nodes, max_path_length): data_list.append(self.get_graph(N, L)) self.data, self.slices = 
self.collate(data_list) def get_graph(self, num_infected_nodes: int, max_path_length: int) -> Explanation: data = self.graph_generator() assert data.num_nodes is not None perm = torch.randperm(data.num_nodes) x = torch.zeros((data.num_nodes, 2)) x[perm[:num_infected_nodes], 1] = 1 # Infected x[perm[num_infected_nodes:], 0] = 1 # Healthy y = torch.empty(data.num_nodes, dtype=torch.long) y.fill_(max_path_length + 1) y[perm[:num_infected_nodes]] = 0 # Infected nodes have label `0`. assert data.edge_index is not None edge_mask = torch.zeros(data.num_edges, dtype=torch.bool) for num_hops in range(1, max_path_length + 1): sub_node_index, _, _, sub_edge_mask = k_hop_subgraph( perm[:num_infected_nodes], num_hops, data.edge_index, num_nodes=data.num_nodes, flow='target_to_source', directed=True) value = torch.full_like(sub_node_index, fill_value=num_hops) y[sub_node_index] = torch.min(y[sub_node_index], value) edge_mask |= sub_edge_mask return Explanation( x=x, edge_index=data.edge_index, y=y, edge_mask=edge_mask.to(torch.float), ) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'graph_generator={self.graph_generator}, ' f'num_infected_nodes={self.num_infected_nodes}, ' f'max_path_length={self.max_path_length})') ================================================ FILE: torch_geometric/datasets/instruct_mol_dataset.py ================================================ import json import sys from typing import Callable, List, Optional import torch from tqdm import tqdm from torch_geometric.data import Data, InMemoryDataset from torch_geometric.io import fs from torch_geometric.utils import one_hot class InstructMolDataset(InMemoryDataset): r"""The dataset from the `"InstructMol: Multi-Modal Integration for Building a Versatile and Reliable Molecular Assistant in Drug Discovery" `_ paper. Args: root (str): Root directory where the dataset should be saved. 
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ): super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['all_clean.json'] @property def processed_file_names(self) -> List[str]: return ['data.pt'] def download(self) -> None: print('downloading dataset...') fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir) def process(self) -> None: try: from rdkit import Chem from rdkit.Chem.rdchem import BondType as BT WITH_RDKIT = True except ImportError: WITH_RDKIT = False if not WITH_RDKIT: print(("Using a pre-processed version of the dataset. 
Please " "install 'rdkit' to alternatively process the raw data."), file=sys.stderr) data_list = fs.torch_load(self.raw_paths[0]) data_list = [Data(**data_dict) for data_dict in data_list] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[0]) return # types of atom and bond types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5} bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} # load data mols = json.load(open(f'{self.raw_dir}/all_clean.json')) data_list = [] for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)): mol = Chem.MolFromSmiles(smiles) if mol is None: continue x: torch.Tensor = torch.tensor([ types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 for atom in mol.GetAtoms() ]) x = one_hot(x, num_classes=len(types), dtype=torch.float) rows, cols, edge_types = [], [], [] for bond in mol.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_types += [bonds[bond.GetBondType()]] * 2 rows += [i, j] cols += [j, i] edge_index = torch.tensor([rows, cols], dtype=torch.long) edge_type = torch.tensor(edge_types, dtype=torch.long) edge_attr = one_hot(edge_type, num_classes=len(bonds)) for question, answer in qa_pairs: data = Data( x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles, instruction=question, y=answer, ) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/jodie.py ================================================ import os.path as osp from typing import Callable, Optional import torch from torch_geometric.data import InMemoryDataset, TemporalData, download_url class JODIEDataset(InMemoryDataset): 
r"""The temporal graph datasets from the `"JODIE: Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks" `_ paper. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"Reddit"`, :obj:`"Wikipedia"`, :obj:`"MOOC"`, and :obj:`"LastFM"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - Reddit - 6,509 - 25,470 - 172 - 1 * - Wikipedia - 9,227 - 157,474 - 172 - 2 * - MOOC - 7,144 - 411,749 - 4 - 2 * - LastFM - 1,980 - 1,293,103 - 2 - 1 """ url = 'http://snap.stanford.edu/jodie/{}.csv' names = ['reddit', 'wikipedia', 'mooc', 'lastfm'] def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in self.names super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=TemporalData) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> str: return f'{self.name}.csv' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.url.format(self.name), self.raw_dir) def process(self) -> 
None: import pandas as pd df = pd.read_csv(self.raw_paths[0], skiprows=1, header=None) src = torch.from_numpy(df.iloc[:, 0].values).to(torch.long) dst = torch.from_numpy(df.iloc[:, 1].values).to(torch.long) dst += int(src.max()) + 1 t = torch.from_numpy(df.iloc[:, 2].values).to(torch.long) y = torch.from_numpy(df.iloc[:, 3].values).to(torch.long) msg = torch.from_numpy(df.iloc[:, 4:].values).to(torch.float) data = TemporalData(src=src, dst=dst, t=t, msg=msg, y=y) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name.capitalize()}()' ================================================ FILE: torch_geometric/datasets/karate.py ================================================ from typing import Callable, Optional import torch from torch_geometric.data import Data, InMemoryDataset class KarateClub(InMemoryDataset): r"""Zachary's karate club network from the `"An Information Flow Model for Conflict and Fission in Small Groups" `_ paper, containing 34 nodes, connected by 156 (undirected and unweighted) edges. Every node is labeled by one of four classes obtained via modularity-based clustering, following the `"Semi-supervised Classification with Graph Convolutional Networks" `_ paper. Training is based on a single labeled example per class, *i.e.* a total number of 4 labeled nodes. Args: transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 34 - 156 - 34 - 4 """ def __init__(self, transform: Optional[Callable] = None): super().__init__(None, transform) row = [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 10, 10, 10, 11, 12, 12, 13, 13, 13, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27, 27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33 ] col = [ 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31, 0, 2, 3, 7, 13, 17, 19, 21, 30, 0, 1, 3, 7, 8, 9, 13, 27, 28, 32, 0, 1, 2, 7, 12, 13, 0, 6, 10, 0, 6, 10, 16, 0, 4, 5, 16, 0, 1, 2, 3, 0, 2, 30, 32, 33, 2, 33, 0, 4, 5, 0, 0, 3, 0, 1, 2, 3, 33, 32, 33, 32, 33, 5, 6, 0, 1, 32, 33, 0, 1, 33, 32, 33, 0, 1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33, 2, 23, 24, 33, 2, 31, 33, 23, 26, 32, 33, 1, 8, 32, 33, 0, 24, 25, 28, 32, 33, 2, 8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33, 8, 9, 13, 14, 15, 18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32 ] edge_index = torch.tensor([row, col]) y = torch.tensor([ # Create communities. 1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0 ]) x = torch.eye(y.size(0), dtype=torch.float) # Select a single training node for each community # (we just use the first one). 
train_mask = torch.zeros(y.size(0), dtype=torch.bool) for i in range(int(y.max()) + 1): train_mask[(y == i).nonzero(as_tuple=False)[0]] = True data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask) self.data, self.slices = self.collate([data]) ================================================ FILE: torch_geometric/datasets/last_fm.py ================================================ import os import os.path as osp from itertools import product from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) class LastFM(InMemoryDataset): r"""A subset of the last.fm music website keeping track of users' listining information from various sources, as collected in the `"MAGNN: Metapath Aggregated Graph Neural Network for Heterogeneous Graph Embedding" `_ paper. last.fm is a heterogeneous graph containing three types of entities - users (1,892 nodes), artists (17,632 nodes), and artist tags (1,088 nodes). This dataset can be used for link prediction, and no labels or features are provided. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = 'https://www.dropbox.com/s/jvlbs09pz6zwcka/LastFM_processed.zip?dl=1' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return [ 'adjM.npz', 'node_types.npy', 'train_val_test_neg_user_artist.npz', 'train_val_test_pos_user_artist.npz' ] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.remove(path) def process(self) -> None: import scipy.sparse as sp data = HeteroData() node_type_idx = np.load(osp.join(self.raw_dir, 'node_types.npy')) node_type_idx = torch.from_numpy(node_type_idx).to(torch.long) node_types = ['user', 'artist', 'tag'] for i, node_type in enumerate(node_types): data[node_type].num_nodes = int((node_type_idx == i).sum()) pos_split = np.load( osp.join(self.raw_dir, 'train_val_test_pos_user_artist.npz')) neg_split = np.load( osp.join(self.raw_dir, 'train_val_test_neg_user_artist.npz')) for name in ['train', 'val', 'test']: if name != 'train': edge_index = pos_split[f'{name}_pos_user_artist'] edge_index = torch.from_numpy(edge_index) edge_index = edge_index.t().to(torch.long).contiguous() data['user', 'artist'][f'{name}_pos_edge_index'] = edge_index edge_index = neg_split[f'{name}_neg_user_artist'] edge_index = torch.from_numpy(edge_index) edge_index = edge_index.t().to(torch.long).contiguous() data['user', 'artist'][f'{name}_neg_edge_index'] = edge_index s = {} N_u = data['user'].num_nodes N_a = data['artist'].num_nodes N_t = data['tag'].num_nodes s['user'] = (0, N_u) s['artist'] = (N_u, N_u + N_a) s['tag'] = (N_u + N_a, N_u + N_a + N_t) A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')) for src, dst in product(node_types, 
node_types): A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo() if A_sub.nnz > 0: row = torch.from_numpy(A_sub.row).to(torch.long) col = torch.from_numpy(A_sub.col).to(torch.long) data[src, dst].edge_index = torch.stack([row, col], dim=0) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/datasets/lastfm_asia.py ================================================ from typing import Callable, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url class LastFMAsia(InMemoryDataset): r"""The LastFM Asia Network dataset introduced in the `"Characteristic Functions on Graphs: Birds of a Feather, from Statistical Descriptors to Parametric Models" `_ paper. Nodes represent LastFM users from Asia and edges are friendships. It contains 7,624 nodes, 55,612 edges, 128 node features and 18 classes. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = 'https://graphmining.ai/datasets/ptg/lastfm_asia.npz' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> str: return 'lastfm_asia.npz' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.url, self.raw_dir) def process(self) -> None: data = np.load(self.raw_paths[0], 'r', allow_pickle=True) x = torch.from_numpy(data['features']).to(torch.float) y = torch.from_numpy(data['target']).to(torch.long) edge_index = torch.from_numpy(data['edges']).to(torch.long) edge_index = edge_index.t().contiguous() data = Data(x=x, y=y, edge_index=edge_index) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/linkx_dataset.py ================================================ import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.io import fs from torch_geometric.utils import one_hot class LINKXDataset(InMemoryDataset): r"""A variety of non-homophilous graph datasets from the `"Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods" `_ paper. .. note:: Some of the datasets provided in :class:`LINKXDataset` are from other sources, but have been updated with new features and/or labels. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"penn94"`, :obj:`"reed98"`, :obj:`"amherst41"`, :obj:`"cornell5"`, :obj:`"johnshopkins55"`, :obj:`"genius"`). 
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ github_url = ('https://github.com/CUAI/Non-Homophily-Large-Scale/' 'raw/master/data') gdrive_url = 'https://drive.usercontent.google.com/download?confirm=t' facebook_datasets = [ 'penn94', 'reed98', 'amherst41', 'cornell5', 'johnshopkins55' ] datasets = { 'penn94': { 'data.mat': f'{github_url}/facebook100/Penn94.mat' }, 'reed98': { 'data.mat': f'{github_url}/facebook100/Reed98.mat' }, 'amherst41': { 'data.mat': f'{github_url}/facebook100/Amherst41.mat', }, 'cornell5': { 'data.mat': f'{github_url}/facebook100/Cornell5.mat' }, 'johnshopkins55': { 'data.mat': f'{github_url}/facebook100/Johns%20Hopkins55.mat' }, 'genius': { 'data.mat': f'{github_url}/genius.mat' }, 'wiki': { 'wiki_views2M.pt': f'{gdrive_url}&id=1p5DlVHrnFgYm3VsNIzahSsvCD424AyvP', 'wiki_edges2M.pt': f'{gdrive_url}&id=14X7FlkjrlUgmnsYtPwdh-gGuFla4yb5u', 'wiki_features2M.pt': f'{gdrive_url}&id=1ySNspxbK-snNoAZM7oxiWGvOnTRdSyEK' } } splits = { 'penn94': f'{github_url}/splits/fb100-Penn94-splits.npy', } def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in self.datasets.keys() super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return 
osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: names = list(self.datasets[self.name].keys()) if self.name in self.splits: names += [self.splits[self.name].split('/')[-1]] return names @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for filename, path in self.datasets[self.name].items(): download_url(path, self.raw_dir, filename=filename) if self.name in self.splits: download_url(self.splits[self.name], self.raw_dir) def _process_wiki(self) -> Data: paths = {x.split('/')[-1]: x for x in self.raw_paths} x = fs.torch_load(paths['wiki_features2M.pt']) edge_index = fs.torch_load(paths['wiki_edges2M.pt']).t().contiguous() y = fs.torch_load(paths['wiki_views2M.pt']) return Data(x=x, edge_index=edge_index, y=y) def _process_facebook(self) -> Data: from scipy.io import loadmat mat = loadmat(self.raw_paths[0]) A = mat['A'].tocsr().tocoo() row = torch.from_numpy(A.row).to(torch.long) col = torch.from_numpy(A.col).to(torch.long) edge_index = torch.stack([row, col], dim=0) metadata = torch.from_numpy(mat['local_info'].astype('int64')) xs = [] y = metadata[:, 1] - 1 # gender label, -1 means unlabeled x = torch.cat([metadata[:, :1], metadata[:, 2:]], dim=-1) for i in range(x.size(1)): _, out = x[:, i].unique(return_inverse=True) xs.append(one_hot(out)) x = torch.cat(xs, dim=-1) data = Data(x=x, edge_index=edge_index, y=y) if self.name in self.splits: splits = np.load(self.raw_paths[1], allow_pickle=True) assert data.num_nodes is not None sizes = (data.num_nodes, len(splits)) data.train_mask = torch.zeros(sizes, dtype=torch.bool) data.val_mask = torch.zeros(sizes, dtype=torch.bool) data.test_mask = torch.zeros(sizes, dtype=torch.bool) for i, split in enumerate(splits): data.train_mask[:, i][torch.tensor(split['train'])] = True data.val_mask[:, i][torch.tensor(split['valid'])] = True data.test_mask[:, i][torch.tensor(split['test'])] = True return data def _process_genius(self) -> Data: from 
scipy.io import loadmat mat = loadmat(self.raw_paths[0]) edge_index = torch.from_numpy(mat['edge_index']).to(torch.long) x = torch.from_numpy(mat['node_feat']).to(torch.float) y = torch.from_numpy(mat['label']).squeeze().to(torch.long) return Data(x=x, edge_index=edge_index, y=y) def process(self) -> None: if self.name in self.facebook_datasets: data = self._process_facebook() elif self.name == 'genius': data = self._process_genius() elif self.name == 'wiki': data = self._process_wiki() else: raise NotImplementedError( f"chosen dataset '{self.name}' is not implemented") if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name.capitalize()}({len(self)})' ================================================ FILE: torch_geometric/datasets/lrgb.py ================================================ import os import os.path as osp import pickle from typing import Callable, Dict, List, Optional import torch from tqdm import tqdm from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class LRGBDataset(InMemoryDataset): r"""The `"Long Range Graph Benchmark (LRGB)" `_ datasets which is a collection of 5 graph learning datasets with tasks that are based on long-range dependencies in graphs. See the original `source code `_ for more details on the individual datasets. 
+------------------------+-------------------+----------------------+ | Dataset | Domain | Task | +========================+===================+======================+ | :obj:`PascalVOC-SP` | Computer Vision | Node Classification | +------------------------+-------------------+----------------------+ | :obj:`COCO-SP` | Computer Vision | Node Classification | +------------------------+-------------------+----------------------+ | :obj:`PCQM-Contact` | Quantum Chemistry | Link Prediction | +------------------------+-------------------+----------------------+ | :obj:`Peptides-func` | Chemistry | Graph Classification | +------------------------+-------------------+----------------------+ | :obj:`Peptides-struct` | Chemistry | Graph Regression | +------------------------+-------------------+----------------------+ Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (one of :obj:`"PascalVOC-SP"`, :obj:`"COCO-SP"`, :obj:`"PCQM-Contact"`, :obj:`"Peptides-func"`, :obj:`"Peptides-struct"`) split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) **STATS:** .. list-table:: :widths: 15 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #classes * - PascalVOC-SP - 11,355 - ~479.40 - ~2,710.48 - 21 * - COCO-SP - 123,286 - ~476.88 - ~2,693.67 - 81 * - PCQM-Contact - 529,434 - ~30.14 - ~61.09 - 1 * - Peptides-func - 15,535 - ~150.94 - ~307.30 - 10 * - Peptides-struct - 15,535 - ~150.94 - ~307.30 - 11 """ names = [ 'pascalvoc-sp', 'coco-sp', 'pcqm-contact', 'peptides-func', 'peptides-struct' ] urls = { 'pascalvoc-sp': 'https://www.dropbox.com/s/8x722ai272wqwl4/pascalvocsp.zip?dl=1', 'coco-sp': 'https://www.dropbox.com/s/r6ihg1f4pmyjjy0/cocosp.zip?dl=1', 'pcqm-contact': 'https://www.dropbox.com/s/qdag867u6h6i60y/pcqmcontact.zip?dl=1', 'peptides-func': 'https://www.dropbox.com/s/ycsq37q8sxs1ou8/peptidesfunc.zip?dl=1', 'peptides-struct': 'https://www.dropbox.com/s/zgv4z8fcpmknhs8/peptidesstruct.zip?dl=1' } dwnld_file_name = { 'pascalvoc-sp': 'voc_superpixels_edge_wt_region_boundary', 'coco-sp': 'coco_superpixels_edge_wt_region_boundary', 'pcqm-contact': 'pcqmcontact', 'peptides-func': 'peptidesfunc', 'peptides-struct': 'peptidesstruct' } def __init__( self, root: str, name: str, split: str = "train", transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in self.names assert split in ['train', 'val', 'test'] super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = osp.join(self.processed_dir, f'{split}.pt') self.load(path) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: if self.name.split('-')[1] == 'sp': return ['train.pickle', 'val.pickle', 'test.pickle'] else: return ['train.pt', 'val.pt', 'test.pt'] @property def 
processed_file_names(self) -> List[str]: return ['train.pt', 'val.pt', 'test.pt'] def download(self) -> None: fs.rm(self.raw_dir) path = download_url(self.urls[self.name], self.root) extract_zip(path, self.root) os.rename(osp.join(self.root, self.dwnld_file_name[self.name]), self.raw_dir) os.unlink(path) def process(self) -> None: if self.name == 'pcqm-contact': # PCQM-Contact self.process_pcqm_contact() else: if self.name == 'coco-sp': # Label remapping for coco-sp. # See self.label_remap_coco() func label_map = self.label_remap_coco() for split in ['train', 'val', 'test']: if self.name.split('-')[1] == 'sp': # PascalVOC-SP and COCO-SP with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f: graphs = pickle.load(f) elif self.name.split('-')[0] == 'peptides': # Peptides-func and Peptides-struct graphs = fs.torch_load( osp.join(self.raw_dir, f'{split}.pt')) data_list = [] for graph in tqdm(graphs, desc=f'Processing {split} dataset'): if self.name.split('-')[1] == 'sp': """ PascalVOC-SP and COCO-SP Each `graph` is a tuple (x, edge_attr, edge_index, y) Shape of x : [num_nodes, 14] Shape of edge_attr : [num_edges, 2] Shape of edge_index : [2, num_edges] Shape of y : [num_nodes] """ x = graph[0].to(torch.float) edge_attr = graph[1].to(torch.float) edge_index = graph[2] y = torch.LongTensor(graph[3]) elif self.name.split('-')[0] == 'peptides': """ Peptides-func and Peptides-struct Each `graph` is a tuple (x, edge_attr, edge_index, y) Shape of x : [num_nodes, 9] Shape of edge_attr : [num_edges, 3] Shape of edge_index : [2, num_edges] Shape of y : [1, 10] for Peptides-func, or [1, 11] for Peptides-struct """ x = graph[0] edge_attr = graph[1] edge_index = graph[2] y = graph[3] if self.name == 'coco-sp': for i, label in enumerate(y): y[i] = label_map[label.item()] data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) if self.pre_filter is not None and not self.pre_filter( data): continue if self.pre_transform is not None: data = self.pre_transform(data) 
data_list.append(data) path = osp.join(self.processed_dir, f'{split}.pt') self.save(data_list, path) def label_remap_coco(self) -> Dict[int, int]: # Util function for name 'COCO-SP' # to remap the labels as the original label idxs are not contiguous original_label_idx = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90 ] label_map = {} for i, key in enumerate(original_label_idx): label_map[key] = i return label_map def process_pcqm_contact(self) -> None: for split in ['train', 'val', 'test']: graphs = fs.torch_load(osp.join(self.raw_dir, f'{split}.pt')) data_list = [] for graph in tqdm(graphs, desc=f'Processing {split} dataset'): """ PCQM-Contact Each `graph` is a tuple (x, edge_attr, edge_index, edge_label_index, edge_label) Shape of x : [num_nodes, 9] Shape of edge_attr : [num_edges, 3] Shape of edge_index : [2, num_edges] Shape of edge_label_index: [2, num_labeled_edges] Shape of edge_label : [num_labeled_edges] where, num_labeled_edges are negative edges and link pred labels, https://github.com/vijaydwivedi75/lrgb/blob/main/graphgps/loader/dataset/pcqm4mv2_contact.py#L192 """ x = graph[0] edge_attr = graph[1] edge_index = graph[2] edge_label_index = graph[3] edge_label = graph[4] data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, edge_label_index=edge_label_index, edge_label=edge_label) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, osp.join(self.processed_dir, f'{split}.pt')) ================================================ FILE: torch_geometric/datasets/malnet_tiny.py ================================================ import os import os.path as osp from 
typing import Callable, Dict, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_tar, extract_zip, ) from torch_geometric.io import fs class MalNetTiny(InMemoryDataset): r"""The MalNet Tiny dataset from the `"A Large-Scale Database for Graph Representation Learning" `_ paper. :class:`MalNetTiny` contains 5,000 malicious and benign software function call graphs across 5 different types. Each graph contains at most 5k nodes. Args: root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"trainval"`, loads the training and validation dataset. If :obj:`"test"`, loads the test dataset. If :obj:`None`, loads the entire dataset. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ data_url = ('http://malnet.cc.gatech.edu/' 'graph-data/malnet-graphs-tiny.tar.gz') split_url = 'http://malnet.cc.gatech.edu/split-info/split_info_tiny.zip' splits = ['train', 'val', 'test'] def __init__( self, root: str, split: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: if split not in {'train', 'val', 'trainval', 'test', None}: raise ValueError(f'Split "{split}" found, but expected either ' f'"train", "val", "trainval", "test" or None') super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) if split is not None: split_slices = fs.torch_load(self.processed_paths[1]) if split == 'train': self._indices = range(split_slices[0], split_slices[1]) elif split == 'val': self._indices = range(split_slices[1], split_slices[2]) elif split == 'trainval': self._indices = range(split_slices[0], split_slices[2]) elif split == 'test': self._indices = range(split_slices[2], split_slices[3]) @property def raw_file_names(self) -> List[str]: return ['malnet-graphs-tiny', osp.join('split_info_tiny', 'type')] @property def processed_file_names(self) -> List[str]: return ['data.pt', 'split_slices.pt'] def download(self) -> None: path = download_url(self.data_url, self.raw_dir) extract_tar(path, self.raw_dir) os.unlink(path) path = download_url(self.split_url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: y_map: Dict[str, int] = {} data_list = [] split_slices = [0] for split in ['train', 'val', 'test']: with open(osp.join(self.raw_paths[1], f'{split}.txt')) as f: filenames = f.read().split('\n')[:-1] split_slices.append(split_slices[-1] + len(filenames)) for filename in filenames: path = osp.join(self.raw_paths[0], f'{filename}.edgelist') malware_type = filename.split('/')[0] y = y_map.setdefault(malware_type, 
len(y_map)) with open(path) as f: edges = f.read().split('\n')[5:-1] edge_indices = [[int(s) for s in e.split()] for e in edges] edge_index = torch.tensor(edge_indices).t().contiguous() num_nodes = int(edge_index.max()) + 1 data = Data(edge_index=edge_index, y=y, num_nodes=num_nodes) data_list.append(data) if self.pre_filter is not None: data_list = [data for data in data_list if self.pre_filter(data)] if self.pre_transform is not None: data_list = [self.pre_transform(data) for data in data_list] self.save(data_list, self.processed_paths[0]) torch.save(split_slices, self.processed_paths[1]) ================================================ FILE: torch_geometric/datasets/md17.py ================================================ import os import os.path as osp from typing import Callable, List, Optional, Union import numpy as np import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_tar, extract_zip, ) class MD17(InMemoryDataset): r"""A variety of ab-initio molecular dynamics trajectories from the authors of `sGDML `_. This class provides access to the original MD17 datasets, their revised versions, and the CCSD(T) trajectories. For every trajectory, the dataset contains the Cartesian positions of atoms (in Angstrom), their atomic numbers, as well as the total energy (in kcal/mol) and forces (kcal/mol/Angstrom) on each atom. The latter two are the regression targets for this collection. .. note:: Data objects contain no edge indices as these are most commonly constructed via the :obj:`torch_geometric.transforms.RadiusGraph` transform, with its cut-off being a hyperparameter. The `original MD17 dataset `_ contains ten molecule trajectories. This version of the dataset was found to suffer from high numerical noise. The `revised MD17 dataset `_ contains the same molecules, but the energies and forces were recalculated at the PBE/def2-SVP level of theory using very tight SCF convergence and very dense DFT integration grid. 
The third version of the dataset contains fewer molecules, computed at the CCSD(T) level of theory. The benzene molecule at the DFT FHI-aims level of theory was `released separately `_. Check the table below for detailed information on the molecule, level of theory and number of data points contained in each dataset. Which trajectory is loaded is determined by the :attr:`name` argument. For the coupled cluster trajectories, the dataset comes with pre-defined training and testing splits which are loaded separately via the :attr:`train` argument. +--------------------+--------------------+-------------------------------+-----------+ | Molecule | Level of Theory | Name | #Examples | +====================+====================+===============================+===========+ | Benzene | DFT | :obj:`benzene` | 627,983 | +--------------------+--------------------+-------------------------------+-----------+ | Uracil | DFT | :obj:`uracil` | 133,770 | +--------------------+--------------------+-------------------------------+-----------+ | Naphthalene | DFT | :obj:`naphthalene` | 326,250 | +--------------------+--------------------+-------------------------------+-----------+ | Aspirin | DFT | :obj:`aspirin` | 211,762 | +--------------------+--------------------+-------------------------------+-----------+ | Salicylic acid | DFT | :obj:`salicylic acid` | 320,231 | +--------------------+--------------------+-------------------------------+-----------+ | Malonaldehyde | DFT | :obj:`malonaldehyde` | 993,237 | +--------------------+--------------------+-------------------------------+-----------+ | Ethanol | DFT | :obj:`ethanol` | 555,092 | +--------------------+--------------------+-------------------------------+-----------+ | Toluene | DFT | :obj:`toluene` | 442,790 | +--------------------+--------------------+-------------------------------+-----------+ | Paracetamol | DFT | :obj:`paracetamol` | 106,490 | 
+--------------------+--------------------+-------------------------------+-----------+ | Azobenzene | DFT | :obj:`azobenzene` | 99,999 | +--------------------+--------------------+-------------------------------+-----------+ | Benzene (R) | DFT (PBE/def2-SVP) | :obj:`revised benzene` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Uracil (R) | DFT (PBE/def2-SVP) | :obj:`revised uracil` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Naphthalene (R) | DFT (PBE/def2-SVP) | :obj:`revised naphthalene` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Aspirin (R) | DFT (PBE/def2-SVP) | :obj:`revised aspirin` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Salicylic acid (R) | DFT (PBE/def2-SVP) | :obj:`revised salicylic acid` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Malonaldehyde (R) | DFT (PBE/def2-SVP) | :obj:`revised malonaldehyde` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Ethanol (R) | DFT (PBE/def2-SVP) | :obj:`revised ethanol` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Toluene (R) | DFT (PBE/def2-SVP) | :obj:`revised toluene` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Paracetamol (R) | DFT (PBE/def2-SVP) | :obj:`revised paracetamol` | 100,000 | +--------------------+--------------------+-------------------------------+-----------+ | Azobenzene (R) | DFT (PBE/def2-SVP) | :obj:`revised azobenzene` | 99,988 | +--------------------+--------------------+-------------------------------+-----------+ | Benzene | CCSD(T) | :obj:`benzene CCSD(T)` | 1,500 | 
+--------------------+--------------------+-------------------------------+-----------+ | Aspirin | CCSD | :obj:`aspirin CCSD` | 1,500 | +--------------------+--------------------+-------------------------------+-----------+ | Malonaldehyde | CCSD(T) | :obj:`malonaldehyde CCSD(T)` | 1,500 | +--------------------+--------------------+-------------------------------+-----------+ | Ethanol | CCSD(T) | :obj:`ethanol CCSD(T)` | 2,000 | +--------------------+--------------------+-------------------------------+-----------+ | Toluene | CCSD(T) | :obj:`toluene CCSD(T)` | 1,501 | +--------------------+--------------------+-------------------------------+-----------+ | Benzene | DFT FHI-aims | :obj:`benzene FHI-aims` | 49,863 | +--------------------+--------------------+-------------------------------+-----------+ .. warning:: It is advised to not train a model on more than 1,000 samples from the original or revised MD17 dataset. Args: root (str): Root directory where the dataset should be saved. name (str): Keyword of the trajectory that should be loaded. train (bool, optional): Determines whether the train or test split gets loaded for the coupled cluster trajectories. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 20 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #tasks * - Benzene - 627,983 - 12 - 0 - 1 - 2 * - Uracil - 133,770 - 12 - 0 - 1 - 2 * - Naphthalene - 326,250 - 10 - 0 - 1 - 2 * - Aspirin - 211,762 - 21 - 0 - 1 - 2 * - Salicylic acid - 320,231 - 16 - 0 - 1 - 2 * - Malonaldehyde - 993,237 - 9 - 0 - 1 - 2 * - Ethanol - 555,092 - 9 - 0 - 1 - 2 * - Toluene - 442,790 - 15 - 0 - 1 - 2 * - Paracetamol - 106,490 - 20 - 0 - 1 - 2 * - Azobenzene - 99,999 - 24 - 0 - 1 - 2 * - Benzene (R) - 100,000 - 12 - 0 - 1 - 2 * - Uracil (R) - 100,000 - 12 - 0 - 1 - 2 * - Naphthalene (R) - 100,000 - 10 - 0 - 1 - 2 * - Aspirin (R) - 100,000 - 21 - 0 - 1 - 2 * - Salicylic acid (R) - 100,000 - 16 - 0 - 1 - 2 * - Malonaldehyde (R) - 100,000 - 9 - 0 - 1 - 2 * - Ethanol (R) - 100,000 - 9 - 0 - 1 - 2 * - Toluene (R) - 100,000 - 15 - 0 - 1 - 2 * - Paracetamol (R) - 100,000 - 20 - 0 - 1 - 2 * - Azobenzene (R) - 99,988 - 24 - 0 - 1 - 2 * - Benzene CCSD-T - 1,500 - 12 - 0 - 1 - 2 * - Aspirin CCSD-T - 1,500 - 21 - 0 - 1 - 2 * - Malonaldehyde CCSD-T - 1,500 - 9 - 0 - 1 - 2 * - Ethanol CCSD-T - 2000 - 9 - 0 - 1 - 2 * - Toluene CCSD-T - 1,501 - 15 - 0 - 1 - 2 * - Benzene FHI-aims - 49,863 - 12 - 0 - 1 - 2 """ # noqa: E501 gdml_url = 'http://quantum-machine.org/gdml/data/npz' revised_url = ('https://archive.materialscloud.org/record/' 'file?filename=rmd17.tar.bz2&record_id=466') file_names = { 'benzene': 'md17_benzene2017.npz', 'uracil': 'md17_uracil.npz', 'naphthalene': 'md17_naphthalene.npz', 'aspirin': 'md17_aspirin.npz', 'salicylic acid': 'md17_salicylic.npz', 'malonaldehyde': 'md17_malonaldehyde.npz', 'ethanol': 'md17_ethanol.npz', 'toluene': 'md17_toluene.npz', 'paracetamol': 'paracetamol_dft.npz', 'azobenzene': 'azobenzene_dft.npz', 'revised benzene': 'rmd17_benzene.npz', 'revised uracil': 'rmd17_uracil.npz', 'revised naphthalene': 'rmd17_naphthalene.npz', 'revised aspirin': 'rmd17_aspirin.npz', 'revised salicylic acid': 
'rmd17_salicylic.npz', 'revised malonaldehyde': 'rmd17_malonaldehyde.npz', 'revised ethanol': 'rmd17_ethanol.npz', 'revised toluene': 'rmd17_toluene.npz', 'revised paracetamol': 'rmd17_paracetamol.npz', 'revised azobenzene': 'rmd17_azobenzene.npz', 'benzene CCSD(T)': 'benzene_ccsd_t.zip', 'aspirin CCSD': 'aspirin_ccsd.zip', 'malonaldehyde CCSD(T)': 'malonaldehyde_ccsd_t.zip', 'ethanol CCSD(T)': 'ethanol_ccsd_t.zip', 'toluene CCSD(T)': 'toluene_ccsd_t.zip', 'benzene FHI-aims': 'benzene2018_dft.npz', } def __init__( self, root: str, name: str, train: Optional[bool] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: if name not in self.file_names: raise ValueError(f"Unknown dataset name '{name}'") self.name = name self.revised = 'revised' in name self.ccsd = 'CCSD' in self.name super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) if len(self.processed_file_names) == 1 and train is not None: raise ValueError( f"'{self.name}' dataset does not provide pre-defined splits " f"but the 'train' argument is set to '{train}'") elif len(self.processed_file_names) == 2 and train is None: raise ValueError( f"'{self.name}' dataset does provide pre-defined splits but " f"the 'train' argument was not specified") idx = 0 if train is None or train else 1 self.load(self.processed_paths[idx]) def mean(self) -> float: assert isinstance(self._data, Data) return float(self._data.energy.mean()) @property def raw_dir(self) -> str: if self.revised: return osp.join(self.root, 'raw') return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> Union[str, List[str]]: name = self.file_names[self.name] if self.revised: return osp.join('rmd17', 'npz_data', name) elif self.ccsd: return [name[:-4] + '-train.npz', name[:-4] + '-test.npz'] 
return name @property def processed_file_names(self) -> List[str]: if self.ccsd: return ['train.pt', 'test.pt'] else: return ['data.pt'] def download(self) -> None: if self.revised: path = download_url(self.revised_url, self.raw_dir) extract_tar(path, self.raw_dir, mode='r:bz2') os.unlink(path) else: url = f'{self.gdml_url}/{self.file_names[self.name]}' path = download_url(url, self.raw_dir) if self.ccsd: extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: it = zip(self.raw_paths, self.processed_paths) for raw_path, processed_path in it: raw_data = np.load(raw_path) if self.revised: z = torch.from_numpy(raw_data['nuclear_charges']).long() pos = torch.from_numpy(raw_data['coords']).float() energy = torch.from_numpy(raw_data['energies']).float() force = torch.from_numpy(raw_data['forces']).float() else: z = torch.from_numpy(raw_data['z']).long() pos = torch.from_numpy(raw_data['R']).float() energy = torch.from_numpy(raw_data['E']).float() force = torch.from_numpy(raw_data['F']).float() data_list = [] for i in range(pos.size(0)): data = Data(z=z, pos=pos[i], energy=energy[i], force=force[i]) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, processed_path) def __repr__(self) -> str: return f"{self.__class__.__name__}({len(self)}, name='{self.name}')" ================================================ FILE: torch_geometric/datasets/medshapenet.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset class MedShapeNet(InMemoryDataset): r"""The MedShapeNet datasets from the `"MedShapeNet -- A Large-Scale Dataset of 3D Medical Shapes for Computer Vision" `_ paper, containing 8 different type of structures (classes). .. 
note:: Data objects hold mesh faces instead of edge indices. To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (str): Root directory where the dataset should be saved. size (int, optional): Number of individual 3D structures to download per type (class). (default: :obj:`100`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`) """ def __init__( self, root: str, size: int = 100, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.size = size super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] self.load(path) @property def raw_file_names(self) -> List[str]: return [ '3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY', 'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy' ] @property def processed_file_names(self) -> List[str]: return ['dataset.pt'] @property def raw_paths(self) -> List[str]: r"""The absolute filepaths that must be present in order to skip downloading. """ return [osp.join(self.raw_dir, f) for f in self.raw_file_names] def process(self) -> None: import urllib3 from MedShapeNet import MedShapeNet as msn msn_instance = msn(timeout=120) urllib3.HTTPConnectionPool("medshapenet.ddns.net", maxsize=50) list_of_datasets = msn_instance.datasets(False) list_of_datasets = list( filter( lambda x: x not in [ 'medshapenetcore/ASOCA', 'medshapenetcore/AVT', 'medshapenetcore/AutoImplantCraniotomy', 'medshapenetcore/FaceVR' ], list_of_datasets)) subset = [] for dataset in list_of_datasets: parts = dataset.split("/") self.newpath = self.root + '/' + parts[1 if len(parts) > 1 else 0] if not os.path.exists(self.newpath): os.makedirs(self.newpath) stl_files = msn_instance.dataset_files(dataset, '.stl') subset.extend(stl_files[:self.size]) for stl_file in stl_files[:self.size]: msn_instance.download_stl_as_numpy(bucket_name=dataset, stl_file=stl_file, output_dir=self.newpath, print_output=False) class_mapping = { '3DTeethSeg': 0, 'CoronaryArteries': 1, 'FLARE': 2, 'KITS': 3, 'PULMONARY': 4, 'SurgicalInstruments': 5, 'ThoracicAorta_Saitta': 6, 'ToothFairy': 7 } for dataset, path in zip([subset], self.processed_paths): data_list = [] for item in dataset: class_name = 
item.split("/")[0] item = item.split("stl")[0] target = class_mapping[class_name] file = osp.join(self.root, item + 'npz') data = np.load(file) pre_data_list = Data( pos=torch.tensor(data["vertices"], dtype=torch.float), face=torch.tensor(data["faces"], dtype=torch.long).t().contiguous()) pre_data_list.y = torch.tensor([target], dtype=torch.long) data_list.append(pre_data_list) if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, path) ================================================ FILE: torch_geometric/datasets/mixhop_synthetic_dataset.py ================================================ import os.path as osp import pickle from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url class MixHopSyntheticDataset(InMemoryDataset): r"""The MixHop synthetic dataset from the `"MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing" `_ paper, containing 10 graphs, each with varying degree of homophily (ranging from 0.0 to 0.9). All graphs have 5,000 nodes, where each node corresponds to 1 out of 10 classes. The feature values of the nodes are sampled from a 2D Gaussian distribution, which are distinct for each class. Args: root (str): Root directory where the dataset should be saved. homophily (float): The degree of homophily (one of :obj:`0.0`, :obj:`0.1`, ..., :obj:`0.9`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = ('https://raw.githubusercontent.com/samihaija/mixhop/master/data' '/synthetic') def __init__( self, root: str, homophily: float, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.homophily = homophily assert homophily in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, f'{self.homophily:0.1f}'[::2], 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, f'{self.homophily:0.1f}'[::2], 'processed') @property def raw_file_names(self) -> List[str]: name = f'ind.n5000-h{self.homophily:0.1f}-c10' return [f'{name}.allx', f'{name}.ally', f'{name}.graph'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for filename in self.raw_file_names: download_url(f'{self.url}/{filename}', self.raw_dir) def process(self) -> None: x = torch.from_numpy(np.load(self.raw_paths[0])) y = torch.from_numpy(np.load(self.raw_paths[1])).argmax(dim=-1) edges = pickle.load(open(self.raw_paths[2], 'rb'), encoding='latin1') row, col = [], [] for k, v in edges.items(): row += [k] * len(v) col += v edge_index = torch.tensor([row, col], dtype=torch.long) N_s = x.size(0) // 3 train_mask = torch.zeros(x.size(0), dtype=torch.bool) train_mask[:N_s] = True val_mask = torch.zeros(x.size(0), dtype=torch.bool) val_mask[N_s:2 * N_s] = True test_mask = torch.zeros(x.size(0), dtype=torch.bool) test_mask[2 * N_s:] = True data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return 
f'{self.__class__.__name__}(homophily={self.homophily:.1f})' ================================================ FILE: torch_geometric/datasets/mnist_superpixels.py ================================================ import os from typing import Callable, List, Optional from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class MNISTSuperpixels(InMemoryDataset): r"""MNIST superpixels dataset from the `"Geometric Deep Learning on Graphs and Manifolds Using Mixture Model CNNs" `_ paper, containing 70,000 graphs with 75 nodes each. Every graph is labeled by one of 10 classes. Args: root (str): Root directory where the dataset should be saved. train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 70,000 - 75 - ~1,393.0 - 1 - 10 """ url = 'https://data.pyg.org/datasets/MNISTSuperpixels.zip' def __init__( self, root: str, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] if train else self.processed_paths[1] self.load(path) @property def raw_file_names(self) -> str: return 'MNISTSuperpixels.pt' @property def processed_file_names(self) -> List[str]: return ['train_data.pt', 'test_data.pt'] def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: inputs = fs.torch_load(self.raw_paths[0]) for i in range(len(inputs)): data_list = [Data(**data_dict) for data_dict in inputs[i]] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[i]) ================================================ FILE: torch_geometric/datasets/modelnet.py ================================================ import glob import os import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs, read_off class ModelNet(InMemoryDataset): r"""The ModelNet10/40 datasets from the `"3D ShapeNets: A Deep Representation for Volumetric Shapes" `_ paper, containing CAD models of 10 and 40 categories, respectively. .. note:: Data objects hold mesh faces instead of edge indices. 
To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (str): Root directory where the dataset should be saved. name (str, optional): The name of the dataset (:obj:`"10"` for ModelNet10, :obj:`"40"` for ModelNet40). (default: :obj:`"10"`) train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 20 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes * - ModelNet10 - 4,899 - ~9,508.2 - ~37,450.5 - 3 - 10 * - ModelNet40 - 12,311 - ~17,744.4 - ~66,060.9 - 3 - 40 """ urls = { '10': 'http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip', # noqa '40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip' } def __init__( self, root: str, name: str = '10', train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert name in ['10', '40'] self.name = name super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] if train else self.processed_paths[1] self.load(path) @property def raw_file_names(self) -> List[str]: return [ 'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 'night_stand', 'sofa', 'table', 'toilet' ] @property def processed_file_names(self) -> List[str]: return ['training.pt', 'test.pt'] def download(self) -> None: path = download_url(self.urls[self.name], self.root) extract_zip(path, self.root) os.unlink(path) folder = osp.join(self.root, f'ModelNet{self.name}') fs.rm(self.raw_dir) os.rename(folder, self.raw_dir) # Delete osx metadata generated during compression of ModelNet10 metadata_folder = osp.join(self.root, '__MACOSX') if osp.exists(metadata_folder): fs.rm(metadata_folder) def process(self) -> None: self.save(self.process_set('train'), self.processed_paths[0]) self.save(self.process_set('test'), self.processed_paths[1]) def process_set(self, dataset: str) -> List[Data]: categories = glob.glob(osp.join(self.raw_dir, '*', '')) categories = sorted([x.split(os.sep)[-2] for x in categories]) data_list = [] for target, category in enumerate(categories): folder = osp.join(self.raw_dir, category, dataset) paths = glob.glob(f'{folder}/{category}_*.off') for path in paths: data = 
read_off(path) data.y = torch.tensor([target]) data_list.append(data) if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] return data_list def __repr__(self) -> str: return f'{self.__class__.__name__}{self.name}({len(self)})' ================================================ FILE: torch_geometric/datasets/molecule_gpt_dataset.py ================================================ import gzip import json import multiprocessing import os import sys from collections import defaultdict from multiprocessing import Pool from typing import Callable, List, Optional, Tuple import numpy as np import requests import torch from tqdm import tqdm from torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.io import fs from torch_geometric.llm.models import LLM from torch_geometric.utils import one_hot def clean_up_description(description: str) -> str: description = description + " " # extra adj Pure if description.startswith("Pure "): description = description.replace("Pure ", "") # fix typo if description.startswith("Mercurycombines"): description = description.replace("Mercurycombines", "Mercury combines") # a special case description = description.replace( "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ", "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is ") # a special case description = description.replace("5-Thymidylic acid. ", "5-Thymidylic acid. is ") # a special case description = description.replace( "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ", "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is ") # a special case description = description.replace( ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride" " with phosphorothioic acid. "), ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride" " with phosphorothioic acid is ")) # a special case description = description.replace("5'-Uridylic acid. 
", "5'-Uridylic acid is ") # a special case description = description.replace("5'-Adenylic acid, ", "5'-Adenylic acid is ") # a special case description = description.replace( "Uridine 5'-(tetrahydrogen triphosphate). ", "Uridine 5'-(tetrahydrogen triphosphate). is ") # a special case description = description.replace("Inosine 5'-Monophosphate. ", "Inosine 5'-Monophosphate. is ") # a special case description = description.replace("Pivaloyloxymethyl butyrate (AN-9), ", "Pivaloyloxymethyl butyrate (AN-9) is ") # a special case description = description.replace( "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ", "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is ") # a special case description = description.replace( "Cardamonin (also known as Dihydroxymethoxychalcone), ", "Cardamonin (also known as Dihydroxymethoxychalcone) is ") # a special case description = description.replace("Lithium has been used to treat ", "Lithium is ") # a special case description = description.replace("4,4'-Methylenebis ", "4,4'-Methylenebis is ") # a special case description = description.replace( "2,3,7,8-Tetrachlorodibenzo-p-dioxin", "2,3,7,8-Tetrachlorodibenzo-p-dioxin is ") # a special case description = description.replace("Exposure to 2,4,5-trichlorophenol ", "2,4,5-Trichlorophenol exposure ") index = 0 L = len(description) if description.startswith('C.I. '): start_index = len('C.I. ') elif description.startswith('Nectriapyrone. D '): start_index = len('Nectriapyrone. D ') elif description.startswith( 'Salmonella enterica sv. Minnesota LPS core oligosaccharide'): start_index = len( 'Salmonella enterica sv. Minnesota LPS core oligosaccharide') else: start_index = 0 for index in range(start_index, L - 1): if index < L - 2: if description[index] == '.' 
and description[ index + 1] == ' ' and 'A' <= description[index + 2] <= 'Z': break elif index == L - 2: break first_sentence = description[:index + 1] return first_sentence def extract_name( name_raw: str, description: str, ) -> Tuple[Optional[str], str, str]: first_sentence = clean_up_description(description) splitter = ' -- -- ' if ' are ' in first_sentence or ' were ' in first_sentence: replaced_words = 'These molecules' else: replaced_words = 'This molecule' first_sentence = first_sentence.replace(' is ', splitter) first_sentence = first_sentence.replace(' are ', splitter) first_sentence = first_sentence.replace(' was ', splitter) first_sentence = first_sentence.replace(' were ', splitter) first_sentence = first_sentence.replace(' appears ', splitter) first_sentence = first_sentence.replace(' occurs ', splitter) first_sentence = first_sentence.replace(' stands for ', splitter) first_sentence = first_sentence.replace(' belongs to ', splitter) first_sentence = first_sentence.replace(' exists ', splitter) # only for CID=11443 first_sentence = first_sentence.replace(' has been used in trials ', splitter) first_sentence = first_sentence.replace(' has been investigated ', splitter) first_sentence = first_sentence.replace(' has many uses ', splitter) if splitter in first_sentence: extracted_name = first_sentence.split(splitter, 1)[0] elif first_sentence.startswith(name_raw): extracted_name = name_raw elif name_raw in first_sentence: extracted_name = name_raw extracted_name = None print("=====", name_raw) print("first sentence: ", first_sentence) else: extracted_name = None if extracted_name is not None: extracted_description = description.replace(extracted_name, replaced_words) else: extracted_description = description return extracted_name, extracted_description, first_sentence class MoleculeGPTDataset(InMemoryDataset): r"""The dataset from the `"MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction" `_ paper. 
Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) total_page_num (int, optional): The number of pages from PubChem. (default: :obj:`10`) total_block_num (int, optional): The blocks of SDF files from PubChem. (default: :obj:`1`) num_units (int, optional): Number of units of the sample. 
(default: :obj:`-1`, which means all units will be used) """ description_url = ( 'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/' 'heading/json?heading_type=Compound&heading=Record+Description&page={}' ) compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/' 'CURRENT-Full/SDF') def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, total_page_num: int = 10, total_block_num: int = 1, num_units: int = -1, ): self.total_page_num = total_page_num self.total_block_num = total_block_num self.num_units = num_units super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['pubchem.csv'] @property def processed_file_names(self) -> List[str]: return ['data.pt'] def download(self) -> None: # Step 01. Extract description step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description" if not os.path.exists(step1_folder): os.makedirs(step1_folder) valid_CID_set = set() CID2name_raw, CID2name_extracted = defaultdict(list), defaultdict( list) CID2text_raw, CID2text_extracted = defaultdict(list), defaultdict( list) for page_index in tqdm(range(self.total_page_num)): page_num = page_index + 1 f_out = open( f"{step1_folder}/Compound_description_{page_num}.txt", "w") description_data = requests.get( self.description_url.format(page_num)).json() description_data = description_data["Annotations"] assert description_data["Page"] == page_num record_list = description_data["Annotation"] for record in record_list: try: CID = record["LinkedRecords"]["CID"][0] if "Name" in record: name_raw = record["Name"] CID2name_raw[CID].append(name_raw) else: name_raw = None data_list = record["Data"] for data in data_list: description = data["Value"]["StringWithMarkup"][0][ "String"].strip() extracted_name, extracted_description, _ = 
extract_name( # noqa: E501 name_raw, description) if extracted_name is not None: CID2name_extracted[CID].append(extracted_name) CID2text_raw[CID].append(description) CID2text_extracted[CID].append( extracted_description) valid_CID_set.add(CID) f_out.write(f"{CID}\n") f_out.write(f"{extracted_description}\n\n") except Exception: continue valid_CID_list = sorted(list(valid_CID_set)) print(f"Total CID (with raw name) {len(CID2name_raw)}") print(f"Total CID (with extracted name) {len(CID2name_extracted)}") print(f"Total CID {len(valid_CID_list)}") with open(f"{self.raw_dir}/CID2name_raw.json", "w") as f: json.dump(CID2name_raw, f) with open(f"{self.raw_dir}/CID2name.json", "w") as f: json.dump(CID2name_extracted, f) with open(f"{self.raw_dir}/CID2text_raw.json", "w") as f: json.dump(CID2text_raw, f) with open(f"{self.raw_dir}/CID2text.json", "w") as f: json.dump(CID2text_extracted, f) # Step 02. Download SDF Files step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF" if not os.path.exists(step2_folder): for block_id in tqdm(range(self.total_block_num)): block_size = 500000 l_id = block_id * block_size + 1 r_id = (block_id + 1) * block_size compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz" download_url(f"{self.compound_url}/{compound_file_name}", step2_folder) def process(self, use_mp: bool = False) -> None: try: from rdkit import Chem from rdkit.Chem.rdchem import BondType as BT WITH_RDKIT = True except ImportError: WITH_RDKIT = False if not WITH_RDKIT: print(("Using a pre-processed version of the dataset. Please " "install 'rdkit' to alternatively process the raw data."), file=sys.stderr) data_list = fs.torch_load(self.raw_paths[0]) data_list = [Data(**data_dict) for data_dict in data_list] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[0]) return # Step 03. 
Filter out SDF step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF" step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered" if not os.path.exists(step3_folder): os.makedirs(step3_folder) with open(f"{self.raw_dir}/CID2text.json") as f: CID2text = json.load(f) target_CID_list = set(CID2text.keys()) block_size = 500000 def extract_one_SDF_file(block_id: int) -> None: valid_mol_count = 0 writer = Chem.SDWriter( f'{step3_folder}/filtered_{block_id}.sdf') l_id = block_id * block_size + 1 r_id = (block_id + 1) * block_size compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz" gzip_loader = gzip.open(f"{step2_folder}/{compound_file_name}") suppl = Chem.ForwardSDMolSupplier(gzip_loader) for mol in tqdm(suppl): if mol is None: continue cid = mol.GetProp("PUBCHEM_COMPOUND_CID") if cid not in target_CID_list: continue writer.write(mol) valid_mol_count += 1 writer.close() print(f"block id: {block_id}\nfound {valid_mol_count}\n\n") sys.stdout.flush() return if use_mp: num_process = multiprocessing.cpu_count() print(f"{num_process} CPUs") num_process = 8 p = Pool(num_process) block_id_list = np.arange(self.total_block_num) with p: p.map(extract_one_SDF_file, block_id_list) else: for block_id in range(self.total_block_num): extract_one_SDF_file(block_id) # Step 04. Merge SDF with open(f"{self.raw_dir}/CID2text.json") as f: CID2text = json.load(f) target_CID_list = set(CID2text.keys()) print(f'The length of target_CID_list: {len(target_CID_list)}') writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf') found_CID_set = set() for block_id in range(self.total_block_num + 1): compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf" try: suppl = Chem.SDMolSupplier(compound_file_path) for mol in tqdm(suppl): writer.write(mol) cid = mol.GetProp("PUBCHEM_COMPOUND_CID") found_CID_set.add(cid) except Exception: print(f"block id: {block_id} with 0 valid SDF file") continue writer.close() print(f"In total: {len(found_CID_set)} molecules") # Step 05. 
Convert to PyG data format types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5} bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} data_list = [] # Real data CID2text_file = f'{self.raw_dir}/CID2text.json' with open(CID2text_file) as f: CID2text_data = json.load(f) suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf') llm = LLM( model_name='Qwen/Qwen3-0.6B', num_params=1, dtype=torch.bfloat16, sys_prompt='You are an agent, answer my questions.', ) prompt = ("Propose a question regarding the molecule '∼' " "whose answer is: {}:") for mol in tqdm(suppl): if mol.HasProp('PUBCHEM_COMPOUND_CID'): CID = mol.GetProp("PUBCHEM_COMPOUND_CID") CAN_SMILES = mol.GetProp("PUBCHEM_SMILES") m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES) if m is None: continue RDKit_CAN_SMILES = Chem.MolToSmiles(m) ground_truth = CID2text_data[CID][0] instruction = llm.inference([prompt.format(ground_truth)])[0] x: torch.Tensor = torch.tensor([ types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 for atom in m.GetAtoms() ]) x = one_hot(x, num_classes=len(types), dtype=torch.float) rows, cols, edge_types = [], [], [] for bond in m.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_types += [bonds[bond.GetBondType()]] * 2 rows += [i, j] cols += [j, i] edge_index = torch.tensor([rows, cols], dtype=torch.long) edge_type = torch.tensor(edge_types, dtype=torch.long) edge_attr = one_hot(edge_type, num_classes=len(bonds)) data = Data( x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=RDKit_CAN_SMILES, instruction=instruction, y=ground_truth, ) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) if self.num_units > 0 and len(data_list) >= self.num_units: break self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/molecule_net.py 
================================================ import os import os.path as osp import re import warnings from typing import Callable, Dict, Optional, Tuple, Union import torch from torch_geometric.data import InMemoryDataset, download_url, extract_gz from torch_geometric.utils import from_smiles as _from_smiles class MoleculeNet(InMemoryDataset): r"""The `MoleculeNet `_ benchmark collection from the `"MoleculeNet: A Benchmark for Molecular Machine Learning" `_ paper, containing datasets from physical chemistry, biophysics and physiology. All datasets come with the additional node and edge features introduced by the :ogb:`null` `Open Graph Benchmark `_. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"ESOL"`, :obj:`"FreeSolv"`, :obj:`"Lipo"`, :obj:`"PCBA"`, :obj:`"MUV"`, :obj:`"HIV"`, :obj:`"BACE"`, :obj:`"BBBP"`, :obj:`"Tox21"`, :obj:`"ToxCast"`, :obj:`"SIDER"`, :obj:`"ClinTox"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) from_smiles (callable, optional): A custom function that takes a SMILES string and outputs a :obj:`~torch_geometric.data.Data` object. If not set, defaults to :meth:`~torch_geometric.utils.from_smiles`. (default: :obj:`None`) **STATS:** .. 
list-table:: :widths: 20 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes * - ESOL - 1,128 - ~13.3 - ~27.4 - 9 - 1 * - FreeSolv - 642 - ~8.7 - ~16.8 - 9 - 1 * - Lipophilicity - 4,200 - ~27.0 - ~59.0 - 9 - 1 * - PCBA - 437,929 - ~26.0 - ~56.2 - 9 - 128 * - MUV - 93,087 - ~24.2 - ~52.6 - 9 - 17 * - HIV - 41,127 - ~25.5 - ~54.9 - 9 - 1 * - BACE - 1513 - ~34.1 - ~73.7 - 9 - 1 * - BBBP - 2,050 - ~23.9 - ~51.6 - 9 - 1 * - Tox21 - 7,831 - ~18.6 - ~38.6 - 9 - 12 * - ToxCast - 8,597 - ~18.7 - ~38.4 - 9 - 617 * - SIDER - 1,427 - ~33.6 - ~70.7 - 9 - 27 * - ClinTox - 1,484 - ~26.1 - ~55.5 - 9 - 2 """ url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/{}' # Format: name: (display_name, url_name, csv_name, smiles_idx, y_idx) names: Dict[str, Tuple[str, str, str, int, Union[int, slice]]] = { 'esol': ('ESOL', 'delaney-processed.csv', 'delaney-processed', -1, -2), 'freesolv': ('FreeSolv', 'SAMPL.csv', 'SAMPL', 1, 2), 'lipo': ('Lipophilicity', 'Lipophilicity.csv', 'Lipophilicity', 2, 1), 'pcba': ('PCBA', 'pcba.csv.gz', 'pcba', -1, slice(0, 128)), 'muv': ('MUV', 'muv.csv.gz', 'muv', -1, slice(0, 17)), 'hiv': ('HIV', 'HIV.csv', 'HIV', 0, -1), 'bace': ('BACE', 'bace.csv', 'bace', 0, 2), 'bbbp': ('BBBP', 'BBBP.csv', 'BBBP', -1, -2), 'tox21': ('Tox21', 'tox21.csv.gz', 'tox21', -1, slice(0, 12)), 'toxcast': ('ToxCast', 'toxcast_data.csv.gz', 'toxcast_data', 0, slice(1, 618)), 'sider': ('SIDER', 'sider.csv.gz', 'sider', 0, slice(1, 28)), 'clintox': ('ClinTox', 'clintox.csv.gz', 'clintox', 0, slice(1, 3)), } def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, from_smiles: Optional[Callable] = None, ) -> None: self.name = name.lower() assert self.name in self.names.keys() self.from_smiles = from_smiles or _from_smiles super().__init__(root, transform, pre_transform, pre_filter, 
force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> str: return f'{self.names[self.name][2]}.csv' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: url = self.url.format(self.names[self.name][1]) path = download_url(url, self.raw_dir) if self.names[self.name][1][-2:] == 'gz': extract_gz(path, self.raw_dir) os.unlink(path) def process(self) -> None: with open(self.raw_paths[0]) as f: dataset = f.read().split('\n')[1:-1] dataset = [x for x in dataset if len(x) > 0] # Filter empty lines. data_list = [] for line in dataset: line = re.sub(r'\".*\"', '', line) # Replace ".*" strings. values = line.split(',') smiles = values[self.names[self.name][3]] labels = values[self.names[self.name][4]] labels = labels if isinstance(labels, list) else [labels] ys = [float(y) if len(y) > 0 else float('NaN') for y in labels] y = torch.tensor(ys, dtype=torch.float).view(1, -1) data = self.from_smiles(smiles) data.y = y if data.num_nodes == 0: warnings.warn( f"Skipping molecule '{smiles}' since it " f"resulted in zero atoms", stacklevel=2) continue if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) def __repr__(self) -> str: return f'{self.names[self.name][0]}({len(self)})' ================================================ FILE: torch_geometric/datasets/motif_generator/__init__.py ================================================ from .base import MotifGenerator from .custom import CustomMotif from .house import HouseMotif from .cycle import CycleMotif from .grid import GridMotif __all__ = classes = [ 'MotifGenerator', 'CustomMotif', 'HouseMotif', 'CycleMotif', 'GridMotif', ] 
================================================ FILE: torch_geometric/datasets/motif_generator/base.py ================================================ from abc import ABC, abstractmethod from typing import Any from torch_geometric.data import Data from torch_geometric.resolver import resolver class MotifGenerator(ABC): r"""An abstract base class for generating a motif.""" @abstractmethod def __call__(self) -> Data: r"""To be implemented by :class:`Motif` subclasses.""" @staticmethod def resolve(query: Any, *args: Any, **kwargs: Any) -> 'MotifGenerator': import torch_geometric.datasets.motif_generator as _motif_generators motif_generators = [ gen for gen in vars(_motif_generators).values() if isinstance(gen, type) and issubclass(gen, MotifGenerator) ] return resolver(motif_generators, {}, query, MotifGenerator, 'Motif', *args, **kwargs) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/datasets/motif_generator/custom.py ================================================ from typing import Any, Optional from torch_geometric.data import Data from torch_geometric.datasets.motif_generator import MotifGenerator from torch_geometric.utils import from_networkx class CustomMotif(MotifGenerator): r"""Generates a motif based on a custom structure coming from a :class:`torch_geometric.data.Data` or :class:`networkx.Graph` object. Args: structure (torch_geometric.data.Data or networkx.Graph): The structure to use as a motif. 
""" def __init__(self, structure: Any): super().__init__() self.structure: Optional[Data] = None if isinstance(structure, Data): self.structure = structure else: try: import networkx as nx if isinstance(structure, nx.Graph): self.structure = from_networkx(structure) except ImportError: pass if self.structure is None: raise ValueError(f"Expected a motif structure of type " f"'torch_geometric.data.Data' or 'networkx.Graph'" f"(got {type(structure)})") def __call__(self) -> Data: assert isinstance(self.structure, Data) return self.structure ================================================ FILE: torch_geometric/datasets/motif_generator/cycle.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.datasets.motif_generator import CustomMotif class CycleMotif(CustomMotif): r"""Generates the cycle motif from the `"GNNExplainer: Generating Explanations for Graph Neural Networks" `__ paper. Args: num_nodes (int): The number of nodes in the cycle. """ def __init__(self, num_nodes: int): self.num_nodes = num_nodes row = torch.arange(num_nodes).view(-1, 1).repeat(1, 2).view(-1) col1 = torch.arange(-1, num_nodes - 1) % num_nodes col2 = torch.arange(1, num_nodes + 1) % num_nodes col = torch.stack([col1, col2], dim=1).sort(dim=-1)[0].view(-1) structure = Data( num_nodes=num_nodes, edge_index=torch.stack([row, col], dim=0), ) super().__init__(structure) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.num_nodes})' ================================================ FILE: torch_geometric/datasets/motif_generator/grid.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.datasets.motif_generator import CustomMotif class GridMotif(CustomMotif): r"""Generates the grid-structured motif from the `"GNNExplainer: Generating Explanations for Graph Neural Networks" `__ paper. 
""" def __init__(self) -> None: edge_indices = [ [0, 1], [0, 3], [1, 4], [3, 4], [1, 2], [2, 5], [4, 5], [3, 6], [6, 7], [4, 7], [5, 8], [7, 8], [1, 0], [3, 0], [4, 1], [4, 3], [2, 1], [5, 2], [5, 4], [6, 3], [7, 6], [7, 4], [8, 5], [8, 7], ] structure = Data( num_nodes=9, edge_index=torch.tensor(edge_indices).t().contiguous(), y=torch.tensor([0, 1, 0, 1, 2, 1, 0, 1, 0]), ) super().__init__(structure) ================================================ FILE: torch_geometric/datasets/motif_generator/house.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.datasets.motif_generator import CustomMotif class HouseMotif(CustomMotif): r"""Generates the house-structured motif from the `"GNNExplainer: Generating Explanations for Graph Neural Networks" `__ paper, containing 5 nodes and 6 undirected edges. Nodes are labeled according to their structural role: the top, middle and bottom of the house. """ def __init__(self) -> None: structure = Data( num_nodes=5, edge_index=torch.tensor([ [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4], [1, 3, 4, 4, 2, 0, 1, 3, 2, 0, 0, 1], ]), y=torch.tensor([0, 0, 1, 1, 2]), ) super().__init__(structure) ================================================ FILE: torch_geometric/datasets/movie_lens.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) class MovieLens(InMemoryDataset): r"""A heterogeneous rating dataset, assembled by GroupLens Research from the `MovieLens web site `_, consisting of nodes of type :obj:`"movie"` and :obj:`"user"`. User ratings for movies are available as ground truth labels for the edges between the users and the movies :obj:`("user", "rates", "movie")`. Args: root (str): Root directory where the dataset should be saved. 
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) model_name (str): Name of model used to transform movie titles to node features. The model comes from the`Huggingface SentenceTransformer `_. force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, model_name: Optional[str] = 'all-MiniLM-L6-v2', force_reload: bool = False, ) -> None: self.model_name = model_name super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return [ osp.join('ml-latest-small', 'movies.csv'), osp.join('ml-latest-small', 'ratings.csv'), ] @property def processed_file_names(self) -> str: return f'data_{self.model_name}.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.remove(path) def process(self) -> None: import pandas as pd from sentence_transformers import SentenceTransformer data = HeteroData() df = pd.read_csv(self.raw_paths[0], index_col='movieId') movie_mapping = {idx: i for i, idx in enumerate(df.index)} genres = df['genres'].str.get_dummies('|').values genres = torch.from_numpy(genres).to(torch.float) model = SentenceTransformer(self.model_name) with torch.no_grad(): emb = model.encode(df['title'].values, show_progress_bar=True, convert_to_tensor=True).cpu() 
data['movie'].x = torch.cat([emb, genres], dim=-1) df = pd.read_csv(self.raw_paths[1]) user_mapping = {idx: i for i, idx in enumerate(df['userId'].unique())} data['user'].num_nodes = len(user_mapping) src = [user_mapping[idx] for idx in df['userId']] dst = [movie_mapping[idx] for idx in df['movieId']] edge_index = torch.tensor([src, dst]) rating = torch.from_numpy(df['rating'].values).to(torch.long) time = torch.from_numpy(df['timestamp'].values).to(torch.long) data['user', 'rates', 'movie'].edge_index = edge_index data['user', 'rates', 'movie'].edge_label = rating data['user', 'rates', 'movie'].time = time if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/movie_lens_100k.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs MOVIE_HEADERS = [ "movieId", "title", "releaseDate", "videoReleaseDate", "IMDb URL", "unknown", "Action", "Adventure", "Animation", "Children's", "Comedy", "Crime", "Documentary", "Drama", "Fantasy", "Film-Noir", "Horror", "Musical", "Mystery", "Romance", "Sci-Fi", "Thriller", "War", "Western" ] USER_HEADERS = ["userId", "age", "gender", "occupation", "zipCode"] RATING_HEADERS = ["userId", "movieId", "rating", "timestamp"] class MovieLens100K(InMemoryDataset): r"""The MovieLens 100K heterogeneous rating dataset, assembled by GroupLens Research from the `MovieLens web site `__, consisting of movies (1,682 nodes) and users (943 nodes) with 100K ratings between them. User ratings for movies are available as ground truth labels. Features of users and movies are encoded according to the `"Inductive Matrix Completion Based on Graph Neural Networks" `__ paper. 
Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 20 10 10 10 :header-rows: 1 * - Node/Edge Type - #nodes/#edges - #features - #tasks * - Movie - 1,682 - 18 - * - User - 943 - 24 - * - User-Movie - 80,000 - 1 - 1 """ url = 'https://files.grouplens.org/datasets/movielens/ml-100k.zip' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return ['u.item', 'u.user', 'u1.base', 'u1.test'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.root) os.remove(path) folder = osp.join(self.root, 'ml-100k') fs.rm(self.raw_dir) os.rename(folder, self.raw_dir) def process(self) -> None: import pandas as pd data = HeteroData() # Process movie data: df = pd.read_csv( self.raw_paths[0], sep='|', header=None, names=MOVIE_HEADERS, index_col='movieId', encoding='ISO-8859-1', ) movie_mapping = {idx: i for i, idx in enumerate(df.index)} x = df[MOVIE_HEADERS[6:]].values data['movie'].x = torch.from_numpy(x).to(torch.float) # Process user data: df = pd.read_csv( self.raw_paths[1], sep='|', header=None, 
names=USER_HEADERS, index_col='userId', encoding='ISO-8859-1', ) user_mapping = {idx: i for i, idx in enumerate(df.index)} age = df['age'].values / df['age'].values.max() age = torch.from_numpy(age).to(torch.float).view(-1, 1) gender = df['gender'].str.get_dummies().values gender = torch.from_numpy(gender).to(torch.float) occupation = df['occupation'].str.get_dummies().values occupation = torch.from_numpy(occupation).to(torch.float) data['user'].x = torch.cat([age, gender, occupation], dim=-1) # Process rating data for training: df = pd.read_csv( self.raw_paths[2], sep='\t', header=None, names=RATING_HEADERS, ) src = [user_mapping[idx] for idx in df['userId']] dst = [movie_mapping[idx] for idx in df['movieId']] edge_index = torch.tensor([src, dst]) data['user', 'rates', 'movie'].edge_index = edge_index rating = torch.from_numpy(df['rating'].values).to(torch.long) data['user', 'rates', 'movie'].rating = rating time = torch.from_numpy(df['timestamp'].values) data['user', 'rates', 'movie'].time = time data['movie', 'rated_by', 'user'].edge_index = edge_index.flip([0]) data['movie', 'rated_by', 'user'].rating = rating data['movie', 'rated_by', 'user'].time = time # Process rating data for testing: df = pd.read_csv( self.raw_paths[3], sep='\t', header=None, names=RATING_HEADERS, ) src = [user_mapping[idx] for idx in df['userId']] dst = [movie_mapping[idx] for idx in df['movieId']] edge_label_index = torch.tensor([src, dst]) data['user', 'rates', 'movie'].edge_label_index = edge_label_index edge_label = torch.from_numpy(df['rating'].values).to(torch.float) data['user', 'rates', 'movie'].edge_label = edge_label if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/movie_lens_1m.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import torch from 
torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs MOVIE_HEADERS = ["movieId", "title", "genres"] USER_HEADERS = ["userId", "gender", "age", "occupation", "zipCode"] RATING_HEADERS = ['userId', 'movieId', 'rating', 'timestamp'] class MovieLens1M(InMemoryDataset): r"""The MovieLens 1M heterogeneous rating dataset, assembled by GroupLens Research from the `MovieLens web site `__, consisting of movies (3,883 nodes) and users (6,040 nodes) with approximately 1 million ratings between them. User ratings for movies are available as ground truth labels. Features of users and movies are encoded according to the `"Inductive Matrix Completion Based on Graph Neural Networks" `__ paper. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 20 10 10 10 :header-rows: 1 * - Node/Edge Type - #nodes/#edges - #features - #tasks * - Movie - 3,883 - 18 - * - User - 6,040 - 30 - * - User-Movie - 1,000,209 - 1 - 1 """ url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return ['movies.dat', 'users.dat', 'ratings.dat'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.root) os.remove(path) folder = osp.join(self.root, 'ml-1m') fs.rm(self.raw_dir) os.rename(folder, self.raw_dir) def process(self) -> None: import pandas as pd data = HeteroData() # Process movie data: df = pd.read_csv( self.raw_paths[0], sep='::', header=None, index_col='movieId', names=MOVIE_HEADERS, encoding='ISO-8859-1', engine='python', ) movie_mapping = {idx: i for i, idx in enumerate(df.index)} genres = df['genres'].str.get_dummies('|').values genres = torch.from_numpy(genres).to(torch.float) data['movie'].x = genres # Process user data: df = pd.read_csv( self.raw_paths[1], sep='::', header=None, index_col='userId', names=USER_HEADERS, dtype='str', encoding='ISO-8859-1', engine='python', ) user_mapping = {idx: i for i, idx in enumerate(df.index)} age = df['age'].str.get_dummies().values age = torch.from_numpy(age).to(torch.float) gender = df['gender'].str.get_dummies().values gender = torch.from_numpy(gender).to(torch.float) occupation = df['occupation'].str.get_dummies().values occupation = torch.from_numpy(occupation).to(torch.float) data['user'].x = torch.cat([age, gender, occupation], dim=-1) # Process rating data: df = pd.read_csv( self.raw_paths[2], sep='::', 
header=None, names=RATING_HEADERS, encoding='ISO-8859-1', engine='python', ) src = [user_mapping[idx] for idx in df['userId']] dst = [movie_mapping[idx] for idx in df['movieId']] edge_index = torch.tensor([src, dst]) data['user', 'rates', 'movie'].edge_index = edge_index rating = torch.from_numpy(df['rating'].values).to(torch.long) data['user', 'rates', 'movie'].rating = rating time = torch.from_numpy(df['timestamp'].values) data['user', 'rates', 'movie'].time = time data['movie', 'rated_by', 'user'].edge_index = edge_index.flip([0]) data['movie', 'rated_by', 'user'].rating = rating data['movie', 'rated_by', 'user'].time = time if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/myket.py ================================================ from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import InMemoryDataset, TemporalData, download_url class MyketDataset(InMemoryDataset): r"""The Myket Android Application Install dataset from the `"Effect of Choosing Loss Function when Using T-Batching for Representation Learning on Dynamic Networks" `_ paper. The dataset contains a temporal graph of application install interactions in an Android application market. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - Myket - 17,988 - 694,121 - 33 - 1 """ url = ('https://raw.githubusercontent.com/erfanloghmani/' 'myket-android-application-market-dataset/main/data_int_index') def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=TemporalData) @property def raw_file_names(self) -> List[str]: return ['myket.csv', 'app_info_sample.npy'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for file_name in self.raw_file_names: download_url(f'{self.url}/{file_name}', self.raw_dir) def process(self) -> None: import pandas as pd df = pd.read_csv(self.raw_paths[0], skiprows=1, header=None) src = torch.from_numpy(df[0].values) dst = torch.from_numpy(df[1].values) t = torch.from_numpy(df[2].values) x = torch.from_numpy(np.load(self.raw_paths[1])).to(torch.float) msg = x[dst] dst = dst + (int(src.max()) + 1) data = TemporalData(src=src, dst=dst, t=t, msg=msg) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/nell.py ================================================ import os import os.path as osp from typing import Callable, List, Optional from torch_geometric.data import InMemoryDataset, download_url, extract_tar from torch_geometric.io import fs, read_planetoid_data class NELL(InMemoryDataset): r"""The NELL dataset, a knowledge graph from the `"Toward an Architecture for Never-Ending Language Learning" `_ paper. The dataset is processed as in the `"Revisiting Semi-Supervised Learning with Graph Embeddings" `_ paper. .. 
note:: Entity nodes are described by sparse feature vectors of type :class:`torch.sparse_csr_tensor`. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 65,755 - 251,550 - 61,278 - 186 """ url = 'http://www.cs.cmu.edu/~zhiliny/data/nell_data.tar.gz' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index'] return [f'ind.nell.0.001.{name}' for name in names] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.root) extract_tar(path, self.root) os.unlink(path) fs.rm(self.raw_dir) os.rename(osp.join(self.root, 'nell_data'), self.raw_dir) def process(self) -> None: data = read_planetoid_data(self.raw_dir, 'nell.0.001') data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/neurograph.py ================================================ import os import os.path as osp from 
typing import Callable, List, Optional from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class NeuroGraphDataset(InMemoryDataset): r"""The NeuroGraph benchmark datasets from the `"NeuroGraph: Benchmarks for Graph Machine Learning in Brain Connectomics" `_ paper. :class:`NeuroGraphDataset` holds a collection of five neuroimaging graph learning datasets that span multiple categories of demographics, mental states, and cognitive traits. See the `documentation `_ and the `Github `_ for more details. +--------------------+---------+----------------------+ | Dataset | #Graphs | Task | +====================+=========+======================+ | :obj:`HCPTask` | 7,443 | Graph Classification | +--------------------+---------+----------------------+ | :obj:`HCPGender` | 1,078 | Graph Classification | +--------------------+---------+----------------------+ | :obj:`HCPAge` | 1,065 | Graph Classification | +--------------------+---------+----------------------+ | :obj:`HCPFI` | 1,071 | Graph Regression | +--------------------+---------+----------------------+ | :obj:`HCPWM` | 1,078 | Graph Regression | +--------------------+---------+----------------------+ Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (one of :obj:`"HCPGender"`, :obj:`"HCPTask"`, :obj:`"HCPAge"`, :obj:`"HCPFI"`, :obj:`"HCPWM"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://vanderbilt.box.com/shared/static' filenames = { 'HCPGender': 'r6hlz2arm7yiy6v6981cv2nzq3b0meax.zip', 'HCPTask': '8wzz4y17wpxg2stip7iybtmymnybwvma.zip', 'HCPAge': 'lzzks4472czy9f9vc8aikp7pdbknmtfe.zip', 'HCPWM': 'xtmpa6712fidi94x6kevpsddf9skuoxy.zip', 'HCPFI': 'g2md9h9snh7jh6eeay02k1kr9m4ido9f.zip', } def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert name in self.filenames.keys() self.name = name super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def raw_file_names(self) -> str: return 'data.pt' @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: url = f'{self.url}/{self.filenames[self.name]}' path = download_url(url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) os.rename( osp.join(self.raw_dir, self.name, 'processed', f'{self.name}.pt'), osp.join(self.raw_dir, 'data.pt')) fs.rm(osp.join(self.raw_dir, self.name)) def process(self) -> None: data, slices = fs.torch_load(self.raw_paths[0]) num_samples = slices['x'].size(0) - 1 data_list: List[Data] = [] for i in range(num_samples): x = data.x[slices['x'][i]:slices['x'][i + 1]] start = slices['edge_index'][i] end = slices['edge_index'][i + 1] edge_index = data.edge_index[:, start:end] sample = Data(x=x, 
edge_index=edge_index, y=data.y[i]) if self.pre_filter is not None and not self.pre_filter(sample): continue if self.pre_transform is not None: sample = self.pre_transform(sample) data_list.append(sample) self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/ogb_mag.py ================================================ import os import os.path as osp import shutil from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class OGB_MAG(InMemoryDataset): r"""The :obj:`ogbn-mag` dataset from the `"Open Graph Benchmark: Datasets for Machine Learning on Graphs" `_ paper. :obj:`ogbn-mag` is a heterogeneous graph composed of a subset of the Microsoft Academic Graph (MAG). It contains four types of entities — papers (736,389 nodes), authors (1,134,649 nodes), institutions (8,740 nodes), and fields of study (59,965 nodes) — as well as four types of directed relations connecting two types of entities. Each paper is associated with a 128-dimensional :obj:`word2vec` feature vector, while all other node types are not associated with any input features. The task is to predict the venue (conference or journal) of each paper. In total, there are 349 different venues. Args: root (str): Root directory where the dataset should be saved. preprocess (str, optional): Pre-processes the original dataset by adding structural features (:obj:`"metapath2vec"`, :obj:`"TransE"`) to featureless nodes. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip' urls = { 'metapath2vec': ('https://data.pyg.org/datasets/' 'mag_metapath2vec_emb.zip'), 'transe': ('https://data.pyg.org/datasets/' 'mag_transe_emb.zip'), } def __init__( self, root: str, preprocess: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: preprocess = None if preprocess is None else preprocess.lower() self.preprocess = preprocess assert self.preprocess in [None, 'metapath2vec', 'transe'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def num_classes(self) -> int: assert isinstance(self._data, HeteroData) return int(self._data['paper'].y.max()) + 1 @property def raw_dir(self) -> str: return osp.join(self.root, 'mag', 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, 'mag', 'processed') @property def raw_file_names(self) -> List[str]: file_names = [ 'node-feat', 'node-label', 'relations', 'split', 'num-node-dict.csv.gz' ] if self.preprocess is not None: file_names += [f'mag_{self.preprocess}_emb.pt'] return file_names @property def processed_file_names(self) -> str: if self.preprocess is not None: return f'data_{self.preprocess}.pt' else: return 'data.pt' def download(self) -> None: if not all([osp.exists(f) for f in self.raw_paths[:5]]): path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) for file_name in ['node-feat', 'node-label', 'relations']: path = osp.join(self.raw_dir, 'mag', 'raw', file_name) 
shutil.move(path, self.raw_dir) path = osp.join(self.raw_dir, 'mag', 'split') shutil.move(path, self.raw_dir) path = osp.join(self.raw_dir, 'mag', 'raw', 'num-node-dict.csv.gz') shutil.move(path, self.raw_dir) fs.rm(osp.join(self.raw_dir, 'mag')) os.remove(osp.join(self.raw_dir, 'mag.zip')) if self.preprocess is not None: path = download_url(self.urls[self.preprocess], self.raw_dir) extract_zip(path, self.raw_dir) os.remove(path) def process(self) -> None: import pandas as pd data = HeteroData() path = osp.join(self.raw_dir, 'node-feat', 'paper', 'node-feat.csv.gz') x_paper = pd.read_csv(path, compression='gzip', header=None, dtype=np.float32).values data['paper'].x = torch.from_numpy(x_paper) path = osp.join(self.raw_dir, 'node-feat', 'paper', 'node_year.csv.gz') year_paper = pd.read_csv(path, compression='gzip', header=None, dtype=np.int64).values data['paper'].year = torch.from_numpy(year_paper).view(-1) path = osp.join(self.raw_dir, 'node-label', 'paper', 'node-label.csv.gz') y_paper = pd.read_csv(path, compression='gzip', header=None, dtype=np.int64).values.flatten() data['paper'].y = torch.from_numpy(y_paper) if self.preprocess is None: path = osp.join(self.raw_dir, 'num-node-dict.csv.gz') num_nodes_df = pd.read_csv(path, compression='gzip') for node_type in ['author', 'institution', 'field_of_study']: data[node_type].num_nodes = num_nodes_df[node_type].tolist()[0] else: emb_dict = fs.torch_load(self.raw_paths[-1]) for key, value in emb_dict.items(): if key != 'paper': data[key].x = value for edge_type in [('author', 'affiliated_with', 'institution'), ('author', 'writes', 'paper'), ('paper', 'cites', 'paper'), ('paper', 'has_topic', 'field_of_study')]: f = '___'.join(edge_type) path = osp.join(self.raw_dir, 'relations', f, 'edge.csv.gz') edge_index = pd.read_csv(path, compression='gzip', header=None, dtype=np.int64).values edge_index = torch.from_numpy(edge_index).t().contiguous() data[edge_type].edge_index = edge_index for f, v in [('train', 'train'), 
('valid', 'val'), ('test', 'test')]: path = osp.join(self.raw_dir, 'split', 'time', 'paper', f'{f}.csv.gz') idx = pd.read_csv(path, compression='gzip', header=None, dtype=np.int64).values.flatten() idx = torch.from_numpy(idx) mask = torch.zeros(data['paper'].num_nodes, dtype=torch.bool) mask[idx] = True data['paper'][f'{v}_mask'] = mask if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return 'ogbn-mag()' ================================================ FILE: torch_geometric/datasets/omdb.py ================================================ import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, extract_tar class OMDB(InMemoryDataset): r"""The `Organic Materials Database (OMDB) `__ of bulk organic crystals. Args: root (str): Root directory where the dataset should be saved. train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = 'https://omdb.mathub.io/dataset' def __init__( self, root: str, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] if train else self.processed_paths[1] self.load(path) @property def raw_file_names(self) -> str: return 'OMDB-GAP1_v1.1.tar.gz' @property def processed_file_names(self) -> List[str]: return ['train_data.pt', 'test_data.pt'] def download(self) -> None: raise RuntimeError( f"Dataset not found. Please download '{self.raw_file_names}' from " f"'{self.url}' and move it to '{self.raw_dir}'") def process(self) -> None: from ase.io import read extract_tar(self.raw_paths[0], self.raw_dir, log=False) materials = read(osp.join(self.raw_dir, 'structures.xyz'), index=':') bandgaps = np.loadtxt(osp.join(self.raw_dir, 'bandgaps.csv')) data_list = [] for material, bandgap in zip(materials, bandgaps): pos = torch.from_numpy(material.get_positions()).to(torch.float) z = torch.from_numpy(material.get_atomic_numbers()).to(torch.int64) y = torch.tensor([float(bandgap)]) data_list.append(Data(z=z, pos=pos, y=y)) train_data = data_list[:10000] test_data = data_list[10000:] if self.pre_filter is not None: train_data = [d for d in train_data if self.pre_filter(d)] test_data = [d for d in test_data if self.pre_filter(d)] if self.pre_transform is not None: train_data = [self.pre_transform(d) for d in train_data] test_data = [self.pre_transform(d) for d in test_data] self.save(train_data, self.processed_paths[0]) self.save(test_data, self.processed_paths[1]) ================================================ FILE: torch_geometric/datasets/opf.py ================================================ import json import os import os.path as osp from typing import Callable, Dict, List, Literal, Optional import torch 
import tqdm from torch import Tensor from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_tar, ) class OPFDataset(InMemoryDataset): r"""The heterogeneous OPF data from the `"Large-scale Datasets for AC Optimal Power Flow with Topological Perturbations" `_ paper. :class:`OPFDataset` is a large-scale dataset of solved optimal power flow problems, derived from the `pglib-opf `_ dataset. The physical topology of the grid is represented by the :obj:`"bus"` node type, and the connecting AC lines and transformers. Additionally, :obj:`"generator"`, :obj:`"load"`, and :obj:`"shunt"` nodes are connected to :obj:`"bus"` nodes using a dedicated edge type each, *e.g.*, :obj:`"generator_link"`. Edge direction corresponds to the properties of the line, *e.g.*, :obj:`b_fr` is the line charging susceptance at the :obj:`from` (source/sender) bus. Args: root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) case_name (str, optional): The name of the original pglib-opf case. (default: :obj:`"pglib_opf_case14_ieee"`) num_groups (int, optional): The dataset is divided into 20 groups with each group containing 15,000 samples. For large networks, this amount of data can be overwhelming. The :obj:`num_groups` parameters controls the amount of data being downloaded. Allowed values are :obj:`[1, 20]`. (default: :obj:`20`) topological_perturbations (bool, optional): Whether to use the dataset with added topological perturbations. (default: :obj:`False`) transform (callable, optional): A function/transform that takes in a :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in a :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in a :obj:`torch_geometric.data.HeteroData` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://storage.googleapis.com/gridopt-dataset' def __init__( self, root: str, split: Literal['train', 'val', 'test'] = 'train', case_name: Literal[ 'pglib_opf_case14_ieee', 'pglib_opf_case30_ieee', 'pglib_opf_case57_ieee', 'pglib_opf_case118_ieee', 'pglib_opf_case500_goc', 'pglib_opf_case2000_goc', 'pglib_opf_case6470_rte', 'pglib_opf_case4661_sdet' 'pglib_opf_case10000_goc', 'pglib_opf_case13659_pegase', ] = 'pglib_opf_case14_ieee', num_groups: int = 20, topological_perturbations: bool = False, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.split = split self.case_name = case_name self.num_groups = num_groups self.topological_perturbations = topological_perturbations self._release = 'dataset_release_1' if topological_perturbations: self._release += '_nminusone' super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) idx = self.processed_file_names.index(f'{split}.pt') self.load(self.processed_paths[idx]) @property def raw_dir(self) -> str: return osp.join(self.root, self._release, self.case_name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self._release, self.case_name, f'processed_{self.num_groups}') @property def raw_file_names(self) -> List[str]: return [f'{self.case_name}_{i}.tar.gz' for i in 
range(self.num_groups)] @property def processed_file_names(self) -> List[str]: return ['train.pt', 'val.pt', 'test.pt'] def download(self) -> None: for name in self.raw_file_names: url = f'{self.url}/{self._release}/{name}' path = download_url(url, self.raw_dir) extract_tar(path, self.raw_dir) def process(self) -> None: train_data_list = [] val_data_list = [] test_data_list = [] for group in tqdm.tqdm(range(self.num_groups)): tmp_dir = osp.join( self.raw_dir, 'gridopt-dataset-tmp', self._release, self.case_name, f'group_{group}', ) for name in os.listdir(tmp_dir): with open(osp.join(tmp_dir, name)) as f: obj = json.load(f) grid = obj['grid'] solution = obj['solution'] metadata = obj['metadata'] # Graph-level properties: data = HeteroData() data.x = torch.tensor(grid['context']).view(-1) data.objective = torch.tensor(metadata['objective']) # Nodes (only some have a target): data['bus'].x = torch.tensor(grid['nodes']['bus']) data['bus'].y = torch.tensor(solution['nodes']['bus']) data['generator'].x = torch.tensor(grid['nodes']['generator']) data['generator'].y = torch.tensor( solution['nodes']['generator']) data['load'].x = torch.tensor(grid['nodes']['load']) data['shunt'].x = torch.tensor(grid['nodes']['shunt']) # Edges (only ac lines and transformers have features): data['bus', 'ac_line', 'bus'].edge_index = ( # extract_edge_index(obj, 'ac_line')) data['bus', 'ac_line', 'bus'].edge_attr = torch.tensor( grid['edges']['ac_line']['features']) data['bus', 'ac_line', 'bus'].edge_label = torch.tensor( solution['edges']['ac_line']['features']) data['bus', 'transformer', 'bus'].edge_index = ( # extract_edge_index(obj, 'transformer')) data['bus', 'transformer', 'bus'].edge_attr = torch.tensor( grid['edges']['transformer']['features']) data['bus', 'transformer', 'bus'].edge_label = torch.tensor( solution['edges']['transformer']['features']) data['generator', 'generator_link', 'bus'].edge_index = ( # extract_edge_index(obj, 'generator_link')) data['bus', 'generator_link', 
'generator'].edge_index = ( # extract_edge_index_rev(obj, 'generator_link')) data['load', 'load_link', 'bus'].edge_index = ( # extract_edge_index(obj, 'load_link')) data['bus', 'load_link', 'load'].edge_index = ( # extract_edge_index_rev(obj, 'load_link')) data['shunt', 'shunt_link', 'bus'].edge_index = ( # extract_edge_index(obj, 'shunt_link')) data['bus', 'shunt_link', 'shunt'].edge_index = ( # extract_edge_index_rev(obj, 'shunt_link')) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) i = int(name.split('.')[0].split('_')[1]) train_limit = int(15_000 * self.num_groups * 0.9) val_limit = train_limit + int(15_000 * self.num_groups * 0.05) if i < train_limit: train_data_list.append(data) elif i < val_limit: val_data_list.append(data) else: test_data_list.append(data) self.save(train_data_list, self.processed_paths[0]) self.save(val_data_list, self.processed_paths[1]) self.save(test_data_list, self.processed_paths[2]) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'split={self.split}, ' f'case_name={self.case_name}, ' f'topological_perturbations={self.topological_perturbations})') def extract_edge_index(obj: Dict, edge_name: str) -> Tensor: return torch.tensor([ obj['grid']['edges'][edge_name]['senders'], obj['grid']['edges'][edge_name]['receivers'], ]) def extract_edge_index_rev(obj: Dict, edge_name: str) -> Tensor: return torch.tensor([ obj['grid']['edges'][edge_name]['receivers'], obj['grid']['edges'][edge_name]['senders'], ]) ================================================ FILE: torch_geometric/datasets/ose_gvcs.py ================================================ import json import os from collections import defaultdict from typing import Callable, List, Optional import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_tar, ) class OSE_GVCS(InMemoryDataset): r"""A dataset describing the `Product 
ecology `_ of the Open Source Ecology's iconoclastic `Global Village Construction Set `_. GVCS is a modular, DIY, low-cost set of blueprints that enables the fabrication of the 50 different industrial machines that it takes to build a small, sustainable civilization with modern comforts. The dataset contains a heterogenous graphs with 50 :obj:`machine` nodes, composing the GVCS, and 290 directed edges, each representing one out of three relationships between machines. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ machines = [ '3D Printer', '3D Scanner', 'Aluminum Extractor', 'Backhoe', 'Bakery Oven', 'Baler', 'Bioplastic Extruder', 'Bulldozer', 'Car', 'CEB Press', 'Cement Mixer', 'Chipper Hammermill', 'CNC Circuit Mill', 'CNC Torch Table', 'Dairy Milker', 'Drill Press', 'Electric Motor Generator', 'Gasifier Burner', 'Hay Cutter', 'Hay Rake', 'Hydraulic Motor', 'Induction Furnace', 'Industrial Robot', 'Ironworker', 'Laser Cutter', 'Metal Roller', 'Microcombine', 'Microtractor', 'Multimachine', 'Nickel-Iron Battery', 'Pelletizer', 'Plasma Cutter', 'Power Cube', 'Press Forge', 'Rod and Wire Mill', 'Rototiller', 'Sawmill', 'Seeder', 'Solar Concentrator', 'Spader', 'Steam Engine', 'Steam Generator', 'Tractor', 'Trencher', 'Truck', 'Universal Power Supply', 'Universal Rotor', 'Welder', 'Well-Drilling Rig', 'Wind Turbine' ] categories = [ 'habitat', 'agriculture', 'industry', 'energy', 'materials', 'transportation' ] relationships = ['from', 'uses', 'enables'] url = 'https://github.com/Wesxdz/ose_gvcs/raw/master/ose_gvcs.tar.gz' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return [ f"{machine.lower().replace(' ', '_')}.json" for machine in self.machines ] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.root) extract_tar(path, self.raw_dir) os.unlink(path) def process(self) -> None: data = HeteroData() categories = [] edges = defaultdict(list) for path in self.raw_paths: with open(path) as f: product = json.load(f) categories.append(self.categories.index(product['category'])) for interaction in product['ecology']: # NOTE Some ecology items are not GVCS machines or have other # relationship 
types we don't want included. rt = interaction['relationship'] if rt not in self.relationships: continue dst = interaction['tool'] if dst not in self.machines: continue # Machines are guaranteed to be sorted according to their order # in `self.machines`, so we can use its index for the mapping: src = self.machines.index(product['machine']) dst = self.machines.index(dst) edges[rt].append((src, dst)) data['machine'].num_nodes = len(categories) data['machine'].category = torch.tensor(categories) for rel, edge_indices, in edges.items(): edge_index = torch.tensor(edge_indices).t().contiguous() data['machine', rel, 'machine'].edge_index = edge_index if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/particle.py ================================================ import glob import os.path as osp from typing import Any, Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, Dataset from torch_geometric.utils import index_sort, scatter class TrackingData(Data): def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any: if key == 'y_index': return torch.tensor([value[0].max().item() + 1, self.num_nodes]) else: return super().__inc__(key, value, *args, **kwargs) class TrackMLParticleTrackingDataset(Dataset): r"""The `TrackML Particle Tracking Challenge `_ dataset to reconstruct particle tracks from 3D points left in the silicon detectors. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) """ url = 'https://www.kaggle.com/c/trackml-particle-identification' def __init__( self, root: str, transform: Optional[Callable] = None, ) -> None: super().__init__(root, transform) events = glob.glob(osp.join(self.raw_dir, 'event*-hits.csv')) events = [e.split(osp.sep)[-1].split('-')[0][5:] for e in events] self.events: List[str] = sorted(events) @property def raw_file_names(self) -> List[str]: event_indices = ['000001000'] file_names = [] file_names += [f'event{idx}-cells.csv' for idx in event_indices] file_names += [f'event{idx}-hits.csv' for idx in event_indices] file_names += [f'event{idx}-particles.csv' for idx in event_indices] file_names += [f'event{idx}-truth.csv' for idx in event_indices] return file_names def download(self) -> None: raise RuntimeError( f'Dataset not found. Please download it from {self.url} and move ' f'all *.csv files to {self.raw_dir}') def len(self) -> int: return len(glob.glob(osp.join(self.raw_dir, 'event*-hits.csv'))) def get(self, i: int) -> TrackingData: import pandas as pd idx = self.events[i] # Get hit positions. hits_path = osp.join(self.raw_dir, f'event{idx}-hits.csv') pos = pd.read_csv(hits_path, usecols=['x', 'y', 'z'], dtype=np.float32) pos = torch.from_numpy(pos.values).div_(1000.) # Get hit features. cells_path = osp.join(self.raw_dir, f'event{idx}-cells.csv') cell = pd.read_csv(cells_path, usecols=['hit_id', 'value']) hit_id = torch.from_numpy(cell['hit_id'].values).to(torch.long).sub_(1) value = torch.from_numpy(cell['value'].values).to(torch.float) ones = torch.ones(hit_id.size(0)) num_cells = scatter(ones, hit_id, 0, pos.size(0), 'sum').div_(10.) value = scatter(value, hit_id, 0, pos.size(0), 'sum') x = torch.stack([num_cells, value], dim=-1) # Get ground-truth hit assignments. 
truth_path = osp.join(self.raw_dir, f'event{idx}-truth.csv') y = pd.read_csv(truth_path, usecols=['hit_id', 'particle_id', 'weight']) hit_id = torch.from_numpy(y['hit_id'].values).to(torch.long).sub_(1) particle_id = torch.from_numpy(y['particle_id'].values).to(torch.long) particle_id = particle_id.unique(return_inverse=True)[1].sub_(1) weight = torch.from_numpy(y['weight'].values).to(torch.float) # Sort. _, perm = index_sort(particle_id * hit_id.size(0) + hit_id) hit_id = hit_id[perm] particle_id = particle_id[perm] weight = weight[perm] # Remove invalid particle ids. mask = particle_id >= 0 hit_id = hit_id[mask] particle_id = particle_id[mask] weight = weight[mask] y_index = torch.stack([particle_id, hit_id], dim=0) return TrackingData(x=x, pos=pos, y_index=y_index, y_weight=weight) ================================================ FILE: torch_geometric/datasets/pascal.py ================================================ import os import os.path as osp from itertools import chain from typing import Callable, Dict, List, Optional from xml.dom import minidom import numpy as np import torch import torch.nn.functional as F from torch import Tensor from torch.utils.data import DataLoader from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_tar, ) from torch_geometric.io import fs class PascalVOCKeypoints(InMemoryDataset): r"""The Pascal VOC 2011 dataset with Berkely annotations of keypoints from the `"Poselets: Body Part Detectors Trained Using 3D Human Pose Annotations" `_ paper, containing 0 to 23 keypoints per example over 20 categories. The dataset is pre-filtered to exclude difficult, occluded and truncated objects. The keypoints contain interpolated features from a pre-trained VGG16 model on ImageNet (:obj:`relu4_2` and :obj:`relu5_1`). Args: root (str): Root directory where the dataset should be saved. 
category (str): The category of the images (one of :obj:`"Aeroplane"`, :obj:`"Bicycle"`, :obj:`"Bird"`, :obj:`"Boat"`, :obj:`"Bottle"`, :obj:`"Bus"`, :obj:`"Car"`, :obj:`"Cat"`, :obj:`"Chair"`, :obj:`"Diningtable"`, :obj:`"Dog"`, :obj:`"Horse"`, :obj:`"Motorbike"`, :obj:`"Person"`, :obj:`"Pottedplant"`, :obj:`"Sheep"`, :obj:`"Sofa"`, :obj:`"Train"`, :obj:`"TVMonitor"`) train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) device (str or torch.device, optional): The device to use for processing the raw data. If set to :obj:`None`, will utilize GPU-processing if available. 
(default: :obj:`None`) """ image_url = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2011/' 'VOCtrainval_25-May-2011.tar') annotation_url = ('https://www2.eecs.berkeley.edu/Research/Projects/CS/' 'vision/shape/poselets/voc2011_keypoints_Feb2012.tgz') # annotation_url = 'http://www.roemisch-drei.de/pascal_annotations.tar' # split_url = 'http://cvgl.stanford.edu/projects/ucn/voc2011_pairs.npz' split_url = ('https://github.com/Thinklab-SJTU/PCA-GM/raw/master/data/' 'PascalVOC/voc2011_pairs.npz') categories = [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ] batch_size = 32 def __init__( self, root: str, category: str, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, device: Optional[str] = None, ) -> None: if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.category = category.lower() assert self.category in self.categories self.device = device super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] if train else self.processed_paths[1] self.load(path) @property def raw_dir(self) -> str: return osp.join(self.root, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.category.capitalize(), 'processed') @property def raw_file_names(self) -> List[str]: return ['images', 'annotations', 'splits.npz'] @property def processed_file_names(self) -> List[str]: return ['training.pt', 'test.pt'] def download(self) -> None: path = download_url(self.image_url, self.raw_dir) extract_tar(path, self.raw_dir, mode='r') os.unlink(path) image_path = osp.join(self.raw_dir, 'TrainVal', 'VOCdevkit', 'VOC2011') os.rename(image_path, osp.join(self.raw_dir, 'images')) fs.rm(osp.join(self.raw_dir, 'TrainVal')) path = 
download_url(self.annotation_url, self.raw_dir) extract_tar(path, self.raw_dir, mode='r') os.unlink(path) path = download_url(self.split_url, self.raw_dir) os.rename(path, osp.join(self.raw_dir, 'splits.npz')) def process(self) -> None: import torchvision.models as models import torchvision.transforms as T from PIL import Image splits = np.load(osp.join(self.raw_dir, 'splits.npz'), allow_pickle=True) category_idx = self.categories.index(self.category) train_split = list(splits['train'])[category_idx] test_split = list(splits['test'])[category_idx] image_path = osp.join(self.raw_dir, 'images', 'JPEGImages') info_path = osp.join(self.raw_dir, 'images', 'Annotations') annotation_path = osp.join(self.raw_dir, 'annotations') labels: Dict[str, int] = {} vgg16_outputs = [] def hook(module: torch.nn.Module, x: Tensor, y: Tensor) -> None: vgg16_outputs.append(y) vgg16 = models.vgg16(pretrained=True).to(self.device) vgg16.eval() vgg16.features[20].register_forward_hook(hook) # relu4_2 vgg16.features[25].register_forward_hook(hook) # relu5_1 transform = T.Compose([ T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) train_set, test_set = [], [] for i, name in enumerate(chain(train_split, test_split)): filename = '_'.join(name.split('/')[1].split('_')[:-1]) file_idx = int(name.split('_')[-1].split('.')[0]) - 1 path = osp.join(info_path, f'{filename}.xml') obj = minidom.parse(path).getElementsByTagName('object')[file_idx] child = obj.getElementsByTagName('truncated')[0].firstChild assert child is not None trunc = child.data # type: ignore elements = obj.getElementsByTagName('occluded') if len(elements) == 0: occ = '0' else: child = elements[0].firstChild assert child is not None occ = child.data # type: ignore child = obj.getElementsByTagName('difficult')[0].firstChild diff = child.data # type: ignore if bool(int(trunc)) or bool(int(occ)) or bool(int(diff)): continue if self.category == 'person' and int(filename[:4]) > 2008: continue child = 
obj.getElementsByTagName('xmin')[0].firstChild assert child is not None xmin = int(child.data) # type: ignore child = obj.getElementsByTagName('xmax')[0].firstChild assert child is not None xmax = int(child.data) # type: ignore child = obj.getElementsByTagName('ymin')[0].firstChild assert child is not None ymin = int(child.data) # type: ignore child = obj.getElementsByTagName('ymax')[0].firstChild assert child is not None ymax = int(child.data) # type: ignore box = (xmin, ymin, xmax, ymax) dom = minidom.parse(osp.join(annotation_path, name)) keypoints = dom.getElementsByTagName('keypoint') poss, ys = [], [] for keypoint in keypoints: label = keypoint.attributes['name'].value if label not in labels: labels[label] = len(labels) ys.append(labels[label]) _x = float(keypoint.attributes['x'].value) _y = float(keypoint.attributes['y'].value) poss += [_x, _y] y = torch.tensor(ys, dtype=torch.long) pos = torch.tensor(poss, dtype=torch.float).view(-1, 2) if pos.numel() == 0: continue # These examples do not make any sense anyway... # Add a small offset to the bounding because some keypoints lay # outside the bounding box intervals. box = ( min(int(pos[:, 0].min().floor()), box[0]) - 16, min(int(pos[:, 1].min().floor()), box[1]) - 16, max(int(pos[:, 0].max().ceil()), box[2]) + 16, max(int(pos[:, 1].max().ceil()), box[3]) + 16, ) # Rescale keypoints. 
pos[:, 0] = (pos[:, 0] - box[0]) * 256.0 / (box[2] - box[0]) pos[:, 1] = (pos[:, 1] - box[1]) * 256.0 / (box[3] - box[1]) path = osp.join(image_path, f'{filename}.jpg') with open(path, 'rb') as f: img = Image.open(f).convert('RGB').crop(box) img = img.resize((256, 256), resample=Image.Resampling.BICUBIC) img = transform(img) data = Data(img=img, pos=pos, y=y, name=filename) if i < len(train_split): train_set.append(data) else: test_set.append(data) data_list = list(chain(train_set, test_set)) imgs = [data.img for data in data_list] loader: DataLoader = DataLoader( dataset=imgs, # type: ignore batch_size=self.batch_size, shuffle=False, ) for i, batch_img in enumerate(loader): vgg16_outputs.clear() with torch.no_grad(): vgg16(batch_img.to(self.device)) out1 = F.interpolate(vgg16_outputs[0], (256, 256), mode='bilinear', align_corners=False) out2 = F.interpolate(vgg16_outputs[1], (256, 256), mode='bilinear', align_corners=False) for j in range(out1.size(0)): data = data_list[i * self.batch_size + j] assert data.pos is not None idx = data.pos.round().long().clamp(0, 255) x_1 = out1[j, :, idx[:, 1], idx[:, 0]].to('cpu') x_2 = out2[j, :, idx[:, 1], idx[:, 0]].to('cpu') data.img = None data.x = torch.cat([x_1.t(), x_2.t()], dim=-1) del out1 del out2 if self.pre_filter is not None: train_set = [data for data in train_set if self.pre_filter(data)] test_set = [data for data in test_set if self.pre_filter(data)] if self.pre_transform is not None: train_set = [self.pre_transform(data) for data in train_set] test_set = [self.pre_transform(data) for data in test_set] self.save(train_set, self.processed_paths[0]) self.save(test_set, self.processed_paths[1]) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'category={self.category})') ================================================ FILE: torch_geometric/datasets/pascal_pf.py ================================================ import glob import os import os.path as osp from typing import Callable, List, 
Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class PascalPF(InMemoryDataset): r"""The Pascal-PF dataset from the `"Proposal Flow" `_ paper, containing 4 to 16 keypoints per example over 20 categories. Args: root (str): Root directory where the dataset should be saved. category (str): The category of the images (one of :obj:`"Aeroplane"`, :obj:`"Bicycle"`, :obj:`"Bird"`, :obj:`"Boat"`, :obj:`"Bottle"`, :obj:`"Bus"`, :obj:`"Car"`, :obj:`"Cat"`, :obj:`"Chair"`, :obj:`"Diningtable"`, :obj:`"Dog"`, :obj:`"Horse"`, :obj:`"Motorbike"`, :obj:`"Person"`, :obj:`"Pottedplant"`, :obj:`"Sheep"`, :obj:`"Sofa"`, :obj:`"Train"`, :obj:`"TVMonitor"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = ('https://www.di.ens.fr/willow/research/proposalflow/dataset/' 'PF-dataset-PASCAL.zip') categories = [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ] def __init__( self, root: str, category: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.category = category.lower() assert self.category in self.categories super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) self.pairs = fs.torch_load(self.processed_paths[1]) @property def raw_file_names(self) -> List[str]: return ['Annotations', 'parsePascalVOC.mat'] @property def processed_file_names(self) -> List[str]: return [f'{self.category}.pt', f'{self.category}_pairs.pt'] def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.root) fs.rm(self.raw_dir) os.rename(osp.join(self.root, 'PF-dataset-PASCAL'), self.raw_dir) def process(self) -> None: from scipy.io import loadmat path = osp.join(self.raw_dir, 'Annotations', self.category, '*.mat') filenames = glob.glob(path) names = [] data_list = [] for filename in filenames: name = osp.basename(filename).split('.')[0] pos = torch.from_numpy(loadmat(filename)['kps']).to(torch.float) mask = ~torch.isnan(pos[:, 0]) pos = pos[mask] # Normalize points to unit sphere. 
pos = pos - pos.mean(dim=0, keepdim=True) pos = pos / pos.norm(dim=1).max() y = mask.nonzero(as_tuple=False).flatten() data = Data(pos=pos, y=y, name=name) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) names.append(name) data_list.append(data) pairs = loadmat(osp.join(self.raw_dir, 'parsePascalVOC.mat')) pairs = pairs['PascalVOC']['pair'][0, 0][ 0, self.categories.index(self.category)] pairs = [(names.index(x[0][0]), names.index(x[1][0])) for x in pairs] self.save(data_list, self.processed_paths[0]) torch.save(pairs, self.processed_paths[1]) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'category={self.category})') ================================================ FILE: torch_geometric/datasets/pcpnet_dataset.py ================================================ import os import os.path as osp from typing import Callable, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import read_txt_array class PCPNetDataset(InMemoryDataset): r"""The PCPNet dataset from the `"PCPNet: Learning Local Shape Properties from Raw Point Clouds" `_ paper, consisting of 30 shapes, each given as a point cloud, densely sampled with 100k points. For each shape, surface normals and local curvatures are given as node features. Args: root (str): Root directory where the dataset should be saved. category (str): The training set category (one of :obj:`"NoNoise"`, :obj:`"Noisy"`, :obj:`"VarDensity"`, :obj:`"NoisyAndVarDensity"` for :obj:`split="train"` or :obj:`split="val"`, or one of :obj:`"All"`, :obj:`"LowNoise"`, :obj:`"MedNoise"`, :obj:`"HighNoise", :obj:`"VarDensityStriped", :obj:`"VarDensityGradient"` for :obj:`split="test"`). split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. 
(default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'http://geometry.cs.ucl.ac.uk/projects/2018/pcpnet/pclouds.zip' category_files_train = { 'NoNoise': 'trainingset_no_noise.txt', 'Noisy': 'trainingset_whitenoise.txt', 'VarDensity': 'trainingset_vardensity.txt', 'NoisyAndVarDensity': 'trainingset_vardensity_whitenoise.txt' } category_files_val = { 'NoNoise': 'validationset_no_noise.txt', 'Noisy': 'validationset_whitenoise.txt', 'VarDensity': 'validationset_vardensity.txt', 'NoisyAndVarDensity': 'validationset_vardensity_whitenoise.txt' } category_files_test = { 'All': 'testset_all.txt', 'NoNoise': 'testset_no_noise.txt', 'LowNoise': 'testset_low_noise.txt', 'MedNoise': 'testset_med_noise.txt', 'HighNoise': 'testset_high_noise.txt', 'VarDensityStriped': 'testset_vardensity_striped.txt', 'VarDensityGradient': 'testset_vardensity_gradient.txt' } def __init__( self, root: str, category: str, split: str = 'train', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert split in ['train', 'val', 'test'] if split == 'train': assert category in self.category_files_train.keys() elif split == 'val': assert category in 
self.category_files_val.keys() else: assert category in self.category_files_test.keys() self.category = category self.split = split super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> str: if self.split == 'train': return self.category_files_train[self.category] elif self.split == 'val': return self.category_files_val[self.category] else: return self.category_files_test[self.category] @property def processed_file_names(self) -> str: return self.split + '_' + self.category + '.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: path_file = self.raw_paths with open(path_file[0]) as f: filenames = f.read().split('\n')[:-1] data_list = [] for filename in filenames: pos_path = osp.join(self.raw_dir, filename + '.xyz') normal_path = osp.join(self.raw_dir, filename + '.normals') curv_path = osp.join(self.raw_dir, filename + '.curv') idx_path = osp.join(self.raw_dir, filename + '.pidx') pos = read_txt_array(pos_path) normals = read_txt_array(normal_path) curv = read_txt_array(curv_path) normals_and_curv = torch.cat([normals, curv], dim=1) test_idx = read_txt_array(idx_path, dtype=torch.long) data = Data(pos=pos, x=normals_and_curv) data.test_idx = test_idx if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'category={self.category})') ================================================ FILE: torch_geometric/datasets/pcqm4m.py ================================================ import os import os.path as osp from typing import Any, Callable, Dict, List, Optional import torch from tqdm import tqdm from torch_geometric.data import Data, 
OnDiskDataset, download_url, extract_zip from torch_geometric.data.data import BaseData from torch_geometric.io import fs from torch_geometric.utils import from_smiles as _from_smiles class PCQM4Mv2(OnDiskDataset): r"""The PCQM4Mv2 dataset from the `"OGB-LSC: A Large-Scale Challenge for Machine Learning on Graphs" `_ paper. :class:`PCQM4Mv2` is a quantum chemistry dataset originally curated under the `PubChemQC project `_. The task is to predict the DFT-calculated HOMO-LUMO energy gap of molecules given their 2D molecular graphs. .. note:: This dataset uses the :class:`OnDiskDataset` base class to load data dynamically from disk. Args: root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. If :obj:`"holdout"`, loads the holdout dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) backend (str): The :class:`Database` backend to use. (default: :obj:`"sqlite"`) from_smiles (callable, optional): A custom function that takes a SMILES string and outputs a :obj:`~torch_geometric.data.Data` object. If not set, defaults to :meth:`~torch_geometric.utils.from_smiles`. 
(default: :obj:`None`) """ url = ('https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/' 'pcqm4m-v2.zip') split_mapping = { 'train': 'train', 'val': 'valid', 'test': 'test-dev', 'holdout': 'test-challenge', } def __init__( self, root: str, split: str = 'train', transform: Optional[Callable] = None, backend: str = 'sqlite', from_smiles: Optional[Callable] = None, ) -> None: assert split in ['train', 'val', 'test', 'holdout'] schema = { 'x': dict(dtype=torch.int64, size=(-1, 9)), 'edge_index': dict(dtype=torch.int64, size=(2, -1)), 'edge_attr': dict(dtype=torch.int64, size=(-1, 3)), 'smiles': str, 'y': float, } self.from_smiles = from_smiles or _from_smiles super().__init__(root, transform, backend=backend, schema=schema) split_idx = fs.torch_load(self.raw_paths[1]) self._indices = split_idx[self.split_mapping[split]].tolist() @property def raw_file_names(self) -> List[str]: return [ osp.join('pcqm4m-v2', 'raw', 'data.csv.gz'), osp.join('pcqm4m-v2', 'split_dict.pt'), ] def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: import pandas as pd df = pd.read_csv(self.raw_paths[0]) data_list: List[Data] = [] iterator = enumerate(zip(df['smiles'], df['homolumogap'])) for i, (smiles, y) in tqdm(iterator, total=len(df)): data = self.from_smiles(smiles) data.y = y data_list.append(data) if i + 1 == len(df) or (i + 1) % 1000 == 0: # Write batch-wise: self.extend(data_list) data_list = [] def serialize(self, data: BaseData) -> Dict[str, Any]: assert isinstance(data, Data) return dict( x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, y=data.y, smiles=data.smiles, ) def deserialize(self, data: Dict[str, Any]) -> Data: return Data.from_dict(data) ================================================ FILE: torch_geometric/datasets/planetoid.py ================================================ import os.path as osp from typing import Callable, List, Optional import numpy as 
np import torch from torch_geometric.data import InMemoryDataset from torch_geometric.io import fs, read_planetoid_data class Planetoid(InMemoryDataset): r"""The citation network datasets :obj:`"Cora"`, :obj:`"CiteSeer"` and :obj:`"PubMed"` from the `"Revisiting Semi-Supervised Learning with Graph Embeddings" `_ paper. Nodes represent documents and edges represent citation links. Training, validation and test splits are given by binary masks. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"Cora"`, :obj:`"CiteSeer"`, :obj:`"PubMed"`). split (str, optional): The type of dataset split (:obj:`"public"`, :obj:`"full"`, :obj:`"geom-gcn"`, :obj:`"random"`). If set to :obj:`"public"`, the split will be the public fixed split from the `"Revisiting Semi-Supervised Learning with Graph Embeddings" `_ paper. If set to :obj:`"full"`, all nodes except those in the validation and test sets will be used for training (as in the `"FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling" `_ paper). If set to :obj:`"geom-gcn"`, the 10 public fixed splits from the `"Geom-GCN: Geometric Graph Convolutional Networks" `_ paper are given. If set to :obj:`"random"`, train, validation, and test sets will be randomly generated, according to :obj:`num_train_per_class`, :obj:`num_val` and :obj:`num_test`. (default: :obj:`"public"`) num_train_per_class (int, optional): The number of training samples per class in case of :obj:`"random"` split. (default: :obj:`20`) num_val (int, optional): The number of validation samples in case of :obj:`"random"` split. (default: :obj:`500`) num_test (int, optional): The number of test samples in case of :obj:`"random"` split. (default: :obj:`1000`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - Cora - 2,708 - 10,556 - 1,433 - 7 * - CiteSeer - 3,327 - 9,104 - 3,703 - 6 * - PubMed - 19,717 - 88,648 - 500 - 3 """ url = 'https://github.com/kimiyoung/planetoid/raw/master/data' geom_gcn_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/' 'geom-gcn/master') def __init__( self, root: str, name: str, split: str = "public", num_train_per_class: int = 20, num_val: int = 500, num_test: int = 1000, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name self.split = split.lower() assert self.split in ['public', 'full', 'geom-gcn', 'random'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) if split == 'full': data = self.get(0) data.train_mask.fill_(True) data.train_mask[data.val_mask | data.test_mask] = False self.data, self.slices = self.collate([data]) elif split == 'random': data = self.get(0) data.train_mask.fill_(False) for c in range(self.num_classes): idx = (data.y == c).nonzero(as_tuple=False).view(-1) idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]] data.train_mask[idx] = True remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1) remaining = remaining[torch.randperm(remaining.size(0))] data.val_mask.fill_(False) data.val_mask[remaining[:num_val]] = True data.test_mask.fill_(False) data.test_mask[remaining[num_val:num_val + num_test]] = True self.data, self.slices = self.collate([data]) @property def raw_dir(self) -> str: if 
self.split == 'geom-gcn': return osp.join(self.root, self.name, 'geom-gcn', 'raw') return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: if self.split == 'geom-gcn': return osp.join(self.root, self.name, 'geom-gcn', 'processed') return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index'] return [f'ind.{self.name.lower()}.{name}' for name in names] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for name in self.raw_file_names: fs.cp(f'{self.url}/{name}', self.raw_dir) if self.split == 'geom-gcn': for i in range(10): url = f'{self.geom_gcn_url}/splits/{self.name.lower()}' fs.cp(f'{url}_split_0.6_0.2_{i}.npz', self.raw_dir) def process(self) -> None: data = read_planetoid_data(self.raw_dir, self.name) if self.split == 'geom-gcn': train_masks, val_masks, test_masks = [], [], [] for i in range(10): name = f'{self.name.lower()}_split_0.6_0.2_{i}.npz' splits = np.load(osp.join(self.raw_dir, name)) train_masks.append(torch.from_numpy(splits['train_mask'])) val_masks.append(torch.from_numpy(splits['val_mask'])) test_masks.append(torch.from_numpy(splits['test_mask'])) data.train_mask = torch.stack(train_masks, dim=1) data.val_mask = torch.stack(val_masks, dim=1) data.test_mask = torch.stack(test_masks, dim=1) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name}()' ================================================ FILE: torch_geometric/datasets/polblogs.py ================================================ import os from typing import Callable, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_tar, ) class PolBlogs(InMemoryDataset): r"""The Political Blogs dataset from the `"The Political Blogosphere and the 2004 US Election: 
Divided they Blog" `_ paper. :class:`Polblogs` is a graph with 1,490 vertices (representing political blogs) and 19,025 edges (links between blogs). The links are automatically extracted from a crawl of the front page of the blog. Each vertex receives a label indicating the political leaning of the blog: liberal or conservative. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 1,490 - 19,025 - 0 - 2 """ url = 'https://netset.telecom-paris.fr/datasets/polblogs.tar.gz' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['adjacency.tsv', 'labels.tsv'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_tar(path, self.raw_dir) os.unlink(path) def process(self) -> None: import pandas as pd edge_index = pd.read_csv(self.raw_paths[0], header=None, sep='\t', usecols=[0, 1]) edge_index = torch.from_numpy(edge_index.values).t().contiguous() y = pd.read_csv(self.raw_paths[1], header=None, sep='\t') y = torch.from_numpy(y.values).view(-1) data = 
Data(edge_index=edge_index, y=y, num_nodes=y.size(0)) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/ppi.py ================================================ import json import os import os.path as osp from itertools import product from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.utils import remove_self_loops class PPI(InMemoryDataset): r"""The protein-protein interaction networks from the `"Predicting Multicellular Function through Multi-layer Tissue Networks" `_ paper, containing positional gene sets, motif gene sets and immunological signatures as features (50 in total) and gene ontology sets as labels (121 in total). Args: root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #tasks * - 20 - ~2,245.3 - ~61,318.4 - 50 - 121 """ url = 'https://data.dgl.ai/dataset/ppi.zip' def __init__( self, root: str, split: str = 'train', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert split in ['train', 'val', 'test'] super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) if split == 'train': self.load(self.processed_paths[0]) elif split == 'val': self.load(self.processed_paths[1]) elif split == 'test': self.load(self.processed_paths[2]) @property def raw_file_names(self) -> List[str]: splits = ['train', 'valid', 'test'] files = ['feats.npy', 'graph_id.npy', 'graph.json', 'labels.npy'] return [f'{split}_{name}' for split, name in product(splits, files)] @property def processed_file_names(self) -> List[str]: return ['train.pt', 'val.pt', 'test.pt'] def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: import networkx as nx from networkx.readwrite import json_graph for s, split in enumerate(['train', 'valid', 'test']): path = osp.join(self.raw_dir, f'{split}_graph.json') with open(path) as f: G = nx.DiGraph( json_graph.node_link_graph(json.load(f), edges="links")) x = np.load(osp.join(self.raw_dir, f'{split}_feats.npy')) x = torch.from_numpy(x).to(torch.float) y = np.load(osp.join(self.raw_dir, f'{split}_labels.npy')) y = torch.from_numpy(y).to(torch.float) data_list = [] path = osp.join(self.raw_dir, f'{split}_graph_id.npy') idx = torch.from_numpy(np.load(path)).to(torch.long) idx = idx - idx.min() for i in range(int(idx.max()) + 1): mask = idx == i G_s = G.subgraph( mask.nonzero(as_tuple=False).view(-1).tolist()) edge_index = torch.tensor(list(G_s.edges)).t().contiguous() edge_index = edge_index - edge_index.min() 
edge_index, _ = remove_self_loops(edge_index) data = Data(edge_index=edge_index, x=x[mask], y=y[mask]) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[s]) ================================================ FILE: torch_geometric/datasets/protein_mpnn_dataset.py ================================================ import os import pickle import random from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from tqdm import tqdm from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_tar, ) class ProteinMPNNDataset(InMemoryDataset): r"""The ProteinMPNN dataset from the `"Robust deep learning based protein sequence design using ProteinMPNN" `_ paper. Args: root (str): Root directory where the dataset should be saved. size (str): Size of the PDB information to train the model. If :obj:`"small"`, loads the small dataset (229.4 MB). If :obj:`"large"`, loads the large dataset (64.1 GB). (default: :obj:`"small"`) split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"valid"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) datacut (str, optional): Date cutoff to filter the dataset. (default: :obj:`"2030-01-01"`) rescut (float, optional): PDB resolution cutoff. (default: :obj:`3.5`) homo (float, optional): Homology cutoff. (default: :obj:`0.70`) max_length (int, optional): Maximum length of the protein complex. (default: :obj:`10000`) num_units (int, optional): Number of units of the protein complex. (default: :obj:`150`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ raw_url = { 'small': 'https://files.ipd.uw.edu/pub/training_sets/' 'pdb_2021aug02_sample.tar.gz', 'large': 'https://files.ipd.uw.edu/pub/training_sets/' 'pdb_2021aug02.tar.gz', } splits = { 'train': 1, 'valid': 2, 'test': 3, } def __init__( self, root: str, size: str = 'small', split: str = 'train', datacut: str = '2030-01-01', rescut: float = 3.5, homo: float = 0.70, max_length: int = 10000, num_units: int = 150, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.size = size self.split = split self.datacut = datacut self.rescut = rescut self.homo = homo self.max_length = max_length self.num_units = num_units self.sub_folder = self.raw_url[self.size].split('/')[-1].split('.')[0] super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[self.splits[self.split]]) @property def raw_file_names(self) -> List[str]: return [ f'{self.sub_folder}/{f}' for f in ['list.csv', 'valid_clusters.txt', 'test_clusters.txt'] ] @property def processed_file_names(self) -> List[str]: return ['splits.pkl', 'train.pt', 'valid.pt', 'test.pt'] def download(self) -> None: file_path = download_url(self.raw_url[self.size], self.raw_dir) extract_tar(file_path, self.raw_dir) os.unlink(file_path) def process(self) -> None: alphabet_set = 
set(list('ACDEFGHIKLMNPQRSTVWYX')) cluster_ids = self._process_split() total_items = sum(len(items) for items in cluster_ids.values()) data_list = [] with tqdm(total=total_items, desc="Processing") as pbar: for _, items in cluster_ids.items(): for chain_id, _ in items: item = self._process_pdb1(chain_id) if 'label' not in item: pbar.update(1) continue if len(list(np.unique(item['idx']))) >= 352: pbar.update(1) continue my_dict = self._process_pdb2(item) if len(my_dict['seq']) > self.max_length: pbar.update(1) continue bad_chars = set(list( my_dict['seq'])).difference(alphabet_set) if len(bad_chars) > 0: pbar.update(1) continue x_chain_all, chain_seq_label_all, mask, chain_mask_all, residue_idx, chain_encoding_all = self._process_pdb3( # noqa: E501 my_dict) data = Data( x=x_chain_all, # [seq_len, 4, 3] chain_seq_label=chain_seq_label_all, # [seq_len] mask=mask, # [seq_len] chain_mask_all=chain_mask_all, # [seq_len] residue_idx=residue_idx, # [seq_len] chain_encoding_all=chain_encoding_all, # [seq_len] ) if self.pre_filter is not None and not self.pre_filter( data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) if len(data_list) >= self.num_units: pbar.update(total_items - pbar.n) break pbar.update(1) else: continue break self.save(data_list, self.processed_paths[self.splits[self.split]]) def _process_split(self) -> Dict[int, List[Tuple[str, int]]]: import pandas as pd save_path = self.processed_paths[0] if os.path.exists(save_path): print('Load split') with open(save_path, 'rb') as f: data = pickle.load(f) else: # CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE df = pd.read_csv(self.raw_paths[0]) df = df[(df['RESOLUTION'] <= self.rescut) & (df['DEPOSITION'] <= self.datacut)] val_ids = pd.read_csv(self.raw_paths[1], header=None)[0].tolist() test_ids = pd.read_csv(self.raw_paths[2], header=None)[0].tolist() # compile training and validation sets data = { 'train': defaultdict(list), 'valid': 
defaultdict(list), 'test': defaultdict(list), } for _, r in tqdm(df.iterrows(), desc='Processing split', total=len(df)): cluster_id = r['CLUSTER'] hash_id = r['HASH'] chain_id = r['CHAINID'] if cluster_id in val_ids: data['valid'][cluster_id].append((chain_id, hash_id)) elif cluster_id in test_ids: data['test'][cluster_id].append((chain_id, hash_id)) else: data['train'][cluster_id].append((chain_id, hash_id)) with open(save_path, 'wb') as f: pickle.dump(data, f) return data[self.split] def _process_pdb1(self, chain_id: str) -> Dict[str, Any]: pdbid, chid = chain_id.split('_') prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}' # load metadata if not os.path.isfile(f'{prefix}.pt'): return {'seq': np.zeros(5)} meta = torch.load(f'{prefix}.pt') asmb_ids = meta['asmb_ids'] asmb_chains = meta['asmb_chains'] chids = np.array(meta['chains']) # find candidate assemblies which contain chid chain asmb_candidates = { a for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',') } # if the chains is missing is missing from all the assemblies # then return this chain alone if len(asmb_candidates) < 1: chain = torch.load(f'{prefix}_{chid}.pt') L = len(chain['seq']) return { 'seq': chain['seq'], 'xyz': chain['xyz'], 'idx': torch.zeros(L).int(), 'masked': torch.Tensor([0]).int(), 'label': chain_id, } # randomly pick one assembly from candidates asmb_i = random.sample(list(asmb_candidates), 1) # indices of selected transforms idx = np.where(np.array(asmb_ids) == asmb_i)[0] # load relevant chains chains = { c: torch.load(f'{prefix}_{c}.pt') for i in idx for c in asmb_chains[i] if c in meta['chains'] } # generate assembly asmb = {} for k in idx: # pick k-th xform xform = meta[f'asmb_xform{k}'] u = xform[:, :3, :3] r = xform[:, :3, 3] # select chains which k-th xform should be applied to s1 = set(meta['chains']) s2 = set(asmb_chains[k].split(',')) chains_k = s1 & s2 # transform selected chains for c in chains_k: try: xyz = chains[c]['xyz'] xyz_ru = 
torch.einsum('bij,raj->brai', u, xyz) + r[:, None, None, :] asmb.update({ (c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru) }) except KeyError: return {'seq': np.zeros(5)} # select chains which share considerable similarity to chid seqid = meta['tm'][chids == chid][0, :, 1] homo = { ch_j for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo } # stack all chains in the assembly together seq: str = '' xyz_all: List[torch.Tensor] = [] idx_all: List[torch.Tensor] = [] masked: List[int] = [] seq_list = [] for counter, (k, v) in enumerate(asmb.items()): seq += chains[k[0]]['seq'] seq_list.append(chains[k[0]]['seq']) xyz_all.append(v) idx_all.append(torch.full((v.shape[0], ), counter)) if k[0] in homo: masked.append(counter) return { 'seq': seq, 'xyz': torch.cat(xyz_all, dim=0), 'idx': torch.cat(idx_all, dim=0), 'masked': torch.Tensor(masked).int(), 'label': chain_id, } def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]: init_alphabet = list( 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') extra_alphabet = [str(item) for item in list(np.arange(300))] chain_alphabet = init_alphabet + extra_alphabet my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {} concat_seq = '' mask_list = [] visible_list = [] for idx in list(np.unique(t['idx'])): letter = chain_alphabet[idx] res = np.argwhere(t['idx'] == idx) initial_sequence = "".join(list( np.array(list(t['seq']))[res][ 0, ])) if initial_sequence[-6:] == "HHHHHH": res = res[:, :-6] if initial_sequence[0:6] == "HHHHHH": res = res[:, 6:] if initial_sequence[-7:-1] == "HHHHHH": res = res[:, :-7] if initial_sequence[-8:-2] == "HHHHHH": res = res[:, :-8] if initial_sequence[-9:-3] == "HHHHHH": res = res[:, :-9] if initial_sequence[-10:-4] == "HHHHHH": res = res[:, :-10] if initial_sequence[1:7] == "HHHHHH": res = res[:, 7:] if initial_sequence[2:8] == "HHHHHH": res = res[:, 8:] if initial_sequence[3:9] == "HHHHHH": res = res[:, 9:] if initial_sequence[4:10] == "HHHHHH": res = res[:, 10:] if 
res.shape[1] >= 4: chain_seq = "".join(list(np.array(list(t['seq']))[res][0])) my_dict[f'seq_chain_{letter}'] = chain_seq concat_seq += chain_seq if idx in t['masked']: mask_list.append(letter) else: visible_list.append(letter) coords_dict_chain = {} all_atoms = np.array(t['xyz'][res])[0] # [L, 14, 3] for i, c in enumerate(['N', 'CA', 'C', 'O']): coords_dict_chain[ f'{c}_chain_{letter}'] = all_atoms[:, i, :].tolist() my_dict[f'coords_chain_{letter}'] = coords_dict_chain my_dict['name'] = t['label'] my_dict['masked_list'] = mask_list my_dict['visible_list'] = visible_list my_dict['num_of_chains'] = len(mask_list) + len(visible_list) my_dict['seq'] = concat_seq return my_dict def _process_pdb3( self, b: Dict[str, Any] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: L = len(b['seq']) # residue idx with jumps across chains residue_idx = -100 * np.ones([L], dtype=np.int32) # get the list of masked / visible chains masked_chains, visible_chains = b['masked_list'], b['visible_list'] visible_temp_dict, masked_temp_dict = {}, {} for letter in masked_chains + visible_chains: chain_seq = b[f'seq_chain_{letter}'] if letter in visible_chains: visible_temp_dict[letter] = chain_seq elif letter in masked_chains: masked_temp_dict[letter] = chain_seq # check for duplicate chains (same sequence but different identity) for _, vm in masked_temp_dict.items(): for kv, vv in visible_temp_dict.items(): if vm == vv: if kv not in masked_chains: masked_chains.append(kv) if kv in visible_chains: visible_chains.remove(kv) # build protein data structures all_chains = masked_chains + visible_chains np.random.shuffle(all_chains) x_chain_list = [] chain_mask_list = [] chain_seq_list = [] chain_encoding_list = [] c, l0, l1 = 1, 0, 0 for letter in all_chains: chain_seq = b[f'seq_chain_{letter}'] chain_length = len(chain_seq) chain_coords = b[f'coords_chain_{letter}'] x_chain = np.stack([ chain_coords[c] for c in [ f'N_chain_{letter}', 
f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}' ] ], 1) # [chain_length, 4, 3] x_chain_list.append(x_chain) chain_seq_list.append(chain_seq) if letter in visible_chains: chain_mask = np.zeros(chain_length) # 0 for visible chains elif letter in masked_chains: chain_mask = np.ones(chain_length) # 1 for masked chains chain_mask_list.append(chain_mask) chain_encoding_list.append(c * np.ones(chain_length)) l1 += chain_length residue_idx[l0:l1] = 100 * (c - 1) + np.arange(l0, l1) l0 += chain_length c += 1 x_chain_all = np.concatenate(x_chain_list, 0) # [L, 4, 3] chain_seq_all = "".join(chain_seq_list) # [L,] 1.0 for places that need to be predicted chain_mask_all = np.concatenate(chain_mask_list, 0) chain_encoding_all = np.concatenate(chain_encoding_list, 0) # Convert to labels alphabet = 'ACDEFGHIKLMNPQRSTVWYX' chain_seq_label_all = np.asarray( [alphabet.index(a) for a in chain_seq_all], dtype=np.int32) isnan = np.isnan(x_chain_all) mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32) x_chain_all[isnan] = 0. # Conversion return ( torch.from_numpy(x_chain_all).to(dtype=torch.float32), torch.from_numpy(chain_seq_label_all).to(dtype=torch.long), torch.from_numpy(mask).to(dtype=torch.float32), torch.from_numpy(chain_mask_all).to(dtype=torch.float32), torch.from_numpy(residue_idx).to(dtype=torch.long), torch.from_numpy(chain_encoding_all).to(dtype=torch.long), ) ================================================ FILE: torch_geometric/datasets/qm7.py ================================================ from typing import Callable, Optional import torch from torch_geometric.data import Data, InMemoryDataset, download_url class QM7b(InMemoryDataset): r"""The QM7b dataset from the `"MoleculeNet: A Benchmark for Molecular Machine Learning" `_ paper, consisting of 7,211 molecules with 14 regression targets. Args: root (str): Root directory where the dataset should be saved. 
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #tasks * - 7,211 - ~15.4 - ~245.0 - 0 - 14 """ url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm7b.mat' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> str: return 'qm7b.mat' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_url(self.url, self.raw_dir) def process(self) -> None: from scipy.io import loadmat data = loadmat(self.raw_paths[0]) coulomb_matrix = torch.from_numpy(data['X']) target = torch.from_numpy(data['T']).to(torch.float) data_list = [] for i in range(target.shape[0]): edge_index = coulomb_matrix[i].nonzero( as_tuple=False).t().contiguous() edge_attr = coulomb_matrix[i, edge_index[0], edge_index[1]] y = target[i].view(1, -1) data = Data(edge_index=edge_index, edge_attr=edge_attr, y=y) data.num_nodes = 
int(edge_index.max()) + 1 data_list.append(data) if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/qm9.py ================================================ import os import os.path as osp import sys from typing import Callable, List, Optional import torch from torch import Tensor from tqdm import tqdm from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs from torch_geometric.utils import one_hot, scatter HAR2EV = 27.211386246 KCALMOL2EV = 0.04336414 conversion = torch.tensor([ 1., 1., HAR2EV, HAR2EV, HAR2EV, 1., HAR2EV, HAR2EV, HAR2EV, HAR2EV, HAR2EV, 1., KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, 1., 1., 1. ]) atomrefs = { 6: [0., 0., 0., 0., 0.], 7: [ -13.61312172, -1029.86312267, -1485.30251237, -2042.61123593, -2713.48485589 ], 8: [ -13.5745904, -1029.82456413, -1485.26398105, -2042.5727046, -2713.44632457 ], 9: [ -13.54887564, -1029.79887659, -1485.2382935, -2042.54701705, -2713.42063702 ], 10: [ -13.90303183, -1030.25891228, -1485.71166277, -2043.01812778, -2713.88796536 ], 11: [0., 0., 0., 0., 0.], } class QM9(InMemoryDataset): r"""The QM9 dataset from the `"MoleculeNet: A Benchmark for Molecular Machine Learning" `_ paper, consisting of about 130,000 molecules with 19 regression targets. Each molecule includes complete spatial information for the single low energy conformation of the atoms in the molecule. In addition, we provide the atom features from the `"Neural Message Passing for Quantum Chemistry" `_ paper. 
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | Target | Property | Description | Unit | +========+==================================+===================================================================================+=============================================+ | 0 | :math:`\mu` | Dipole moment | :math:`\textrm{D}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 1 | :math:`\alpha` | Isotropic polarizability | :math:`{a_0}^3` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 2 | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 3 | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 4 | :math:`\Delta \epsilon` | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 5 | :math:`\langle R^2 \rangle` | Electronic spatial extent | :math:`{a_0}^2` | 
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 6 | :math:`\textrm{ZPVE}` | Zero point vibrational energy | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 7 | :math:`U_0` | Internal energy at 0K | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 8 | :math:`U` | Internal energy at 298.15K | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 9 | :math:`H` | Enthalpy at 298.15K | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 10 | :math:`G` | Free energy at 298.15K | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 11 | :math:`c_{\textrm{v}}` | Heat capavity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 12 | :math:`U_0^{\textrm{ATOM}}` | Atomization energy at 0K | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 13 | :math:`U^{\textrm{ATOM}}` | Atomization energy 
at 298.15K | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 14 | :math:`H^{\textrm{ATOM}}` | Atomization enthalpy at 298.15K | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 15 | :math:`G^{\textrm{ATOM}}` | Atomization free energy at 298.15K | :math:`\textrm{eV}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 16 | :math:`A` | Rotational constant | :math:`\textrm{GHz}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 17 | :math:`B` | Rotational constant | :math:`\textrm{GHz}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ | 18 | :math:`C` | Rotational constant | :math:`\textrm{GHz}` | +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ .. note:: We also provide a pre-processed version of the dataset in case :class:`rdkit` is not installed. The pre-processed version matches with the manually processed version as outlined in :meth:`process`. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #tasks * - 130,831 - ~18.0 - ~37.3 - 11 - 19 """ # noqa: E501 raw_url = ('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/' 'molnet_publish/qm9.zip') raw_url2 = 'https://ndownloader.figshare.com/files/3195404' processed_url = 'https://data.pyg.org/datasets/qm9_v3.zip' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) def mean(self, target: int) -> float: y = torch.cat([self.get(i).y for i in range(len(self))], dim=0) return float(y[:, target].mean()) def std(self, target: int) -> float: y = torch.cat([self.get(i).y for i in range(len(self))], dim=0) return float(y[:, target].std()) def atomref(self, target: int) -> Optional[Tensor]: if target in atomrefs: out = torch.zeros(100) out[torch.tensor([1, 6, 7, 8, 9])] = torch.tensor(atomrefs[target]) return out.view(-1, 1) return None @property def raw_file_names(self) -> List[str]: try: import rdkit # noqa return ['gdb9.sdf', 'gdb9.sdf.csv', 'uncharacterized.txt'] except ImportError: return ['qm9_v3.pt'] @property def processed_file_names(self) -> str: return 'data_v3.pt' def 
download(self) -> None: try: import rdkit # noqa file_path = download_url(self.raw_url, self.raw_dir) extract_zip(file_path, self.raw_dir) os.unlink(file_path) file_path = download_url(self.raw_url2, self.raw_dir) os.rename(osp.join(self.raw_dir, '3195404'), osp.join(self.raw_dir, 'uncharacterized.txt')) except ImportError: path = download_url(self.processed_url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: try: from rdkit import Chem, RDLogger from rdkit.Chem.rdchem import BondType as BT from rdkit.Chem.rdchem import HybridizationType RDLogger.DisableLog('rdApp.*') # type: ignore[attr-defined] WITH_RDKIT = True except ImportError: WITH_RDKIT = False if not WITH_RDKIT: print(("Using a pre-processed version of the dataset. Please " "install 'rdkit' to alternatively process the raw data."), file=sys.stderr) data_list = fs.torch_load(self.raw_paths[0]) data_list = [Data(**data_dict) for data_dict in data_list] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[0]) return types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} with open(self.raw_paths[1]) as f: target = [[float(x) for x in line.split(',')[1:20]] for line in f.read().split('\n')[1:-1]] y = torch.tensor(target, dtype=torch.float) y = torch.cat([y[:, 3:], y[:, :3]], dim=-1) y = y * conversion.view(1, -1) with open(self.raw_paths[2]) as f: skip = [int(x.split()[0]) - 1 for x in f.read().split('\n')[9:-2]] suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False) data_list = [] for i, mol in enumerate(tqdm(suppl)): if i in skip: continue N = mol.GetNumAtoms() conf = mol.GetConformer() pos = conf.GetPositions() pos = torch.tensor(pos, dtype=torch.float) type_idx = [] atomic_number = [] aromatic = [] sp = [] sp2 = [] sp3 
= [] num_hs = [] for atom in mol.GetAtoms(): type_idx.append(types[atom.GetSymbol()]) atomic_number.append(atom.GetAtomicNum()) aromatic.append(1 if atom.GetIsAromatic() else 0) hybridization = atom.GetHybridization() sp.append(1 if hybridization == HybridizationType.SP else 0) sp2.append(1 if hybridization == HybridizationType.SP2 else 0) sp3.append(1 if hybridization == HybridizationType.SP3 else 0) z = torch.tensor(atomic_number, dtype=torch.long) rows, cols, edge_types = [], [], [] for bond in mol.GetBonds(): start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() rows += [start, end] cols += [end, start] edge_types += 2 * [bonds[bond.GetBondType()]] edge_index = torch.tensor([rows, cols], dtype=torch.long) edge_type = torch.tensor(edge_types, dtype=torch.long) edge_attr = one_hot(edge_type, num_classes=len(bonds)) perm = (edge_index[0] * N + edge_index[1]).argsort() edge_index = edge_index[:, perm] edge_type = edge_type[perm] edge_attr = edge_attr[perm] row, col = edge_index hs = (z == 1).to(torch.float) num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist() x1 = one_hot(torch.tensor(type_idx), num_classes=len(types)) x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs], dtype=torch.float).t().contiguous() x = torch.cat([x1, x2], dim=-1) name = mol.GetProp('_Name') smiles = Chem.MolToSmiles(mol, isomericSmiles=True) data = Data( x=x, z=z, pos=pos, edge_index=edge_index, smiles=smiles, edge_attr=edge_attr, y=y[i].unsqueeze(0), name=name, idx=i, ) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/rcdd.py ================================================ import os from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, 
download_url, extract_zip, ) from torch_geometric.utils import index_to_mask class RCDD(InMemoryDataset): r"""The risk commodity detection dataset (RCDD) from the `"Datasets and Interfaces for Benchmarking Heterogeneous Graph Neural Networks" `_ paper. RCDD is an industrial-scale heterogeneous graph dataset based on a real risk detection scenario from Alibaba's e-commerce platform. It consists of 13,806,619 nodes and 157,814,864 edges across 7 node types and 7 edge types, respectively. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = ('https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/' 'openhgnn/AliRCD_ICDM.zip') def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> List[str]: return [ 'AliRCD_ICDM_nodes.csv', 'AliRCD_ICDM_edges.csv', 'AliRCD_ICDM_train_labels.csv', 'AliRCD_ICDM_test_labels.csv', ] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) @property def num_classes(self) -> int: return 2 def process(self) -> None: import pandas as pd data = HeteroData() node_df = pd.read_csv( # AliRCD_ICDM_nodes.csv: self.raw_paths[0], header=None, names=['node_id', 'node_type', 'node_feat'], ) # Map global node IDs to local ones for each node type: mapping = torch.empty(len(node_df), dtype=torch.long) for node_type in node_df['node_type'].unique(): mask = node_df['node_type'] == node_type node_id = torch.from_numpy(node_df['node_id'][mask].values) num_nodes = mask.sum() mapping[node_id] = torch.arange(num_nodes) data[node_type].num_nodes = num_nodes x = np.vstack([ np.asarray(f.split(':'), dtype=np.float32) for f in node_df['node_feat'][mask] ]) data[node_type].x = torch.from_numpy(x) edge_df = pd.read_csv( # AliRCD_ICDM_edges.csv: self.raw_paths[1], header=None, names=['src_id', 'dst_id', 'src_type', 'dst_type', 'edge_type'], ) for edge_type in edge_df['edge_type'].unique(): edge_type_df = edge_df[edge_df['edge_type'] == edge_type] src_type = edge_type_df['src_type'].iloc[0] dst_type = edge_type_df['dst_type'].iloc[0] src = mapping[torch.from_numpy(edge_type_df['src_id'].values)] dst = mapping[torch.from_numpy(edge_type_df['dst_id'].values)] edge_index = 
torch.stack([src, dst], dim=0) data[src_type, edge_type, dst_type].edge_index = edge_index train_df = pd.read_csv( # AliRCD_ICDM_train_labels.csv: self.raw_paths[2], header=None, names=['node_id', 'label'], dtype=int, ) test_df = pd.read_csv( # AliRCD_ICDM_test_labels.csv: self.raw_paths[3], header=None, sep='\t', names=['node_id', 'label'], dtype=int, ) train_idx = mapping[torch.from_numpy(train_df['node_id'].values)] test_idx = mapping[torch.from_numpy(test_df['node_id'].values)] y = torch.full((data['item'].num_nodes, ), -1, dtype=torch.long) y[train_idx] = torch.from_numpy(train_df['label'].values) y[test_idx] = torch.from_numpy(test_df['label'].values) train_mask = index_to_mask(train_idx, data['item'].num_nodes) test_mask = index_to_mask(test_idx, data['item'].num_nodes) data['item'].y = y data['item'].train_mask = train_mask data['item'].test_mask = test_mask if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/reddit.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.utils import coalesce class Reddit(InMemoryDataset): r"""The Reddit dataset from the `"Inductive Representation Learning on Large Graphs" `_ paper, containing Reddit posts belonging to different communities. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. 
The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 232,965 - 114,615,892 - 602 - 41 """ url = 'https://data.dgl.ai/dataset/reddit.zip' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['reddit_data.npz', 'reddit_graph.npz'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: import scipy.sparse as sp data = np.load(osp.join(self.raw_dir, 'reddit_data.npz')) x = torch.from_numpy(data['feature']).to(torch.float) y = torch.from_numpy(data['label']).to(torch.long) split = torch.from_numpy(data['node_types']) adj = sp.load_npz(osp.join(self.raw_dir, 'reddit_graph.npz')) row = torch.from_numpy(adj.row).to(torch.long) col = torch.from_numpy(adj.col).to(torch.long) edge_index = torch.stack([row, col], dim=0) edge_index = coalesce(edge_index, num_nodes=x.size(0)) data = Data(x=x, edge_index=edge_index, y=y) data.train_mask = split == 1 data.val_mask = split == 2 data.test_mask = split == 3 data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/reddit2.py ================================================ import json import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, 
download_google_url class Reddit2(InMemoryDataset): r"""The Reddit dataset from the `"GraphSAINT: Graph Sampling Based Inductive Learning Method" `_ paper, containing Reddit posts belonging to different communities. .. note:: This is a sparser version of the original :obj:`~torch_geometric.datasets.Reddit` dataset (~23M edges instead of ~114M edges), and is used in papers such as `SGC `_ and `GraphSAINT `_. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 232,965 - 23,213,838 - 602 - 41 """ adj_full_id = '1sncK996BM5lpuDf75lDFqCiDZyErc1c2' feats_id = '1ZsHaJ0ussP1W722krmEIp_8pwKAoi5b3' class_map_id = '1JF3Pjv9OboMNYs2aXRQGbJbc4t_nDd5u' role_id = '1nJIKd77lcAGU4j-kVNx_AIGEkveIKz3A' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz') download_google_url(self.feats_id, self.raw_dir, 'feats.npy') download_google_url(self.class_map_id, self.raw_dir, 'class_map.json') download_google_url(self.role_id, self.raw_dir, 'role.json') def process(self) -> None: import scipy.sparse as sp f = np.load(osp.join(self.raw_dir, 'adj_full.npz')) adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape']) adj = adj.tocoo() row = torch.from_numpy(adj.row).to(torch.long) col = torch.from_numpy(adj.col).to(torch.long) edge_index = torch.stack([row, col], dim=0) x = np.load(osp.join(self.raw_dir, 'feats.npy')) x = torch.from_numpy(x).to(torch.float) ys = [-1] * x.size(0) with open(osp.join(self.raw_dir, 'class_map.json')) as f: class_map = json.load(f) for key, item in class_map.items(): ys[int(key)] = item y = torch.tensor(ys) with open(osp.join(self.raw_dir, 'role.json')) as f: role = json.load(f) train_mask = torch.zeros(x.size(0), dtype=torch.bool) train_mask[torch.tensor(role['tr'])] = True val_mask = torch.zeros(x.size(0), dtype=torch.bool) val_mask[torch.tensor(role['va'])] = True test_mask = torch.zeros(x.size(0), dtype=torch.bool) 
test_mask[torch.tensor(role['te'])] = True data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/rel_link_pred_dataset.py ================================================ import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import Data, InMemoryDataset, download_url class RelLinkPredDataset(InMemoryDataset): r"""The relational link prediction datasets from the `"Modeling Relational Data with Graph Convolutional Networks" `_ paper. Training and test splits are given by sets of triplets. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"FB15k-237"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #classes * - 14,541 - 544,230 - 0 - 0 """ urls = { 'FB15k-237': ('https://raw.githubusercontent.com/MichSchli/' 'RelationPrediction/master/data/FB-Toutanova') } def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name assert name in ['FB15k-237'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def num_relations(self) -> int: return int(self._data.edge_type.max()) + 1 # type: ignore @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def processed_file_names(self) -> str: return 'data.pt' @property def raw_file_names(self) -> List[str]: return [ 'entities.dict', 'relations.dict', 'test.txt', 'train.txt', 'valid.txt' ] def download(self) -> None: for file_name in self.raw_file_names: download_url(f'{self.urls[self.name]}/{file_name}', self.raw_dir) def process(self) -> None: with open(osp.join(self.raw_dir, 'entities.dict')) as f: lines = [row.split('\t') for row in f.read().split('\n')[:-1]] entities_dict = {key: int(value) for value, key in lines} with open(osp.join(self.raw_dir, 'relations.dict')) as f: lines = [row.split('\t') for row in f.read().split('\n')[:-1]] relations_dict = {key: int(value) for value, key in lines} kwargs = {} for split in ['train', 'valid', 'test']: with open(osp.join(self.raw_dir, f'{split}.txt')) as f: lines = [row.split('\t') for row in f.read().split('\n')[:-1]] src = [entities_dict[row[0]] for row in lines] rel = [relations_dict[row[1]] for row in lines] dst = [entities_dict[row[2]] for row in lines] kwargs[f'{split}_edge_index'] = torch.tensor([src, dst]) kwargs[f'{split}_edge_type'] = torch.tensor(rel) # For message passing, we 
add reverse edges and types to the graph: row, col = kwargs['train_edge_index'] edge_type = kwargs['train_edge_type'] row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) edge_index = torch.stack([row, col], dim=0) edge_type = torch.cat([edge_type, edge_type + len(relations_dict)]) data = Data(num_nodes=len(entities_dict), edge_index=edge_index, edge_type=edge_type, **kwargs) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name}()' ================================================ FILE: torch_geometric/datasets/s3dis.py ================================================ import os import os.path as osp from typing import Callable, List, Optional import torch from torch import Tensor from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class S3DIS(InMemoryDataset): r"""The (pre-processed) Stanford Large-Scale 3D Indoor Spaces dataset from the `"3D Semantic Parsing of Large-Scale Indoor Spaces" `_ paper, containing point clouds of six large-scale indoor parts in three buildings with 12 semantic elements (and one clutter class). Args: root (str): Root directory where the dataset should be saved. test_area (int, optional): Which area to use for testing (1-6). (default: :obj:`6`) train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = ('https://shapenet.cs.stanford.edu/media/' 'indoor3d_sem_seg_hdf5_data.zip') # In case `shapenet.cs.stanford.edu` is offline, try to download the data # from here: # https://cvg-data.inf.ethz.ch/s3dis/ def __init__( self, root: str, test_area: int = 6, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert test_area >= 1 and test_area <= 6 self.test_area = test_area super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[0] if train else self.processed_paths[1] self.load(path) @property def raw_file_names(self) -> List[str]: return ['all_files.txt', 'room_filelist.txt'] @property def processed_file_names(self) -> List[str]: return [f'{split}_{self.test_area}.pt' for split in ['train', 'test']] def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.root) os.unlink(path) fs.rm(self.raw_dir) name = self.url.split('/')[-1].split('.')[0] os.rename(osp.join(self.root, name), self.raw_dir) def process(self) -> None: import h5py with open(self.raw_paths[0]) as f: filenames = [x.split('/')[-1] for x in f.read().split('\n')[:-1]] with open(self.raw_paths[1]) as f: rooms = f.read().split('\n')[:-1] xs: List[Tensor] = [] ys: List[Tensor] = [] for filename in filenames: h5 = h5py.File(osp.join(self.raw_dir, filename)) xs += torch.from_numpy(h5['data'][:]).unbind(0) ys += torch.from_numpy(h5['label'][:]).to(torch.long).unbind(0) test_area = f'Area_{self.test_area}' train_data_list, test_data_list = [], [] for 
i, (x, y) in enumerate(zip(xs, ys)): data = Data(pos=x[:, :3], x=x[:, 3:], y=y) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) if test_area not in rooms[i]: train_data_list.append(data) else: test_data_list.append(data) self.save(train_data_list, self.processed_paths[0]) self.save(test_data_list, self.processed_paths[1]) ================================================ FILE: torch_geometric/datasets/sbm_dataset.py ================================================ import os.path as osp from typing import Any, Callable, List, Optional, Union import numpy as np import torch from torch import Tensor from torch_geometric.data import Data, InMemoryDataset from torch_geometric.utils import stochastic_blockmodel_graph class StochasticBlockModelDataset(InMemoryDataset): r"""A synthetic graph dataset generated by the stochastic block model. The node features of each block are sampled from normal distributions where the centers of clusters are vertices of a hypercube, as computed by the :meth:`sklearn.datasets.make_classification` method. Args: root (str): Root directory where the dataset should be saved. block_sizes ([int] or LongTensor): The sizes of blocks. edge_probs ([[float]] or FloatTensor): The density of edges going from each block to each other block. Must be symmetric if the graph is undirected. num_graphs (int, optional): The number of graphs. (default: :obj:`1`) num_channels (int, optional): The number of node features. If given as :obj:`None`, node features are not generated. (default: :obj:`None`) is_undirected (bool, optional): Whether the graph to generate is undirected. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **kwargs (optional): The keyword arguments that are passed down to the :meth:`sklearn.datasets.make_classification` method for drawing node features. """ def __init__( self, root: str, block_sizes: Union[List[int], Tensor], edge_probs: Union[List[List[float]], Tensor], num_graphs: int = 1, num_channels: Optional[int] = None, is_undirected: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, **kwargs: Any, ) -> None: if not isinstance(block_sizes, torch.Tensor): block_sizes = torch.tensor(block_sizes, dtype=torch.long) if not isinstance(edge_probs, torch.Tensor): edge_probs = torch.tensor(edge_probs, dtype=torch.float) assert num_graphs > 0 self.block_sizes = block_sizes self.edge_probs = edge_probs self.num_graphs = num_graphs self.num_channels = num_channels self.is_undirected = is_undirected self.kwargs = { 'n_informative': num_channels, 'n_redundant': 0, 'flip_y': 0.0, 'shuffle': False, } self.kwargs.update(kwargs) super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def processed_dir(self) -> str: return osp.join(self.root, self.__class__.__name__, 'processed') @property def processed_file_names(self) -> str: block_sizes = self.block_sizes.view(-1).tolist() hash1 = '-'.join([f'{x:.1f}' for x in block_sizes]) edge_probs = self.edge_probs.view(-1).tolist() hash2 = '-'.join([f'{x:.1f}' for x in edge_probs]) return f'data_{self.num_channels}_{hash1}_{hash2}_{self.num_graphs}.pt' def process(self) -> None: from sklearn.datasets import make_classification edge_index = 
stochastic_blockmodel_graph( self.block_sizes, self.edge_probs, directed=not self.is_undirected) num_samples = int(self.block_sizes.sum()) num_classes = self.block_sizes.size(0) data_list = [] for _ in range(self.num_graphs): x = None if self.num_channels is not None: x, y_not_sorted = make_classification( n_samples=num_samples, n_features=self.num_channels, n_classes=num_classes, weights=self.block_sizes / num_samples, **self.kwargs, ) x = x[np.argsort(y_not_sorted)] x = torch.from_numpy(x).to(torch.float) y = torch.arange(num_classes).repeat_interleave(self.block_sizes) data = Data(x=x, edge_index=edge_index, y=y) if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) class RandomPartitionGraphDataset(StochasticBlockModelDataset): r"""The random partition graph dataset from the `"How to Find Your Friendly Neighborhood: Graph Attention Design with Self-Supervision" `_ paper. This is a synthetic graph of communities controlled by the node homophily and the average degree, and each community is considered as a class. The node features are sampled from normal distributions where the centers of clusters are vertices of a hypercube, as computed by the :meth:`sklearn.datasets.make_classification` method. Args: root (str): Root directory where the dataset should be saved. num_classes (int): The number of classes. num_nodes_per_class (int): The number of nodes per class. node_homophily_ratio (float): The degree of node homophily. average_degree (float): The average degree of the graph. num_graphs (int, optional): The number of graphs. (default: :obj:`1`) num_channels (int, optional): The number of node features. If given as :obj:`None`, node features are not generated. (default: :obj:`None`) is_undirected (bool, optional): Whether the graph to generate is undirected. 
(default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) **kwargs (optional): The keyword arguments that are passed down to :meth:`sklearn.datasets.make_classification` method in drawing node features. """ def __init__( self, root: str, num_classes: int, num_nodes_per_class: int, node_homophily_ratio: float, average_degree: float, num_graphs: int = 1, num_channels: Optional[int] = None, is_undirected: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, **kwargs: Any, ) -> None: self._num_classes = num_classes self.num_nodes_per_class = num_nodes_per_class self.node_homophily_ratio = node_homophily_ratio self.average_degree = average_degree # (p_in + (C - 1) * p_out) / C = |E|/|V|^2 # i.e., p_in + (C - 1) * p_out = average_degree / num_nodes_per_class ec_over_v2 = average_degree / num_nodes_per_class p_in = node_homophily_ratio * ec_over_v2 p_out = (ec_over_v2 - p_in) / (num_classes - 1) block_sizes = [num_nodes_per_class for _ in range(num_classes)] edge_probs = [[p_out for _ in range(num_classes)] for _ in range(num_classes)] for r in range(num_classes): edge_probs[r][r] = p_in super().__init__(root, block_sizes, edge_probs, num_graphs, num_channels, is_undirected, transform, pre_transform, **kwargs) @property def processed_file_names(self) -> str: return (f'data_{self.num_channels}_{self._num_classes}_' f'{self.num_nodes_per_class}_{self.node_homophily_ratio:.1f}_' f'{self.average_degree:.1f}_{self.num_graphs}.pt') def process(self) -> None: return super().process() 
================================================ FILE: torch_geometric/datasets/shapenet.py ================================================ import json import os import os.path as osp from typing import Callable, List, Optional, Union import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs, read_txt_array class ShapeNet(InMemoryDataset): r"""The ShapeNet part level segmentation dataset from the `"A Scalable Active Framework for Region Annotation in 3D Shape Collections" `_ paper, containing about 17,000 3D shape point clouds from 16 shape categories. Each category is annotated with 2 to 6 parts. Args: root (str): Root directory where the dataset should be saved. categories (str or [str], optional): The category of the CAD models (one or a combination of :obj:`"Airplane"`, :obj:`"Bag"`, :obj:`"Cap"`, :obj:`"Car"`, :obj:`"Chair"`, :obj:`"Earphone"`, :obj:`"Guitar"`, :obj:`"Knife"`, :obj:`"Lamp"`, :obj:`"Laptop"`, :obj:`"Motorbike"`, :obj:`"Mug"`, :obj:`"Pistol"`, :obj:`"Rocket"`, :obj:`"Skateboard"`, :obj:`"Table"`). Can be explicitly set to :obj:`None` to load all categories. (default: :obj:`None`) include_normals (bool, optional): If set to :obj:`False`, will not include normal vectors as input features to :obj:`data.x`. As a result, :obj:`data.x` will be :obj:`None`. (default: :obj:`True`) split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"trainval"`, loads the training and validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"trainval"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - #graphs - #nodes - #edges - #features - #classes * - 16,881 - ~2,616.2 - 0 - 3 - 50 """ url = ('https://shapenet.cs.stanford.edu/media/' 'shapenetcore_partanno_segmentation_benchmark_v0_normal.zip') # In case `shapenet.cs.stanford.edu` is offline, try to download the data # from Kaggle instead (requires login): # https://www.kaggle.com/datasets/mitkir/shapenet/download?datasetVersionNumber=1 category_ids = { 'Airplane': '02691156', 'Bag': '02773838', 'Cap': '02954340', 'Car': '02958343', 'Chair': '03001627', 'Earphone': '03261776', 'Guitar': '03467517', 'Knife': '03624134', 'Lamp': '03636649', 'Laptop': '03642806', 'Motorbike': '03790512', 'Mug': '03797390', 'Pistol': '03948459', 'Rocket': '04099429', 'Skateboard': '04225987', 'Table': '04379243', } seg_classes = { 'Airplane': [0, 1, 2, 3], 'Bag': [4, 5], 'Cap': [6, 7], 'Car': [8, 9, 10, 11], 'Chair': [12, 13, 14, 15], 'Earphone': [16, 17, 18], 'Guitar': [19, 20, 21], 'Knife': [22, 23], 'Lamp': [24, 25, 26, 27], 'Laptop': [28, 29], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Mug': [36, 37], 'Pistol': [38, 39, 40], 'Rocket': [41, 42, 43], 'Skateboard': [44, 45, 46], 'Table': [47, 48, 49], } def __init__( self, root: str, categories: Optional[Union[str, List[str]]] = None, include_normals: bool = True, split: str = 'trainval', transform: Optional[Callable] = None, pre_transform: 
Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: if categories is None: categories = list(self.category_ids.keys()) if isinstance(categories, str): categories = [categories] assert all(category in self.category_ids for category in categories) self.categories = categories super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) if split == 'train': path = self.processed_paths[0] elif split == 'val': path = self.processed_paths[1] elif split == 'test': path = self.processed_paths[2] elif split == 'trainval': path = self.processed_paths[3] else: raise ValueError(f'Split {split} found, but expected either ' 'train, val, trainval or test') self.load(path) assert isinstance(self._data, Data) self._data.x = self._data.x if include_normals else None self.y_mask = torch.zeros((len(self.seg_classes.keys()), 50), dtype=torch.bool) for i, labels in enumerate(self.seg_classes.values()): self.y_mask[i, labels] = 1 @property def num_classes(self) -> int: return self.y_mask.size(-1) @property def raw_file_names(self) -> List[str]: return list(self.category_ids.values()) + ['train_test_split'] @property def processed_file_names(self) -> List[str]: cats = '_'.join([cat[:3].lower() for cat in self.categories]) return [ osp.join(f'{cats}_{split}.pt') for split in ['train', 'val', 'test', 'trainval'] ] def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.root) os.unlink(path) fs.rm(self.raw_dir) name = self.url.split('/')[-1].split('.')[0] os.rename(osp.join(self.root, name), self.raw_dir) def process_filenames(self, filenames: List[str]) -> List[Data]: data_list = [] categories_ids = [self.category_ids[cat] for cat in self.categories] cat_idx = {categories_ids[i]: i for i in range(len(categories_ids))} for name in filenames: cat = name.split(osp.sep)[0] if cat not in categories_ids: continue tensor = read_txt_array(osp.join(self.raw_dir, name)) pos = tensor[:, 
:3] x = tensor[:, 3:6] y = tensor[:, -1].type(torch.long) data = Data(pos=pos, x=x, y=y, category=cat_idx[cat]) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) return data_list def process(self) -> None: trainval = [] for i, split in enumerate(['train', 'val', 'test']): path = osp.join(self.raw_dir, 'train_test_split', f'shuffled_{split}_file_list.json') with open(path) as f: filenames = [ osp.sep.join(name.split('/')[1:]) + '.txt' for name in json.load(f) ] # Removing first directory. data_list = self.process_filenames(filenames) if split == 'train' or split == 'val': trainval += data_list self.save(data_list, self.processed_paths[i]) self.save(trainval, self.processed_paths[3]) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'categories={self.categories})') ================================================ FILE: torch_geometric/datasets/shrec2016.py ================================================ import glob import os import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import InMemoryDataset, download_url, extract_zip from torch_geometric.io import fs, read_off, read_txt_array class SHREC2016(InMemoryDataset): r"""The SHREC 2016 partial matching dataset from the `"SHREC'16: Partial Matching of Deformable Shapes" `_ paper. The reference shape can be referenced via :obj:`dataset.ref`. .. note:: Data objects hold mesh faces instead of edge indices. To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (str): Root directory where the dataset should be saved. 
partiality (str): The partiality of the dataset (one of :obj:`"Holes"`, :obj:`"Cuts"`). category (str): The category of the dataset (one of :obj:`"Cat"`, :obj:`"Centaur"`, :obj:`"David"`, :obj:`"Dog"`, :obj:`"Horse"`, :obj:`"Michael"`, :obj:`"Victoria"`, :obj:`"Wolf"`). train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ train_url = ('http://www.dais.unive.it/~shrec2016/data/' 'shrec2016_PartialDeformableShapes.zip') test_url = ('http://www.dais.unive.it/~shrec2016/data/' 'shrec2016_PartialDeformableShapes_TestSet.zip') categories = [ 'cat', 'centaur', 'david', 'dog', 'horse', 'michael', 'victoria', 'wolf' ] partialities = ['holes', 'cuts'] def __init__( self, root: str, partiality: str, category: str, train: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert partiality.lower() in self.partialities self.part = partiality.lower() assert category.lower() in self.categories self.cat = category.lower() super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.__ref__ = fs.torch_load(self.processed_paths[0]) path = self.processed_paths[1] if train else self.processed_paths[2] self.load(path) @property def ref(self) -> str: ref = self.__ref__ if self.transform is not None: ref = self.transform(ref) return ref @property def raw_file_names(self) -> List[str]: return ['training', 'test'] @property def processed_file_names(self) -> List[str]: name = f'{self.part}_{self.cat}.pt' return [f'{i}_{name}' for i in ['ref', 'training', 'test']] def download(self) -> None: path = download_url(self.train_url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) path = osp.join(self.raw_dir, 'shrec2016_PartialDeformableShapes') os.rename(path, osp.join(self.raw_dir, 'training')) path = download_url(self.test_url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) path = osp.join(self.raw_dir, 'shrec2016_PartialDeformableShapes_TestSet') os.rename(path, osp.join(self.raw_dir, 'test')) def process(self) -> None: ref_data = read_off( osp.join(self.raw_paths[0], 'null', f'{self.cat}.off')) train_list = [] name = f'{self.part}_{self.cat}_*.off' paths = glob.glob(osp.join(self.raw_paths[0], self.part, 
name)) paths = [path[:-4] for path in paths] paths = sorted(paths, key=lambda e: (len(e), e)) for path in paths: data = read_off(f'{path}.off') y = read_txt_array(f'{path}.baryc_gt') data.y = y[:, 0].to(torch.long) - 1 data.y_baryc = y[:, 1:] train_list.append(data) test_list = [] name = f'{self.part}_{self.cat}_*.off' paths = glob.glob(osp.join(self.raw_paths[1], self.part, name)) paths = [path[:-4] for path in paths] paths = sorted(paths, key=lambda e: (len(e), e)) for path in paths: test_list.append(read_off(f'{path}.off')) if self.pre_filter is not None: train_list = [d for d in train_list if self.pre_filter(d)] test_list = [d for d in test_list if self.pre_filter(d)] if self.pre_transform is not None: ref_data = self.pre_transform(ref_data) train_list = [self.pre_transform(d) for d in train_list] test_list = [self.pre_transform(d) for d in test_list] torch.save(ref_data, self.processed_paths[0]) self.save(train_list, self.processed_paths[1]) self.save(test_list, self.processed_paths[2]) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'partiality={self.part}, category={self.cat})') ================================================ FILE: torch_geometric/datasets/snap_dataset.py ================================================ import os import os.path as osp from typing import Any, Callable, Dict, List, Optional, Union import fsspec import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset from torch_geometric.io import fs from torch_geometric.utils import coalesce class EgoData(Data): def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any: if key == 'circle': return self.num_nodes elif key == 'circle_batch': return int(value.max()) + 1 if value.numel() > 0 else 0 return super().__inc__(key, value, *args, **kwargs) def read_ego(files: List[str], name: str) -> List[EgoData]: import pandas as pd import tqdm files = sorted(files) all_featnames = [] files = [ x for x in files if 
x.split('.')[-1] in ['circles', 'edges', 'egofeat', 'feat', 'featnames'] ] for i in range(4, len(files), 5): featnames_file = files[i] with fsspec.open(featnames_file, 'r') as f: featnames = f.read().split('\n')[:-1] featnames = [' '.join(x.split(' ')[1:]) for x in featnames] all_featnames += featnames all_featnames = sorted(list(set(all_featnames))) all_featnames_dict = {key: i for i, key in enumerate(all_featnames)} data_list = [] for i in tqdm.tqdm(range(0, len(files), 5)): circles_file = files[i] edges_file = files[i + 1] egofeat_file = files[i + 2] feat_file = files[i + 3] featnames_file = files[i + 4] x = None if name != 'gplus': # Don't read node features on g-plus: x_ego = pd.read_csv(egofeat_file, sep=' ', header=None, dtype=np.float32) x_ego = torch.from_numpy(x_ego.values) x = pd.read_csv(feat_file, sep=' ', header=None, dtype=np.float32) x = torch.from_numpy(x.values)[:, 1:] x_all = torch.cat([x, x_ego], dim=0) # Reorder `x` according to `featnames` ordering. x_all = torch.zeros(x.size(0), len(all_featnames)) with fsspec.open(featnames_file, 'r') as f: featnames = f.read().split('\n')[:-1] featnames = [' '.join(x.split(' ')[1:]) for x in featnames] indices = [all_featnames_dict[featname] for featname in featnames] x_all[:, torch.tensor(indices)] = x x = x_all if x.size(1) > 100_000: x = x.to_sparse_csr() idx = pd.read_csv(feat_file, sep=' ', header=None, dtype=str, usecols=[0]).squeeze() idx_assoc: Dict[str, int] = {} for i, j in enumerate(idx): idx_assoc[j] = i circles: List[int] = [] circles_batch: List[int] = [] with fsspec.open(circles_file, 'r') as f: for i, line in enumerate(f.read().split('\n')[:-1]): circle_indices = [idx_assoc[c] for c in line.split()[1:]] circles += circle_indices circles_batch += [i] * len(circle_indices) circle = torch.tensor(circles) circle_batch = torch.tensor(circles_batch) try: row = pd.read_csv(edges_file, sep=' ', header=None, dtype=str, usecols=[0]).squeeze() col = pd.read_csv(edges_file, sep=' ', header=None, 
dtype=str, usecols=[1]).squeeze() except Exception: continue row = torch.tensor([idx_assoc[i] for i in row]) col = torch.tensor([idx_assoc[i] for i in col]) N = max(int(row.max()), int(col.max())) + 2 N = x.size(0) if x is not None else N row_ego = torch.full((N - 1, ), N - 1, dtype=torch.long) col_ego = torch.arange(N - 1) # Ego node should be connected to every other node. row = torch.cat([row, row_ego, col_ego], dim=0) col = torch.cat([col, col_ego, row_ego], dim=0) edge_index = torch.stack([row, col], dim=0) edge_index = coalesce(edge_index, num_nodes=int(N)) data = EgoData(x=x, edge_index=edge_index, circle=circle, circle_batch=circle_batch) data_list.append(data) return data_list def read_soc(files: List[str], name: str) -> List[Data]: import pandas as pd skiprows = 4 if name == 'pokec': skiprows = 0 edge_index = pd.read_csv(files[0], sep='\t', header=None, skiprows=skiprows, dtype=np.int64) edge_index = torch.from_numpy(edge_index.values).t() num_nodes = int(edge_index.max()) + 1 edge_index = coalesce(edge_index, num_nodes=num_nodes) return [Data(edge_index=edge_index, num_nodes=num_nodes)] def read_wiki(files: List[str], name: str) -> List[Data]: import pandas as pd edge_index = pd.read_csv(files[0], sep='\t', header=None, skiprows=4, dtype=np.int64) edge_index = torch.from_numpy(edge_index.values).t() idx = torch.unique(edge_index.flatten()) idx_assoc = torch.full( (edge_index.max() + 1, ), # type: ignore -1, dtype=torch.long, ) idx_assoc[idx] = torch.arange(idx.size(0)) edge_index = idx_assoc[edge_index] num_nodes = int(edge_index.max()) + 1 edge_index = coalesce(edge_index, num_nodes=num_nodes) return [Data(edge_index=edge_index, num_nodes=num_nodes)] class SNAPDataset(InMemoryDataset): r"""A variety of graph datasets collected from `SNAP at Stanford University `_. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset. 
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://snap.stanford.edu/data' available_datasets = { 'ego-facebook': ['facebook.tar.gz'], 'ego-gplus': ['gplus.tar.gz'], 'ego-twitter': ['twitter.tar.gz'], 'soc-ca-astroph': ['ca-AstroPh.txt.gz'], 'soc-ca-grqc': ['ca-GrQc.txt.gz'], 'soc-epinions1': ['soc-Epinions1.txt.gz'], 'soc-livejournal1': ['soc-LiveJournal1.txt.gz'], 'soc-pokec': ['soc-pokec-relationships.txt.gz'], 'soc-slashdot0811': ['soc-Slashdot0811.txt.gz'], 'soc-slashdot0922': ['soc-Slashdot0902.txt.gz'], 'wiki-vote': ['wiki-Vote.txt.gz'], } def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in self.available_datasets.keys() super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def processed_file_names(self) -> str: return 'data.pt' def _download(self) -> None: if osp.isdir(self.raw_dir) and 
len(os.listdir(self.raw_dir)) > 0: return fs.makedirs(self.raw_dir, exist_ok=True) self.download() def download(self) -> None: for name in self.available_datasets[self.name]: fs.cp(f'{self.url}/{name}', self.raw_dir, extract=True) def process(self) -> None: raw_dir = self.raw_dir filenames = fs.ls(self.raw_dir) if len(filenames) == 1 and fs.isdir(filenames[0]): raw_dir = filenames[0] raw_files = fs.ls(raw_dir) data_list: Union[List[Data], List[EgoData]] if self.name[:4] == 'ego-': data_list = read_ego(raw_files, self.name[4:]) elif self.name[:4] == 'soc-': data_list = read_soc(raw_files, self.name[:4]) elif self.name[:5] == 'wiki-': data_list = read_wiki(raw_files, self.name[5:]) else: raise NotImplementedError if len(data_list) > 1 and self.pre_filter is not None: data_list = [data for data in data_list if self.pre_filter(data)] if self.pre_transform is not None: data_list = [self.pre_transform(data) for data in data_list] self.save(data_list, self.processed_paths[0]) def __repr__(self) -> str: return f'SNAP-{self.name}({len(self)})' ================================================ FILE: torch_geometric/datasets/suite_sparse.py ================================================ import os.path as osp from typing import Callable, Optional import fsspec import torch from torch_geometric.data import Data, InMemoryDataset from torch_geometric.io import fs class SuiteSparseMatrixCollection(InMemoryDataset): r"""A suite of sparse matrix benchmarks known as the `Suite Sparse Matrix Collection `_ collected from a wide range of applications. Args: root (str): Root directory where the dataset should be saved. group (str): The group of the sparse matrix. name (str): The name of the sparse matrix. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://sparse.tamu.edu/mat/{}/{}.mat' def __init__( self, root: str, group: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.group = group self.name = name super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.group, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.group, self.name, 'processed') @property def raw_file_names(self) -> str: return f'{self.name}.mat' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: fs.cp(self.url.format(self.group, self.name), self.raw_dir) def process(self) -> None: from scipy.io import loadmat with fsspec.open(self.raw_paths[0], 'rb') as f: mat = loadmat(f)['Problem'][0][0][2].tocsr().tocoo() row = torch.from_numpy(mat.row).to(torch.long) col = torch.from_numpy(mat.col).to(torch.long) edge_index = torch.stack([row, col], dim=0) value = torch.from_numpy(mat.data).to(torch.float) edge_attr = None if torch.all(value == 1.0) else value size: Optional[torch.Size] = torch.Size(mat.shape) if mat.shape[0] == mat.shape[1]: size = None num_nodes = mat.shape[0] data = Data(edge_index=edge_index, edge_attr=edge_attr, size=size, num_nodes=num_nodes) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return (f'{self.__class__.__name__}(group={self.group}, ' f'name={self.name})') 
================================================ FILE: torch_geometric/datasets/tag_dataset.py ================================================ import csv import os import os.path as osp from collections.abc import Sequence from typing import Dict, List, Optional, Union import numpy as np import torch from torch import Tensor from tqdm import tqdm from torch_geometric.data import InMemoryDataset, download_google_url from torch_geometric.data.data import BaseData from torch_geometric.io import fs try: from pandas import DataFrame, read_csv WITH_PANDAS = True except ImportError: WITH_PANDAS = False IndexType = Union[slice, Tensor, np.ndarray, Sequence] class TAGDataset(InMemoryDataset): r"""The Text Attributed Graph datasets from the `"Learning on Large-scale Text-attributed Graphs via Variational Inference" `_ paper and `"Harnessing Explanations: LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation Learning" `_ paper. This dataset is aiming on transform `ogbn products`, `ogbn arxiv` into Text Attributed Graph that each node in graph is associate with a raw text, LLM prediction and explanation, that dataset can be adapt to DataLoader (for LM training) and NeighborLoader(for GNN training). In addition, this class can be use as a wrapper class by convert a InMemoryDataset with Tokenizer and text into Text Attributed Graph. Args: root (str): Root directory where the dataset should be saved. dataset (InMemoryDataset): The name of the dataset (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`). tokenizer_name (str): The tokenizer name for language model, Be sure to use same tokenizer name as your `model id` of model repo on huggingface.co. 
text (List[str]): list of raw text associate with node, the order of list should be align with node list split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary, for saving split index, it is required that if your dataset doesn't have get_split_idx function tokenize_batch_size (int): batch size of tokenizing text, the tokenizing process will run on cpu, default: 256 token_on_disk (bool): save token as .pt file on disk or not, default: False text_on_disk (bool): save given text(list of str) as dataframe on disk or not, default: False force_reload (bool): default: False .. note:: See `example/llm/glem.py` for example usage """ raw_text_id = { 'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3', 'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt' } llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds' llm_explanation_id = { 'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ', } def __init__( self, root: str, dataset: InMemoryDataset, tokenizer_name: str, text: Optional[List[str]] = None, split_idx: Optional[Dict[str, Tensor]] = None, tokenize_batch_size: int = 256, token_on_disk: bool = False, text_on_disk: bool = False, force_reload: bool = False, ) -> None: # list the vars you want to pass in before run download & process self.name = dataset.name self.text = text self.llm_prediction_topk = 5 self.tokenizer_name = tokenizer_name from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token_id = self.tokenizer.eos_token_id if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.dir_name = '_'.join(dataset.name.split('-')) self.root = osp.join(root, self.dir_name) missing_str_list = [] if not WITH_PANDAS: missing_str_list.append('pandas') if len(missing_str_list) > 0: missing_str = ' '.join(missing_str_list) error_out = f"`pip install {missing_str}` to use this dataset." 
raise ImportError(error_out) if hasattr(dataset, 'get_idx_split'): self.split_idx = dataset.get_idx_split() elif split_idx is not None: self.split_idx = split_idx else: raise ValueError("TAGDataset need split idx for generating " "is_gold mask, please pass splited index " "in format of dictionaty with 'train', 'valid' " "'test' index tensor to 'split_idx'") if text_on_disk: if text is not None: self.save_node_text(text) self.text_on_disk = text_on_disk # init will call download and process super().__init__(self.root, transform=None, pre_transform=None, pre_filter=None, force_reload=force_reload) # after processing and download # Dataset has to have BaseData as _data assert dataset._data is not None self._data = dataset._data # reassign reference assert self._data is not None assert dataset._data.y is not None assert isinstance(self._data, BaseData) assert self._data.num_nodes is not None assert isinstance(dataset._data.num_nodes, int) assert isinstance(self._data.num_nodes, int) self._n_id = torch.arange(self._data.num_nodes) is_good_tensor = self.load_gold_mask() self._is_gold = is_good_tensor.squeeze() self._data['is_gold'] = is_good_tensor if self.text is not None and len(self.text) != self._data.num_nodes: raise ValueError("The number of text sequence in 'text' should be " "equal to number of nodes!") self.token_on_disk = token_on_disk self.tokenize_batch_size = tokenize_batch_size self._token = self.tokenize_graph(self.tokenize_batch_size) self._llm_explanation_token: Dict[str, Tensor] = {} self._all_token: Dict[str, Tensor] = {} if self.name in self.llm_explanation_id: self._llm_explanation_token = self.tokenize_graph( self.tokenize_batch_size, text_type='llm_explanation') self._all_token = self.tokenize_graph(self.tokenize_batch_size, text_type='all') self.__num_classes__ = dataset.num_classes @property def num_classes(self) -> int: return self.__num_classes__ @property def raw_file_names(self) -> List[str]: file_names = [] for _, _, files in 
os.walk(osp.join(self.root, 'raw')): for file in files: file_names.append(file) return file_names @property def processed_file_names(self) -> List[str]: return [ 'geometric_data_processed.pt', 'pre_filter.pt', 'pre_transformed.pt' ] @property def token(self) -> Dict[str, Tensor]: if self._token is None: # lazy load self._token = self.tokenize_graph() return self._token @property def llm_explanation_token(self) -> Dict[str, Tensor]: if self._llm_explanation_token is None and \ self.name in self.llm_explanation_id: self._llm_explanation_token = self.tokenize_graph( text_type='llm_explanation') return self._llm_explanation_token @property def all_token(self) -> Dict[str, Tensor]: if self._all_token is None and \ self.name in self.llm_explanation_id: self._all_token = self.tokenize_graph(text_type='all') return self._all_token # load is_gold after init @property def is_gold(self) -> Tensor: if self._is_gold is None: print('lazy load is_gold!!') self._is_gold = self.load_gold_mask() return self._is_gold def get_n_id(self, node_idx: IndexType) -> Tensor: if self._n_id is None: assert self._data is not None assert self._data.num_nodes is not None assert isinstance(self._data.num_nodes, int) self._n_id = torch.arange(self._data.num_nodes) return self._n_id[node_idx] def load_gold_mask(self) -> Tensor: r"""Use original train split as gold split, generating is_gold mask for picking ground truth labels and pseudo labels. """ train_split_idx = self.get_idx_split()['train'] assert self._data is not None assert self._data.num_nodes is not None assert isinstance(self._data.num_nodes, int) is_good_tensor = torch.zeros(self._data.num_nodes, dtype=torch.bool).view(-1, 1) is_good_tensor[train_split_idx] = True return is_good_tensor def get_gold(self, node_idx: IndexType) -> Tensor: r"""Get gold mask for given node_idx. 
Args: node_idx (torch.tensor): a tensor contain node idx """ if self._is_gold is None: self._is_gold = self.is_gold return self._is_gold[node_idx] def get_idx_split(self) -> Dict[str, Tensor]: return self.split_idx def download(self) -> None: print('downloading raw text') raw_text_path = download_google_url(id=self.raw_text_id[self.name], folder=f'{self.root}/raw', filename='node-text.csv.gz', log=True) self.text = list(read_csv(raw_text_path)['text']) if self.name in self.llm_explanation_id: print('downloading llm explanations') llm_explanation_path = download_google_url( id=self.llm_explanation_id[self.name], folder=f'{self.root}/raw', filename='node-gpt-response.csv.gz', log=True) self.llm_explanation = list(read_csv(llm_explanation_path)['text']) print('downloading llm predictions') fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir) def process(self) -> None: # process Title and Abstraction if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')): text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz')) self.text = list(text_df['text']) elif self.name in self.raw_text_id: self.download() else: print('The dataset is not ogbn-products nor ogbn-arxiv,' 'please pass in your raw text string list to `text`') if self.text is None: raise ValueError("The TAGDataset only have ogbn-products and " "ogbn-arxiv raw text in default " "The raw text of each node is not specified" "Please pass in 'text' when convert your dataset " "to Text Attribute Graph Dataset") # process LLM explanation and prediction llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz' llm_prediction_path = f'{self.raw_dir}/{self.name}.csv' if osp.exists(llm_explanation_path) and osp.exists( llm_prediction_path): # load LLM explanation self.llm_explanation = list(read_csv(llm_explanation_path)['text']) # load LLM prediction preds = [] with open(llm_prediction_path) as file: reader = csv.reader(file) for row in reader: inner_list = [] for value in row: 
inner_list.append(int(value)) preds.append(inner_list) pl = torch.zeros(len(preds), self.llm_prediction_topk, dtype=torch.long) for i, pred in enumerate(preds): pl[i][:len(pred)] = torch.tensor( pred[:self.llm_prediction_topk], dtype=torch.long) + 1 if self.llm_explanation is None or pl is None: raise ValueError( "The TAGDataset only have ogbn-arxiv LLM explanations" "and predictions in default. The llm explanation and" "prediction of each node is not specified.Please pass in" "'llm_explanation' and 'llm_prediction' when" "convert your dataset to Text Attribute Graph Dataset") elif self.name in self.llm_explanation_id: self.download() else: print( 'The dataset is not ogbn-arxiv,' 'please pass in your llm explanation list to `llm_explanation`' 'and llm prediction list to `llm_prediction`') def save_node_text(self, text: List[str]) -> None: node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz') if osp.exists(node_text_path): print(f'The raw text is existed at {node_text_path}') else: print(f'Saving raw text file at {node_text_path}') os.makedirs(f'{self.root}/raw', exist_ok=True) text_df = DataFrame(text, columns=['text']) text_df.to_csv(osp.join(node_text_path), compression='gzip', index=False) def tokenize_graph(self, batch_size: int = 256, text_type: str = 'raw_text') -> Dict[str, Tensor]: r"""Tokenizing the text associate with each node, running in cpu. 
Args: batch_size (Optional[int]): batch size of list of text for generating emebdding text_type (Optional[str]): type of text Returns: Dict[str, torch.Tensor]: tokenized graph """ assert text_type in ['raw_text', 'llm_explanation', 'all'] if text_type == 'raw_text': _text = self.text elif text_type == 'llm_explanation': _text = self.llm_explanation elif text_type == 'all': if self.text is None or self.llm_explanation is None: raise ValueError("The TAGDataset need text and llm explanation" "for tokenizing all text") _text = [ f'{raw_txt} Explanation: {exp_txt}' for raw_txt, exp_txt in zip(self.text, self.llm_explanation) ] data_len = 0 if _text is not None: data_len = len(_text) else: raise ValueError("The TAGDataset need text for tokenization") token_keys = ['input_ids', 'token_type_ids', 'attention_mask'] path = os.path.join(self.processed_dir, 'token', text_type, self.tokenizer_name) # Check if the .pt files already exist token_files_exist = any( os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys) if token_files_exist and self.token_on_disk: print('Found tokenized file, loading may take several minutes...') all_encoded_token = { k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True) for k in token_keys if os.path.exists(os.path.join(path, f'{k}.pt')) } return all_encoded_token all_encoded_token = {k: [] for k in token_keys} pbar = tqdm(total=data_len) pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}') for i in range(0, data_len, batch_size): end_index = min(data_len, i + batch_size) token = self.tokenizer(_text[i:end_index], padding='max_length', truncation=True, max_length=512, return_tensors="pt") for k in token.keys(): all_encoded_token[k].append(token[k]) pbar.update(end_index - i) pbar.close() all_encoded_token = { k: torch.cat(v) for k, v in all_encoded_token.items() if len(v) > 0 } if self.token_on_disk: os.makedirs(path, exist_ok=True) print('Saving tokens on Disk') for k, tensor in all_encoded_token.items(): 
torch.save(tensor, os.path.join(path, f'{k}.pt')) print('Token saved:', os.path.join(path, f'{k}.pt')) os.environ["TOKENIZERS_PARALLELISM"] = 'true' # suppressing warning return all_encoded_token def __repr__(self) -> str: return f'{self.__class__.__name__}()' class TextDataset(torch.utils.data.Dataset): r"""This nested dataset provides textual data for each node in the graph. Factory method to create TextDataset from TAGDataset. Args: tag_dataset (TAGDataset): the parent dataset text_type (str): type of text """ def __init__(self, tag_dataset: 'TAGDataset', text_type: str = 'raw_text') -> None: assert text_type in ['raw_text', 'llm_explanation', 'all'] self.tag_dataset = tag_dataset if text_type == 'raw_text': self.token = tag_dataset.token elif text_type == 'llm_explanation': self.token = tag_dataset.llm_explanation_token elif text_type == 'all': self.token = tag_dataset.all_token assert tag_dataset._data is not None self._data = tag_dataset._data assert tag_dataset._data.y is not None self.labels = tag_dataset._data.y def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]: r"""This function will be called in __getitem__(). Args: node_idx (IndexType): selected node idx in each batch Returns: items (Dict[str, Tensor]): input for LM """ items = {k: v[node_idx] for k, v in self.token.items()} return items # for LM training def __getitem__( self, node_id: IndexType, ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: r"""This function will override the function in torch.utils.data.Dataset, and will be called when you iterate batch in the dataloader, make sure all following key value pairs are present in the return dict. Args: node_id (List[int]): list of node idx for selecting tokens, labels etc. 
when iterating data loader for LM Returns: items (dict): input k,v pairs for Language model training and inference """ item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {} item['input'] = self.get_token(node_id) item['labels'] = self.labels[node_id] item['is_gold'] = self.tag_dataset.get_gold(node_id) item['n_id'] = self.tag_dataset.get_n_id(node_id) return item def __len__(self) -> int: assert self._data.num_nodes is not None return self._data.num_nodes def get(self, idx: int) -> BaseData: return self._data def __repr__(self) -> str: return f'{self.__class__.__name__}()' def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset: r"""Factory Build text dataset from Text Attributed Graph Dataset each data point is node's associated text token. """ return TAGDataset.TextDataset(self, text_type) ================================================ FILE: torch_geometric/datasets/taobao.py ================================================ import os from typing import Callable, Optional import numpy as np import torch from torch_geometric.data import ( HeteroData, InMemoryDataset, download_url, extract_zip, ) class Taobao(InMemoryDataset): r"""Taobao is a dataset of user behaviors from Taobao offered by Alibaba, provided by the `Tianchi Alicloud platform `_. Taobao is a heterogeneous graph for recommendation. Nodes represent users with user IDs, items with item IDs, and categories with category ID. Edges between users and items represent different types of user behaviors towards items (alongside with timestamps). Edges between items and categories assign each item to its set of categories. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.HeteroData` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = ('https://alicloud-dev.oss-cn-hangzhou.aliyuncs.com/' 'UserBehavior.csv.zip') def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0], data_cls=HeteroData) @property def raw_file_names(self) -> str: return 'UserBehavior.csv' @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.remove(path) def process(self) -> None: import pandas as pd cols = ['userId', 'itemId', 'categoryId', 'behaviorType', 'timestamp'] df = pd.read_csv(self.raw_paths[0], names=cols) # Time representation (YYYY.MM.DD-HH:MM:SS -> Integer) # start: 1511539200 = 2017.11.25-00:00:00 # end: 1512316799 = 2017.12.03-23:59:59 start = 1511539200 end = 1512316799 df = df[(df["timestamp"] >= start) & (df["timestamp"] <= end)] df = df.drop_duplicates() behavior_dict = {'pv': 0, 'cart': 1, 'buy': 2, 'fav': 3} df['behaviorType'] = df['behaviorType'].map(behavior_dict) num_entries = {} for name in ['userId', 'itemId', 'categoryId']: # Map IDs to consecutive integers: value, df[name] = np.unique(df[[name]].values, return_inverse=True) num_entries[name] = value.shape[0] data = HeteroData() data['user'].num_nodes = num_entries['userId'] data['item'].num_nodes = num_entries['itemId'] data['category'].num_nodes = num_entries['categoryId'] row = torch.from_numpy(df['userId'].values) col = torch.from_numpy(df['itemId'].values) data['user', 
'item'].edge_index = torch.stack([row, col], dim=0) data['user', 'item'].time = torch.from_numpy(df['timestamp'].values) behavior = torch.from_numpy(df['behaviorType'].values) data['user', 'item'].behavior = behavior df = df[['itemId', 'categoryId']].drop_duplicates() row = torch.from_numpy(df['itemId'].values) col = torch.from_numpy(df['categoryId'].values) data['item', 'category'].edge_index = torch.stack([row, col], dim=0) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/teeth3ds.py ================================================ import json import os import os.path as osp from glob import glob from typing import Callable, Dict, List, Optional import numpy as np import torch from tqdm import tqdm from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) class Teeth3DS(InMemoryDataset): r"""The Teeth3DS+ dataset from the `"An Extended Benchmark for Intra-oral 3D Scans Analysis" `_ paper. This dataset is the first comprehensive public benchmark designed to advance the field of intra-oral 3D scan analysis developed as part of the 3DTeethSeg 2022 and 3DTeethLand 2024 MICCAI challenges, aiming to drive research in teeth identification, segmentation, labeling, 3D modeling, and dental landmark identification. The dataset includes at least 1,800 intra-oral scans (containing 23,999 annotated teeth) collected from 900 patients, covering both upper and lower jaws separately. Args: root (str): Root directory where the dataset should be saved. split (str): The split name (one of :obj:`"Teeth3DS"`, :obj:`"3DTeethSeg22_challenge"` or :obj:`"3DTeethLand_challenge"`). train (bool, optional): If :obj:`True`, loads the training dataset, otherwise the test dataset. (default: :obj:`True`) num_samples (int, optional): Number of points to sample from each mesh. 
(default: :obj:`30000`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ urls = { 'data_part_1.zip': 'https://osf.io/download/qhprs/', 'data_part_2.zip': 'https://osf.io/download/4pwnr/', 'data_part_3.zip': 'https://osf.io/download/frwdp/', 'data_part_4.zip': 'https://osf.io/download/2arn4/', 'data_part_5.zip': 'https://osf.io/download/xrz5f/', 'data_part_6.zip': 'https://osf.io/download/23hgq/', 'data_part_7.zip': 'https://osf.io/download/u83ad/', 'train_test_split': 'https://files.de-1.osf.io/v1/' 'resources/xctdy/providers/osfstorage/?zip=' } sample_url = { 'teeth3ds_sample': 'https://osf.io/download/vr38s/', } landmarks_urls = { '3DTeethLand_landmarks_train.zip': 'https://osf.io/download/k5hbj/', '3DTeethLand_landmarks_test.zip': 'https://osf.io/download/sqw5e/', } def __init__( self, root: str, split: str = 'Teeth3DS', # [3DTeethSeg22_challenge, 3DTeethLand_challenge] train: bool = True, num_samples: int = 30000, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.mode = 'training' if train else 'testing' self.split = split self.num_samples = num_samples super().__init__(root, transform, pre_transform, force_reload=force_reload) @property def processed_dir(self) -> str: return os.path.join(self.root, f'processed_{self.split}_{self.mode}') @property def raw_file_names(self) -> List[str]: return ['license.txt'] @property def processed_file_names(self) -> List[str]: # Directory containing 
train/test split files: split_subdir = 'teeth3ds_sample' if self.split == 'sample' else '' split_dir = osp.join( self.raw_dir, split_subdir, f'{self.split}_train_test_split', ) split_files = glob(osp.join(split_dir, f'{self.mode}*.txt')) # Collect all file names from the split files: combined_list = [] for file_path in split_files: with open(file_path) as file: combined_list.extend(file.read().splitlines()) # Generate the list of processed file paths: return [f'{file_name}.pt' for file_name in combined_list] def download(self) -> None: if self.split == 'sample': for key, url in self.sample_url.items(): path = download_url(url, self.root, filename=key) extract_zip(path, self.raw_dir) os.unlink(path) else: for key, url in self.urls.items(): path = download_url(url, self.root, filename=key) extract_zip(path, self.raw_dir) os.unlink(path) for key, url in self.landmarks_urls.items(): path = download_url(url, self.root, filename=key) extract_zip(path, self.raw_dir) # Extract each downloaded part os.unlink(path) def process_file(self, file_path: str) -> Optional[Data]: """Processes the input file path to load mesh data, annotations, and prepare the input features for a graph-based deep learning model. 
""" import trimesh from fpsample import bucket_fps_kdline_sampling mesh = trimesh.load_mesh(file_path) if isinstance(mesh, list): # Handle the case where a list of Geometry objects is returned mesh = mesh[0] vertices = mesh.vertices vertex_normals = mesh.vertex_normals # Perform sampling on mesh vertices: if len(vertices) < self.num_samples: sampled_indices = np.random.choice( len(vertices), self.num_samples, replace=True, ) else: sampled_indices = bucket_fps_kdline_sampling( vertices, self.num_samples, h=5, start_idx=0, ) if len(sampled_indices) != self.num_samples: raise RuntimeError(f"Sampled points mismatch, expected " f"{self.num_samples} points, but got " f"{len(sampled_indices)} for '{file_path}'") # Extract features and annotations for the sampled points: pos = torch.tensor(vertices[sampled_indices], dtype=torch.float) x = torch.tensor(vertex_normals[sampled_indices], dtype=torch.float) # Load segmentation annotations: seg_annotation_path = file_path.replace('.obj', '.json') if osp.exists(seg_annotation_path): with open(seg_annotation_path) as f: seg_annotations = json.load(f) y = torch.tensor( np.asarray(seg_annotations['labels'])[sampled_indices], dtype=torch.float) instances = torch.tensor( np.asarray(seg_annotations['instances'])[sampled_indices], dtype=torch.float) else: y = torch.empty(0, 3) instances = torch.empty(0, 3) # Load landmarks annotations: landmarks_annotation_path = file_path.replace('.obj', '__kpt.json') # Parse keypoint annotations into structured tensors: keypoints_dict: Dict[str, List] = { key: [] for key in [ 'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint', 'FacialPoint' ] } keypoint_tensors: Dict[str, torch.Tensor] = { key: torch.empty(0, 3) for key in [ 'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint', 'FacialPoint' ] } if osp.exists(landmarks_annotation_path): with open(landmarks_annotation_path) as f: landmarks_annotations = json.load(f) for keypoint in landmarks_annotations['objects']: 
keypoints_dict[keypoint['class']].extend(keypoint['coord']) keypoint_tensors = { k: torch.tensor(np.asarray(v), dtype=torch.float).reshape(-1, 3) for k, v in keypoints_dict.items() } data = Data( pos=pos, x=x, y=y, instances=instances, jaw=file_path.split('.obj')[0].split('_')[1], mesial=keypoint_tensors['Mesial'], distal=keypoint_tensors['Distal'], cusp=keypoint_tensors['Cusp'], inner_point=keypoint_tensors['InnerPoint'], outer_point=keypoint_tensors['OuterPoint'], facial_point=keypoint_tensors['FacialPoint'], ) if self.pre_transform is not None: data = self.pre_transform(data) return data def process(self) -> None: for file in tqdm(self.processed_file_names): name = file.split('.')[0] path = osp.join(self.raw_dir, '**', '*', name + '.obj') paths = glob(path) if len(paths) == 1: data = self.process_file(paths[0]) torch.save(data, osp.join(self.processed_dir, file)) def len(self) -> int: return len(self.processed_file_names) def get(self, idx: int) -> Data: return torch.load( osp.join(self.processed_dir, self.processed_file_names[idx]), weights_only=False, ) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'mode={self.mode}, split={self.split})') ================================================ FILE: torch_geometric/datasets/tosca.py ================================================ import glob import os import os.path as osp from typing import Callable, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import read_txt_array class TOSCA(InMemoryDataset): r"""The TOSCA dataset from the `"Numerical Geometry of Non-Ridig Shapes" `_ book, containing 80 meshes. Meshes within the same category have the same triangulation and an equal number of vertices numbered in a compatible way. .. note:: Data objects hold mesh faces instead of edge indices. To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. 
To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (str): Root directory where the dataset should be saved. categories (list, optional): List of categories to include in the dataset. Can include the categories :obj:`"Cat"`, :obj:`"Centaur"`, :obj:`"David"`, :obj:`"Dog"`, :obj:`"Gorilla"`, :obj:`"Horse"`, :obj:`"Michael"`, :obj:`"Victoria"`, :obj:`"Wolf"`. If set to :obj:`None`, the dataset will contain all categories. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = 'http://tosca.cs.technion.ac.il/data/toscahires-asci.zip' categories = [ 'cat', 'centaur', 'david', 'dog', 'gorilla', 'horse', 'michael', 'victoria', 'wolf' ] def __init__( self, root: str, categories: Optional[List[str]] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: categories = self.categories if categories is None else categories categories = [cat.lower() for cat in categories] for cat in categories: assert cat in self.categories self.categories = categories super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['cat0.vert', 'cat0.tri'] @property def processed_file_names(self) -> str: name = '_'.join([cat[:2] for cat in self.categories]) return f'{name}.pt' def download(self) -> None: path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self) -> None: data_list = [] for cat in self.categories: paths = glob.glob(osp.join(self.raw_dir, f'{cat}*.tri')) paths = [path[:-4] for path in paths] paths = sorted(paths, key=lambda e: (len(e), e)) for path in paths: pos = read_txt_array(f'{path}.vert') face = read_txt_array(f'{path}.tri', dtype=torch.long) face = face - face.min() # Ensure zero-based index. 
data = Data(pos=pos, face=face.t().contiguous()) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/tu_dataset.py ================================================ import os.path as osp from typing import Callable, List, Optional from torch_geometric.data import Data, InMemoryDataset from torch_geometric.io import fs, read_tu_data class TUDataset(InMemoryDataset): r"""A variety of graph kernel benchmark datasets, *.e.g.*, :obj:`"IMDB-BINARY"`, :obj:`"REDDIT-BINARY"` or :obj:`"PROTEINS"`, collected from the `TU Dortmund University `_. In addition, this dataset wrapper provides `cleaned dataset versions `_ as motivated by the `"Understanding Isomorphism Bias in Graph Data Sets" `_ paper, containing only non-isomorphic graphs. .. note:: Some datasets may not come with any node labels. You can then either make use of the argument :obj:`use_node_attr` to load additional continuous node attributes (if present) or provide synthetic node features using transforms such as :class:`torch_geometric.transforms.Constant` or :class:`torch_geometric.transforms.OneHotDegree`. Args: root (str): Root directory where the dataset should be saved. name (str): The `name `_ of the dataset. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) use_node_attr (bool, optional): If :obj:`True`, the dataset will contain additional continuous node attributes (if present). (default: :obj:`False`) use_edge_attr (bool, optional): If :obj:`True`, the dataset will contain additional continuous edge attributes (if present). (default: :obj:`False`) cleaned (bool, optional): If :obj:`True`, the dataset will contain only non-isomorphic graphs. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 20 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes * - MUTAG - 188 - ~17.9 - ~39.6 - 7 - 2 * - ENZYMES - 600 - ~32.6 - ~124.3 - 3 - 6 * - PROTEINS - 1,113 - ~39.1 - ~145.6 - 3 - 2 * - COLLAB - 5,000 - ~74.5 - ~4914.4 - 0 - 3 * - IMDB-BINARY - 1,000 - ~19.8 - ~193.1 - 0 - 2 * - REDDIT-BINARY - 2,000 - ~429.6 - ~995.5 - 0 - 2 * - ... - - - - - """ url = 'https://www.chrsmrrs.com/graphkerneldatasets' cleaned_url = ('https://raw.githubusercontent.com/nd7141/' 'graph_datasets/master/datasets') def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, use_node_attr: bool = False, use_edge_attr: bool = False, cleaned: bool = False, ) -> None: self.name = name self.cleaned = cleaned super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) out = fs.torch_load(self.processed_paths[0]) if not isinstance(out, tuple) or len(out) < 3: raise RuntimeError( "The 'data' object was created by an older version of PyG. 
" "If this error occurred while loading an already existing " "dataset, remove the 'processed/' directory in the dataset's " "root folder and try again.") assert len(out) == 3 or len(out) == 4 if len(out) == 3: # Backward compatibility. data, self.slices, self.sizes = out data_cls = Data else: data, self.slices, self.sizes, data_cls = out if not isinstance(data, dict): # Backward compatibility. self.data = data else: self.data = data_cls.from_dict(data) assert isinstance(self._data, Data) if self._data.x is not None and not use_node_attr: num_node_attributes = self.num_node_attributes self._data.x = self._data.x[:, num_node_attributes:] if self._data.edge_attr is not None and not use_edge_attr: num_edge_attrs = self.num_edge_attributes self._data.edge_attr = self._data.edge_attr[:, num_edge_attrs:] @property def raw_dir(self) -> str: name = f'raw{"_cleaned" if self.cleaned else ""}' return osp.join(self.root, self.name, name) @property def processed_dir(self) -> str: name = f'processed{"_cleaned" if self.cleaned else ""}' return osp.join(self.root, self.name, name) @property def num_node_labels(self) -> int: return self.sizes['num_node_labels'] @property def num_node_attributes(self) -> int: return self.sizes['num_node_attributes'] @property def num_edge_labels(self) -> int: return self.sizes['num_edge_labels'] @property def num_edge_attributes(self) -> int: return self.sizes['num_edge_attributes'] @property def raw_file_names(self) -> List[str]: names = ['A', 'graph_indicator'] return [f'{self.name}_{name}.txt' for name in names] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: url = self.cleaned_url if self.cleaned else self.url fs.cp(f'{url}/{self.name}.zip', self.raw_dir, extract=True) for filename in fs.ls(osp.join(self.raw_dir, self.name)): fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename))) fs.rm(osp.join(self.raw_dir, self.name)) def process(self) -> None: self.data, self.slices, sizes = 
read_tu_data(self.raw_dir, self.name) if self.pre_filter is not None or self.pre_transform is not None: data_list = [self.get(idx) for idx in range(len(self))] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.data, self.slices = self.collate(data_list) self._data_list = None # Reset cache. assert isinstance(self._data, Data) fs.torch_save( (self._data.to_dict(), self.slices, sizes, self._data.__class__), self.processed_paths[0], ) def __repr__(self) -> str: return f'{self.name}({len(self)})'
================================================
FILE: torch_geometric/datasets/twitch.py
================================================
import os.path as osp
from typing import Callable, Optional

import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset, download_url


class Twitch(InMemoryDataset):
    r"""The Twitch Gamer networks introduced in the
    `"Multi-scale Attributed Node Embedding"
    `_ paper.
    Nodes represent gamers on Twitch and edges are followerships between them.
    Node features represent embeddings of games played by the Twitch users.
    The task is to predict whether a user streams mature content.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (:obj:`"DE"`, :obj:`"EN"`,
            :obj:`"ES"`, :obj:`"FR"`, :obj:`"PT"`, :obj:`"RU"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #nodes
          - #edges
          - #features
          - #classes
        * - DE
          - 9,498
          - 315,774
          - 128
          - 2
        * - EN
          - 7,126
          - 77,774
          - 128
          - 2
        * - ES
          - 4,648
          - 123,412
          - 128
          - 2
        * - FR
          - 6,551
          - 231,883
          - 128
          - 2
        * - PT
          - 1,912
          - 64,510
          - 128
          - 2
        * - RU
          - 4,385
          - 78,993
          - 128
          - 2
    """
    # Base URL; download() appends f'/{name}.npz'.
    url = 'https://graphmining.ai/datasets/ptg/twitch'

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        """Validate ``name``, then download/process via the base class and
        load the single processed graph."""
        self.name = name
        # Only these six language subsets exist on the server.
        assert self.name in ['DE', 'EN', 'ES', 'FR', 'PT', 'RU']
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Per-subset raw directory: ``<root>/<name>/raw``."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        """Per-subset processed directory: ``<root>/<name>/processed``."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> str:
        """The single ``.npz`` archive for the selected subset."""
        return f'{self.name}.npz'

    @property
    def processed_file_names(self) -> str:
        """Name of the file holding the collated processed graph."""
        return 'data.pt'

    def download(self) -> None:
        """Fetch ``<url>/<name>.npz`` into :attr:`raw_dir`."""
        download_url(f'{self.url}/{self.name}.npz', self.raw_dir)

    def process(self) -> None:
        """Convert the ``.npz`` arrays into one :class:`Data` object and save
        it."""
        # Second positional argument of np.load is mmap_mode ('r').
        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)
        x = torch.from_numpy(data['features']).to(torch.float)
        y = torch.from_numpy(data['target']).to(torch.long)

        edge_index = torch.from_numpy(data['edges']).to(torch.long)
        # Transposed into PyG's [2, num_edges] layout; the stored array is
        # presumably [num_edges, 2] -- TODO confirm.
        edge_index = edge_index.t().contiguous()

        data = Data(x=x, y=y, edge_index=edge_index)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])
================================================
FILE: torch_geometric/datasets/upfd.py
================================================
import os
import os.path as osp
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_google_url,
    extract_zip,
)
from torch_geometric.io import read_txt_array
from \
torch_geometric.utils import coalesce, cumsum class UPFD(InMemoryDataset): r"""The tree-structured fake news propagation graph classification dataset from the `"User Preference-aware Fake News Detection" `_ paper. It includes two sets of tree-structured fake & real news propagation graphs extracted from Twitter. For a single graph, the root node represents the source news, and leaf nodes represent Twitter users who retweeted the same root news. A user node has an edge to the news node if and only if the user retweeted the root news directly. Two user nodes have an edge if and only if one user retweeted the root news from the other user. Four different node features are encoded using different encoders. Please refer to `GNN-FakeNews `_ repo for more details. .. note:: For an example of using UPFD, see `examples/upfd.py `_. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the graph set (:obj:`"politifact"`, :obj:`"gossipcop"`). feature (str): The node feature type (:obj:`"profile"`, :obj:`"spacy"`, :obj:`"bert"`, :obj:`"content"`). If set to :obj:`"profile"`, the 10-dimensional node feature is composed of ten Twitter user profile attributes. If set to :obj:`"spacy"`, the 300-dimensional node feature is composed of Twitter user historical tweets encoded by the `spaCy word2vec encoder `_. If set to :obj:`"bert"`, the 768-dimensional node feature is composed of Twitter user historical tweets encoded by the `bert-as-service `_. If set to :obj:`"content"`, the 310-dimensional node feature is composed of a 300-dimensional "spacy" vector plus a 10-dimensional "profile" vector. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. 
The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ file_ids = { 'politifact': '1toou2GO0agoY_OS54LaCWEECQfe93nuq', 'gossipcop': '1DkMAzC7XUUciAxsSujRJt3sq1MqaVI3g', } def __init__( self, root: str, name: str, feature: str, split: str = "train", transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: assert name in ['politifact', 'gossipcop'] assert split in ['train', 'val', 'test'] self.root = root self.name = name self.feature = feature super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = self.processed_paths[['train', 'val', 'test'].index(split)] self.load(path) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed', self.feature) @property def raw_file_names(self) -> List[str]: return [ 'node_graph_id.npy', 'graph_labels.npy', 'A.txt', 'train_idx.npy', 'val_idx.npy', 'test_idx.npy', f'new_{self.feature}_feature.npz' ] @property def processed_file_names(self) -> List[str]: return ['train.pt', 'val.pt', 'test.pt'] def download(self) -> None: id = self.file_ids[self.name] path = download_google_url(id, self.raw_dir, 'data.zip') extract_zip(path, self.raw_dir) os.remove(path) def process(self) -> None: import scipy.sparse as sp x = 
sp.load_npz( osp.join(self.raw_dir, f'new_{self.feature}_feature.npz')) x = torch.from_numpy(x.todense()).to(torch.float) edge_index = read_txt_array(osp.join(self.raw_dir, 'A.txt'), sep=',', dtype=torch.long).t() edge_index = coalesce(edge_index, num_nodes=x.size(0)) y = np.load(osp.join(self.raw_dir, 'graph_labels.npy')) y = torch.from_numpy(y).to(torch.long) _, y = y.unique(sorted=True, return_inverse=True) batch = np.load(osp.join(self.raw_dir, 'node_graph_id.npy')) batch = torch.from_numpy(batch).to(torch.long) node_slice = cumsum(batch.bincount()) edge_slice = cumsum(batch[edge_index[0]].bincount()) graph_slice = torch.arange(y.size(0) + 1) self.slices = { 'x': node_slice, 'edge_index': edge_slice, 'y': graph_slice } edge_index -= node_slice[batch[edge_index[0]]].view(1, -1) self.data = Data(x=x, edge_index=edge_index, y=y) for path, split in zip(self.processed_paths, ['train', 'val', 'test']): idx = np.load(osp.join(self.raw_dir, f'{split}_idx.npy')).tolist() data_list = [self.get(i) for i in idx] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, path) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, name={self.name}, ' f'feature={self.feature})') ================================================ FILE: torch_geometric/datasets/utils/__init__.py ================================================ from .cheatsheet import paper_link, has_stats, get_stat, get_children, get_type __all__ = [ 'paper_link', 'has_stats', 'get_stat', 'get_children', 'get_type', ] ================================================ FILE: torch_geometric/datasets/utils/cheatsheet.py ================================================ import importlib import inspect import re from typing import Any, List, Optional def paper_link(cls: str) -> Optional[str]: cls = 
importlib.import_module('torch_geometric.datasets').__dict__[cls] doc = inspect.getdoc(cls) assert doc is not None match = re.search('<.+?>', doc, flags=re.DOTALL) return None if match is None else match.group().replace('\n', ' ')[1:-1] def get_stats_table(cls: str) -> str: cls = importlib.import_module('torch_geometric.datasets').__dict__[cls] doc = inspect.getdoc(cls) assert doc is not None match = re.search(r'\*\*STATS:\*\*\n.*$', doc, flags=re.DOTALL) return '' if match is None else match.group() def has_stats(cls: str) -> bool: return len(get_stats_table(cls)) > 0 def get_type(cls: str) -> str: return 'Edge' if '-' in cls else 'Node' def get_stat(cls: str, name: str, child: Optional[str] = None, default: Any = None) -> str: if child is None and len(get_children(cls)) > 0: return '' stats_table = get_stats_table(cls) if len(stats_table) > 0: stats_table = '\n'.join(stats_table.split('\n')[2:]) match = re.search(f'^.*- {name}', stats_table, flags=re.DOTALL) if match is None: return default column = match.group().count(' -') if child is not None: child = child.replace('(', r'\(').replace(')', r'\)') match = re.search(f'[*] - {child}\n.*$', stats_table, flags=re.DOTALL) assert match is not None stats_row = match.group() else: stats_row = '*' + stats_table.split('*')[2] return stats_row.split(' -')[column].split('\n')[0].strip() def get_children(cls: str) -> List[str]: matches = re.findall('[*] -.*', get_stats_table(cls)) return [match[4:] for match in matches[1:]] if len(matches) > 2 else [] ================================================ FILE: torch_geometric/datasets/web_qsp_dataset.py ================================================ # Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630 import gc import os from itertools import chain from typing import Any, Dict, Iterator, List, Optional import torch from tqdm import tqdm from torch_geometric.data import InMemoryDataset from torch_geometric.llm.large_graph_indexer import ( EDGE_RELATION, 
LargeGraphIndexer, TripletLike, get_features_for_triplets_groups, ) from torch_geometric.llm.models import SentenceTransformer from torch_geometric.llm.utils.backend_utils import ( preprocess_triplet, retrieval_via_pcst, ) class KGQABaseDataset(InMemoryDataset): r"""Base class for the 2 KGQA datasets used in `"Reasoning on Graphs: Faithful and Interpretable Large Language Model Reasoning" `_ paper. Args: dataset_name (str): HuggingFace `dataset` name. root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) verbose (bool, optional): Whether to print output. Defaults to False. use_pcst (bool, optional): Whether to preprocess the dataset's graph with PCST or return the full graphs. (default: :obj:`True`) load_dataset_kwargs (dict, optional): Keyword arguments for the `datasets.load_dataset` function. (default: :obj:`{}`) retrieval_kwargs (dict, optional): Keyword arguments for the `get_features_for_triplets_groups` function. (default: :obj:`{}`) """ def __init__( self, dataset_name: str, root: str, split: str = "train", force_reload: bool = False, verbose: bool = False, use_pcst: bool = True, load_dataset_kwargs: Optional[Dict[str, Any]] = None, retrieval_kwargs: Optional[Dict[str, Any]] = None, ) -> None: self.split = split self.dataset_name = dataset_name self.use_pcst = use_pcst self.load_dataset_kwargs = load_dataset_kwargs or {} """ NOTE: If running into memory issues, try reducing this batch size for the LargeGraphIndexer used to build our KG. 
Example: self.retrieval_kwargs = {"batch_size": 64} """ self.retrieval_kwargs = retrieval_kwargs or {} # Caching custom subsets of the dataset results in unsupported behavior if 'split' in self.load_dataset_kwargs: print("WARNING: Caching custom subsets of the dataset \ results in unsupported behavior.\ Please specify a separate root directory for each split,\ or set force_reload=True on subsequent instantiations\ of the dataset.") self.required_splits = ['train', 'validation', 'test'] self.verbose = verbose self.force_reload = force_reload super().__init__(root, force_reload=force_reload) """ NOTE: Current behavior is to process the entire dataset, and only return the split specified by the user. """ if f'{split}_data.pt' not in set(self.processed_file_names): raise ValueError(f"Invalid 'split' argument (got {split})") if split == 'val': split = 'validation' self.load(self.processed_paths[self.required_splits.index(split)]) @property def raw_file_names(self) -> List[str]: return ["raw.pt"] @property def processed_file_names(self) -> List[str]: return ["train_data.pt", "val_data.pt", "test_data.pt"] def download(self) -> None: import datasets # HF Load Dataset by dataset name if no path is specified self.load_dataset_kwargs['path'] = self.load_dataset_kwargs.get( 'path', self.dataset_name) raw_dataset = datasets.load_dataset(**self.load_dataset_kwargs) # Assert that the dataset contains the required splits assert all(split in raw_dataset for split in self.required_splits), \ f"Dataset '{self.dataset_name}' is missing required splits: \ {self.required_splits}" raw_dataset.save_to_disk(self.raw_paths[0]) def _get_trips(self) -> Iterator[TripletLike]: # Iterate over each element's graph in each split of the dataset # Using chain to lazily iterate without storing all trips in memory split_iterators = [] for split in self.required_splits: # Create an iterator for each element's graph in the current split split_graphs = (element['graph'] for element in 
self.raw_dataset[split]) split_iterators.append(chain.from_iterable(split_graphs)) # Chain all split iterators together return chain.from_iterable(split_iterators) def _build_graph(self) -> None: print("Encoding graph...") trips = self._get_trips() self.indexer: LargeGraphIndexer = LargeGraphIndexer.from_triplets( trips, pre_transform=preprocess_triplet) # Nodes: print("\tEncoding nodes...") nodes = self.indexer.get_unique_node_features() x = self.model.encode(nodes, batch_size=256, output_device='cpu') self.indexer.add_node_feature(new_feature_name="x", new_feature_vals=x) # Edges: print("\tEncoding edges...") edges = self.indexer.get_unique_edge_features( feature_name=EDGE_RELATION) edge_attr = self.model.encode(edges, batch_size=256, output_device='cpu') self.indexer.add_edge_feature( new_feature_name="edge_attr", new_feature_vals=edge_attr, map_from_feature=EDGE_RELATION, ) print("\tSaving graph...") self.indexer.save(self.indexer_path) def _retrieve_subgraphs(self) -> None: raw_splits = [ self.raw_dataset[split] for split in self.required_splits ] zipped = zip( self.required_splits, raw_splits, # noqa self.processed_paths, ) for split_name, dataset, path in zipped: print(f"Processing {split_name} split...") print("\tEncoding questions...") split_questions = [str(element['question']) for element in dataset] split_q_embs = self.model.encode(split_questions, batch_size=256, output_device='cpu') print("\tRetrieving subgraphs...") results_graphs = [] retrieval_kwargs = { **self.retrieval_kwargs, **{ 'pre_transform': preprocess_triplet, 'verbose': self.verbose, } } graph_gen = get_features_for_triplets_groups( self.indexer, (element['graph'] for element in dataset), **retrieval_kwargs) for index in tqdm(range(len(dataset)), disable=not self.verbose): data_i = dataset[index] graph = next(graph_gen) textual_nodes = self.textual_nodes.iloc[ graph["node_idx"]].reset_index() textual_edges = self.textual_edges.iloc[ graph["edge_idx"]].reset_index() if self.use_pcst and 
len(textual_nodes) > 0 and len( textual_edges) > 0: subgraph, desc = retrieval_via_pcst( graph, split_q_embs[index], textual_nodes, textual_edges, ) else: desc = textual_nodes.to_csv( index=False) + "\n" + textual_edges.to_csv( index=False, columns=["src", "edge_attr", "dst"], ) subgraph = graph question = f"Question: {data_i['question']}\nAnswer: " label = ("|").join(data_i["answer"]).lower() subgraph["question"] = question subgraph["label"] = label subgraph["desc"] = desc results_graphs.append(subgraph.to("cpu")) print("\tSaving subgraphs...") self.save(results_graphs, path) def process(self) -> None: import datasets from pandas import DataFrame self.raw_dataset = datasets.load_from_disk(self.raw_paths[0]) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model_name = 'sentence-transformers/all-roberta-large-v1' self.model: SentenceTransformer = SentenceTransformer(model_name).to( device) self.model.eval() self.indexer_path = os.path.join(self.processed_dir, "large_graph_indexer") if self.force_reload or not os.path.exists(self.indexer_path): self._build_graph() else: print("Loading graph...") self.indexer = LargeGraphIndexer.from_disk(self.indexer_path) self.textual_nodes = DataFrame.from_dict( {"node_attr": self.indexer.get_node_features()}) self.textual_nodes["node_id"] = self.textual_nodes.index self.textual_nodes = self.textual_nodes[["node_id", "node_attr"]] self.textual_edges = DataFrame(self.indexer.get_edge_features(), columns=["src", "edge_attr", "dst"]) self.textual_edges["src"] = [ self.indexer._nodes[h] for h in self.textual_edges["src"] ] self.textual_edges["dst"] = [ self.indexer._nodes[h] for h in self.textual_edges["dst"] ] self._retrieve_subgraphs() gc.collect() torch.cuda.empty_cache() class WebQSPDataset(KGQABaseDataset): r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse Labeling for Knowledge Base Question Answering" `_ paper. Args: root (str): Root directory where the dataset should be saved. 
split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) verbose (bool, optional): Whether to print output. Defaults to False. use_pcst (bool, optional): Whether to preprocess the dataset's graph with PCST or return the full graphs. (default: :obj:`True`) load_dataset_kwargs (dict, optional): Keyword arguments for the `datasets.load_dataset` function. (default: :obj:`{}`) retrieval_kwargs (dict, optional): Keyword arguments for the `get_features_for_triplets_groups` function. (default: :obj:`{}`) """ def __init__( self, root: str, split: str = "train", force_reload: bool = False, verbose: bool = False, use_pcst: bool = True, load_dataset_kwargs: Optional[Dict[str, Any]] = None, retrieval_kwargs: Optional[Dict[str, Any]] = None, ) -> None: load_dataset_kwargs = load_dataset_kwargs or {} retrieval_kwargs = retrieval_kwargs or {} # Modify these paramters if running into memory/compute issues default_retrieval_kwargs = { 'max_batch_size': 250, # Lower batch size to reduce memory usage 'num_workers': None, # Use all available workers, or set to number of threads } retrieval_kwargs = {**default_retrieval_kwargs, **retrieval_kwargs} dataset_name = 'rmanluo/RoG-webqsp' super().__init__(dataset_name, root, split, force_reload, verbose, use_pcst, load_dataset_kwargs=load_dataset_kwargs, retrieval_kwargs=retrieval_kwargs) class CWQDataset(KGQABaseDataset): r"""The ComplexWebQuestions (CWQ) dataset of the `"The Web as a Knowledge-base forAnswering Complex Questions" `_ paper. Args: root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. 
(default: :obj:`"train"`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) verbose (bool, optional): Whether to print output. Defaults to False. use_pcst (bool, optional): Whether to preprocess the dataset's graph with PCST or return the full graphs. (default: :obj:`True`) load_dataset_kwargs (dict, optional): Keyword arguments for the `datasets.load_dataset` function. (default: :obj:`{}`) retrieval_kwargs (dict, optional): Keyword arguments for the `get_features_for_triplets_groups` function. (default: :obj:`{}`) """ def __init__( self, root: str, split: str = "train", force_reload: bool = False, verbose: bool = False, use_pcst: bool = True, load_dataset_kwargs: Optional[Dict[str, Any]] = None, retrieval_kwargs: Optional[Dict[str, Any]] = None, ) -> None: load_dataset_kwargs = load_dataset_kwargs or {} retrieval_kwargs = retrieval_kwargs or {} dataset_name = 'rmanluo/RoG-cwq' super().__init__(dataset_name, root, split, force_reload, verbose, use_pcst, load_dataset_kwargs=load_dataset_kwargs, retrieval_kwargs=retrieval_kwargs) ================================================ FILE: torch_geometric/datasets/webkb.py ================================================ import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.utils import coalesce class WebKB(InMemoryDataset): r"""The WebKB datasets used in the `"Geom-GCN: Geometric Graph Convolutional Networks" `_ paper. Nodes represent web pages and edges represent hyperlinks between them. Node features are the bag-of-words representation of web pages. The task is to classify the nodes into one of the five categories, student, project, course, staff, and faculty. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"Cornell"`, :obj:`"Texas"`, :obj:`"Wisconsin"`). 
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 10 :header-rows: 1 * - Name - #nodes - #edges - #features - #classes * - Cornell - 183 - 298 - 1,703 - 5 * - Texas - 183 - 325 - 1,703 - 5 * - Wisconsin - 251 - 515 - 1,703 - 5 """ url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master' def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() assert self.name in ['cornell', 'texas', 'wisconsin'] super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: out = ['out1_node_feature_label.txt', 'out1_graph_edges.txt'] out += [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)] return out @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for f in self.raw_file_names[:2]: download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir) for f in self.raw_file_names[2:]: download_url(f'{self.url}/splits/{f}', self.raw_dir) def process(self) -> None: with open(self.raw_paths[0]) as f: lines = f.read().split('\n')[1:-1] xs = [[float(value) for value in 
line.split('\t')[1].split(',')] for line in lines] x = torch.tensor(xs, dtype=torch.float) ys = [int(line.split('\t')[2]) for line in lines] y = torch.tensor(ys, dtype=torch.long) with open(self.raw_paths[1]) as f: lines = f.read().split('\n')[1:-1] edge_indices = [[int(value) for value in line.split('\t')] for line in lines] edge_index = torch.tensor(edge_indices).t().contiguous() edge_index = coalesce(edge_index, num_nodes=x.size(0)) train_masks, val_masks, test_masks = [], [], [] for path in self.raw_paths[2:]: tmp = np.load(path) train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)] val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)] test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)] train_mask = torch.stack(train_masks, dim=1) val_mask = torch.stack(val_masks, dim=1) test_mask = torch.stack(test_masks, dim=1) data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name}()' ================================================ FILE: torch_geometric/datasets/wikics.py ================================================ import json import warnings from itertools import chain from typing import Callable, List, Optional import torch from torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.utils import to_undirected class WikiCS(InMemoryDataset): r"""The semi-supervised Wikipedia-based dataset from the `"Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks" `_ paper, containing 11,701 nodes, 216,123 edges, 10 classes and 20 different training splits. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. 
The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) is_undirected (bool, optional): Whether the graph is undirected. (default: :obj:`True`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = 'https://github.com/pmernyei/wiki-cs-dataset/raw/master/dataset' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, is_undirected: Optional[bool] = None, force_reload: bool = False, ) -> None: if is_undirected is None: warnings.warn( f"The {self.__class__.__name__} dataset now returns an " f"undirected graph by default. Please explicitly specify " f"'is_undirected=False' to restore the old behavior.", stacklevel=2) is_undirected = True self.is_undirected = is_undirected super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['data.json'] @property def processed_file_names(self) -> str: return 'data_undirected.pt' if self.is_undirected else 'data.pt' def download(self) -> None: for name in self.raw_file_names: download_url(f'{self.url}/{name}', self.raw_dir) def process(self) -> None: with open(self.raw_paths[0]) as f: data = json.load(f) x = torch.tensor(data['features'], dtype=torch.float) y = torch.tensor(data['labels'], dtype=torch.long) edges = [[(i, j) for j in js] for i, js in enumerate(data['links'])] edges = list(chain(*edges)) # type: ignore edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() if self.is_undirected: edge_index = to_undirected(edge_index, num_nodes=x.size(0)) train_mask = torch.tensor(data['train_masks'], dtype=torch.bool) train_mask = train_mask.t().contiguous() 
val_mask = torch.tensor(data['val_masks'], dtype=torch.bool) val_mask = val_mask.t().contiguous() test_mask = torch.tensor(data['test_mask'], dtype=torch.bool) stopping_mask = torch.tensor(data['stopping_masks'], dtype=torch.bool) stopping_mask = stopping_mask.t().contiguous() data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask, stopping_mask=stopping_mask) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/wikidata.py ================================================ import os import os.path as osp from typing import Callable, Dict, List, Optional import torch from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_tar, ) from torch_geometric.io import fs class Wikidata5M(InMemoryDataset): r"""The Wikidata-5M dataset from the `"KEPLER: A Unified Model for Knowledge Embedding and Pre-trained Language Representation" `_ paper, containing 4,594,485 entities, 822 relations, 20,614,279 train triples, 5,163 validation triples, and 5,133 test triples. `Wikidata-5M `_ is a large-scale knowledge graph dataset with aligned corpus extracted form Wikidata. Args: root (str): Root directory where the dataset should be saved. setting (str, optional): If :obj:`"transductive"`, loads the transductive dataset. If :obj:`"inductive"`, loads the inductive dataset. (default: :obj:`"transductive"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ def __init__( self, root: str, setting: str = 'transductive', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: if setting not in {'transductive', 'inductive'}: raise ValueError(f"Invalid 'setting' argument (got '{setting}')") self.setting = setting self.urls = [ ('https://www.dropbox.com/s/7jp4ib8zo3i6m10/' 'wikidata5m_text.txt.gz?dl=1'), 'https://uni-bielefeld.sciebo.de/s/yuBKzBxsEc9j3hy/download', ] if self.setting == 'inductive': self.urls.append('https://www.dropbox.com/s/csed3cgal3m7rzo/' 'wikidata5m_inductive.tar.gz?dl=1') else: self.urls.append('https://www.dropbox.com/s/6sbhm0rwo4l73jq/' 'wikidata5m_transductive.tar.gz?dl=1') super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return [ 'wikidata5m_text.txt.gz', 'download', f'wikidata5m_{self.setting}_train.txt', f'wikidata5m_{self.setting}_valid.txt', f'wikidata5m_{self.setting}_test.txt', ] @property def processed_file_names(self) -> str: return f'{self.setting}_data.pt' def download(self) -> None: for url in self.urls: download_url(url, self.raw_dir) path = osp.join(self.raw_dir, f'wikidata5m_{self.setting}.tar.gz') extract_tar(path, self.raw_dir) os.remove(path) def process(self) -> None: import gzip entity_to_id: Dict[str, int] = {} with gzip.open(self.raw_paths[0], 'rt') as f: for i, line in enumerate(f): values = line.strip().split('\t') entity_to_id[values[0]] = i x = fs.torch_load(self.raw_paths[1]) edge_indices = [] edge_types = [] split_indices = [] rel_to_id: Dict[str, int] = {} for split, path in enumerate(self.raw_paths[2:]): with open(path) as f: for line in f: head, rel, tail = line[:-1].split('\t') src = entity_to_id[head] dst = entity_to_id[tail] edge_indices.append([src, dst]) if rel not in 
rel_to_id: rel_to_id[rel] = len(rel_to_id) edge_types.append(rel_to_id[rel]) split_indices.append(split) edge_index = torch.tensor(edge_indices).t().contiguous() edge_type = torch.tensor(edge_types) split_index = torch.tensor(split_indices) data = Data( x=x, edge_index=edge_index, edge_type=edge_type, train_mask=split_index == 0, val_mask=split_index == 1, test_mask=split_index == 2, ) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/wikipedia_network.py ================================================ import os.path as osp from typing import Callable, List, Optional, Union import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.utils import coalesce class WikipediaNetwork(InMemoryDataset): r"""The Wikipedia networks introduced in the `"Multi-scale Attributed Node Embedding" `_ paper. Nodes represent web pages and edges represent hyperlinks between them. Node features represent several informative nouns in the Wikipedia pages. The task is to predict the average daily traffic of the web page. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"chameleon"`, :obj:`"crocodile"`, :obj:`"squirrel"`). geom_gcn_preprocess (bool): If set to :obj:`True`, will load the pre-processed data as introduced in the `"Geom-GCN: Geometric Graph Convolutional Networks" _`, in which the average monthly traffic of the web page is converted into five categories to predict. If set to :obj:`True`, the dataset :obj:`"crocodile"` is not available. If set to :obj:`True`, train/validation/test splits will be available as masks for multiple splits with shape :obj:`[num_nodes, num_splits]`. 
(default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ raw_url = 'https://graphmining.ai/datasets/ptg/wiki' processed_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/' 'geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f') def __init__( self, root: str, name: str, geom_gcn_preprocess: bool = True, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.name = name.lower() self.geom_gcn_preprocess = geom_gcn_preprocess assert self.name in ['chameleon', 'crocodile', 'squirrel'] if geom_gcn_preprocess and self.name == 'crocodile': raise AttributeError("The dataset 'crocodile' is not available in " "case 'geom_gcn_preprocess=True'") super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: if self.geom_gcn_preprocess: return osp.join(self.root, self.name, 'geom_gcn', 'raw') else: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: if self.geom_gcn_preprocess: return osp.join(self.root, self.name, 'geom_gcn', 'processed') else: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> Union[List[str], str]: if self.geom_gcn_preprocess: return (['out1_node_feature_label.txt', 'out1_graph_edges.txt'] + [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)]) else: return f'{self.name}.npz' @property def 
processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: if self.geom_gcn_preprocess: for filename in self.raw_file_names[:2]: url = f'{self.processed_url}/new_data/{self.name}/{filename}' download_url(url, self.raw_dir) for filename in self.raw_file_names[2:]: url = f'{self.processed_url}/splits/{filename}' download_url(url, self.raw_dir) else: download_url(f'{self.raw_url}/{self.name}.npz', self.raw_dir) def process(self) -> None: if self.geom_gcn_preprocess: with open(self.raw_paths[0]) as f: lines = f.read().split('\n')[1:-1] xs = [[float(value) for value in line.split('\t')[1].split(',')] for line in lines] x = torch.tensor(xs, dtype=torch.float) ys = [int(line.split('\t')[2]) for line in lines] y = torch.tensor(ys, dtype=torch.long) with open(self.raw_paths[1]) as f: lines = f.read().split('\n')[1:-1] edge_indices = [[int(value) for value in line.split('\t')] for line in lines] edge_index = torch.tensor(edge_indices).t().contiguous() edge_index = coalesce(edge_index, num_nodes=x.size(0)) train_masks, val_masks, test_masks = [], [], [] for filepath in self.raw_paths[2:]: masks = np.load(filepath) train_masks += [torch.from_numpy(masks['train_mask'])] val_masks += [torch.from_numpy(masks['val_mask'])] test_masks += [torch.from_numpy(masks['test_mask'])] train_mask = torch.stack(train_masks, dim=1).to(torch.bool) val_mask = torch.stack(val_masks, dim=1).to(torch.bool) test_mask = torch.stack(test_masks, dim=1).to(torch.bool) data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) else: raw_data = np.load(self.raw_paths[0], 'r', allow_pickle=True) x = torch.from_numpy(raw_data['features']).to(torch.float) edge_index = torch.from_numpy(raw_data['edges']).to(torch.long) edge_index = edge_index.t().contiguous() edge_index = coalesce(edge_index, num_nodes=x.size(0)) y = torch.from_numpy(raw_data['target']).to(torch.float) data = Data(x=x, edge_index=edge_index, y=y) if self.pre_transform is 
not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/willow_object_class.py ================================================ import glob import os import os.path as osp from typing import Callable, List, Optional import torch import torch.nn.functional as F from torch import Tensor from torch.utils.data import DataLoader from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class WILLOWObjectClass(InMemoryDataset): r"""The WILLOW-ObjectClass dataset from the `"Learning Graphs to Match" `_ paper, containing 10 equal keypoints of at least 40 images in each category. The keypoints contain interpolated features from a pre-trained VGG16 model on ImageNet (:obj:`relu4_2` and :obj:`relu5_1`). Args: root (str): Root directory where the dataset should be saved. category (str): The category of the images (one of :obj:`"Car"`, :obj:`"Duck"`, :obj:`"Face"`, :obj:`"Motorbike"`, :obj:`"Winebottle"`). transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) device (str or torch.device, optional): The device to use for processing the raw data. 
If set to :obj:`None`, will utilize GPU-processing if available. (default: :obj:`None`) """ url = ('http://www.di.ens.fr/willow/research/graphlearning/' 'WILLOW-ObjectClass_dataset.zip') categories = ['face', 'motorbike', 'car', 'duck', 'winebottle'] batch_size = 32 def __init__( self, root: str, category: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, device: Optional[str] = None, ) -> None: if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' assert category.lower() in self.categories self.category = category self.device = device super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_dir(self) -> str: return osp.join(self.root, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.category.capitalize(), 'processed') @property def raw_file_names(self) -> List[str]: return [category.capitalize() for category in self.categories] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: path = download_url(self.url, self.root) extract_zip(path, self.root) os.unlink(path) os.unlink(osp.join(self.root, 'README')) os.unlink(osp.join(self.root, 'demo_showAnno.m')) fs.rm(self.raw_dir) os.rename(osp.join(self.root, 'WILLOW-ObjectClass'), self.raw_dir) def process(self) -> None: import torchvision.models as models import torchvision.transforms as T from PIL import Image from scipy.io import loadmat category = self.category.capitalize() names = glob.glob(osp.join(self.raw_dir, category, '*.png')) names = sorted([name[:-4] for name in names]) vgg16_outputs = [] def hook(module: torch.nn.Module, x: Tensor, y: Tensor) -> None: vgg16_outputs.append(y.to('cpu')) vgg16 = models.vgg16(pretrained=True).to(self.device) vgg16.eval() vgg16.features[20].register_forward_hook(hook) # relu4_2 
vgg16.features[25].register_forward_hook(hook) # relu5_1 transform = T.Compose([ T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) data_list = [] for name in names: pos = loadmat(f'{name}.mat')['pts_coord'] x, y = torch.from_numpy(pos).to(torch.float) pos = torch.stack([x, y], dim=1) # The "face" category contains a single image with less than 10 # keypoints, so we need to skip it. if pos.size(0) != 10: continue with open(f'{name}.png', 'rb') as f: img = Image.open(f).convert('RGB') # Rescale keypoints. pos[:, 0] = pos[:, 0] * 256.0 / (img.size[0]) pos[:, 1] = pos[:, 1] * 256.0 / (img.size[1]) img = img.resize((256, 256), resample=Image.Resampling.BICUBIC) img = transform(img) data = Data(img=img, pos=pos, name=name) data_list.append(data) imgs = [data.img for data in data_list] loader = DataLoader( dataset=imgs, # type: ignore batch_size=self.batch_size, shuffle=False, ) for i, batch_img in enumerate(loader): vgg16_outputs.clear() with torch.no_grad(): vgg16(batch_img.to(self.device)) out1 = F.interpolate(vgg16_outputs[0], (256, 256), mode='bilinear', align_corners=False) out2 = F.interpolate(vgg16_outputs[1], (256, 256), mode='bilinear', align_corners=False) for j in range(out1.size(0)): data = data_list[i * self.batch_size + j] assert data.pos is not None idx = data.pos.round().long().clamp(0, 255) x_1 = out1[j, :, idx[:, 1], idx[:, 0]].to('cpu') x_2 = out2[j, :, idx[:, 1], idx[:, 0]].to('cpu') data.img = None data.x = torch.cat([x_1.t(), x_2.t()], dim=-1) del out1 del out2 if self.pre_filter is not None: data_list = [data for data in data_list if self.pre_filter(data)] if self.pre_transform is not None: data_list = [self.pre_transform(data) for data in data_list] self.save(data_list, self.processed_paths[0]) def __repr__(self) -> str: return (f'{self.__class__.__name__}({len(self)}, ' f'category={self.category})') ================================================ FILE: torch_geometric/datasets/word_net.py 
================================================ from itertools import chain from typing import Callable, List, Optional import torch from torch_geometric.data import Data, InMemoryDataset, download_url from torch_geometric.utils import index_sort class WordNet18(InMemoryDataset): r"""The WordNet18 dataset from the `"Translating Embeddings for Modeling Multi-Relational Data" `_ paper, containing 40,943 entities, 18 relations and 151,442 fact triplets, *e.g.*, furniture includes bed. .. note:: The original :obj:`WordNet18` dataset suffers from test leakage, *i.e.* more than 80% of test triplets can be found in the training set with another relation type. Therefore, it should not be used for research evaluation anymore. We recommend to use its cleaned version :class:`~torch_geometric.datasets.WordNet18RR` instead. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. 
(default: :obj:`False`) """ url = ('https://raw.githubusercontent.com/villmow/' 'datasets_knowledge_embedding/master/WN18/original') def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['train.txt', 'valid.txt', 'test.txt'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for filename in self.raw_file_names: download_url(f'{self.url}/{filename}', self.raw_dir) def process(self) -> None: srcs, dsts, edge_types = [], [], [] for path in self.raw_paths: with open(path) as f: edges = [int(x) for x in f.read().split()[1:]] edge = torch.tensor(edges, dtype=torch.long) srcs.append(edge[::3]) dsts.append(edge[1::3]) edge_types.append(edge[2::3]) src = torch.cat(srcs, dim=0) dst = torch.cat(dsts, dim=0) edge_type = torch.cat(edge_types, dim=0) train_mask = torch.zeros(src.size(0), dtype=torch.bool) train_mask[:srcs[0].size(0)] = True val_mask = torch.zeros(src.size(0), dtype=torch.bool) val_mask[srcs[0].size(0):srcs[0].size(0) + srcs[1].size(0)] = True test_mask = torch.zeros(src.size(0), dtype=torch.bool) test_mask[srcs[0].size(0) + srcs[1].size(0):] = True num_nodes = max(int(src.max()), int(dst.max())) + 1 _, perm = index_sort(num_nodes * src + dst) edge_index = torch.stack([src[perm], dst[perm]], dim=0) edge_type = edge_type[perm] train_mask = train_mask[perm] val_mask = val_mask[perm] test_mask = test_mask[perm] data = Data( edge_index=edge_index, edge_type=edge_type, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask, num_nodes=num_nodes, ) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) class WordNet18RR(InMemoryDataset): r"""The WordNet18RR dataset from the `"Convolutional 2D Knowledge Graph 
Embeddings" `_ paper, containing 40,943 entities, 11 relations and 93,003 fact triplets. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) """ url = ('https://raw.githubusercontent.com/villmow/' 'datasets_knowledge_embedding/master/WN18RR/original') edge2id = { '_also_see': 0, '_derivationally_related_form': 1, '_has_part': 2, '_hypernym': 3, '_instance_hypernym': 4, '_member_meronym': 5, '_member_of_domain_region': 6, '_member_of_domain_usage': 7, '_similar_to': 8, '_synset_domain_topic_of': 9, '_verb_group': 10, } def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['train.txt', 'valid.txt', 'test.txt'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: for filename in self.raw_file_names: download_url(f'{self.url}/{filename}', self.raw_dir) def process(self) -> None: node2id, idx = {}, 0 srcs, dsts, edge_types = [], [], [] for path in self.raw_paths: with open(path) as f: edges = f.read().split() _src = edges[::3] _dst = edges[2::3] _edge_type = edges[1::3] for i in chain(_src, _dst): if i not in node2id: node2id[i] = idx idx += 1 srcs.append(torch.tensor([node2id[i] for i in _src])) 
dsts.append(torch.tensor([node2id[i] for i in _dst])) edge_types.append( torch.tensor([self.edge2id[i] for i in _edge_type])) src = torch.cat(srcs, dim=0) dst = torch.cat(dsts, dim=0) edge_type = torch.cat(edge_types, dim=0) train_mask = torch.zeros(src.size(0), dtype=torch.bool) train_mask[:srcs[0].size(0)] = True val_mask = torch.zeros(src.size(0), dtype=torch.bool) val_mask[srcs[0].size(0):srcs[0].size(0) + srcs[1].size(0)] = True test_mask = torch.zeros(src.size(0), dtype=torch.bool) test_mask[srcs[0].size(0) + srcs[1].size(0):] = True num_nodes = max(int(src.max()), int(dst.max())) + 1 _, perm = index_sort(num_nodes * src + dst) edge_index = torch.stack([src[perm], dst[perm]], dim=0) edge_type = edge_type[perm] train_mask = train_mask[perm] val_mask = val_mask[perm] test_mask = test_mask[perm] data = Data(edge_index=edge_index, edge_type=edge_type, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask, num_nodes=num_nodes) if self.pre_transform is not None: data = self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/yelp.py ================================================ import json import os.path as osp from typing import Callable, List, Optional import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset, download_google_url class Yelp(InMemoryDataset): r"""The Yelp dataset from the `"GraphSAINT: Graph Sampling Based Inductive Learning Method" `_ paper, containing customer reviewers and their friendship. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. 
(default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. list-table:: :widths: 10 10 10 10 :header-rows: 1 * - #nodes - #edges - #features - #tasks * - 716,847 - 13,954,819 - 300 - 100 """ adj_full_id = '1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA' feats_id = '1Zy6BZH_zLEjKlEFSduKE5tV9qqA_8VtM' class_map_id = '1VUcBGr0T0-klqerjAjxRmAqFuld_SMWU' role_id = '1NI5pa5Chpd-52eSmLW60OnB3WS5ikxq_' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, force_reload: bool = False, ) -> None: super().__init__(root, transform, pre_transform, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json'] @property def processed_file_names(self) -> str: return 'data.pt' def download(self) -> None: download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz') download_google_url(self.feats_id, self.raw_dir, 'feats.npy') download_google_url(self.class_map_id, self.raw_dir, 'class_map.json') download_google_url(self.role_id, self.raw_dir, 'role.json') def process(self) -> None: import scipy.sparse as sp f = np.load(osp.join(self.raw_dir, 'adj_full.npz')) adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape']) adj = adj.tocoo() row = torch.from_numpy(adj.row).to(torch.long) col = torch.from_numpy(adj.col).to(torch.long) edge_index = torch.stack([row, col], dim=0) x = np.load(osp.join(self.raw_dir, 'feats.npy')) x = torch.from_numpy(x).to(torch.float) ys = [-1] * x.size(0) with open(osp.join(self.raw_dir, 'class_map.json')) as f: class_map = json.load(f) for key, item in class_map.items(): ys[int(key)] 
= item y = torch.tensor(ys) with open(osp.join(self.raw_dir, 'role.json')) as f: role = json.load(f) train_mask = torch.zeros(x.size(0), dtype=torch.bool) train_mask[torch.tensor(role['tr'])] = True val_mask = torch.zeros(x.size(0), dtype=torch.bool) val_mask[torch.tensor(role['va'])] = True test_mask = torch.zeros(x.size(0), dtype=torch.bool) test_mask[torch.tensor(role['te'])] = True data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) ================================================ FILE: torch_geometric/datasets/zinc.py ================================================ import os import os.path as osp import pickle from typing import Callable, List, Optional import torch from tqdm import tqdm from torch_geometric.data import ( Data, InMemoryDataset, download_url, extract_zip, ) from torch_geometric.io import fs class ZINC(InMemoryDataset): r"""The ZINC dataset from the `ZINC database `_ and the `"Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules" `_ paper, containing about 250,000 molecular graphs with up to 38 heavy atoms. The task is to regress the penalized :obj:`logP` (also called constrained solubility in some works), given by :obj:`y = logP - SAS - cycles`, where :obj:`logP` is the water-octanol partition coefficient, :obj:`SAS` is the synthetic accessibility score, and :obj:`cycles` denotes the number of cycles with more than six atoms. Penalized :obj:`logP` is a score commonly used for training molecular generation models, see, *e.g.*, the `"Junction Tree Variational Autoencoder for Molecular Graph Generation" `_ and `"Grammar Variational Autoencoder" `_ papers. Args: root (str): Root directory where the dataset should be saved. 
subset (bool, optional): If set to :obj:`True`, will only load a subset of the dataset (12,000 molecular graphs), following the `"Benchmarking Graph Neural Networks" `_ paper. (default: :obj:`False`) split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) **STATS:** .. 
list-table:: :widths: 20 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes * - ZINC Full - 249,456 - ~23.2 - ~49.8 - 1 - 1 * - ZINC Subset - 12,000 - ~23.2 - ~49.8 - 1 - 1 """ url = 'https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1' split_url = ('https://raw.githubusercontent.com/graphdeeplearning/' 'benchmarking-gnns/master/data/molecules/{}.index') def __init__( self, root: str, subset: bool = False, split: str = 'train', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, ) -> None: self.subset = subset assert split in ['train', 'val', 'test'] super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) path = osp.join(self.processed_dir, f'{split}.pt') self.load(path) @property def raw_file_names(self) -> List[str]: return [ 'train.pickle', 'val.pickle', 'test.pickle', 'train.index', 'val.index', 'test.index' ] @property def processed_dir(self) -> str: name = 'subset' if self.subset else 'full' return osp.join(self.root, name, 'processed') @property def processed_file_names(self) -> List[str]: return ['train.pt', 'val.pt', 'test.pt'] def download(self) -> None: fs.rm(self.raw_dir) path = download_url(self.url, self.root) extract_zip(path, self.root) os.rename(osp.join(self.root, 'molecules'), self.raw_dir) os.unlink(path) for split in ['train', 'val', 'test']: download_url(self.split_url.format(split), self.raw_dir) def process(self) -> None: for split in ['train', 'val', 'test']: with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f: mols = pickle.load(f) indices = list(range(len(mols))) if self.subset: with open(osp.join(self.raw_dir, f'{split}.index')) as f: indices = [int(x) for x in f.read()[:-1].split(',')] pbar = tqdm(total=len(indices)) pbar.set_description(f'Processing {split} dataset') data_list = [] for idx in indices: mol = mols[idx] x = 
mol['atom_type'].to(torch.long).view(-1, 1) y = mol['logP_SA_cycle_normalized'].to(torch.float) adj = mol['bond_type'] edge_index = adj.nonzero(as_tuple=False).t().contiguous() edge_attr = adj[edge_index[0], edge_index[1]].to(torch.long) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) pbar.update(1) pbar.close() self.save(data_list, osp.join(self.processed_dir, f'{split}.pt')) ================================================ FILE: torch_geometric/debug.py ================================================ from typing import Any __debug_flag__ = {'enabled': False} def is_debug_enabled() -> bool: r"""Returns :obj:`True` if the debug mode is enabled.""" return __debug_flag__['enabled'] def set_debug_enabled(mode: bool) -> None: __debug_flag__['enabled'] = mode class debug: r"""Context-manager that enables the debug mode to help track down errors and separate usage errors from real bugs. .. code-block:: python with torch_geometric.debug(): out = model(data.x, data.edge_index) """ def __init__(self) -> None: self.prev = is_debug_enabled() def __enter__(self) -> None: set_debug_enabled(True) def __exit__(self, *args: Any) -> None: set_debug_enabled(self.prev) class set_debug: r"""Context-manager that sets the debug mode on or off. :class:`set_debug` will enable or disable the debug mode based on its argument :attr:`mode`. It can be used as a context-manager or as a function. See :class:`debug` above for more details. 
""" def __init__(self, mode: bool) -> None: self.prev = is_debug_enabled() set_debug_enabled(mode) def __enter__(self) -> None: pass def __exit__(self, *args: Any) -> None: set_debug_enabled(self.prev) ================================================ FILE: torch_geometric/deprecation.py ================================================ import functools import inspect import warnings from typing import Any, Callable, Optional def deprecated( details: Optional[str] = None, func_name: Optional[str] = None, ) -> Callable: def decorator(func: Callable) -> Callable: name = func_name or func.__name__ if inspect.isclass(func): cls = type(func.__name__, (func, ), {}) cls.__init__ = deprecated(details, name)( # type: ignore func.__init__) cls.__doc__ = func.__doc__ return cls @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: out = f"'{name}' is deprecated" if details is not None: out += f", {details}" warnings.warn(out, stacklevel=2) return func(*args, **kwargs) return wrapper return decorator ================================================ FILE: torch_geometric/device.py ================================================ from typing import Any import torch def is_mps_available() -> bool: r"""Returns a bool indicating if MPS is currently available.""" if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): try: # Github CI may not have access to MPS hardware. Confirm: torch.empty(1, device='mps') return True except Exception: return False return False def is_xpu_available() -> bool: r"""Returns a bool indicating if XPU is currently available.""" if hasattr(torch, 'xpu') and torch.xpu.is_available(): return True try: import intel_extension_for_pytorch as ipex return ipex.xpu.is_available() except ImportError: return False def device(device: Any) -> torch.device: r"""Returns a :class:`torch.device`. If :obj:`"auto"` is specified, returns the optimal device depending on available hardware. 
""" if device != 'auto': return torch.device(device) if torch.cuda.is_available(): return torch.device('cuda') if is_mps_available(): return torch.device('mps') if is_xpu_available(): return torch.device('xpu') return torch.device('cpu') ================================================ FILE: torch_geometric/distributed/__init__.py ================================================ from .dist_context import DistContext from .local_feature_store import LocalFeatureStore from .local_graph_store import LocalGraphStore from .partition import Partitioner from .dist_neighbor_sampler import DistNeighborSampler from .dist_loader import DistLoader from .dist_neighbor_loader import DistNeighborLoader from .dist_link_neighbor_loader import DistLinkNeighborLoader __all__ = classes = [ 'DistContext', 'LocalFeatureStore', 'LocalGraphStore', 'Partitioner', 'DistNeighborSampler', 'DistLoader', 'DistNeighborLoader', 'DistLinkNeighborLoader', ] ================================================ FILE: torch_geometric/distributed/dist_context.py ================================================ from dataclasses import dataclass from enum import Enum class DistRole(Enum): WORKER = 1 @dataclass class DistContext: r"""Context information of the current process.""" rank: int global_rank: int world_size: int global_world_size: int group_name: str role: DistRole = DistRole.WORKER @property def worker_name(self) -> str: return f'{self.group_name}-{self.rank}' ================================================ FILE: torch_geometric/distributed/dist_link_neighbor_loader.py ================================================ from typing import Callable, Dict, List, Optional, Tuple, Union from warnings import warn import torch from torch_geometric.distributed import ( DistContext, DistLoader, DistNeighborSampler, LocalFeatureStore, LocalGraphStore, ) from torch_geometric.loader import LinkLoader from torch_geometric.sampler.base import NegativeSampling, SubgraphType from torch_geometric.typing import 
EdgeType, InputEdges, OptTensor class DistLinkNeighborLoader(LinkLoader, DistLoader): r"""A distributed loader that performs sampling from edges. Args: data (tuple): A (:class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`) data object. num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to :obj:`-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. master_addr (str): RPC address for distributed loader communication, *i.e.* the IP address of the master node. master_port (Union[int, str]): Open port for RPC communication with the master node. current_ctx (DistContext): Distributed context information of the current process. concurrency (int, optional): RPC concurrency used for defining the maximum size of the asynchronous processing queue. (default: :obj:`1`) All other arguments follow the interface of :class:`torch_geometric.loader.LinkNeighborLoader`. 
""" def __init__( self, data: Tuple[LocalFeatureStore, LocalGraphStore], num_neighbors: Union[List[int], Dict[EdgeType, List[int]]], master_addr: str, master_port: Union[int, str], current_ctx: DistContext, edge_label_index: InputEdges = None, edge_label: OptTensor = None, edge_label_time: OptTensor = None, dist_sampler: Optional[DistNeighborSampler] = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = "directional", disjoint: bool = False, temporal_strategy: str = "uniform", neg_sampling: Optional[NegativeSampling] = None, neg_sampling_ratio: Optional[Union[int, float]] = None, time_attr: Optional[str] = None, transform: Optional[Callable] = None, concurrency: int = 1, num_rpc_threads: int = 16, filter_per_worker: Optional[bool] = False, async_sampling: bool = True, device: Optional[torch.device] = None, **kwargs, ): assert isinstance(data[0], LocalFeatureStore) assert isinstance(data[1], LocalGraphStore) assert concurrency >= 1, "RPC concurrency must be greater than 1" if (edge_label_time is not None) != (time_attr is not None): raise ValueError( f"Received conflicting 'edge_label_time' and 'time_attr' " f"arguments: 'edge_label_time' is " f"{'set' if edge_label_time is not None else 'not set'} " f"while 'time_attr' is " f"{'set' if time_attr is not None else 'not set'}. " f"Both arguments must be provided for temporal sampling.") channel = torch.multiprocessing.Queue() if async_sampling else None if dist_sampler is None: dist_sampler = DistNeighborSampler( data=data, current_ctx=current_ctx, num_neighbors=num_neighbors, replace=replace, subgraph_type=subgraph_type, disjoint=disjoint, temporal_strategy=temporal_strategy, time_attr=time_attr, device=device, channel=channel, concurrency=concurrency, ) else: warn( # noqa: B028 "`torch_geometric.distributed` has been deprecated since 2.7.0 and will " # noqa: E501 "no longer be maintained. 
For distributed training, refer to our " # noqa: E501 "tutorials on distributed training at " "https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed.html " # noqa: E501 "or cuGraph examples at " "https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples", # noqa: E501 stack_level=2) DistLoader.__init__( self, channel=channel, master_addr=master_addr, master_port=master_port, current_ctx=current_ctx, dist_sampler=dist_sampler, num_rpc_threads=num_rpc_threads, **kwargs, ) LinkLoader.__init__( self, data=data, link_sampler=dist_sampler, edge_label_index=edge_label_index, edge_label=edge_label, edge_label_time=edge_label_time, neg_sampling=neg_sampling, neg_sampling_ratio=neg_sampling_ratio, transform=transform, filter_per_worker=filter_per_worker, worker_init_fn=self.worker_init_fn, transform_sampler_output=self.channel_get if channel else None, **kwargs, ) def __repr__(self) -> str: return DistLoader.__repr__(self) ================================================ FILE: torch_geometric/distributed/dist_loader.py ================================================ import atexit import logging import os from typing import Any, Optional, Union import torch.distributed import torch.multiprocessing as mp from torch_geometric.distributed import DistNeighborSampler from torch_geometric.distributed.dist_context import DistContext from torch_geometric.distributed.rpc import ( global_barrier, init_rpc, shutdown_rpc, ) from torch_geometric.loader.base import DataLoaderIterator class DistLoader: r"""A base class for creating distributed data loading routines. Args: current_ctx (DistContext): Distributed context info of the current process. master_addr (str, optional): RPC address for distributed loader communication. Refers to the IP address of the master node. (default: :obj:`None`) master_port (int or str, optional): The open port for RPC communication with the master node. 
(default: :obj:`None`) channel (mp.Queue, optional): A communication channel for messages. (default: :obj:`None`) num_rpc_threads (int, optional): The number of threads in the thread-pool used by :class:`~torch.distributed.rpc.TensorPipeAgent` to execute requests. (default: :obj:`16`) rpc_timeout (int, optional): The default timeout in seconds for RPC requests. If the RPC has not completed in this timeframe, an exception will be raised. Callers can override this timeout for individual RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and :meth:`~torch.distributed.rpc.rpc_async` if necessary. (default: :obj:`180`) """ def __init__( self, current_ctx: DistContext, master_addr: Optional[str] = None, master_port: Optional[Union[int, str]] = None, channel: Optional[mp.Queue] = None, num_rpc_threads: int = 16, rpc_timeout: int = 180, dist_sampler: DistNeighborSampler = None, **kwargs, ): if master_addr is None and os.environ.get('MASTER_ADDR') is not None: master_addr = os.environ['MASTER_ADDR'] if master_addr is None: raise ValueError(f"Missing master address for RPC communication " f"in '{self.__class__.__name__}'. Try to provide " f"it or set it via the 'MASTER_ADDR' environment " f"variable.") if master_port is None and os.environ.get('MASTER_PORT') is not None: # Select next port to MASTER_PORT used for DDP. # If multiple loaders are launched in the same script, please # provide distinct ports for each. master_port = int(os.environ['MASTER_PORT']) + 1 if master_port is None: raise ValueError(f"Missing master port for RPC communication in " f"'{self.__class__.__name__}'. 
Try to provide it " f"or set it via the 'MASTER_ADDR' environment " f"variable.") assert num_rpc_threads > 0 assert rpc_timeout > 0 self.dist_sampler = dist_sampler self.current_ctx = current_ctx self.master_addr = master_addr self.master_port = master_port self.channel = channel self.pid = mp.current_process().pid self.num_rpc_threads = num_rpc_threads self.rpc_timeout = rpc_timeout self.num_workers = kwargs.get('num_workers', 0) logging.info(f"[{self}] MASTER_ADDR={master_addr}, " f"MASTER_PORT={master_port}") if self.num_workers == 0: # Initialize RPC in main process: self.worker_init_fn(0) def channel_get(self, out: Any) -> Any: if self.channel: out = self.channel.get() logging.debug(f"[{self}] Retrieved message") return out def reset_channel(self, channel=None): # clean remaining queue items and restart new queue logging.debug(f'{self} Resetting msg channel') while not self.channel.empty(): self.channel.get_nowait() torch.distributed.barrier() self.channel = channel or mp.Queue() self.dist_sampler.channel = self.channel def worker_init_fn(self, worker_id: int): try: num_sampler_proc = self.num_workers if self.num_workers > 0 else 1 self.current_ctx_worker = DistContext( world_size=self.current_ctx.world_size * num_sampler_proc, rank=self.current_ctx.rank * num_sampler_proc + worker_id, global_world_size=self.current_ctx.world_size * num_sampler_proc, global_rank=self.current_ctx.rank * num_sampler_proc + worker_id, group_name='mp_sampling_worker', ) init_rpc( current_ctx=self.current_ctx_worker, master_addr=self.master_addr, master_port=self.master_port, num_rpc_threads=self.num_rpc_threads, rpc_timeout=self.rpc_timeout, ) logging.info( f"RPC initiated in worker-{worker_id} " f"(current_ctx_worker={self.current_ctx_worker.worker_name})") self.dist_sampler.init_sampler_instance() self.dist_sampler.register_sampler_rpc() global_barrier(timeout=10) # Wait for all workers to initialize. 
# close RPC & worker group at exit: atexit.register(shutdown_rpc, self.current_ctx_worker.worker_name) except RuntimeError as e: raise RuntimeError(f"`{self}.init_fn()` could not initialize the " f"worker loop of the neighbor sampler") from e def __repr__(self) -> str: return f'{self.__class__.__name__}(pid={self.pid})' def __enter__(self) -> DataLoaderIterator: # fetch a single batch for init self._prefetch_old = self.prefetch_factor self.prefetch_factor = 1 self._iterator = self._get_iterator() return self._iterator def __exit__(self, *args) -> None: if self.channel: self.reset_channel() if self._iterator: del self._iterator torch.distributed.barrier() self._iterator = None self.prefetch_factor = self._prefetch_old ================================================ FILE: torch_geometric/distributed/dist_neighbor_loader.py ================================================ from typing import Callable, Dict, List, Optional, Tuple, Union from warnings import warn import torch from torch_geometric.distributed import ( DistContext, DistLoader, DistNeighborSampler, LocalFeatureStore, LocalGraphStore, ) from torch_geometric.loader import NodeLoader from torch_geometric.sampler.base import SubgraphType from torch_geometric.typing import EdgeType, InputNodes, OptTensor class DistNeighborLoader(NodeLoader, DistLoader): r"""A distributed loader that performs sampling from nodes. Args: data (tuple): A (:class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`) data object. num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to :obj:`-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. master_addr (str): RPC address for distributed loader communication, *i.e.* the IP address of the master node. 
master_port (Union[int, str]): Open port for RPC communication with the master node. current_ctx (DistContext): Distributed context information of the current process. concurrency (int, optional): RPC concurrency used for defining the maximum size of the asynchronous processing queue. (default: :obj:`1`) All other arguments follow the interface of :class:`torch_geometric.loader.NeighborLoader`. """ def __init__( self, data: Tuple[LocalFeatureStore, LocalGraphStore], num_neighbors: Union[List[int], Dict[EdgeType, List[int]]], master_addr: str, master_port: Union[int, str], current_ctx: DistContext, input_nodes: InputNodes = None, input_time: OptTensor = None, dist_sampler: Optional[DistNeighborSampler] = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = "directional", disjoint: bool = False, temporal_strategy: str = "uniform", time_attr: Optional[str] = None, transform: Optional[Callable] = None, concurrency: int = 1, num_rpc_threads: int = 16, filter_per_worker: Optional[bool] = False, async_sampling: bool = True, device: Optional[torch.device] = None, **kwargs, ): assert isinstance(data[0], LocalFeatureStore) assert isinstance(data[1], LocalGraphStore) assert concurrency >= 1, "RPC concurrency must be greater than 1" if input_time is not None and time_attr is None: raise ValueError("Received conflicting 'input_time' and " "'time_attr' arguments: 'input_time' is set " "while 'time_attr' is not set.") channel = torch.multiprocessing.Queue() if async_sampling else None if dist_sampler is None: dist_sampler = DistNeighborSampler( data=data, current_ctx=current_ctx, num_neighbors=num_neighbors, replace=replace, subgraph_type=subgraph_type, disjoint=disjoint, temporal_strategy=temporal_strategy, time_attr=time_attr, device=device, channel=channel, concurrency=concurrency, ) else: warn( # noqa: B028 "`torch_geometric.distributed` has been deprecated since 2.7.0 and will " # noqa: E501 "no longer be maintained. 
For distributed training, refer to our " # noqa: E501 "tutorials on distributed training at " "https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed.html " # noqa: E501 "or cuGraph examples at " "https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples", # noqa: E501 stack_level=2) DistLoader.__init__( self, channel=channel, master_addr=master_addr, master_port=master_port, current_ctx=current_ctx, dist_sampler=dist_sampler, num_rpc_threads=num_rpc_threads, **kwargs, ) NodeLoader.__init__( self, data=data, node_sampler=dist_sampler, input_nodes=input_nodes, input_time=input_time, transform=transform, filter_per_worker=filter_per_worker, transform_sampler_output=self.channel_get if channel else None, worker_init_fn=self.worker_init_fn, **kwargs, ) def __repr__(self) -> str: return DistLoader.__repr__(self) ================================================ FILE: torch_geometric/distributed/dist_neighbor_sampler.py ================================================ import itertools import logging import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union from warnings import warn import numpy as np import torch import torch.multiprocessing as mp from torch import Tensor from torch_geometric.distributed import ( DistContext, LocalFeatureStore, LocalGraphStore, ) from torch_geometric.distributed.event_loop import ( ConcurrentEventLoop, to_asyncio_future, ) from torch_geometric.distributed.rpc import ( RPCCallBase, RPCRouter, rpc_async, rpc_partition_to_workers, rpc_register, ) from torch_geometric.distributed.utils import ( BatchDict, DistEdgeHeteroSamplerInput, NodeDict, remove_duplicates, ) from torch_geometric.sampler import ( EdgeSamplerInput, HeteroSamplerOutput, NegativeSampling, NeighborSampler, NodeSamplerInput, SamplerOutput, ) from torch_geometric.sampler.base import NumNeighbors, SubgraphType from torch_geometric.sampler.neighbor_sampler import neg_sample from torch_geometric.sampler.utils 
import remap_keys from torch_geometric.typing import EdgeType, NodeType NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]] class RPCSamplingCallee(RPCCallBase): r"""A wrapper for RPC callee that will perform RPC sampling from remote processes. """ def __init__(self, sampler: NeighborSampler): super().__init__() self.sampler = sampler def rpc_async(self, *args, **kwargs) -> Any: return self.sampler._sample_one_hop(*args, **kwargs) def rpc_sync(self, *args, **kwargs) -> Any: pass class DistNeighborSampler: r"""An implementation of a distributed and asynchronised neighbor sampler used by :class:`~torch_geometric.distributed.DistNeighborLoader` and :class:`~torch_geometric.distributed.DistLinkNeighborLoader`. """ def __init__( self, current_ctx: DistContext, data: Tuple[LocalFeatureStore, LocalGraphStore], num_neighbors: NumNeighborsType, channel: Optional[mp.Queue] = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = 'directional', disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, concurrency: int = 1, device: Optional[torch.device] = None, **kwargs, ): warn( # noqa: B028 "`torch_geometric.distributed` has been deprecated since 2.7.0 and will " # noqa: E501 "no longer be maintained. 
For distributed training, refer to our " # noqa: E501 "tutorials on distributed training at " "https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed.html " # noqa: E501 "or cuGraph examples at " "https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples", # noqa: E501 stack_level=2) self.current_ctx = current_ctx self.feature_store, self.graph_store = data assert isinstance(self.graph_store, LocalGraphStore) assert isinstance(self.feature_store, LocalFeatureStore) self.is_hetero = self.graph_store.meta['is_hetero'] self.num_neighbors = num_neighbors self.channel = channel self.concurrency = concurrency self.device = device self.event_loop = None self.replace = replace self.subgraph_type = SubgraphType(subgraph_type) self.disjoint = disjoint self.temporal_strategy = temporal_strategy self.time_attr = time_attr self.temporal = time_attr is not None self.with_edge_attr = self.feature_store.has_edge_attr() self.csc = True def init_sampler_instance(self): self._sampler = NeighborSampler( data=(self.feature_store, self.graph_store), num_neighbors=self.num_neighbors, subgraph_type=self.subgraph_type, replace=self.replace, disjoint=self.disjoint, temporal_strategy=self.temporal_strategy, time_attr=self.time_attr, ) self.num_hops = self._sampler.num_neighbors.num_hops self.node_types = self._sampler.node_types self.edge_types = self._sampler.edge_types self.node_time = self._sampler.node_time self.edge_time = self._sampler.edge_time def register_sampler_rpc(self) -> None: partition2workers = rpc_partition_to_workers( current_ctx=self.current_ctx, num_partitions=self.graph_store.num_partitions, current_partition_idx=self.graph_store.partition_idx, ) self.rpc_router = RPCRouter(partition2workers) self.feature_store.set_rpc_router(self.rpc_router) rpc_sample_callee = RPCSamplingCallee(self) self.rpc_sample_callee_id = rpc_register(rpc_sample_callee) def init_event_loop(self) -> None: if self.event_loop is None: self.event_loop 
= ConcurrentEventLoop(self.concurrency) self.event_loop.start_loop() logging.info(f'{self} uses {self.event_loop}') # Node-based distributed sampling ######################################### def sample_from_nodes( self, inputs: NodeSamplerInput, **kwargs, ) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]: self.init_event_loop() inputs = NodeSamplerInput.cast(inputs) if self.channel is None: # synchronous sampling return self.event_loop.run_task( coro=self._sample_from(self.node_sample, inputs)) # asynchronous sampling cb = kwargs.get("callback", None) self.event_loop.add_task( coro=self._sample_from(self.node_sample, inputs), callback=cb) return None # Edge-based distributed sampling ######################################### def sample_from_edges( self, inputs: EdgeSamplerInput, neg_sampling: Optional[NegativeSampling] = None, **kwargs, ) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]: self.init_event_loop() if self.channel is None: # synchronous sampling return self.event_loop.run_task(coro=self._sample_from( self.edge_sample, inputs, self.node_sample, self._sampler. 
num_nodes, self.disjoint, self.node_time, neg_sampling)) # asynchronous sampling cb = kwargs.get("callback", None) self.event_loop.add_task( coro=self._sample_from(self.edge_sample, inputs, self.node_sample, self._sampler.num_nodes, self.disjoint, self.node_time, neg_sampling), callback=cb) return None async def _sample_from( self, async_func, *args, **kwargs, ) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]: sampler_output = await async_func(*args, **kwargs) if self.subgraph_type == SubgraphType.bidirectional: sampler_output = sampler_output.to_bidirectional() res = await self._collate_fn(sampler_output) if self.channel is None: return res self.channel.put(res) return None async def node_sample( self, inputs: Union[NodeSamplerInput, DistEdgeHeteroSamplerInput], ) -> Union[SamplerOutput, HeteroSamplerOutput]: r"""Performs layer-by-layer distributed sampling from a :class:`NodeSamplerInput` or :class:`DistEdgeHeteroSamplerInput` and returns the output of the sampling procedure. .. note:: In case of distributed training it is required to synchronize the results between machines after each layer. """ input_type = inputs.input_type self.input_type = input_type if isinstance(inputs, NodeSamplerInput): seed = inputs.node.to(self.device) batch_size = len(inputs.node) seed_batch = torch.arange(batch_size) if self.disjoint else None metadata = (inputs.input_id, inputs.time, batch_size) seed_time: Optional[Tensor] = None if self.temporal: if inputs.time is not None: seed_time = inputs.time.to(self.device) elif self.node_time is not None: if not self.is_hetero: seed_time = self.node_time[seed] else: seed_time = self.node_time[input_type][seed] else: raise ValueError("Seed time needs to be specified") else: # `DistEdgeHeteroSamplerInput`: metadata = None # Metadata is added during `edge_sample`. 
# Heterogeneous Neighborhood Sampling ################################# if self.is_hetero: if input_type is None: raise ValueError("Input type should be defined") node_dict = NodeDict(self.node_types, self.num_hops) batch_dict = BatchDict(self.node_types, self.num_hops) if isinstance(inputs, NodeSamplerInput): seed_dict: Dict[NodeType, Tensor] = {input_type: seed} if self.temporal: node_dict.seed_time[input_type][0] = seed_time.clone() else: # `DistEdgeHeteroSamplerInput`: seed_dict = inputs.node_dict if self.temporal: for k, v in inputs.node_dict.items(): if inputs.time_dict is not None: node_dict.seed_time[k][0] = inputs.time_dict[k] elif self.node_time is not None: node_dict.seed_time[k][0] = self.node_time[k][v] else: raise ValueError("Seed time needs to be specified") edge_dict: Dict[EdgeType, Tensor] = { k: torch.empty(0, dtype=torch.int64) for k in self.edge_types } sampled_nbrs_per_node_dict: Dict[EdgeType, List[List]] = { k: [[] for _ in range(self.num_hops)] for k in self.edge_types } num_sampled_edges_dict: Dict[EdgeType, List[int]] = { k: [] for k in self.edge_types } num_sampled_nodes_dict: Dict[NodeType, List[int]] = { k: [0] for k in self.node_types } # Fill in node_dict and batch_dict with input data: batch_len = 0 for k, v in seed_dict.items(): node_dict.src[k][0] = v node_dict.out[k] = v num_sampled_nodes_dict[k][0] = len(v) if self.disjoint: src_batch = torch.arange(batch_len, batch_len + len(v)) batch_dict.src[k][0] = src_batch batch_dict.out[k] = src_batch batch_len = len(src_batch) # Loop over the layers: for i in range(self.num_hops): # Sample neighbors per edge type: for edge_type in self.edge_types: # `src` is a destination node type of a given edge. src = edge_type[0] if not self.csc else edge_type[2] if node_dict.src[src][i].numel() == 0: # No source nodes of this type in the current layer. 
num_sampled_edges_dict[edge_type].append(0) continue if isinstance(self.num_neighbors, list): one_hop_num = self.num_neighbors[i] else: one_hop_num = self.num_neighbors[edge_type][i] # Sample neighbors: out = await self.sample_one_hop( node_dict.src[src][i], one_hop_num, node_dict.seed_time[src][i], batch_dict.src[src][i], edge_type, ) if out.node.numel() == 0: # No neighbors were sampled. num_sampled_edges_dict[edge_type].append(0) continue # `dst` is a destination node type of a given edge. dst = edge_type[2] if not self.csc else edge_type[0] # Remove duplicates: ( src_node, node_dict.out[dst], src_batch, batch_dict.out[dst], ) = remove_duplicates( out, node_dict.out[dst], batch_dict.out[dst], self.disjoint, ) # Create src nodes for the next layer: node_dict.src[dst][i + 1] = torch.cat( [node_dict.src[dst][i + 1], src_node]) if self.disjoint: batch_dict.src[dst][i + 1] = torch.cat( [batch_dict.src[dst][i + 1], src_batch]) # Save sampled nodes with duplicates to be able to create # local edge indices: node_dict.with_dupl[dst] = torch.cat( [node_dict.with_dupl[dst], out.node]) edge_dict[edge_type] = torch.cat( [edge_dict[edge_type], out.edge]) if self.disjoint: batch_dict.with_dupl[dst] = torch.cat( [batch_dict.with_dupl[dst], out.batch]) if self.temporal and i < self.num_hops - 1: # Assign seed time based on source node subgraph ID: if isinstance(inputs, NodeSamplerInput): src_seed_time = [ seed_time[(seed_batch == batch_idx).nonzero()] for batch_idx in src_batch ] src_seed_time = torch.as_tensor( src_seed_time, dtype=torch.int64) else: # `DistEdgeHeteroSamplerInput`: src_seed_time = torch.empty(0, dtype=torch.int64) for k, v in batch_dict.src.items(): time = [ node_dict.seed_time[k][0][( v[0] == batch_idx).nonzero()] for batch_idx in src_batch ] try: time = torch.as_tensor( time, dtype=torch.int64) src_seed_time = torch.cat( [src_seed_time, time]) except Exception: # `time` may be an empty tensors, because # no nodes of this type were sampled. 
pass node_dict.seed_time[dst][i + 1] = torch.cat( [node_dict.seed_time[dst][i + 1], src_seed_time]) # Collect sampled neighbors per node for each layer: sampled_nbrs_per_node_dict[edge_type][i] += out.metadata[0] num_sampled_edges_dict[edge_type].append(len(out.node)) for node_type in self.node_types: num_sampled_nodes_dict[node_type].append( len(node_dict.src[node_type][i + 1])) sampled_nbrs_per_node_dict = remap_keys(sampled_nbrs_per_node_dict, self._sampler.to_rel_type) # Create local edge indices for a batch: row_dict, col_dict = torch.ops.pyg.hetero_relabel_neighborhood( self.node_types, self.edge_types, seed_dict, node_dict.with_dupl, sampled_nbrs_per_node_dict, self._sampler.num_nodes, batch_dict.with_dupl, self.csc, self.disjoint, ) sampler_output = HeteroSamplerOutput( node=node_dict.out, row=remap_keys(row_dict, self._sampler.to_edge_type), col=remap_keys(col_dict, self._sampler.to_edge_type), edge=edge_dict, batch=batch_dict.out if self.disjoint else None, num_sampled_nodes=num_sampled_nodes_dict, num_sampled_edges=num_sampled_edges_dict, metadata=metadata, ) # Homogeneous Neighborhood Sampling ################################### else: src = seed node = src.clone() src_batch = seed_batch.clone() if self.disjoint else None batch = seed_batch.clone() if self.disjoint else None src_seed_time = seed_time.clone() if self.temporal else None node_with_dupl = [torch.empty(0, dtype=torch.int64)] batch_with_dupl = [torch.empty(0, dtype=torch.int64)] edge = [torch.empty(0, dtype=torch.int64)] sampled_nbrs_per_node = [] num_sampled_nodes = [seed.numel()] num_sampled_edges = [] # Loop over the layers: for i, one_hop_num in enumerate(self.num_neighbors): out = await self.sample_one_hop(src, one_hop_num, src_seed_time, src_batch) if out.node.numel() == 0: # No neighbors were sampled: num_zero_layers = self.num_hops - i num_sampled_nodes += num_zero_layers * [0] num_sampled_edges += num_zero_layers * [0] break # Remove duplicates: src, node, src_batch, batch = 
remove_duplicates( out, node, batch, self.disjoint) node_with_dupl.append(out.node) edge.append(out.edge) if self.disjoint: batch_with_dupl.append(out.batch) if self.temporal and i < self.num_hops - 1: # Assign seed time based on src nodes subgraph IDs. src_seed_time = [ seed_time[(seed_batch == batch_idx).nonzero()] for batch_idx in src_batch ] src_seed_time = torch.as_tensor(src_seed_time, dtype=torch.int64) num_sampled_nodes.append(len(src)) num_sampled_edges.append(len(out.node)) sampled_nbrs_per_node += out.metadata[0] row, col = torch.ops.pyg.relabel_neighborhood( seed, torch.cat(node_with_dupl), sampled_nbrs_per_node, self._sampler.num_nodes, torch.cat(batch_with_dupl) if self.disjoint else None, self.csc, self.disjoint, ) sampler_output = SamplerOutput( node=node, row=row, col=col, edge=torch.cat(edge), batch=batch if self.disjoint else None, num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, metadata=metadata, ) return sampler_output async def edge_sample( self, inputs: EdgeSamplerInput, sample_fn: Callable, num_nodes: Union[int, Dict[NodeType, int]], disjoint: bool, node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None, neg_sampling: Optional[NegativeSampling] = None, ) -> Union[SamplerOutput, HeteroSamplerOutput]: r"""Performs layer-by-layer distributed sampling from an :class:`EdgeSamplerInput` and returns the output of the sampling procedure. .. note:: In case of distributed training it is required to synchronize the results between machines after each layer. 
""" input_id = inputs.input_id src = inputs.row dst = inputs.col edge_label = inputs.label edge_label_time = inputs.time input_type = inputs.input_type src_time = dst_time = edge_label_time assert edge_label_time is None or disjoint assert isinstance(num_nodes, (dict, int)) if not isinstance(num_nodes, dict): num_src_nodes = num_dst_nodes = num_nodes else: num_src_nodes = num_nodes[input_type[0]] num_dst_nodes = num_nodes[input_type[-1]] num_pos = src.numel() num_neg = 0 # Negative Sampling ################################################### if neg_sampling is not None: # When we are doing negative sampling, we append negative # information of nodes/edges to `src`, `dst`, `src_time`, # `dst_time`. Later on, we can easily reconstruct what belongs to # positive and negative examples by slicing via `num_pos`. num_neg = math.ceil(num_pos * neg_sampling.amount) if neg_sampling.is_binary(): # In the "binary" case, we randomly sample negative pairs of # nodes. if isinstance(node_time, dict): src_node_time = node_time.get(input_type[0]) else: src_node_time = node_time src_neg = neg_sample(src, neg_sampling, num_src_nodes, src_time, src_node_time) src = torch.cat([src, src_neg], dim=0) if isinstance(node_time, dict): dst_node_time = node_time.get(input_type[-1]) else: dst_node_time = node_time dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time, dst_node_time) dst = torch.cat([dst, dst_neg], dim=0) if edge_label is None: edge_label = torch.ones(num_pos) size = (num_neg, ) + edge_label.size()[1:] edge_neg_label = edge_label.new_zeros(size) edge_label = torch.cat([edge_label, edge_neg_label]) if edge_label_time is not None: src_time = dst_time = edge_label_time.repeat( 1 + math.ceil(neg_sampling.amount))[:num_pos + num_neg] elif neg_sampling.is_triplet(): # In the "triplet" case, we randomly sample negative # destinations. 
if isinstance(node_time, dict): dst_node_time = node_time.get(input_type[-1]) else: dst_node_time = node_time dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time, dst_node_time) dst = torch.cat([dst, dst_neg], dim=0) assert edge_label is None if edge_label_time is not None: dst_time = edge_label_time.repeat(1 + neg_sampling.amount) # Heterogeneus Neighborhood Sampling ################################## if input_type is not None: if input_type[0] != input_type[-1]: # Two distinct node types: if not disjoint: src, inverse_src = src.unique(return_inverse=True) dst, inverse_dst = dst.unique(return_inverse=True) seed_dict = {input_type[0]: src, input_type[-1]: dst} seed_time_dict = None if edge_label_time is not None: # Always disjoint. seed_time_dict = { input_type[0]: src_time, input_type[-1]: dst_time, } out = await sample_fn( DistEdgeHeteroSamplerInput( input_id=inputs.input_id, node_dict=seed_dict, time_dict=seed_time_dict, input_type=input_type, )) else: # Only a single node type: Merge both source and destination. seed = torch.cat([src, dst], dim=0) if not disjoint: seed, inverse_seed = seed.unique(return_inverse=True) seed_dict = {input_type[0]: seed} seed_time = None if edge_label_time is not None: # Always disjoint. 
seed_time = torch.cat([src_time, dst_time], dim=0) out = await sample_fn( NodeSamplerInput( input_id=inputs.input_id, node=seed, time=seed_time, input_type=input_type[0], )) # Enhance `out` by label information ############################## if disjoint: for key, batch in out.batch.items(): out.batch[key] = batch % num_pos if neg_sampling is None or neg_sampling.is_binary(): if disjoint: if input_type[0] != input_type[-1]: edge_label_index = torch.arange(num_pos + num_neg) edge_label_index = edge_label_index.repeat(2) edge_label_index = edge_label_index.view(2, -1) else: num_labels = num_pos + num_neg edge_label_index = torch.arange(2 * (num_labels)) edge_label_index = edge_label_index.view(2, -1) else: if input_type[0] != input_type[-1]: edge_label_index = torch.stack([ inverse_src, inverse_dst, ], dim=0) else: edge_label_index = inverse_seed.view(2, -1) out.metadata = (input_id, edge_label_index, edge_label, src_time) elif neg_sampling.is_triplet(): if disjoint: src_index = torch.arange(num_pos) if input_type[0] != input_type[-1]: dst_pos_index = torch.arange(num_pos) # `dst_neg_index` needs to be offset such that indices # with offset `num_pos` belong to the same triplet: dst_neg_index = torch.arange( num_pos, seed_dict[input_type[-1]].numel()) dst_neg_index = dst_neg_index.view(-1, num_pos).t() else: dst_pos_index = torch.arange(num_pos, 2 * num_pos) dst_neg_index = torch.arange( 2 * num_pos, seed_dict[input_type[-1]].numel()) dst_neg_index = dst_neg_index.view(-1, num_pos).t() else: if input_type[0] != input_type[-1]: src_index = inverse_src dst_pos_index = inverse_dst[:num_pos] dst_neg_index = inverse_dst[num_pos:] else: src_index = inverse_seed[:num_pos] dst_pos_index = inverse_seed[num_pos:2 * num_pos] dst_neg_index = inverse_seed[2 * num_pos:] dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1) out.metadata = ( input_id, src_index, dst_pos_index, dst_neg_index, src_time, ) # Homogeneous Neighborhood Sampling ################################### 
else: seed = torch.cat([src, dst], dim=0) seed_time = None if not disjoint: seed, inverse_seed = seed.unique(return_inverse=True) if edge_label_time is not None: # Always disjoint. seed_time = torch.cat([src_time, dst_time]) out = await sample_fn( NodeSamplerInput( input_id=inputs.input_id, node=seed, time=seed_time, input_type=None, )) # Enhance `out` by label information ############################## if neg_sampling is None or neg_sampling.is_binary(): if disjoint: out.batch = out.batch % num_pos edge_label_index = torch.arange(seed.numel()).view(2, -1) else: edge_label_index = inverse_seed.view(2, -1) out.metadata = (input_id, edge_label_index, edge_label, src_time) elif neg_sampling.is_triplet(): if disjoint: out.batch = out.batch % num_pos src_index = torch.arange(num_pos) dst_pos_index = torch.arange(num_pos, 2 * num_pos) # `dst_neg_index` needs to be offset such that indices with # offset `num_pos` belong to the same triplet: dst_neg_index = torch.arange(2 * num_pos, seed.numel()) dst_neg_index = dst_neg_index.view(-1, num_pos).t() else: src_index = inverse_seed[:num_pos] dst_pos_index = inverse_seed[num_pos:2 * num_pos] dst_neg_index = inverse_seed[2 * num_pos:] dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1) out.metadata = ( input_id, src_index, dst_pos_index, dst_neg_index, src_time, ) return out def _get_sampler_output( self, outputs: List[SamplerOutput], seed_size: int, p_id: int, src_batch: Optional[Tensor] = None, ) -> SamplerOutput: r"""Used when seed nodes belongs to one partition. It's purpose is to remove seed nodes from sampled nodes and calculates how many neighbors were sampled by each src node based on the :obj:`cumsum_neighbors_per_node`. Returns updated sampler output. 
""" cumsum_neighbors_per_node = outputs[p_id].metadata[0] # do not include seed outputs[p_id].node = outputs[p_id].node[seed_size:] begin = np.array(cumsum_neighbors_per_node[1:]) end = np.array(cumsum_neighbors_per_node[:-1]) sampled_nbrs_per_node = list(np.subtract(begin, end)) outputs[p_id].metadata = (sampled_nbrs_per_node, ) if self.disjoint: batch = [[src_batch[i]] * nbrs_per_node for i, nbrs_per_node in enumerate(sampled_nbrs_per_node)] outputs[p_id].batch = Tensor( list(itertools.chain.from_iterable(batch))).type(torch.int64) return outputs[p_id] def _merge_sampler_outputs( self, partition_ids: Tensor, partition_orders: Tensor, outputs: List[SamplerOutput], one_hop_num: int, src_batch: Optional[Tensor] = None, ) -> SamplerOutput: r"""Merges samplers outputs from different partitions, so that they are sorted according to the sampling order. Removes seed nodes from sampled nodes and calculates how many neighbors were sampled by each src node based on the :obj:`cumsum_neighbors_per_node`. Leverages the :obj:`pyg-lib` :meth:`merge_sampler_outputs` function. Args: partition_ids (torch.Tensor): Contains information on which partition seeds nodes are located on. partition_orders (torch.Tensor): Contains information about the order of seed nodes in each partition. outputs (List[SamplerOutput]): List of all samplers outputs. one_hop_num (int): Max number of neighbors sampled in the current layer. src_batch (torch.Tensor, optional): The batch assignment of seed nodes. (default: :obj:`None`) Returns :obj:`SamplerOutput` containing all merged outputs. 
""" sampled_nodes_with_dupl = [ o.node if o is not None else torch.empty(0, dtype=torch.int64) for o in outputs ] edge_ids = [ o.edge if o is not None else torch.empty(0, dtype=torch.int64) for o in outputs ] cumm_sampled_nbrs_per_node = [ o.metadata[0] if o is not None else [] for o in outputs ] partition_ids = partition_ids.tolist() partition_orders = partition_orders.tolist() partitions_num = self.graph_store.meta["num_parts"] out = torch.ops.pyg.merge_sampler_outputs( sampled_nodes_with_dupl, edge_ids, cumm_sampled_nbrs_per_node, partition_ids, partition_orders, partitions_num, one_hop_num, src_batch, self.disjoint, ) ( out_node_with_dupl, out_edge, out_batch, out_sampled_nbrs_per_node, ) = out return SamplerOutput( out_node_with_dupl, None, None, out_edge, out_batch if self.disjoint else None, metadata=(out_sampled_nbrs_per_node, ), ) async def sample_one_hop( self, srcs: Tensor, one_hop_num: int, seed_time: Optional[Tensor] = None, src_batch: Optional[Tensor] = None, edge_type: Optional[EdgeType] = None, ) -> SamplerOutput: r"""Samples one-hop neighbors for a set of seed nodes in :obj:`srcs`. If seed nodes are located on a local partition, evaluates the sampling function on the current machine. If seed nodes are from a remote partition, sends a request to a remote machine that contains this partition. 
""" src_node_type = None if not self.is_hetero else edge_type[2] partition_ids = self.graph_store.get_partition_ids_from_nids( srcs, src_node_type) partition_orders = torch.zeros(len(partition_ids), dtype=torch.long) p_outputs: List[SamplerOutput] = [ None ] * self.graph_store.meta["num_parts"] futs: List[torch.futures.Future] = [] local_only = True single_partition = len(set(partition_ids.tolist())) == 1 for i in range(self.graph_store.num_partitions): p_id = (self.graph_store.partition_idx + i) % self.graph_store.num_partitions p_mask = partition_ids == p_id p_srcs = torch.masked_select(srcs, p_mask) p_seed_time = (torch.masked_select(seed_time, p_mask) if self.temporal else None) p_indices = torch.arange(len(p_srcs), dtype=torch.long) partition_orders[p_mask] = p_indices if p_srcs.shape[0] > 0: if p_id == self.graph_store.partition_idx: # Sample for one hop on a local machine: p_nbr_out = self._sample_one_hop(p_srcs, one_hop_num, p_seed_time, edge_type) p_outputs.pop(p_id) p_outputs.insert(p_id, p_nbr_out) else: # Sample on a remote machine: local_only = False to_worker = self.rpc_router.get_to_worker(p_id) futs.append( rpc_async( to_worker, self.rpc_sample_callee_id, args=(p_srcs, one_hop_num, p_seed_time, edge_type), )) if not local_only: # Src nodes are remote res_fut_list = await to_asyncio_future( torch.futures.collect_all(futs)) for i, res_fut in enumerate(res_fut_list): p_id = (self.graph_store.partition_idx + i + 1) % self.graph_store.num_partitions p_outputs.pop(p_id) p_outputs.insert(p_id, res_fut.wait()) # All src nodes are in the same partition if single_partition: return self._get_sampler_output(p_outputs, len(srcs), partition_ids[0], src_batch) return self._merge_sampler_outputs(partition_ids, partition_orders, p_outputs, one_hop_num, src_batch) def _sample_one_hop( self, input_nodes: Tensor, num_neighbors: int, seed_time: Optional[Tensor] = None, edge_type: Optional[EdgeType] = None, ) -> SamplerOutput: r"""Implements one-hop neighbor sampling for 
a set of input nodes for a specific edge type. """ if not self.is_hetero: colptr = self._sampler.colptr row = self._sampler.row node_time = self.node_time edge_time = self.edge_time else: # Given edge type, get input data and evaluate sample function: rel_type = '__'.join(edge_type) colptr = self._sampler.colptr_dict[rel_type] row = self._sampler.row_dict[rel_type] # `node_time` is a destination node time: node_time = (self.node_time or {}).get(edge_type[0], None) edge_time = (self.edge_time or {}).get(edge_type, None) out = torch.ops.pyg.dist_neighbor_sample( colptr, row, input_nodes.to(colptr.dtype), num_neighbors, node_time, edge_time, seed_time, None, # TODO: edge_weight True, # csc self.replace, self.subgraph_type != SubgraphType.induced, self.disjoint and self.temporal, self.temporal_strategy, ) node, edge, cumsum_neighbors_per_node = out if self.disjoint and self.temporal: # We create a batch during the step of merging sampler outputs. _, node = node.t().contiguous() return SamplerOutput( node=node, row=None, col=None, edge=edge, batch=None, metadata=(cumsum_neighbors_per_node, ), ) async def _collate_fn( self, output: Union[SamplerOutput, HeteroSamplerOutput] ) -> Union[SamplerOutput, HeteroSamplerOutput]: r"""Collect labels and features for the sampled subgrarph if necessary, and put them into a sample message. """ if self.is_hetero: labels = {} nfeats = {} efeats = {} labels = self.feature_store.labels if labels is not None: if isinstance(self.input_type, tuple): # Edge labels. labels = { self.input_type: labels[output.edge[self.input_type]] } else: # Node labels. labels = { self.input_type: labels[output.node[self.input_type]] } # Collect node features. 
if output.node is not None: for ntype in output.node.keys(): if output.node[ntype].numel() > 0: fut = self.feature_store.lookup_features( is_node_feat=True, index=output.node[ntype], input_type=ntype, ) nfeat = await to_asyncio_future(fut) nfeat = nfeat.to(torch.device("cpu")) nfeats[ntype] = nfeat else: nfeats[ntype] = None # Collect edge features if output.edge is not None and self.with_edge_attr: for edge_type in output.edge.keys(): if output.edge[edge_type].numel() > 0: fut = self.feature_store.lookup_features( is_node_feat=False, index=output.edge[edge_type], input_type=edge_type, ) efeat = await to_asyncio_future(fut) efeat = efeat.to(torch.device("cpu")) efeats[edge_type] = efeat else: efeats[edge_type] = None else: # Homogeneous: # Collect node labels. if self.feature_store.labels is not None: labels = self.feature_store.labels[output.node] else: labels = None # Collect node features. if output.node is not None: fut = self.feature_store.lookup_features( is_node_feat=True, index=output.node) nfeats = await to_asyncio_future(fut) nfeats = nfeats.to(torch.device("cpu")) else: nfeats = None # Collect edge features. 
if output.edge is not None and self.with_edge_attr: fut = self.feature_store.lookup_features( is_node_feat=False, index=output.edge) efeats = await to_asyncio_future(fut) efeats = efeats.to(torch.device("cpu")) else: efeats = None output.metadata = (*output.metadata, nfeats, labels, efeats) return output @property def edge_permutation(self) -> None: return None def __repr__(self) -> str: return f'{self.__class__.__name__}(pid={mp.current_process().pid})' ================================================ FILE: torch_geometric/distributed/event_loop.py ================================================ import asyncio import atexit import logging from threading import BoundedSemaphore, Thread from typing import Callable, Optional import torch # Based on graphlearn-for-pytorch repository python/distributed/event_loop.py # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/ # LICENSE: Apache v2 def to_asyncio_future(future: torch.futures.Future) -> asyncio.futures.Future: r"""Convert a :class:`torch.futures.Future` to a :obj:`asyncio` future.""" loop = asyncio.get_event_loop() asyncio_future = loop.create_future() def on_done(*_): try: result = future.wait() except Exception as e: loop.call_soon_threadsafe(asyncio_future.set_exception, e) else: loop.call_soon_threadsafe(asyncio_future.set_result, result) future.add_done_callback(on_done) return asyncio_future class ConcurrentEventLoop: r"""Concurrent event loop context. Args: concurrency: max processing concurrency. 
""" def __init__(self, concurrency: int): self._concurrency = concurrency self._sem = BoundedSemaphore(concurrency) self._loop = asyncio.new_event_loop() self._runner_t = Thread(target=self._run_loop) self._runner_t.daemon = True def cleanup(): for _ in range(self._concurrency): self._sem.acquire() for _ in range(self._concurrency): self._sem.release() if self._runner_t.is_alive(): self._loop.stop() self._runner_t.join(timeout=1) logging.debug(f'{self}: Closed `ConcurrentEventLoop`') atexit.register(cleanup) def start_loop(self): if not self._runner_t.is_alive(): self._runner_t.start() def wait_all(self): r"""Wait for all pending tasks to be finished.""" for _ in range(self._concurrency): self._sem.acquire() for _ in range(self._concurrency): self._sem.release() def add_task(self, coro, callback: Optional[Callable] = None): r"""Adds an asynchronized coroutine task to run. Args: coro: The asynchronous coroutine function. callback (callable, optional): The callback function applied on the returned results after the coroutine task is finished. (default: :obj:`None`) Note that any result returned by :obj:`callback` will be ignored. """ def on_done(f: asyncio.futures.Future): try: res = f.result() if callback is not None: callback(res) except Exception as e: logging.error(f"Coroutine task failed with error: {e}") self._sem.release() self._sem.acquire() fut = asyncio.run_coroutine_threadsafe(coro, self._loop) fut.add_done_callback(on_done) def run_task(self, coro): r"""Runs a coroutine task synchronously. Args: coro: The synchronous coroutine function. 
""" with self._sem: fut = asyncio.run_coroutine_threadsafe(coro, self._loop) return fut.result() def _run_loop(self): self._loop.run_forever() ================================================ FILE: torch_geometric/distributed/local_feature_store.py ================================================ import copy import os.path as osp from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import FeatureStore, TensorAttr from torch_geometric.data.feature_store import _FieldStatus from torch_geometric.distributed.partition import load_partition_info from torch_geometric.distributed.rpc import ( RPCCallBase, RPCRouter, rpc_async, rpc_register, ) from torch_geometric.io import fs from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType class RPCCallFeatureLookup(RPCCallBase): r"""A wrapper for RPC calls to the feature store.""" def __init__(self, dist_feature: FeatureStore): super().__init__() self.dist_feature = dist_feature def rpc_async(self, *args, **kwargs): return self.dist_feature._rpc_local_feature_get(*args, **kwargs) def rpc_sync(self, *args, **kwargs): raise NotImplementedError @dataclass class LocalTensorAttr(TensorAttr): r"""Tensor attribute for storing features without :obj:`index`.""" def __init__( self, group_name: Optional[Union[NodeType, EdgeType]] = _FieldStatus.UNSET, attr_name: Optional[str] = _FieldStatus.UNSET, index=None, ): super().__init__(group_name, attr_name, index) class LocalFeatureStore(FeatureStore): r"""Implements the :class:`~torch_geometric.data.FeatureStore` interface to act as a local feature store for distributed training. 
""" def __init__(self): super().__init__(tensor_attr_cls=LocalTensorAttr) self._feat: Dict[Tuple[Union[NodeType, EdgeType], str], Tensor] = {} # Save the global node/edge IDs: self._global_id: Dict[Union[NodeType, EdgeType], Tensor] = {} # Save the mapping from global node/edge IDs to indices in `_feat`: self._global_id_to_index: Dict[Union[NodeType, EdgeType], Tensor] = {} # For partition/RPC info related to distributed features: self.num_partitions: int = 1 self.partition_idx: int = 0 # Mapping between node ID and partition ID: self.node_feat_pb: Union[Tensor, Dict[NodeType, Tensor]] # Mapping between edge ID and partition ID: self.edge_feat_pb: Union[Tensor, Dict[EdgeType, Tensor]] # Node labels: self.labels: Optional[Tensor] = None self.local_only: bool = False self.rpc_router: Optional[RPCRouter] = None self.meta: Optional[Dict] = None self.rpc_call_id: Optional[int] = None @staticmethod def key(attr: TensorAttr) -> Tuple[str, str]: return (attr.group_name, attr.attr_name) def put_global_id( self, global_id: Tensor, group_name: Union[NodeType, EdgeType], ) -> bool: self._global_id[group_name] = global_id self._set_global_id_to_index(group_name) return True def get_global_id( self, group_name: Union[NodeType, EdgeType], ) -> Optional[Tensor]: return self._global_id.get(group_name) def remove_global_id(self, group_name: Union[NodeType, EdgeType]) -> bool: return self._global_id.pop(group_name) is not None def _set_global_id_to_index(self, group_name: Union[NodeType, EdgeType]): global_id = self.get_global_id(group_name) if global_id is None: return # TODO Compute this mapping without materializing a full-sized tensor: global_id_to_index = global_id.new_full((int(global_id.max()) + 1, ), fill_value=-1) global_id_to_index[global_id] = torch.arange(global_id.numel()) self._global_id_to_index[group_name] = global_id_to_index def _put_tensor(self, tensor: Tensor, attr: TensorAttr) -> bool: assert attr.index is None self._feat[self.key(attr)] = tensor return True def 
_get_tensor(self, attr: TensorAttr) -> Optional[Tensor]: tensor = self._feat.get(self.key(attr)) if tensor is None: return None if attr.index is None: # Empty indices return the full tensor: return tensor return tensor[attr.index] def _remove_tensor(self, attr: TensorAttr) -> bool: assert attr.index is None return self._feat.pop(self.key(attr), None) is not None def get_tensor_from_global_id(self, *args, **kwargs) -> Optional[Tensor]: attr = self._tensor_attr_cls.cast(*args, **kwargs) assert attr.index is not None attr = copy.copy(attr) attr.index = self._global_id_to_index[attr.group_name][attr.index] return self.get_tensor(attr) def _get_tensor_size(self, attr: TensorAttr) -> Tuple[int, ...]: return self._get_tensor(attr).size() def get_all_tensor_attrs(self) -> List[LocalTensorAttr]: return [self._tensor_attr_cls.cast(*key) for key in self._feat.keys()] def set_rpc_router(self, rpc_router: RPCRouter): self.rpc_router = rpc_router if not self.local_only: if self.rpc_router is None: raise ValueError("An RPC router must be provided") rpc_call = RPCCallFeatureLookup(self) self.rpc_call_id = rpc_register(rpc_call) else: self.rpc_call_id = None def has_edge_attr(self) -> bool: has_edge_attr = False for k in [key for key in self._feat.keys() if 'edge_attr' in key]: try: self.get_tensor(k[0], 'edge_attr') has_edge_attr = True except KeyError: pass return has_edge_attr def lookup_features( self, index: Tensor, is_node_feat: bool = True, input_type: Optional[NodeOrEdgeType] = None, ) -> torch.futures.Future: r"""Lookup of local/remote features.""" remote_fut = self._remote_lookup_features(index, is_node_feat, input_type) local_feature = self._local_lookup_features(index, is_node_feat, input_type) res_fut = torch.futures.Future() def when_finish(*_): try: remote_feature_list = remote_fut.wait() # combine the feature from remote and local result = torch.zeros( index.size(0), local_feature[0].size(1), dtype=local_feature[0].dtype, ) result[local_feature[1]] = 
local_feature[0] for remote in remote_feature_list: result[remote[1]] = remote[0] except Exception as e: res_fut.set_exception(e) else: res_fut.set_result(result) remote_fut.add_done_callback(when_finish) return res_fut def _local_lookup_features( self, index: Tensor, is_node_feat: bool = True, input_type: Optional[Union[NodeType, EdgeType]] = None, ) -> Tuple[Tensor, Tensor]: r"""Lookup the features in local nodes based on node/edge IDs.""" pb = self.node_feat_pb if is_node_feat else self.edge_feat_pb input_order = torch.arange(index.size(0), dtype=torch.long) if self.meta['is_hetero']: partition_ids = pb[input_type][index] else: partition_ids = pb[index] local_mask = partition_ids == self.partition_idx local_ids = torch.masked_select(index, local_mask) local_index = torch.masked_select(input_order, local_mask) if self.meta['is_hetero']: if is_node_feat: kwargs = dict(group_name=input_type, attr_name='x') ret_feat = self.get_tensor_from_global_id( index=local_ids, **kwargs) else: kwargs = dict(group_name=input_type, attr_name='edge_attr') ret_feat = self.get_tensor_from_global_id( index=local_ids, **kwargs) else: if is_node_feat: kwargs = dict(group_name=None, attr_name='x') ret_feat = self.get_tensor_from_global_id( index=local_ids, **kwargs) else: kwargs = dict(group_name=(None, None), attr_name='edge_attr') ret_feat = self.get_tensor_from_global_id( index=local_ids, **kwargs) return ret_feat, local_index def _remote_lookup_features( self, index: Tensor, is_node_feat: bool = True, input_type: Optional[Union[NodeType, EdgeType]] = None, ) -> torch.futures.Future: r"""Fetch the remote features with the remote node/edge IDs.""" pb = self.node_feat_pb if is_node_feat else self.edge_feat_pb input_order = torch.arange(index.size(0), dtype=torch.long) if self.meta['is_hetero']: partition_ids = pb[input_type][index] else: partition_ids = pb[index] futs, indexes = [], [] for pidx in range(0, self.num_partitions): if pidx == self.partition_idx: continue remote_mask = 
partition_ids == pidx remote_ids = index[remote_mask] if remote_ids.shape[0] > 0: to_worker = self.rpc_router.get_to_worker(pidx) futs.append( rpc_async( to_worker, self.rpc_call_id, args=(remote_ids.cpu(), is_node_feat, input_type), )) indexes.append(torch.masked_select(input_order, remote_mask)) collect_fut = torch.futures.collect_all(futs) res_fut = torch.futures.Future() def when_finish(*_): try: fut_list = collect_fut.wait() result = [] for i, fut in enumerate(fut_list): result.append((fut.wait(), indexes[i])) except Exception as e: res_fut.set_exception(e) else: res_fut.set_result(result) collect_fut.add_done_callback(when_finish) return res_fut def _rpc_local_feature_get( self, index: Tensor, is_node_feat: bool = True, input_type: Optional[Union[NodeType, EdgeType]] = None, ) -> Tensor: r"""Lookup of features in remote nodes.""" if self.meta['is_hetero']: feat = self if is_node_feat: kwargs = dict(group_name=input_type, attr_name='x') ret_feat = feat.get_tensor_from_global_id( index=index, **kwargs) else: kwargs = dict(group_name=input_type, attr_name='edge_attr') ret_feat = feat.get_tensor_from_global_id( index=index, **kwargs) else: feat = self if is_node_feat: kwargs = dict(group_name=None, attr_name='x') ret_feat = feat.get_tensor_from_global_id( index=index, **kwargs) else: kwargs = dict(group_name=(None, None), attr_name='edge_attr') ret_feat = feat.get_tensor_from_global_id( index=index, **kwargs) return ret_feat # Initialization ########################################################## @classmethod def from_data( cls, node_id: Tensor, x: Optional[Tensor] = None, y: Optional[Tensor] = None, edge_id: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, ) -> 'LocalFeatureStore': r"""Creates a local feature store from homogeneous :pyg:`PyG` tensors. Args: node_id (torch.Tensor): The global identifier for every local node. x (torch.Tensor, optional): The node features. (default: :obj:`None`) y (torch.Tensor, optional): The node labels. 
(default: :obj:`None`) edge_id (torch.Tensor, optional): The global identifier for every local edge. (default: :obj:`None`) edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`) """ feat_store = cls() feat_store.put_global_id(node_id, group_name=None) if x is not None: feat_store.put_tensor(x, group_name=None, attr_name='x') if y is not None: feat_store.put_tensor(y, group_name=None, attr_name='y') if edge_id is not None: feat_store.put_global_id(edge_id, group_name=(None, None)) if edge_attr is not None: if edge_id is None: raise ValueError("'edge_id' needs to be present in case " "'edge_attr' is passed") feat_store.put_tensor(edge_attr, group_name=(None, None), attr_name='edge_attr') return feat_store @classmethod def from_hetero_data( cls, node_id_dict: Dict[NodeType, Tensor], x_dict: Optional[Dict[NodeType, Tensor]] = None, y_dict: Optional[Dict[NodeType, Tensor]] = None, edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None, edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None, ) -> 'LocalFeatureStore': r"""Creates a local graph store from heterogeneous :pyg:`PyG` tensors. Args: node_id_dict (Dict[NodeType, torch.Tensor]): The global identifier for every local node of every node type. x_dict (Dict[NodeType, torch.Tensor], optional): The node features of every node type. (default: :obj:`None`) y_dict (Dict[NodeType, torch.Tensor], optional): The node labels of every node type. (default: :obj:`None`) edge_id_dict (Dict[EdgeType, torch.Tensor], optional): The global identifier for every local edge of every edge types. (default: :obj:`None`) edge_attr_dict (Dict[EdgeType, torch.Tensor], optional): The edge features of every edge type. 
(default: :obj:`None`) """ feat_store = cls() for node_type, node_id in node_id_dict.items(): feat_store.put_global_id(node_id, group_name=node_type) if x_dict is not None: for node_type, x in x_dict.items(): feat_store.put_tensor(x, group_name=node_type, attr_name='x') if y_dict is not None: for node_type, y in y_dict.items(): feat_store.put_tensor(y, group_name=node_type, attr_name='y') if edge_id_dict is not None: for edge_type, edge_id in edge_id_dict.items(): feat_store.put_global_id(edge_id, group_name=edge_type) if edge_attr_dict is not None: for edge_type, edge_attr in edge_attr_dict.items(): if edge_id_dict is None or edge_type not in edge_id_dict: raise ValueError("'edge_id' needs to be present in case " "'edge_attr' is passed") feat_store.put_tensor(edge_attr, group_name=edge_type, attr_name='edge_attr') return feat_store @classmethod def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore': part_dir = osp.join(root, f'part_{pid}') assert osp.exists(part_dir) feat_store = cls() ( meta, num_partitions, partition_idx, node_pb, edge_pb, ) = load_partition_info(root, pid) feat_store.num_partitions = num_partitions feat_store.partition_idx = partition_idx feat_store.node_feat_pb = node_pb feat_store.edge_feat_pb = edge_pb feat_store.meta = meta node_feats: Optional[Dict[str, Any]] = None if osp.exists(osp.join(part_dir, 'node_feats.pt')): node_feats = fs.torch_load(osp.join(part_dir, 'node_feats.pt')) edge_feats: Optional[Dict[str, Any]] = None if osp.exists(osp.join(part_dir, 'edge_feats.pt')): edge_feats = fs.torch_load(osp.join(part_dir, 'edge_feats.pt')) if not meta['is_hetero'] and node_feats is not None: feat_store.put_global_id(node_feats['global_id'], group_name=None) for key, value in node_feats['feats'].items(): feat_store.put_tensor(value, group_name=None, attr_name=key) if 'time' in node_feats: feat_store.put_tensor(node_feats['time'], group_name=None, attr_name='time') if not meta['is_hetero'] and edge_feats is not None: if 'global_id' 
in edge_feats: feat_store.put_global_id(edge_feats['global_id'], group_name=(None, None)) if 'feats' in edge_feats: for key, value in edge_feats['feats'].items(): feat_store.put_tensor(value, group_name=(None, None), attr_name=key) if 'edge_time' in edge_feats: feat_store.put_tensor(edge_feats['edge_time'], group_name=(None, None), attr_name='edge_time') if meta['is_hetero'] and node_feats is not None: for node_type, node_feat in node_feats.items(): feat_store.put_global_id(node_feat['global_id'], group_name=node_type) for key, value in node_feat['feats'].items(): feat_store.put_tensor(value, group_name=node_type, attr_name=key) if 'time' in node_feat: feat_store.put_tensor(node_feat['time'], group_name=node_type, attr_name='time') if meta['is_hetero'] and edge_feats is not None: for edge_type, edge_feat in edge_feats.items(): if 'global_id' in edge_feat: feat_store.put_global_id(edge_feat['global_id'], group_name=edge_type) if 'feats' in edge_feat: for key, value in edge_feat['feats'].items(): feat_store.put_tensor(value, group_name=edge_type, attr_name=key) if 'edge_time' in edge_feat: feat_store.put_tensor(edge_feat['edge_time'], group_name=edge_type, attr_name='edge_time') return feat_store ================================================ FILE: torch_geometric/distributed/local_graph_store.py ================================================ import os.path as osp from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import EdgeAttr, GraphStore from torch_geometric.distributed.partition import load_partition_info from torch_geometric.io import fs from torch_geometric.typing import EdgeTensorType, EdgeType, NodeType from torch_geometric.utils import sort_edge_index class LocalGraphStore(GraphStore): r"""Implements the :class:`~torch_geometric.data.GraphStore` interface to act as a local graph store for distributed training. 
""" def __init__(self): super().__init__() self._edge_index: Dict[Tuple, EdgeTensorType] = {} self._edge_attr: Dict[Tuple, EdgeAttr] = {} self._edge_id: Dict[Tuple, Tensor] = {} self.num_partitions = 1 self.partition_idx = 0 # Mapping between node ID and partition ID self.node_pb: Union[Tensor, Dict[NodeType, Tensor]] = None # Mapping between edge ID and partition ID self.edge_pb: Union[Tensor, Dict[EdgeType, Tensor]] = None # Meta information related to partition and graph store info self.meta: Optional[Dict[Any, Any]] = None # If data is sorted based on destination nodes (CSC format): self.is_sorted: Optional[bool] = None @staticmethod def key(attr: EdgeAttr) -> Tuple: return (attr.edge_type, attr.layout.value) def get_partition_ids_from_nids( self, ids: torch.Tensor, node_type: Optional[NodeType] = None, ) -> Tensor: r"""Returns the partition IDs of node IDs for a specific node type.""" if self.meta['is_hetero']: return self.node_pb[node_type][ids] else: return self.node_pb[ids] def get_partition_ids_from_eids(self, eids: torch.Tensor, edge_type: Optional[EdgeType] = None): r"""Returns the partition IDs of edge IDs for a specific edge type.""" if self.meta['is_hetero']: return self.edge_pb[edge_type][eids] else: return self.edge_pb[eids] def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool: edge_attr = self._edge_attr_cls.cast(*args, **kwargs) self._edge_id[self.key(edge_attr)] = edge_id return True def get_edge_id(self, *args, **kwargs) -> Optional[EdgeTensorType]: edge_attr = self._edge_attr_cls.cast(*args, **kwargs) return self._edge_id.get(self.key(edge_attr)) def remove_edge_id(self, *args, **kwargs) -> bool: edge_attr = self._edge_attr_cls.cast(*args, **kwargs) return self._edge_id.pop(self.key(edge_attr), None) is not None def _put_edge_index(self, edge_index: EdgeTensorType, edge_attr: EdgeAttr) -> bool: self._edge_index[self.key(edge_attr)] = edge_index self._edge_attr[self.key(edge_attr)] = edge_attr return True def _get_edge_index(self, 
edge_attr: EdgeAttr) -> Optional[EdgeTensorType]: return self._edge_index.get(self.key(edge_attr), None) def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool: self._edge_attr.pop(self.key(edge_attr), None) return self._edge_index.pop(self.key(edge_attr), None) is not None def get_all_edge_attrs(self) -> List[EdgeAttr]: return [self._edge_attr[key] for key in self._edge_index.keys()] # Initialization ########################################################## @classmethod def from_data( cls, edge_id: Tensor, edge_index: Tensor, num_nodes: int, is_sorted: bool = False, ) -> 'LocalGraphStore': r"""Creates a local graph store from a homogeneous or heterogenous :pyg:`PyG` graph. Args: edge_id (torch.Tensor): The global identifier for every local edge. edge_index (torch.Tensor): The local edge indices. num_nodes (int): The number of nodes in the local graph. is_sorted (bool): Whether edges are sorted by column/destination nodes (CSC format). (default: :obj:`False`) """ graph_store = cls() graph_store.meta = {'is_hetero': False} if not is_sorted: edge_index, edge_id = sort_edge_index( edge_index, edge_id, sort_by_row=False, ) attr = dict( edge_type=None, layout='coo', size=(num_nodes, num_nodes), is_sorted=True, ) graph_store.put_edge_index(edge_index, **attr) graph_store.put_edge_id(edge_id, **attr) return graph_store @classmethod def from_hetero_data( cls, edge_id_dict: Dict[EdgeType, Tensor], edge_index_dict: Dict[EdgeType, Tensor], num_nodes_dict: Dict[NodeType, int], is_sorted: bool = False, ) -> "LocalGraphStore": r"""Creates a local graph store from a heterogeneous :pyg:`PyG` graph. Args: edge_id_dict (Dict[EdgeType, torch.Tensor]): The global identifier for every local edge of every edge type. edge_index_dict (Dict[EdgeType, torch.Tensor]): The local edge indices of every edge type. num_nodes_dict: (Dict[str, int]): The number of nodes for every node type. is_sorted (bool): Whether edges are sorted by column/destination nodes (CSC format). 
(default: :obj:`False`) """ graph_store = cls() graph_store.meta = {'is_hetero': True} for edge_type, edge_index in edge_index_dict.items(): src, _, dst = edge_type attr = dict( edge_type=edge_type, layout='coo', size=(num_nodes_dict[src], num_nodes_dict[dst]), is_sorted=True, ) edge_id = edge_id_dict[edge_type] if not is_sorted: edge_index, edge_id = sort_edge_index( edge_index, edge_id, sort_by_row=False, ) graph_store.put_edge_index(edge_index, **attr) graph_store.put_edge_id(edge_id, **attr) return graph_store @classmethod def from_partition(cls, root: str, pid: int) -> 'LocalGraphStore': part_dir = osp.join(root, f'part_{pid}') assert osp.exists(part_dir) graph_store = cls() ( meta, num_partitions, partition_idx, node_pb, edge_pb, ) = load_partition_info(root, pid) graph_store.num_partitions = num_partitions graph_store.partition_idx = partition_idx graph_store.node_pb = node_pb graph_store.edge_pb = edge_pb graph_store.meta = meta graph_data = fs.torch_load(osp.join(part_dir, 'graph.pt')) graph_store.is_sorted = meta['is_sorted'] if not meta['is_hetero']: edge_index = torch.stack((graph_data['row'], graph_data['col']), dim=0) edge_id = graph_data['edge_id'] if not graph_store.is_sorted: edge_index, edge_id = sort_edge_index(edge_index, edge_id, sort_by_row=False) attr = dict( edge_type=None, layout='coo', size=graph_data['size'], is_sorted=True, ) graph_store.put_edge_index(edge_index, **attr) graph_store.put_edge_id(edge_id, **attr) if meta['is_hetero']: for edge_type, data in graph_data.items(): attr = dict( edge_type=edge_type, layout='coo', size=data['size'], is_sorted=True, ) edge_index = torch.stack((data['row'], data['col']), dim=0) edge_id = data['edge_id'] if not graph_store.is_sorted: edge_index, edge_id = sort_edge_index( edge_index, edge_id, sort_by_row=False) graph_store.put_edge_index(edge_index, **attr) graph_store.put_edge_id(edge_id, **attr) return graph_store ================================================ FILE: 
torch_geometric/distributed/partition.py
================================================
import json
import logging
import os
import os.path as osp
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import torch

import torch_geometric.distributed as pyg_dist
from torch_geometric.data import Data, HeteroData
from torch_geometric.io import fs
from torch_geometric.loader.cluster import ClusterData
from torch_geometric.sampler.utils import sort_csc
from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType


class Partitioner:
    r"""Partitions the graph and its features of a
    :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` object.

    Partitioned data output will be structured as shown below, with one
    :obj:`part_{pid}` directory per partition.

    **Homogeneous graphs:**

    .. code-block:: none

        root/
        |-- META.json
        |-- node_map.pt
        |-- edge_map.pt
        |-- part_0/
            |-- graph.pt
            |-- node_feats.pt
            |-- edge_feats.pt
        |-- part_1/
            |-- graph.pt
            |-- node_feats.pt
            |-- edge_feats.pt

    **Heterogeneous graphs:**

    .. code-block:: none

        root/
        |-- META.json
        |-- node_map/
            |-- ntype1.pt
            |-- ntype2.pt
        |-- edge_map/
            |-- etype1.pt
            |-- etype2.pt
        |-- part_0/
            |-- graph.pt
            |-- node_feats.pt
            |-- edge_feats.pt
        |-- part_1/
            |-- graph.pt
            |-- node_feats.pt
            |-- edge_feats.pt

    Args:
        data (Data or HeteroData): The data object.
        num_parts (int): The number of partitions. Must be greater than
            :obj:`1`.
        root (str): Root directory where the partitioned dataset should be
            saved.
        recursive (bool, optional): If set to :obj:`True`, will use multilevel
            recursive bisection instead of multilevel k-way partitioning.
            (default: :obj:`False`)
"""
    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_parts: int,
        root: str,
        recursive: bool = False,
    ):
        # Partitioning into a single part is rejected:
        assert num_parts > 1
        self.data = data
        self.num_parts = num_parts
        self.root = root
        self.recursive = recursive

    @property
    def is_hetero(self) -> bool:
        r"""Whether :obj:`data` is a :class:`HeteroData` object."""
        return isinstance(self.data, HeteroData)

    @property
    def is_node_level_time(self) -> bool:
        r"""Whether :obj:`data` holds a node-level :obj:`time` attribute."""
        if 'time' not in self.data:
            return False

        if self.is_hetero:
            return any(['time' in store for store in self.data.node_stores])

        return self.data.is_node_attr('time')

    @property
    def is_edge_level_time(self) -> bool:
        r"""Whether :obj:`data` holds an :obj:`edge_time` attribute or an
        edge-level :obj:`time` attribute."""
        if 'edge_time' in self.data:
            return True

        if 'time' not in self.data:
            return False

        if self.is_hetero:
            return any(['time' in store for store in self.data.edge_stores])

        return self.data.is_edge_attr('time')

    @property
    def node_types(self) -> Optional[List[NodeType]]:
        r"""The node types of :obj:`data`, or :obj:`None` if homogeneous."""
        return self.data.node_types if self.is_hetero else None

    @property
    def edge_types(self) -> Optional[List[EdgeType]]:
        r"""The edge types of :obj:`data`, or :obj:`None` if homogeneous."""
        return self.data.edge_types if self.is_hetero else None

    def generate_partition(self):
        r"""Generates the partitions.

        Partitions the (homogenized) graph via :class:`ClusterData`, writes
        :obj:`graph.pt`, :obj:`node_feats.pt` and :obj:`edge_feats.pt` into
        :obj:`{root}/part_{pid}` for every partition, and then writes the
        node/edge partition maps and :obj:`META.json`.
        """
        os.makedirs(self.root, exist_ok=True)

        if self.is_hetero and self.is_node_level_time:
            time_data = {  # Get temporal information before converting data:
                node_type: self.data[node_type].time
                for node_type in self.data.node_types
            }

        data = self.data.to_homogeneous() if self.is_hetero else self.data
        cluster_data = ClusterData(
            data,
            num_parts=self.num_parts,
            recursive=self.recursive,
            log=True,
            keep_inter_cluster_edges=True,
            sparse_format='csc',
        )

        node_perm = cluster_data.partition.node_perm
        partptr = cluster_data.partition.partptr
        edge_perm = cluster_data.partition.edge_perm

        # Filled below with the partition ID of every (homogenized) node/edge:
        node_map = torch.empty(data.num_nodes, dtype=torch.int64)
        edge_map = torch.empty(data.num_edges, dtype=torch.int64)
        node_offset, edge_offset = {}, {}

        if self.is_hetero:
            # Start offset of every node/edge type within the homogenized
            # graph (assumes `to_homogeneous()` lays types out in
            # `node_types`/`edge_types` order — the asserts below check it):
            offset = 0
            for node_type in self.node_types:
                node_offset[node_type] = offset
                offset += self.data[node_type].num_nodes

            offset = 0
            for edge_name in self.edge_types:
                edge_offset[edge_name] = offset
                offset += self.data.num_edges_dict[edge_name]

            edge_start = 0
            for pid in range(self.num_parts):
                logging.info(f'Saving graph partition {pid}')
                path = osp.join(self.root, f'part_{pid}')
                os.makedirs(path, exist_ok=True)

                part_data = cluster_data[pid]
                start, end = int(partptr[pid]), int(partptr[pid + 1])

                # The edges of partition `pid` occupy the next `num_edges`
                # entries of `edge_perm`:
                num_edges = part_data.num_edges
                edge_id = edge_perm[edge_start:edge_start + num_edges]
                edge_map[edge_id] = pid
                edge_start += num_edges

                node_id = node_perm[start:end]
                node_map[node_id] = pid

                graph = {}
                efeat = defaultdict(dict)
                for i, edge_type in enumerate(self.edge_types):
                    # Row vector refers to source nodes.
                    # Column vector refers to destination nodes.
                    src, _, dst = edge_type
                    size = (self.data[src].num_nodes, self.data[dst].num_nodes)

                    mask = part_data.edge_type == i
                    row = part_data.edge_index[0, mask]
                    col = part_data.edge_index[1, mask]
                    global_col = node_id[col]
                    global_row = node_perm[row]

                    edge_time = src_node_time = None
                    if self.is_edge_level_time:
                        if 'edge_time' in part_data:
                            edge_time = part_data.edge_time[mask]
                        elif 'time' in part_data:
                            edge_time = part_data.time[mask]

                    elif self.is_node_level_time:
                        src_node_time = time_data[src]

                    # Convert homogenized node IDs back to per-type IDs:
                    offsetted_row = global_row - node_offset[src]
                    offsetted_col = global_col - node_offset[dst]
                    # Sort by column to avoid keeping track of permutations in
                    # `NeighborSampler` when converting to CSC format:
                    offsetted_row, offsetted_col, perm = sort_csc(
                        offsetted_row, offsetted_col, src_node_time, edge_time)

                    global_eid = edge_id[mask][perm]
                    # Sanity checks against the homogenized and the original
                    # typed edge indices:
                    assert torch.equal(
                        data.edge_index[:, global_eid],
                        torch.stack((offsetted_row + node_offset[src],
                                     offsetted_col + node_offset[dst]), dim=0),
                    )
                    offsetted_eid = global_eid - edge_offset[edge_type]
                    assert torch.equal(
                        self.data[edge_type].edge_index[:, offsetted_eid],
                        torch.stack((
                            offsetted_row,
                            offsetted_col,
                        ), dim=0),
                    )
                    graph[edge_type] = {
                        'edge_id': global_eid,
                        'row': offsetted_row,
                        'col': offsetted_col,
                        'size': size,
                    }

                    if 'edge_attr' in part_data:
                        edge_attr = part_data.edge_attr[mask][perm]
                        efeat[edge_type].update({
                            'global_id': offsetted_eid,
                            'feats': dict(edge_attr=edge_attr),
                        })
                    if self.is_edge_level_time:
                        efeat[edge_type].update({'edge_time': edge_time[perm]})

                torch.save(efeat, osp.join(path, 'edge_feats.pt'))
                torch.save(graph, osp.join(path, 'graph.pt'))

                nfeat = {}
                for i, node_type in enumerate(self.node_types):
                    mask = part_data.node_type == i
                    x = part_data.x[mask] if 'x' in part_data else None
                    nfeat[node_type] = {
                        'global_id': node_id[mask],
                        'id': node_id[mask] - node_offset[node_type],
                        'feats': dict(x=x),
                    }
                    if self.is_node_level_time:
                        # NOTE(review): this stores the full `time` tensor of
                        # the node type, not only this partition's nodes —
                        # confirm this is intended.
                        nfeat[node_type].update({'time': time_data[node_type]})

                torch.save(nfeat, osp.join(path, 'node_feats.pt'))

            logging.info('Saving partition mapping info')
            path = osp.join(self.root, 'node_map')
            os.makedirs(path, exist_ok=True)
            for i, node_type in enumerate(self.node_types):
                mask = data.node_type == i
                torch.save(node_map[mask], osp.join(path, f'{node_type}.pt'))

            path = osp.join(self.root, 'edge_map')
            os.makedirs(path, exist_ok=True)
            for i, edge_type in enumerate(self.edge_types):
                mask = data.edge_type == i
                torch.save(
                    edge_map[mask],
                    osp.join(path, f'{EdgeTypeStr(edge_type)}.pt'),
                )

        else:  # `if not self.is_hetero:`

            edge_start = 0
            for pid in range(self.num_parts):
                logging.info(f'Saving graph partition {pid}')
                path = osp.join(self.root, f'part_{pid}')
                os.makedirs(path, exist_ok=True)

                part_data = cluster_data[pid]
                start, end = int(partptr[pid]), int(partptr[pid + 1])

                num_edges = part_data.num_edges
                edge_id = edge_perm[edge_start:edge_start + num_edges]
                edge_map[edge_id] = pid
                edge_start += num_edges

                node_id = node_perm[start:end]  # global node_ids
                node_map[node_id] = pid  # partition ID of these nodes

                row = part_data.edge_index[0]
                col = part_data.edge_index[1]

                global_col = node_id[col]  # part_ids -> global
                global_row = node_perm[row]

                edge_time = node_time = None
                if self.is_edge_level_time:
                    if 'edge_time' in part_data:
                        edge_time = part_data.edge_time
                    elif 'time' in part_data:
                        edge_time = part_data.time

                elif self.is_node_level_time:
                    node_time = data.time

                # Sort by column to avoid keeping track of permutations in
                # `NeighborSampler` when converting to CSC format:
                global_row, global_col, perm = sort_csc(
                    global_row, global_col, node_time, edge_time)

                edge_id = edge_id[perm]

                # Sanity checks against the original graph:
                assert torch.equal(
                    self.data.edge_index[:, edge_id],
                    torch.stack((global_row, global_col)),
                )
                if 'edge_attr' in part_data:
                    edge_attr = part_data.edge_attr[perm]
                    assert torch.equal(self.data.edge_attr[edge_id, :],
                                       edge_attr)

                torch.save(
                    {
                        'edge_id': edge_id,
                        'row': global_row,
                        'col': global_col,
                        'size': (data.num_nodes, data.num_nodes),
                    }, osp.join(path, 'graph.pt'))

                nfeat = {
                    'global_id': node_id,
                    'feats': dict(x=part_data.x),
                }
                if self.is_node_level_time:
                    # NOTE(review): this stores the full node-level `time`
                    # tensor, not only this partition's nodes — confirm.
                    nfeat.update({'time': data.time})

                torch.save(nfeat, osp.join(path, 'node_feats.pt'))

                # NOTE: `defaultdict()` without a default factory behaves like
                # a plain `dict`.
                efeat = defaultdict()
                if 'edge_attr' in part_data:
                    efeat.update({
                        'global_id': edge_id,
                        'feats': dict(edge_attr=part_data.edge_attr[perm]),
                    })
                if self.is_edge_level_time:
                    efeat.update({'edge_time': edge_time[perm]})

                torch.save(efeat, osp.join(path, 'edge_feats.pt'))

            logging.info('Saving partition mapping info')
            torch.save(node_map, osp.join(self.root, 'node_map.pt'))
            torch.save(edge_map, osp.join(self.root, 'edge_map.pt'))

        logging.info('Saving metadata')
        meta = {
            'num_parts': self.num_parts,
            'node_types': self.node_types,
            'edge_types': self.edge_types,
            'node_offset': list(node_offset.values()) if node_offset else None,
            'is_hetero': self.is_hetero,
            'is_sorted': True,  # Based on column/destination.
}

        with open(osp.join(self.root, 'META.json'), 'w') as f:
            json.dump(meta, f)


def load_partition_info(
    root_dir: str,
    partition_idx: int,
) -> Tuple[Dict, int, int, Union[torch.Tensor, Dict[NodeType, torch.Tensor]],
           Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]]:
    r"""Loads the metadata and the node/edge partition maps of a dataset
    partitioned by :class:`Partitioner`.

    Args:
        root_dir (str): The root directory of the partitioned dataset.
        partition_idx (int): The partition index; must lie in
            :obj:`[0, num_parts)` and :obj:`{root_dir}/part_{partition_idx}`
            must exist.

    Returns:
        A tuple :obj:`(meta, num_partitions, partition_idx, node_pb,
        edge_pb)`. For homogeneous graphs, :obj:`node_pb`/:obj:`edge_pb` are
        loaded from :obj:`node_map.pt`/:obj:`edge_map.pt`; for heterogeneous
        graphs, they are dictionaries keyed by node type and edge type.
    """
    # load the partition with PyG format (graphstore/featurestore)
    with open(osp.join(root_dir, 'META.json'), 'rb') as infile:
        meta = json.load(infile)
    num_partitions = meta['num_parts']
    assert partition_idx >= 0
    assert partition_idx < num_partitions
    partition_dir = osp.join(root_dir, f'part_{partition_idx}')
    assert osp.exists(partition_dir)

    if meta['is_hetero'] is False:
        node_pb = fs.torch_load(osp.join(root_dir, 'node_map.pt'))
        edge_pb = fs.torch_load(osp.join(root_dir, 'edge_map.pt'))

        return (meta, num_partitions, partition_idx, node_pb, edge_pb)
    else:
        node_pb_dict = {}
        node_pb_dir = osp.join(root_dir, 'node_map')
        for ntype in meta['node_types']:
            node_pb_dict[ntype] = fs.torch_load(
                osp.join(node_pb_dir, f'{pyg_dist.utils.as_str(ntype)}.pt'))

        edge_pb_dict = {}
        edge_pb_dir = osp.join(root_dir, 'edge_map')
        for etype in meta['edge_types']:
            # JSON round-trips edge-type tuples as lists; restore tuples so
            # they can be used as dictionary keys:
            edge_pb_dict[tuple(etype)] = fs.torch_load(
                osp.join(edge_pb_dir, f'{pyg_dist.utils.as_str(etype)}.pt'))

        return (meta, num_partitions, partition_idx, node_pb_dict,
                edge_pb_dict)
================================================
FILE: torch_geometric/distributed/rpc.py
================================================
import logging
import threading
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional

from torch.distributed import rpc

from torch_geometric.distributed.dist_context import DistContext, DistRole

try:
    from torch._C._distributed_rpc import _is_current_rpc_agent_set
except Exception:

    def _is_current_rpc_agent_set() -> bool:
        # Fallback if the private helper cannot be imported: report that no
        # RPC agent is set.
        return False


# Guards RPC initialization and shutdown (re-entrant lock):
_rpc_init_lock = threading.RLock()


def rpc_is_initialized() -> bool:
    r"""Returns whether an RPC agent is currently set for this process."""
    return _is_current_rpc_agent_set()


def rpc_require_initialized(func: Callable) -> Callable:
    r"""Wraps :obj:`func` with the private
    :obj:`rpc.api._require_initialized` decorator when :obj:`rpc.api` is
    available, and returns :obj:`func` unchanged otherwise."""
    if hasattr(rpc, 'api'):
        return rpc.api._require_initialized(func)
    return func


@rpc_require_initialized
def
global_all_gather(obj, timeout: Optional[int] = None) -> Any: r"""Gathers objects from all groups in a list.""" if timeout is None: return rpc.api._all_gather(obj) return rpc.api._all_gather(obj, timeout=timeout) @rpc_require_initialized def global_barrier(timeout: Optional[int] = None) -> None: r"""Block until all local and remote RPC processes.""" try: global_all_gather(obj=None, timeout=timeout) except RuntimeError: logging.error('Failed to respond to global barrier') def init_rpc( current_ctx: DistContext, master_addr: str, master_port: int, num_rpc_threads: int = 16, rpc_timeout: float = 240.0, rpc_worker_names: Optional[Dict[DistRole, List[str]]] = None, ): with _rpc_init_lock: if rpc_is_initialized(): return if current_ctx is None: raise RuntimeError("'dist_context' has not been set in 'init_rpc'") options = rpc.TensorPipeRpcBackendOptions( _transports=['ibv', 'uv'], _channels=['mpt_uv', 'basic'], num_worker_threads=num_rpc_threads, rpc_timeout=rpc_timeout, init_method=f'tcp://{master_addr}:{master_port}', ) rpc.init_rpc( name=current_ctx.worker_name, rank=current_ctx.global_rank, world_size=current_ctx.global_world_size, rpc_backend_options=options, ) global_barrier(timeout=rpc_timeout) def shutdown_rpc(id: str = None, graceful: bool = True, timeout: float = 240.0): with _rpc_init_lock: if rpc_is_initialized(): logging.info(f"Shutdown RPC in {id}" f"{' gracefully' if graceful else ''}") rpc.shutdown(graceful, timeout) else: logging.info(f'RPC in {id} not initialized.') class RPCRouter: r"""A router to get the worker based on the partition ID.""" def __init__(self, partition_to_workers: List[List[str]]): for rpc_worker_list in partition_to_workers: if len(rpc_worker_list) == 0: raise ValueError('No RPC worker is in worker list') self.partition_to_workers = partition_to_workers self.rpc_worker_indices = [0 for _ in range(len(partition_to_workers))] def get_to_worker(self, partition_idx: int) -> str: rpc_worker_list = self.partition_to_workers[partition_idx] 
worker_idx = self.rpc_worker_indices[partition_idx] router_worker = rpc_worker_list[worker_idx] self.rpc_worker_indices[partition_idx] = ((worker_idx + 1) % len(rpc_worker_list)) return router_worker @rpc_require_initialized def rpc_partition_to_workers( current_ctx: DistContext, num_partitions: int, current_partition_idx: int, ): r"""Performs an :obj:`all_gather` to get the mapping between partition and workers. """ ctx = current_ctx partition_to_workers = [[] for _ in range(num_partitions)] gathered_results = global_all_gather( (ctx.role, num_partitions, current_partition_idx)) for worker_name, (_, _, idx) in gathered_results.items(): partition_to_workers[idx].append(worker_name) return partition_to_workers class RPCCallBase(ABC): r"""A wrapper base class for RPC calls in remote processes.""" @abstractmethod def rpc_sync(self, *args, **kwargs): pass @abstractmethod def rpc_async(self, *args, **kwargs): pass _rpc_call_lock = threading.RLock() _rpc_call_id: int = 0 _rpc_call_pool: Dict[int, RPCCallBase] = {} @rpc_require_initialized def rpc_register(call: RPCCallBase) -> int: r"""Registers a call for RPC requests.""" global _rpc_call_id with _rpc_call_lock: call_id = _rpc_call_id _rpc_call_id += 1 if call_id in _rpc_call_pool: raise RuntimeError("Registered function twice in 'rpc_register'") _rpc_call_pool[call_id] = call return call_id def _rpc_async_call(call_id: int, *args, **kwargs): r"""Entry point for RPC requests.""" return _rpc_call_pool.get(call_id).rpc_async(*args, **kwargs) @rpc_require_initialized def rpc_async(worker_name: str, call_id: int, args=None, kwargs=None): r"""Performs an asynchronous RPC request and returns a future.""" return rpc.rpc_async( to=worker_name, func=_rpc_async_call, args=(call_id, *args), kwargs=kwargs, ) def _rpc_sync_call(call_id: int, *args, **kwargs): r"""Entry point for synchronous RPC requests.""" return _rpc_call_pool.get(call_id).rpc_sync(*args, **kwargs) @rpc_require_initialized def rpc_sync(worker_name: str, call_id: 
int, args=None, kwargs=None): r"""Performs a synchronous RPC request and returns a future.""" future = rpc.rpc_async( to=worker_name, func=_rpc_sync_call, args=(call_id, *args), kwargs=kwargs, ) return future.wait() ================================================ FILE: torch_geometric/distributed/utils.py ================================================ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch from torch import Tensor from torch_geometric.data import HeteroData from torch_geometric.distributed.local_feature_store import LocalFeatureStore from torch_geometric.distributed.local_graph_store import LocalGraphStore from torch_geometric.sampler import SamplerOutput from torch_geometric.typing import EdgeType, NodeType @dataclass class DistEdgeHeteroSamplerInput: r"""The sampling input of :meth:`~torch_geometric.dstributed.DistNeighborSampler.node_sample` used during distributed heterogeneous link sampling when source and target node types of an input edge are different. Args: input_id (torch.Tensor, optional): The indices of the data loader input of the current mini-batch. node_dict (Dict[NodeType, torch.Tensor]): The indices of seed nodes of a given node types to start sampling from. time_dict (Dict[NodeType, torch.Tensor], optional): The timestamp for the seed nodes of a given node types. (default: :obj:`None`) input_type (str, optional): The input node type. (default: :obj:`None`) """ input_id: Optional[Tensor] node_dict: Dict[NodeType, Tensor] time_dict: Optional[Dict[NodeType, Tensor]] = None input_type: Optional[EdgeType] = None class NodeDict: r"""Class used during heterogeneous sampling. 1) The nodes to serve as source nodes in the next layer. 2) The nodes with duplicates that are further needed to create COO output. 3) The output nodes without duplicates. 
"""
    def __init__(self, node_types, num_hops):
        # NOTE: `k * [t]` repeats a reference to the same empty tensor `t`;
        # this is fine as long as entries are reassigned rather than mutated
        # in place.
        self.src: Dict[NodeType, List[Tensor]] = {
            k: (num_hops + 1) * [torch.empty(0, dtype=torch.int64)]
            for k in node_types
        }
        self.with_dupl: Dict[NodeType, Tensor] = {
            k: torch.empty(0, dtype=torch.int64)
            for k in node_types
        }
        self.out: Dict[NodeType, Tensor] = {
            k: torch.empty(0, dtype=torch.int64)
            for k in node_types
        }
        # One (initially empty) seed-time tensor per hop and node type:
        self.seed_time: Dict[NodeType, List[Tensor]] = {
            k: num_hops * [torch.empty(0, dtype=torch.int64)]
            for k in node_types
        }


class BatchDict:
    r"""Class used during disjoint heterogeneous sampling.
    1) The batch to serve as initial subgraph IDs for source nodes in the next
    layer.
    2) The subgraph IDs with duplicates that are further needed to create COO
    output.
    3) The output subgraph IDs without duplicates.
    """
    def __init__(self, node_types, num_hops):
        # NOTE: as above, `(num_hops + 1) * [t]` shares one tensor object
        # across all hops until entries are reassigned.
        self.src: Dict[NodeType, List[Tensor]] = {
            k: (num_hops + 1) * [torch.empty(0, dtype=torch.int64)]
            for k in node_types
        }
        self.with_dupl: Dict[NodeType, Tensor] = {
            k: torch.empty(0, dtype=torch.int64)
            for k in node_types
        }
        self.out: Dict[NodeType, Tensor] = {
            k: torch.empty(0, dtype=torch.int64)
            for k in node_types
        }


def remove_duplicates(
    out: SamplerOutput,
    node: Tensor,
    batch: Optional[Tensor] = None,
    disjoint: bool = False,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    r"""Appends the sampled nodes :obj:`out.node` to :obj:`node` and drops
    duplicates, keeping first occurrences in their original order.

    If :obj:`disjoint` is :obj:`True`, :obj:`(batch, node)` pairs are
    deduplicated instead, so the same node may re-appear under a different
    subgraph ID; :obj:`batch` and :obj:`out.batch` are then required.

    Returns:
        A tuple :obj:`(src, node, src_batch, batch)`, where :obj:`src` (and
        :obj:`src_batch`) hold the entries after the first
        :obj:`node.numel()` positions of the deduplicated output.
        :obj:`src_batch` and :obj:`batch` are :obj:`None` if
        :obj:`disjoint` is :obj:`False`.
    """
    num_nodes = node.numel()
    node_combined = torch.cat([node, out.node])

    if not disjoint:
        # `return_index=True` yields the first occurrence of every unique
        # value; sorting these indices restores the original order.
        _, idx = np.unique(node_combined.cpu().numpy(), return_index=True)
        idx = torch.from_numpy(idx).to(node.device).sort().values

        node = node_combined[idx]
        # NOTE(review): `node[num_nodes:]` equals exactly the newly added
        # nodes only if the input `node` has no duplicates — presumably
        # guaranteed by callers; verify.
        src = node[num_nodes:]

        return (src, node, None, None)

    else:
        batch_combined = torch.cat([batch, out.batch])
        node_batch = torch.stack([batch_combined, node_combined], dim=0)

        # Deduplicate `(batch, node)` column pairs:
        _, idx = np.unique(node_batch.cpu().numpy(), axis=1,
                           return_index=True)
        idx = torch.from_numpy(idx).to(node.device).sort().values

        batch = batch_combined[idx]
        node = node_combined[idx]
        src_batch = batch[num_nodes:]
        src = node[num_nodes:]

        return (src, node, src_batch, batch)


def
filter_dist_store( feature_store: LocalFeatureStore, graph_store: LocalGraphStore, node_dict: Dict[str, Tensor], row_dict: Dict[str, Tensor], col_dict: Dict[str, Tensor], edge_dict: Dict[str, Optional[Tensor]], custom_cls: Optional[HeteroData] = None, meta: Optional[Dict[str, Tensor]] = None, input_type: str = None, ) -> HeteroData: r"""Constructs a :class:`HeteroData` object from a feature store that only holds nodes in `node` end edges in `edge` for each node and edge type, respectively. Sorted attribute values are provided as metadata from :class:`DistNeighborSampler`. """ # Construct a new `HeteroData` object: data = custom_cls() if custom_cls is not None else HeteroData() nfeats, labels, efeats = meta[-3:] # Filter edge storage: required_edge_attrs = [] for attr in graph_store.get_all_edge_attrs(): key = attr.edge_type if key in row_dict and key in col_dict: required_edge_attrs.append(attr) edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0) data[attr.edge_type].edge_index = edge_index # Filter node storage: required_node_attrs = [] for attr in feature_store.get_all_tensor_attrs(): if attr.group_name in node_dict: attr.index = node_dict[attr.group_name] required_node_attrs.append(attr) data[attr.group_name].num_nodes = attr.index.size(0) if nfeats: for attr in required_node_attrs: if nfeats[attr.group_name] is not None: data[attr.group_name][attr.attr_name] = nfeats[attr.group_name] if efeats: for attr in required_edge_attrs: if efeats[attr.edge_type] is not None: data[attr.edge_type].edge_attr = efeats[attr.edge_type] if labels: data[input_type].y = labels[input_type] return data def as_str(inputs: Union[NodeType, EdgeType]) -> str: if isinstance(inputs, NodeType): return inputs elif isinstance(inputs, (list, tuple)) and len(inputs) == 3: return '__'.join(inputs) return '' def reverse_edge_type(etype: EdgeType) -> EdgeType: src, rel, dst = etype if src != dst: if rel.split('_', 1)[0] == 'rev': # undirected edge with `rev_` prefix. 
rel = rel.split('_', 1)[1] else: rel = 'rev_' + rel return dst, rel, src ================================================ FILE: torch_geometric/edge_index.py ================================================ import functools from enum import Enum from typing import ( Any, Callable, Dict, Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Type, Union, get_args, overload, ) import numpy as np import torch import torch.utils._pytree as pytree from torch import Tensor import torch_geometric.typing from torch_geometric import Index, is_compiling from torch_geometric.index import index2ptr, ptr2index from torch_geometric.typing import INDEX_DTYPES, SparseTensor aten = torch.ops.aten HANDLED_FUNCTIONS: Dict[Callable, Callable] = {} ReduceType = Literal['sum', 'mean', 'amin', 'amax', 'add', 'min', 'max'] PYG_REDUCE: Dict[ReduceType, ReduceType] = { 'add': 'sum', 'amin': 'min', 'amax': 'max' } TORCH_REDUCE: Dict[ReduceType, ReduceType] = { 'add': 'sum', 'min': 'amin', 'max': 'amax' } class SortOrder(Enum): ROW = 'row' COL = 'col' class CatMetadata(NamedTuple): nnz: List[int] sparse_size: List[Tuple[Optional[int], Optional[int]]] sort_order: List[Optional[SortOrder]] is_undirected: List[bool] def implements(torch_function: Callable) -> Callable: r"""Registers a :pytorch:`PyTorch` function override.""" @functools.wraps(torch_function) def decorator(my_function: Callable) -> Callable: HANDLED_FUNCTIONS[torch_function] = my_function return my_function return decorator def set_tuple_item( values: Tuple[Any, ...], dim: int, value: Any, ) -> Tuple[Any, ...]: if dim < -len(values) or dim >= len(values): raise IndexError("tuple index out of range") dim = dim + len(values) if dim < 0 else dim return values[:dim] + (value, ) + values[dim + 1:] def maybe_add( value: Sequence[Optional[int]], other: Union[int, Sequence[Optional[int]]], alpha: int = 1, ) -> Tuple[Optional[int], ...]: if isinstance(other, int): return tuple(v + alpha * other if v is not None else None for v in 
value)

    assert len(value) == len(other)
    return tuple(v + alpha * o if v is not None and o is not None else None
                 for v, o in zip(value, other))


def maybe_sub(
    value: Sequence[Optional[int]],
    other: Union[int, Sequence[Optional[int]]],
    alpha: int = 1,
) -> Tuple[Optional[int], ...]:
    r"""Element-wise :obj:`v - alpha * other`; :obj:`None` entries in either
    operand yield :obj:`None`."""
    if isinstance(other, int):
        return tuple(v - alpha * other if v is not None else None
                     for v in value)

    assert len(value) == len(other)
    return tuple(v - alpha * o if v is not None and o is not None else None
                 for v, o in zip(value, other))


def assert_valid_dtype(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` if :obj:`tensor.dtype` is not one of
    :obj:`INDEX_DTYPES`."""
    if tensor.dtype not in INDEX_DTYPES:
        raise ValueError(f"'EdgeIndex' holds an unsupported data type "
                         f"(got '{tensor.dtype}', but expected one of "
                         f"{INDEX_DTYPES})")


def assert_two_dimensional(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` unless :obj:`tensor` has shape
    :obj:`[2, *]` (the leading size is not checked while JIT-tracing)."""
    if tensor.dim() != 2:
        raise ValueError(f"'EdgeIndex' needs to be two-dimensional "
                         f"(got {tensor.dim()} dimensions)")
    if not torch.jit.is_tracing() and tensor.size(0) != 2:
        raise ValueError(f"'EdgeIndex' needs to have a shape of "
                         f"[2, *] (got {list(tensor.size())})")


def assert_contiguous(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` unless both the row and the column
    vector of :obj:`tensor` are contiguous."""
    if not tensor[0].is_contiguous() or not tensor[1].is_contiguous():
        raise ValueError("'EdgeIndex' needs to be contiguous. Please call "
                         "`edge_index.contiguous()` before proceeding.")


def assert_symmetric(size: Tuple[Optional[int], Optional[int]]) -> None:
    r"""Raises a :class:`ValueError` if both sizes are known and differ
    (skipped while JIT-tracing)."""
    if (not torch.jit.is_tracing() and size[0] is not None
            and size[1] is not None and size[0] != size[1]):
        raise ValueError(f"'EdgeIndex' is undirected but received a "
                         f"non-symmetric size (got {list(size)})")


def assert_sorted(func: Callable) -> Callable:
    r"""Decorates an :class:`EdgeIndex` method so that it raises a
    :class:`ValueError` when called on an unsorted instance."""
    @functools.wraps(func)
    def wrapper(self: 'EdgeIndex', *args: Any, **kwargs: Any) -> Any:
        if not self.is_sorted:
            cls_name = self.__class__.__name__
            raise ValueError(
                f"Cannot call '{func.__name__}' since '{cls_name}' is not "
                f"sorted. Please call `{cls_name}.sort_by(...)` first.")
        return func(self, *args, **kwargs)

    return wrapper


class EdgeIndex(Tensor):
    r"""A COO :obj:`edge_index` tensor with additional (meta)data attached.

    :class:`EdgeIndex` is a :pytorch:`null` :class:`torch.Tensor` that holds
    an :obj:`edge_index` representation of shape :obj:`[2, num_edges]`.
    Edges are given as pairwise source and destination node indices in sparse
    COO format.

    While :class:`EdgeIndex` sub-classes a general :pytorch:`null`
    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:

    * :obj:`sparse_size`: The underlying sparse matrix size
    * :obj:`sort_order`: The sort order (if present), either by row or column.
    * :obj:`is_undirected`: Whether edges are bidirectional.

    Additionally, :class:`EdgeIndex` caches data for fast CSR or CSC conversion
    in case its representation is sorted, such as its :obj:`rowptr` or
    :obj:`colptr`, or the permutation vector for going from CSR to CSC or vice
    versa.
    Caches are filled based on demand (*e.g.*, when calling
    :meth:`EdgeIndex.sort_by`), or when explicitly requested via
    :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
    lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).

    This representation ensures optimal computation in GNN message passing
    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
    workflows.

    .. code-block:: python

        from torch_geometric import EdgeIndex

        edge_index = EdgeIndex(
            [[0, 1, 1, 2],
             [1, 0, 2, 1]],
            sparse_size=(3, 3),
            sort_order='row',
            is_undirected=True,
            device='cpu',
        )
        >>> EdgeIndex([[0, 1, 1, 2],
        ...            [1, 0, 2, 1]])
        assert edge_index.is_sorted_by_row
        assert edge_index.is_undirected

        # Flipping order:
        edge_index = edge_index.flip(0)
        >>> EdgeIndex([[1, 0, 2, 1],
        ...            [0, 1, 1, 2]])
        assert edge_index.is_sorted_by_col
        assert edge_index.is_undirected

        # Filtering:
        mask = torch.tensor([True, True, True, False])
        edge_index = edge_index[:, mask]
        >>> EdgeIndex([[1, 0, 2],
        ...
[0, 1, 1]]) assert edge_index.is_sorted_by_col assert not edge_index.is_undirected # Sparse-Dense Matrix Multiplication: out = edge_index.flip(0) @ torch.randn(3, 16) assert out.size() == (3, 16) """ # See "https://pytorch.org/docs/stable/notes/extending.html" # for a basic tutorial on how to subclass `torch.Tensor`. # The underlying tensor representation: _data: Tensor # The size of the underlying sparse matrix: _sparse_size: Tuple[Optional[int], Optional[int]] = (None, None) # Whether the `edge_index` representation is non-sorted (`None`), or sorted # based on row or column values. _sort_order: Optional[SortOrder] = None # Whether the `edge_index` is undirected: # NOTE `is_undirected` allows us to assume symmetric adjacency matrix size # and to share compressed pointer representations, however, it does not # allow us get rid of CSR/CSC permutation vectors since ordering within # neighborhoods is not necessarily deterministic. _is_undirected: bool = False # A cache for its compressed representation: _indptr: Optional[Tensor] = None # A cache for its transposed representation: _T_perm: Optional[Tensor] = None _T_index: Tuple[Optional[Tensor], Optional[Tensor]] = (None, None) _T_indptr: Optional[Tensor] = None # A cached "1"-value vector for `torch.sparse` matrix multiplication: _value: Optional[Tensor] = None # Whenever we perform a concatenation of edge indices, we cache the # original metadata to be able to reconstruct individual edge indices: _cat_metadata: Optional[CatMetadata] = None @staticmethod def __new__( cls: Type, data: Any, *args: Any, sparse_size: Optional[Tuple[Optional[int], Optional[int]]] = None, sort_order: Optional[Union[str, SortOrder]] = None, is_undirected: bool = False, **kwargs: Any, ) -> 'EdgeIndex': if not isinstance(data, Tensor): data = torch.tensor(data, *args, **kwargs) elif len(args) > 0: raise TypeError( f"new() received an invalid combination of arguments - got " f"(Tensor, {', '.join(str(type(arg)) for arg in args)})") elif 
len(kwargs) > 0: raise TypeError(f"new() received invalid keyword arguments - got " f"{set(kwargs.keys())})") assert isinstance(data, Tensor) indptr: Optional[Tensor] = None if isinstance(data, cls): # If passed `EdgeIndex`, inherit metadata: indptr = data._indptr sparse_size = sparse_size or data.sparse_size() sort_order = sort_order or data.sort_order is_undirected = is_undirected or data.is_undirected # Convert `torch.sparse` tensors to `EdgeIndex` representation: if data.layout == torch.sparse_coo: sort_order = SortOrder.ROW sparse_size = sparse_size or (data.size(0), data.size(1)) data = data.indices() if data.layout == torch.sparse_csr: indptr = data.crow_indices() col = data.col_indices() assert isinstance(indptr, Tensor) row = ptr2index(indptr, output_size=col.numel()) sort_order = SortOrder.ROW sparse_size = sparse_size or (data.size(0), data.size(1)) if sparse_size[0] is not None and sparse_size[0] != data.size(0): indptr = None data = torch.stack([row, col], dim=0) if data.layout == torch.sparse_csc: row = data.row_indices() indptr = data.ccol_indices() assert isinstance(indptr, Tensor) col = ptr2index(indptr, output_size=row.numel()) sort_order = SortOrder.COL sparse_size = sparse_size or (data.size(0), data.size(1)) if sparse_size[1] is not None and sparse_size[1] != data.size(1): indptr = None data = torch.stack([row, col], dim=0) assert_valid_dtype(data) assert_two_dimensional(data) assert_contiguous(data) if sparse_size is None: sparse_size = (None, None) if is_undirected: assert_symmetric(sparse_size) if sparse_size[0] is not None and sparse_size[1] is None: sparse_size = (sparse_size[0], sparse_size[0]) elif sparse_size[0] is None and sparse_size[1] is not None: sparse_size = (sparse_size[1], sparse_size[1]) out = Tensor._make_wrapper_subclass( cls, size=data.size(), strides=data.stride(), dtype=data.dtype, device=data.device, layout=data.layout, requires_grad=False, ) assert isinstance(out, EdgeIndex) # Attach metadata: out._data = data 
out._sparse_size = sparse_size out._sort_order = None if sort_order is None else SortOrder(sort_order) out._is_undirected = is_undirected out._indptr = indptr if isinstance(data, cls): # If passed `EdgeIndex`, inherit metadata: out._data = data._data out._T_perm = data._T_perm out._T_index = data._T_index out._T_indptr = data._T_indptr out._value = out._value # Reset metadata if cache is invalidated: num_rows = sparse_size[0] if num_rows is not None and num_rows != data.sparse_size(0): out._indptr = None num_cols = sparse_size[1] if num_cols is not None and num_cols != data.sparse_size(1): out._T_indptr = None return out # Validation ############################################################## def validate(self) -> 'EdgeIndex': r"""Validates the :class:`EdgeIndex` representation. In particular, it ensures that * it only holds valid indices. * the sort order is correctly set. * indices are bidirectional in case it is specified as undirected. """ assert_valid_dtype(self._data) assert_two_dimensional(self._data) assert_contiguous(self._data) if self.is_undirected: assert_symmetric(self.sparse_size()) if self.numel() > 0 and self._data.min() < 0: raise ValueError(f"'{self.__class__.__name__}' contains negative " f"indices (got {int(self.min())})") if (self.numel() > 0 and self.num_rows is not None and self._data[0].max() >= self.num_rows): raise ValueError(f"'{self.__class__.__name__}' contains larger " f"indices than its number of rows " f"(got {int(self._data[0].max())}, but expected " f"values smaller than {self.num_rows})") if (self.numel() > 0 and self.num_cols is not None and self._data[1].max() >= self.num_cols): raise ValueError(f"'{self.__class__.__name__}' contains larger " f"indices than its number of columns " f"(got {int(self._data[1].max())}, but expected " f"values smaller than {self.num_cols})") if self.is_sorted_by_row and (self._data[0].diff() < 0).any(): raise ValueError(f"'{self.__class__.__name__}' is not sorted by " f"row indices") if 
self.is_sorted_by_col and (self._data[1].diff() < 0).any(): raise ValueError(f"'{self.__class__.__name__}' is not sorted by " f"column indices") if self.is_undirected: flat_index1 = self._data[0] * self.get_num_rows() + self._data[1] flat_index1 = flat_index1.sort()[0] flat_index2 = self._data[1] * self.get_num_cols() + self._data[0] flat_index2 = flat_index2.sort()[0] if not torch.equal(flat_index1, flat_index2): raise ValueError(f"'{self.__class__.__name__}' is not " f"undirected") return self # Properties ############################################################## @overload def sparse_size(self) -> Tuple[Optional[int], Optional[int]]: pass @overload def sparse_size(self, dim: int) -> Optional[int]: pass def sparse_size( self, dim: Optional[int] = None, ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: r"""The size of the underlying sparse matrix. If :obj:`dim` is specified, returns an integer holding the size of that sparse dimension. Args: dim (int, optional): The dimension for which to retrieve the size. (default: :obj:`None`) """ if dim is not None: return self._sparse_size[dim] return self._sparse_size @property def num_rows(self) -> Optional[int]: r"""The number of rows of the underlying sparse matrix.""" return self._sparse_size[0] @property def num_cols(self) -> Optional[int]: r"""The number of columns of the underlying sparse matrix.""" return self._sparse_size[1] @property def sort_order(self) -> Optional[str]: r"""The sort order of indices, either :obj:`"row"`, :obj:`"col"` or :obj:`None`. 
""" return None if self._sort_order is None else self._sort_order.value @property def is_sorted(self) -> bool: r"""Returns whether indices are either sorted by rows or columns.""" return self._sort_order is not None @property def is_sorted_by_row(self) -> bool: r"""Returns whether indices are sorted by rows.""" return self._sort_order == SortOrder.ROW @property def is_sorted_by_col(self) -> bool: r"""Returns whether indices are sorted by columns.""" return self._sort_order == SortOrder.COL @property def is_undirected(self) -> bool: r"""Returns whether indices are bidirectional.""" return self._is_undirected @property def dtype(self) -> torch.dtype: # type: ignore # TODO Remove once PyTorch does not override `dtype` in `DataLoader`. return self._data.dtype # Cache Interface ######################################################### @overload def get_sparse_size(self) -> torch.Size: pass @overload def get_sparse_size(self, dim: int) -> int: pass def get_sparse_size( self, dim: Optional[int] = None, ) -> Union[torch.Size, int]: r"""The size of the underlying sparse matrix. Automatically computed and cached when not explicitly set. If :obj:`dim` is specified, returns an integer holding the size of that sparse dimension. Args: dim (int, optional): The dimension for which to retrieve the size. (default: :obj:`None`) """ if dim is not None: size = self._sparse_size[dim] if size is not None: return size if self.is_undirected: size = int(self._data.max()) + 1 if self.numel() > 0 else 0 self._sparse_size = (size, size) return size size = int(self._data[dim].max()) + 1 if self.numel() > 0 else 0 self._sparse_size = set_tuple_item(self._sparse_size, dim, size) return size return torch.Size((self.get_sparse_size(0), self.get_sparse_size(1))) def sparse_resize_( # type: ignore self, num_rows: Optional[int], num_cols: Optional[int], ) -> 'EdgeIndex': r"""Assigns or re-assigns the size of the underlying sparse matrix. Args: num_rows (int, optional): The number of rows. 
num_cols (int, optional): The number of columns. """ if self.is_undirected: if num_rows is not None and num_cols is None: num_cols = num_rows elif num_cols is not None and num_rows is None: num_rows = num_cols if num_rows is not None and num_rows != num_cols: raise ValueError(f"'EdgeIndex' is undirected but received a " f"non-symmetric size " f"(got [{num_rows}, {num_cols}])") def _modify_ptr( ptr: Optional[Tensor], size: Optional[int], ) -> Optional[Tensor]: if ptr is None or size is None: return None if ptr.numel() - 1 >= size: return ptr[:size + 1] fill_value = ptr.new_full( (size - ptr.numel() + 1, ), fill_value=ptr[-1], # type: ignore ) return torch.cat([ptr, fill_value], dim=0) if self.is_sorted_by_row: self._indptr = _modify_ptr(self._indptr, num_rows) self._T_indptr = _modify_ptr(self._T_indptr, num_cols) if self.is_sorted_by_col: self._indptr = _modify_ptr(self._indptr, num_cols) self._T_indptr = _modify_ptr(self._T_indptr, num_rows) self._sparse_size = (num_rows, num_cols) return self def get_num_rows(self) -> int: r"""The number of rows of the underlying sparse matrix. Automatically computed and cached when not explicitly set. """ return self.get_sparse_size(0) def get_num_cols(self) -> int: r"""The number of columns of the underlying sparse matrix. Automatically computed and cached when not explicitly set. """ return self.get_sparse_size(1) @assert_sorted def get_indptr(self) -> Tensor: r"""Returns the compressed index representation in case :class:`EdgeIndex` is sorted. 
""" if self._indptr is not None: return self._indptr if self.is_undirected and self._T_indptr is not None: return self._T_indptr dim = 0 if self.is_sorted_by_row else 1 self._indptr = index2ptr(self._data[dim], self.get_sparse_size(dim)) return self._indptr @assert_sorted def _sort_by_transpose(self) -> Tuple[Tuple[Tensor, Tensor], Tensor]: from torch_geometric.utils import index_sort dim = 1 if self.is_sorted_by_row else 0 if self._T_perm is None: max_index = self.get_sparse_size(dim) index, perm = index_sort(self._data[dim], max_index) self._T_index = set_tuple_item(self._T_index, dim, index) self._T_perm = perm.to(self.dtype) if self._T_index[1 - dim] is None: self._T_index = set_tuple_item( # self._T_index, 1 - dim, self._data[1 - dim][self._T_perm]) row, col = self._T_index assert row is not None and col is not None return (row, col), self._T_perm @assert_sorted def get_csr(self) -> Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]: r"""Returns the compressed CSR representation :obj:`(rowptr, col), perm` in case :class:`EdgeIndex` is sorted. """ if self.is_sorted_by_row: return (self.get_indptr(), self._data[1]), None assert self.is_sorted_by_col (row, col), perm = self._sort_by_transpose() if self._T_indptr is not None: rowptr = self._T_indptr elif self.is_undirected and self._indptr is not None: rowptr = self._indptr else: rowptr = self._T_indptr = index2ptr(row, self.get_num_rows()) return (rowptr, col), perm @assert_sorted def get_csc(self) -> Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]: r"""Returns the compressed CSC representation :obj:`(colptr, row), perm` in case :class:`EdgeIndex` is sorted. 
""" if self.is_sorted_by_col: return (self.get_indptr(), self._data[0]), None assert self.is_sorted_by_row (row, col), perm = self._sort_by_transpose() if self._T_indptr is not None: colptr = self._T_indptr elif self.is_undirected and self._indptr is not None: colptr = self._indptr else: colptr = self._T_indptr = index2ptr(col, self.get_num_cols()) return (colptr, row), perm def _get_value(self, dtype: Optional[torch.dtype] = None) -> Tensor: if self._value is not None: if (dtype or torch.get_default_dtype()) == self._value.dtype: return self._value # Expanded tensors are not yet supported in all PyTorch code paths :( # value = torch.ones(1, dtype=dtype, device=self.device) # value = value.expand(self.size(1)) self._value = torch.ones(self.size(1), dtype=dtype, device=self.device) return self._value def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex': r"""Fills the cache with (meta)data information. Args: no_transpose (bool, optional): If set to :obj:`True`, will not fill the cache with information about the transposed :class:`EdgeIndex`. 
(default: :obj:`False`) """ self.get_sparse_size() if self.is_sorted_by_row: self.get_csr() if not no_transpose: self.get_csc() elif self.is_sorted_by_col: self.get_csc() if not no_transpose: self.get_csr() return self # Methods ################################################################# def share_memory_(self) -> 'EdgeIndex': """""" # noqa: D419 self._data.share_memory_() if self._indptr is not None: self._indptr.share_memory_() if self._T_perm is not None: self._T_perm.share_memory_() if self._T_index[0] is not None: self._T_index[0].share_memory_() if self._T_index[1] is not None: self._T_index[1].share_memory_() if self._T_indptr is not None: self._T_indptr.share_memory_() if self._value is not None: self._value.share_memory_() return self def is_shared(self) -> bool: """""" # noqa: D419 return self._data.is_shared() def as_tensor(self) -> Tensor: r"""Zero-copies the :class:`EdgeIndex` representation back to a :class:`torch.Tensor` representation. """ return self._data def sort_by( self, sort_order: Union[str, SortOrder], stable: bool = False, ) -> 'SortReturnType': r"""Sorts the elements by row or column indices. Args: sort_order (str): The sort order, either :obj:`"row"` or :obj:`"col"`. stable (bool, optional): Makes the sorting routine stable, which guarantees that the order of equivalent elements is preserved. (default: :obj:`False`) """ from torch_geometric.utils import index_sort sort_order = SortOrder(sort_order) if self._sort_order == sort_order: # Nothing to do. 
return SortReturnType(self, None) if self.is_sorted: (row, col), perm = self._sort_by_transpose() edge_index = torch.stack([row, col], dim=0) # Otherwise, perform sorting: elif sort_order == SortOrder.ROW: row, perm = index_sort(self._data[0], self.get_num_rows(), stable) edge_index = torch.stack([row, self._data[1][perm]], dim=0) else: col, perm = index_sort(self._data[1], self.get_num_cols(), stable) edge_index = torch.stack([self._data[0][perm], col], dim=0) out = self.__class__(edge_index) # We can inherit metadata and (mostly) cache: out._sparse_size = self.sparse_size() out._sort_order = sort_order out._is_undirected = self.is_undirected out._indptr = self._indptr out._T_indptr = self._T_indptr # NOTE We cannot copy CSR<>CSC permutations since we don't require that # local neighborhoods are sorted, and thus they may run out of sync. out._value = self._value return SortReturnType(out, perm) def to_dense( # type: ignore self, value: Optional[Tensor] = None, fill_value: float = 0.0, dtype: Optional[torch.dtype] = None, ) -> Tensor: r"""Converts :class:`EdgeIndex` into a dense :class:`torch.Tensor`. .. warning:: In case of duplicated edges, the behavior is non-deterministic (one of the values from :obj:`value` will be picked arbitrarily). For deterministic behavior, consider calling :meth:`~torch_geometric.utils.coalesce` beforehand. Args: value (torch.Tensor, optional): The values for non-zero elements. If not specified, non-zero elements will be assigned a value of :obj:`1.0`. (default: :obj:`None`) fill_value (float, optional): The fill value for remaining elements in the dense matrix. (default: :obj:`0.0`) dtype (torch.dtype, optional): The data type of the returned tensor. 
(default: :obj:`None`) """ dtype = value.dtype if value is not None else dtype size = self.get_sparse_size() if value is not None and value.dim() > 1: size = size + value.size()[1:] out = torch.full(size, fill_value, dtype=dtype, device=self.device) out[self._data[0], self._data[1]] = value if value is not None else 1 return out def to_sparse_coo(self, value: Optional[Tensor] = None) -> Tensor: r"""Converts :class:`EdgeIndex` into a :pytorch:`null` :class:`torch.sparse_coo_tensor`. Args: value (torch.Tensor, optional): The values for non-zero elements. If not specified, non-zero elements will be assigned a value of :obj:`1.0`. (default: :obj:`None`) """ value = self._get_value() if value is None else value if not torch_geometric.typing.WITH_PT21: out = torch.sparse_coo_tensor( indices=self._data, values=value, size=self.get_sparse_size(), device=self.device, requires_grad=value.requires_grad, ) if self.is_sorted_by_row: out = out._coalesced_(True) return out return torch.sparse_coo_tensor( indices=self._data, values=value, size=self.get_sparse_size(), device=self.device, requires_grad=value.requires_grad, is_coalesced=True if self.is_sorted_by_row else None, ) def to_sparse_csr( # type: ignore self, value: Optional[Tensor] = None, ) -> Tensor: r"""Converts :class:`EdgeIndex` into a :pytorch:`null` :class:`torch.sparse_csr_tensor`. Args: value (torch.Tensor, optional): The values for non-zero elements. If not specified, non-zero elements will be assigned a value of :obj:`1.0`. 
(default: :obj:`None`) """ (rowptr, col), perm = self.get_csr() if value is not None and perm is not None: value = value[perm] elif value is None: value = self._get_value() return torch.sparse_csr_tensor( crow_indices=rowptr, col_indices=col, values=value, size=self.get_sparse_size(), device=self.device, requires_grad=value.requires_grad, ) def to_sparse_csc( # type: ignore self, value: Optional[Tensor] = None, ) -> Tensor: r"""Converts :class:`EdgeIndex` into a :pytorch:`null` :class:`torch.sparse_csc_tensor`. Args: value (torch.Tensor, optional): The values for non-zero elements. If not specified, non-zero elements will be assigned a value of :obj:`1.0`. (default: :obj:`None`) """ (colptr, row), perm = self.get_csc() if value is not None and perm is not None: value = value[perm] elif value is None: value = self._get_value() return torch.sparse_csc_tensor( ccol_indices=colptr, row_indices=row, values=value, size=self.get_sparse_size(), device=self.device, requires_grad=value.requires_grad, ) def to_sparse( # type: ignore self, *, layout: torch.layout = torch.sparse_coo, value: Optional[Tensor] = None, ) -> Tensor: r"""Converts :class:`EdgeIndex` into a :pytorch:`null` :class:`torch.sparse` tensor. Args: layout (torch.layout, optional): The desired sparse layout. One of :obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`, or :obj:`torch.sparse_csc`. (default: :obj:`torch.sparse_coo`) value (torch.Tensor, optional): The values for non-zero elements. If not specified, non-zero elements will be assigned a value of :obj:`1.0`. (default: :obj:`None`) """ if layout is None or layout == torch.sparse_coo: return self.to_sparse_coo(value) if layout == torch.sparse_csr: return self.to_sparse_csr(value) if layout == torch.sparse_csc: return self.to_sparse_csc(value) raise ValueError(f"Unexpected tensor layout (got '{layout}')") def to_sparse_tensor( self, value: Optional[Tensor] = None, ) -> SparseTensor: r"""Converts :class:`EdgeIndex` into a :class:`torch_sparse.SparseTensor`. 
Requires that :obj:`torch-sparse` is installed. Args: value (torch.Tensor, optional): The values for non-zero elements. (default: :obj:`None`) """ return SparseTensor( row=self._data[0], col=self._data[1], rowptr=self._indptr if self.is_sorted_by_row else None, value=value, sparse_sizes=self.get_sparse_size(), is_sorted=self.is_sorted_by_row, trust_data=True, ) # TODO Investigate how to avoid overlapping return types here. @overload def matmul( # type: ignore self, other: 'EdgeIndex', input_value: Optional[Tensor] = None, other_value: Optional[Tensor] = None, reduce: ReduceType = 'sum', transpose: bool = False, ) -> Tuple['EdgeIndex', Tensor]: pass @overload def matmul( self, other: Tensor, input_value: Optional[Tensor] = None, other_value: None = None, reduce: ReduceType = 'sum', transpose: bool = False, ) -> Tensor: pass def matmul( self, other: Union[Tensor, 'EdgeIndex'], input_value: Optional[Tensor] = None, other_value: Optional[Tensor] = None, reduce: ReduceType = 'sum', transpose: bool = False, ) -> Union[Tensor, Tuple['EdgeIndex', Tensor]]: r"""Performs a matrix multiplication of the matrices :obj:`input` and :obj:`other`. If :obj:`input` is a :math:`(n \times m)` matrix and :obj:`other` is a :math:`(m \times p)` tensor, then the output will be a :math:`(n \times p)` tensor. See :meth:`torch.matmul` for more information. :obj:`input` is a sparse matrix as denoted by the indices in :class:`EdgeIndex`, and :obj:`input_value` corresponds to the values of non-zero elements in :obj:`input`. If not specified, non-zero elements will be assigned a value of :obj:`1.0`. :obj:`other` can either be a dense :class:`torch.Tensor` or a sparse :class:`EdgeIndex`. if :obj:`other` is a sparse :class:`EdgeIndex`, then :obj:`other_value` corresponds to the values of its non-zero elements. This function additionally accepts an optional :obj:`reduce` argument that allows specification of an optional reduction operation. See :meth:`torch.sparse.mm` for more information. 
Lastly, the :obj:`transpose` option allows to perform matrix multiplication where :obj:`input` will be first transposed, *i.e.*: .. math:: \textrm{input}^{\top} \cdot \textrm{other} Args: other (torch.Tensor or EdgeIndex): The second matrix to be multiplied, which can be sparse or dense. input_value (torch.Tensor, optional): The values for non-zero elements of :obj:`input`. If not specified, non-zero elements will be assigned a value of :obj:`1.0`. (default: :obj:`None`) other_value (torch.Tensor, optional): The values for non-zero elements of :obj:`other` in case it is sparse. If not specified, non-zero elements will be assigned a value of :obj:`1.0`. (default: :obj:`None`) reduce (str, optional): The reduce operation, one of :obj:`"sum"`/:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`/:obj:`amin` or :obj:`"max"`/:obj:`amax`. (default: :obj:`"sum"`) transpose (bool, optional): If set to :obj:`True`, will perform matrix multiplication based on the transposed :obj:`input`. (default: :obj:`False`) """ return matmul(self, other, input_value, other_value, reduce, transpose) def sparse_narrow( self, dim: int, start: Union[int, Tensor], length: int, ) -> 'EdgeIndex': r"""Returns a new :class:`EdgeIndex` that is a narrowed version of itself. Narrowing is performed by interpreting :class:`EdgeIndex` as a sparse matrix of shape :obj:`(num_rows, num_cols)`. In contrast to :meth:`torch.narrow`, the returned tensor does not share the same underlying storage anymore. Args: dim (int): The dimension along which to narrow. start (int or torch.Tensor): Index of the element to start the narrowed dimension from. length (int): Length of the narrowed dimension. 
""" dim = dim + 2 if dim < 0 else dim if dim != 0 and dim != 1: raise ValueError(f"Expected dimension to be 0 or 1 (got {dim})") if start < 0: raise ValueError(f"Expected 'start' value to be positive " f"(got {start})") if dim == 0: if self.is_sorted_by_row: (rowptr, col), _ = self.get_csr() rowptr = rowptr.narrow(0, start, length + 1) if rowptr.numel() < 2: row, col = self._data[0, :0], self._data[1, :0] rowptr = None num_rows = 0 else: col = col[rowptr[0]:rowptr[-1]] rowptr = rowptr - rowptr[0] num_rows = rowptr.numel() - 1 row = torch.arange( num_rows, dtype=col.dtype, device=col.device, ).repeat_interleave( rowptr.diff(), output_size=col.numel(), ) edge_index = EdgeIndex( torch.stack([row, col], dim=0), sparse_size=(num_rows, self.sparse_size(1)), sort_order='row', ) edge_index._indptr = rowptr return edge_index else: mask = self._data[0] >= start mask &= self._data[0] < (start + length) offset = torch.tensor([[start], [0]], device=self.device) edge_index = self[:, mask].sub_(offset) # type: ignore edge_index._sparse_size = (length, edge_index._sparse_size[1]) return edge_index else: assert dim == 1 if self.is_sorted_by_col: (colptr, row), _ = self.get_csc() colptr = colptr.narrow(0, start, length + 1) if colptr.numel() < 2: row, col = self._data[0, :0], self._data[1, :0] colptr = None num_cols = 0 else: row = row[colptr[0]:colptr[-1]] colptr = colptr - colptr[0] num_cols = colptr.numel() - 1 col = torch.arange( num_cols, dtype=row.dtype, device=row.device, ).repeat_interleave( colptr.diff(), output_size=row.numel(), ) edge_index = EdgeIndex( torch.stack([row, col], dim=0), sparse_size=(self.sparse_size(0), num_cols), sort_order='col', ) edge_index._indptr = colptr return edge_index else: mask = self._data[1] >= start mask &= self._data[1] < (start + length) offset = torch.tensor([[0], [start]], device=self.device) edge_index = self[:, mask].sub_(offset) # type: ignore edge_index._sparse_size = (edge_index._sparse_size[0], length) return edge_index def 
to_vector(self) -> Tensor: r"""Converts :class:`EdgeIndex` into a one-dimensional index vector representation. """ num_rows, num_cols = self.get_sparse_size() if num_rows * num_cols > torch_geometric.typing.MAX_INT64: raise ValueError("'to_vector()' will result in an overflow") return self._data[0] * num_rows + self._data[1] # PyTorch/Python builtins ################################################# def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]: attrs = ['_data'] if self._indptr is not None: attrs.append('_indptr') if self._T_perm is not None: attrs.append('_T_perm') # TODO We cannot save `_T_index` for now since it is stored as tuple. if self._T_indptr is not None: attrs.append('_T_indptr') ctx = ( self._sparse_size, self._sort_order, self._is_undirected, self._cat_metadata, ) return attrs, ctx @staticmethod def __tensor_unflatten__( inner_tensors: Dict[str, Any], ctx: Tuple[Any, ...], outer_size: Tuple[int, ...], outer_stride: Tuple[int, ...], ) -> 'EdgeIndex': edge_index = EdgeIndex( inner_tensors['_data'], sparse_size=ctx[0], sort_order=ctx[1], is_undirected=ctx[2], ) edge_index._indptr = inner_tensors.get('_indptr', None) edge_index._T_perm = inner_tensors.get('_T_perm', None) edge_index._T_indptr = inner_tensors.get('_T_indptr', None) edge_index._cat_metadata = ctx[3] return edge_index # Prevent auto-wrapping outputs back into the proper subclass type: __torch_function__ = torch._C._disabled_torch_function_impl # type: ignore @classmethod def __torch_dispatch__( # type: ignore cls: Type, func: Callable[..., Any], types: Iterable[Type[Any]], args: Iterable[Tuple[Any, ...]] = (), kwargs: Optional[Dict[Any, Any]] = None, ) -> Any: # `EdgeIndex` should be treated as a regular PyTorch tensor for all # standard PyTorch functionalities. However, # * some of its metadata can be transferred to new functions, e.g., # `torch.cat(dim=1)` can inherit the sparse matrix size, or # `torch.narrow(dim=1)` can inherit cached pointers. 
# * not all operations lead to valid `EdgeIndex` tensors again, e.g., # `torch.sum()` does not yield a `EdgeIndex` as its output, or # `torch.cat(dim=0) violates the [2, *] shape assumption. # To account for this, we hold a number of `HANDLED_FUNCTIONS` that # implement specific functions for valid `EdgeIndex` routines. if func in HANDLED_FUNCTIONS: return HANDLED_FUNCTIONS[func](*args, **(kwargs or {})) # For all other PyTorch functions, we treat them as vanilla tensors. args = pytree.tree_map_only(EdgeIndex, lambda x: x._data, args) if kwargs is not None: kwargs = pytree.tree_map_only(EdgeIndex, lambda x: x._data, kwargs) return func(*args, **(kwargs or {})) def __repr__(self) -> str: # type: ignore prefix = f'{self.__class__.__name__}(' indent = len(prefix) tensor_str = torch._tensor_str._tensor_str(self._data, indent) suffixes = [] num_rows, num_cols = self.sparse_size() if num_rows is not None or num_cols is not None: size_repr = f"({num_rows or '?'}, {num_cols or '?'})" suffixes.append(f'sparse_size={size_repr}') suffixes.append(f'nnz={self._data.size(1)}') if (self.device.type != torch._C._get_default_device() or (self.device.type == 'cuda' and torch.cuda.current_device() != self.device.index) or (self.device.type == 'mps')): suffixes.append(f"device='{self.device}'") if self.dtype != torch.int64: suffixes.append(f'dtype={self.dtype}') if self.is_sorted: suffixes.append(f'sort_order={self.sort_order}') if self.is_undirected: suffixes.append('is_undirected=True') return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=False) def tolist(self) -> List[Any]: """""" # noqa: D419 return self._data.tolist() def numpy(self, *, force: bool = False) -> np.ndarray: """""" # noqa: D419 return self._data.numpy(force=force) # Helpers ################################################################# def _shallow_copy(self) -> 'EdgeIndex': out = EdgeIndex(self._data) out._sparse_size = self._sparse_size out._sort_order = self._sort_order 
out._is_undirected = self._is_undirected out._indptr = self._indptr out._T_perm = self._T_perm out._T_index = self._T_index out._T_indptr = self._T_indptr out._value = self._value out._cat_metadata = self._cat_metadata return out def _clear_metadata(self) -> 'EdgeIndex': self._sparse_size = (None, None) self._sort_order = None self._is_undirected = False self._indptr = None self._T_perm = None self._T_index = (None, None) self._T_indptr = None self._value = None self._cat_metadata = None return self class SortReturnType(NamedTuple): values: EdgeIndex indices: Optional[Tensor] def apply_( tensor: EdgeIndex, fn: Callable, *args: Any, **kwargs: Any, ) -> Union[EdgeIndex, Tensor]: data = fn(tensor._data, *args, **kwargs) if data.dtype not in INDEX_DTYPES: return data if tensor._data.data_ptr() != data.data_ptr(): out = EdgeIndex(data) else: # In-place: tensor._data = data out = tensor # Copy metadata: out._sparse_size = tensor._sparse_size out._sort_order = tensor._sort_order out._is_undirected = tensor._is_undirected out._cat_metadata = tensor._cat_metadata # Convert cache (but do not consider `_value`): if tensor._indptr is not None: out._indptr = fn(tensor._indptr, *args, **kwargs) if tensor._T_perm is not None: out._T_perm = fn(tensor._T_perm, *args, **kwargs) _T_row, _T_col = tensor._T_index if _T_row is not None: _T_row = fn(_T_row, *args, **kwargs) if _T_col is not None: _T_col = fn(_T_col, *args, **kwargs) out._T_index = (_T_row, _T_col) if tensor._T_indptr is not None: out._T_indptr = fn(tensor._T_indptr, *args, **kwargs) return out @implements(aten.clone.default) def _clone( tensor: EdgeIndex, *, memory_format: torch.memory_format = torch.preserve_format, ) -> EdgeIndex: out = apply_(tensor, aten.clone.default, memory_format=memory_format) assert isinstance(out, EdgeIndex) return out @implements(aten._to_copy.default) def _to_copy( tensor: EdgeIndex, *, dtype: Optional[torch.dtype] = None, layout: Optional[torch.layout] = None, device: Optional[torch.device] 
= None, pin_memory: bool = False, non_blocking: bool = False, memory_format: Optional[torch.memory_format] = None, ) -> Union[EdgeIndex, Tensor]: return apply_( tensor, aten._to_copy.default, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, non_blocking=non_blocking, memory_format=memory_format, ) @implements(aten.alias.default) def _alias(tensor: EdgeIndex) -> EdgeIndex: return tensor._shallow_copy() @implements(aten._pin_memory.default) def _pin_memory(tensor: EdgeIndex) -> EdgeIndex: out = apply_(tensor, aten._pin_memory.default) assert isinstance(out, EdgeIndex) return out @implements(aten.cat.default) def _cat( tensors: List[Union[EdgeIndex, Tensor]], dim: int = 0, ) -> Union[EdgeIndex, Tensor]: data_list = pytree.tree_map_only(EdgeIndex, lambda x: x._data, tensors) data = aten.cat.default(data_list, dim=dim) if dim != 1 and dim != -1: # No valid `EdgeIndex` anymore. return data if any([not isinstance(tensor, EdgeIndex) for tensor in tensors]): return data out = EdgeIndex(data) nnz_list = [t.size(1) for t in tensors] sparse_size_list = [t.sparse_size() for t in tensors] # type: ignore sort_order_list = [t._sort_order for t in tensors] # type: ignore is_undirected_list = [t.is_undirected for t in tensors] # type: ignore # Post-process `sparse_size`: total_num_rows: Optional[int] = 0 for num_rows, _ in sparse_size_list: if num_rows is None: total_num_rows = None break assert isinstance(total_num_rows, int) total_num_rows = max(num_rows, total_num_rows) total_num_cols: Optional[int] = 0 for _, num_cols in sparse_size_list: if num_cols is None: total_num_cols = None break assert isinstance(total_num_cols, int) total_num_cols = max(num_cols, total_num_cols) out._sparse_size = (total_num_rows, total_num_cols) # Post-process `is_undirected`: out._is_undirected = all(is_undirected_list) out._cat_metadata = CatMetadata( nnz=nnz_list, sparse_size=sparse_size_list, sort_order=sort_order_list, is_undirected=is_undirected_list, ) return out 
@implements(aten.flip.default) def _flip( input: EdgeIndex, dims: Union[List[int], Tuple[int, ...]], ) -> EdgeIndex: data = aten.flip.default(input._data, dims) out = EdgeIndex(data) out._value = input._value out._is_undirected = input.is_undirected # Flip metadata and cache: if 0 in dims or -2 in dims: out._sparse_size = input.sparse_size()[::-1] if len(dims) == 1 and (dims[0] == 0 or dims[0] == -2): if input.is_sorted_by_row: out._sort_order = SortOrder.COL elif input.is_sorted_by_col: out._sort_order = SortOrder.ROW out._indptr = input._T_indptr out._T_perm = input._T_perm out._T_index = input._T_index[::-1] out._T_indptr = input._indptr return out @implements(aten.index_select.default) def _index_select( input: EdgeIndex, dim: int, index: Tensor, ) -> Union[EdgeIndex, Tensor]: out = aten.index_select.default(input._data, dim, index) if dim == 1 or dim == -1: out = EdgeIndex(out) out._sparse_size = input.sparse_size() return out @implements(aten.slice.Tensor) def _slice( input: EdgeIndex, dim: int, start: Optional[int] = None, end: Optional[int] = None, step: int = 1, ) -> Union[EdgeIndex, Tensor]: if ((start is None or start == 0 or start <= -input.size(dim)) and (end is None or end > input.size(dim)) and step == 1): return input._shallow_copy() # No-op. out = aten.slice.Tensor(input._data, dim, start, end, step) if dim == 1 or dim == -1: if step != 1: out = out.contiguous() out = EdgeIndex(out) out._sparse_size = input.sparse_size() # NOTE We could potentially maintain `rowptr`/`colptr` attributes here, # but it is not really clear if this is worth it. 
The most important # information, the sort order, needs to be maintained though: if step >= 0: out._sort_order = input._sort_order else: if input._sort_order == SortOrder.ROW: out._sort_order = SortOrder.COL elif input._sort_order == SortOrder.COL: out._sort_order = SortOrder.ROW return out @implements(aten.index.Tensor) def _index( input: Union[EdgeIndex, Tensor], indices: List[Optional[Union[Tensor, EdgeIndex]]], ) -> Union[EdgeIndex, Tensor]: if not isinstance(input, EdgeIndex): indices = pytree.tree_map_only(EdgeIndex, lambda x: x._data, indices) return aten.index.Tensor(input, indices) out = aten.index.Tensor(input._data, indices) if len(indices) != 2 or indices[0] is not None: return out index = indices[1] assert isinstance(index, Tensor) out = EdgeIndex(out) # 1. `edge_index[:, mask]` or `edge_index[..., mask]`. if index.dtype in (torch.bool, torch.uint8): out._sparse_size = input.sparse_size() out._sort_order = input._sort_order else: # 2. `edge_index[:, index]` or `edge_index[..., index]`. 
out._sparse_size = input.sparse_size() return out @implements(aten.select.int) def _select(input: EdgeIndex, dim: int, index: int) -> Union[Tensor, Index]: out = aten.select.int(input._data, dim, index) if dim == 0 or dim == -2: out = Index(out) if index == 0 or index == -2: # Row-select: out._dim_size = input.sparse_size(0) out._is_sorted = input.is_sorted_by_row if input.is_sorted_by_row: out._indptr = input._indptr else: # Col-select: assert index == 1 or index == -1 out._dim_size = input.sparse_size(1) out._is_sorted = input.is_sorted_by_col if input.is_sorted_by_col: out._indptr = input._indptr return out @implements(aten.unbind.int) def _unbind( input: EdgeIndex, dim: int = 0, ) -> Union[List[Index], List[Tensor]]: if dim == 0 or dim == -2: row = input[0] assert isinstance(row, Index) col = input[1] assert isinstance(col, Index) return [row, col] return aten.unbind.int(input._data, dim) @implements(aten.add.Tensor) def _add( input: EdgeIndex, other: Union[int, Tensor, EdgeIndex], *, alpha: int = 1, ) -> Union[EdgeIndex, Tensor]: out = aten.add.Tensor( input._data, other._data if isinstance(other, EdgeIndex) else other, alpha=alpha, ) if out.dtype not in INDEX_DTYPES: return out if out.dim() != 2 or out.size(0) != 2: return out out = EdgeIndex(out) if isinstance(other, Tensor) and other.numel() <= 1: other = int(other) if isinstance(other, int): size = maybe_add(input._sparse_size, other, alpha) assert len(size) == 2 out._sparse_size = size out._sort_order = input._sort_order out._is_undirected = input.is_undirected out._T_perm = input._T_perm elif isinstance(other, Tensor) and other.size() == (2, 1): size = maybe_add(input._sparse_size, other.view(-1).tolist(), alpha) assert len(size) == 2 out._sparse_size = size out._sort_order = input._sort_order if torch.equal(other[0], other[1]): out._is_undirected = input.is_undirected out._T_perm = input._T_perm elif isinstance(other, EdgeIndex): size = maybe_add(input._sparse_size, other._sparse_size, alpha) assert 
len(size) == 2 out._sparse_size = size return out @implements(aten.add_.Tensor) def add_( input: EdgeIndex, other: Union[int, Tensor, EdgeIndex], *, alpha: int = 1, ) -> EdgeIndex: sparse_size = input._sparse_size sort_order = input._sort_order is_undirected = input._is_undirected T_perm = input._T_perm input._clear_metadata() aten.add_.Tensor( input._data, other._data if isinstance(other, EdgeIndex) else other, alpha=alpha, ) if isinstance(other, Tensor) and other.numel() <= 1: other = int(other) if isinstance(other, int): size = maybe_add(sparse_size, other, alpha) assert len(size) == 2 input._sparse_size = size input._sort_order = sort_order input._is_undirected = is_undirected input._T_perm = T_perm elif isinstance(other, Tensor) and other.size() == (2, 1): size = maybe_add(sparse_size, other.view(-1).tolist(), alpha) assert len(size) == 2 input._sparse_size = size input._sort_order = sort_order if torch.equal(other[0], other[1]): input._is_undirected = is_undirected input._T_perm = T_perm elif isinstance(other, EdgeIndex): size = maybe_add(sparse_size, other._sparse_size, alpha) assert len(size) == 2 input._sparse_size = size return input @implements(aten.sub.Tensor) def _sub( input: EdgeIndex, other: Union[int, Tensor, EdgeIndex], *, alpha: int = 1, ) -> Union[EdgeIndex, Tensor]: out = aten.sub.Tensor( input._data, other._data if isinstance(other, EdgeIndex) else other, alpha=alpha, ) if out.dtype not in INDEX_DTYPES: return out if out.dim() != 2 or out.size(0) != 2: return out out = EdgeIndex(out) if isinstance(other, Tensor) and other.numel() <= 1: other = int(other) if isinstance(other, int): size = maybe_sub(input._sparse_size, other, alpha) assert len(size) == 2 out._sparse_size = size out._sort_order = input._sort_order out._is_undirected = input.is_undirected out._T_perm = input._T_perm elif isinstance(other, Tensor) and other.size() == (2, 1): size = maybe_sub(input._sparse_size, other.view(-1).tolist(), alpha) assert len(size) == 2 out._sparse_size = 
size out._sort_order = input._sort_order if torch.equal(other[0], other[1]): out._is_undirected = input.is_undirected out._T_perm = input._T_perm return out @implements(aten.sub_.Tensor) def sub_( input: EdgeIndex, other: Union[int, Tensor, EdgeIndex], *, alpha: int = 1, ) -> EdgeIndex: sparse_size = input._sparse_size sort_order = input._sort_order is_undirected = input._is_undirected T_perm = input._T_perm input._clear_metadata() aten.sub_.Tensor( input._data, other._data if isinstance(other, EdgeIndex) else other, alpha=alpha, ) if isinstance(other, Tensor) and other.numel() <= 1: other = int(other) if isinstance(other, int): size = maybe_sub(sparse_size, other, alpha) assert len(size) == 2 input._sparse_size = size input._sort_order = sort_order input._is_undirected = is_undirected input._T_perm = T_perm elif isinstance(other, Tensor) and other.size() == (2, 1): size = maybe_sub(sparse_size, other.view(-1).tolist(), alpha) assert len(size) == 2 input._sparse_size = size input._sort_order = sort_order if torch.equal(other[0], other[1]): input._is_undirected = is_undirected input._T_perm = T_perm return input # Sparse-Dense Matrix Multiplication ########################################## def _torch_sparse_spmm( input: EdgeIndex, other: Tensor, value: Optional[Tensor] = None, reduce: ReduceType = 'sum', transpose: bool = False, ) -> Tensor: # `torch-sparse` still provides a faster sparse-dense matrix multiplication # code path on GPUs (after all these years...): assert torch_geometric.typing.WITH_TORCH_SPARSE reduce = PYG_REDUCE[reduce] if reduce in PYG_REDUCE else reduce # Optional arguments for backpropagation: colptr: Optional[Tensor] = None perm: Optional[Tensor] = None if not transpose: assert input.is_sorted_by_row (rowptr, col), _ = input.get_csr() row = input._data[0] if other.requires_grad and reduce in ['sum', 'mean']: (colptr, _), perm = input.get_csc() else: assert input.is_sorted_by_col (rowptr, col), _ = input.get_csc() row = input._data[1] if 
other.requires_grad and reduce in ['sum', 'mean']: (colptr, _), perm = input.get_csr() if reduce == 'sum': return torch.ops.torch_sparse.spmm_sum( # row, rowptr, col, value, colptr, perm, other) if reduce == 'mean': rowcount = rowptr.diff() if other.requires_grad else None return torch.ops.torch_sparse.spmm_mean( # row, rowptr, col, value, rowcount, colptr, perm, other) if reduce == 'min': return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)[0] if reduce == 'max': return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)[0] raise NotImplementedError class _TorchSPMM(torch.autograd.Function): @staticmethod def forward( ctx: Any, input: EdgeIndex, other: Tensor, value: Optional[Tensor] = None, reduce: ReduceType = 'sum', transpose: bool = False, ) -> Tensor: reduce = TORCH_REDUCE[reduce] if reduce in TORCH_REDUCE else reduce value = value.detach() if value is not None else value if other.requires_grad: other = other.detach() ctx.save_for_backward(input, value) ctx.reduce = reduce ctx.transpose = transpose if not transpose: assert input.is_sorted_by_row adj = input.to_sparse_csr(value) else: assert input.is_sorted_by_col adj = input.to_sparse_csc(value).t() if torch_geometric.typing.WITH_PT20 and not other.is_cuda: return torch.sparse.mm(adj, other, reduce) else: # pragma: no cover assert reduce == 'sum' return adj @ other @staticmethod def backward( ctx: Any, *grad_outputs: Any, ) -> Tuple[None, Optional[Tensor], None, None, None]: grad_out, = grad_outputs other_grad: Optional[Tensor] = None if ctx.needs_input_grad[1]: input, value = ctx.saved_tensors assert ctx.reduce == 'sum' if not ctx.transpose: if value is None and input.is_undirected: adj = input.to_sparse_csr(value) else: (colptr, row), perm = input.get_csc() if value is not None and perm is not None: value = value[perm] else: value = input._get_value() adj = torch.sparse_csr_tensor( crow_indices=colptr, col_indices=row, values=value, size=input.get_sparse_size()[::-1], 
device=input.device, ) else: if value is None and input.is_undirected: adj = input.to_sparse_csc(value).t() else: (rowptr, col), perm = input.get_csr() if value is not None and perm is not None: value = value[perm] else: value = input._get_value() adj = torch.sparse_csr_tensor( crow_indices=rowptr, col_indices=col, values=value, size=input.get_sparse_size()[::-1], device=input.device, ) other_grad = adj @ grad_out if ctx.needs_input_grad[2]: raise NotImplementedError("Gradient computation for 'value' not " "yet supported") return None, other_grad, None, None, None def _scatter_spmm( input: EdgeIndex, other: Tensor, value: Optional[Tensor] = None, reduce: ReduceType = 'sum', transpose: bool = False, ) -> Tensor: from torch_geometric.utils import scatter if not transpose: other_j = other[input._data[1]] index = input._data[0] dim_size = input.get_sparse_size(0) else: other_j = other[input._data[0]] index = input._data[1] dim_size = input.get_sparse_size(1) other_j = other_j * value.view(-1, 1) if value is not None else other_j return scatter(other_j, index, 0, dim_size=dim_size, reduce=reduce) def _spmm( input: EdgeIndex, other: Tensor, value: Optional[Tensor] = None, reduce: ReduceType = 'sum', transpose: bool = False, ) -> Tensor: if reduce not in get_args(ReduceType): raise ValueError(f"`reduce='{reduce}'` is not a valid reduction") if not transpose and not input.is_sorted_by_row: cls_name = input.__class__.__name__ raise ValueError(f"'matmul(..., transpose=False)' requires " f"'{cls_name}' to be sorted by rows") if transpose and not input.is_sorted_by_col: cls_name = input.__class__.__name__ raise ValueError(f"'matmul(..., transpose=True)' requires " f"'{cls_name}' to be sorted by columns") if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling() and other.is_cuda): # pragma: no cover return _torch_sparse_spmm(input, other, value, reduce, transpose) if value is not None and value.requires_grad: if torch_geometric.typing.WITH_TORCH_SPARSE and not 
is_compiling(): return _torch_sparse_spmm(input, other, value, reduce, transpose) return _scatter_spmm(input, other, value, reduce, transpose) if torch_geometric.typing.WITH_PT20: if reduce == 'sum' or reduce == 'add': return _TorchSPMM.apply(input, other, value, 'sum', transpose) if reduce == 'mean': out = _TorchSPMM.apply(input, other, value, 'sum', transpose) count = input.get_indptr().diff() return out / count.clamp_(min=1).to(out.dtype).view(-1, 1) if not other.is_cuda and not other.requires_grad: return _TorchSPMM.apply(input, other, value, reduce, transpose) if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling(): return _torch_sparse_spmm(input, other, value, reduce, transpose) return _scatter_spmm(input, other, value, reduce, transpose) def matmul( input: EdgeIndex, other: Union[Tensor, EdgeIndex], input_value: Optional[Tensor] = None, other_value: Optional[Tensor] = None, reduce: ReduceType = 'sum', transpose: bool = False, ) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]: if not isinstance(other, EdgeIndex): if other_value is not None: raise ValueError("'other_value' not supported for sparse-dense " "matrix multiplication") return _spmm(input, other, input_value, reduce, transpose) if reduce not in ['sum', 'add']: raise NotImplementedError(f"`reduce='{reduce}'` not yet supported for " f"sparse-sparse matrix multiplication") transpose &= not input.is_undirected or input_value is not None if torch_geometric.typing.NO_MKL: # pragma: no cover sparse_input = input.to_sparse_coo(input_value) elif input.is_sorted_by_col: sparse_input = input.to_sparse_csc(input_value) else: sparse_input = input.to_sparse_csr(input_value) if transpose: sparse_input = sparse_input.t() if torch_geometric.typing.NO_MKL: # pragma: no cover other = other.to_sparse_coo(other_value) elif other.is_sorted_by_col: other = other.to_sparse_csc(other_value) else: other = other.to_sparse_csr(other_value) out = torch.matmul(sparse_input, other) rowptr: Optional[Tensor] = None if 
out.layout == torch.sparse_csr: rowptr = out.crow_indices().to(input.dtype) col = out.col_indices().to(input.dtype) edge_index = torch._convert_indices_from_csr_to_coo( rowptr, col, out_int32=rowptr.dtype != torch.int64) elif out.layout == torch.sparse_coo: # pragma: no cover out = out.coalesce() edge_index = out.indices() else: raise NotImplementedError edge_index = EdgeIndex(edge_index) edge_index._sort_order = SortOrder.ROW edge_index._sparse_size = (out.size(0), out.size(1)) edge_index._indptr = rowptr return edge_index, out.values() @implements(aten.mm.default) def _mm( input: EdgeIndex, other: Union[Tensor, EdgeIndex], ) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]: return matmul(input, other) @implements(aten._sparse_addmm.default) def _addmm( input: Tensor, mat1: EdgeIndex, mat2: Tensor, beta: float = 1.0, alpha: float = 1.0, ) -> Tensor: assert input.abs().sum() == 0.0 out = matmul(mat1, mat2) assert isinstance(out, Tensor) return alpha * out if alpha != 1.0 else out if hasattr(aten, '_sparse_mm_reduce_impl'): @implements(aten._sparse_mm_reduce_impl.default) def _mm_reduce( mat1: EdgeIndex, mat2: Tensor, reduce: ReduceType = 'sum', ) -> Tuple[Tensor, Tensor]: out = matmul(mat1, mat2, reduce=reduce) assert isinstance(out, Tensor) return out, out # We return a dummy tensor for `argout` for now. ================================================ FILE: torch_geometric/experimental.py ================================================ import functools import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch # TODO (matthias) This file currently requires manual imports to let # TorchScript work on decorated functions. 
# Not totally sure why :(
from torch_geometric.utils import *  # noqa

# Registry of experimental feature flags (the only flag defined here defaults
# to `False`):
__experimental_flag__: Dict[str, bool] = {
    'disable_dynamic_shapes': False,
}

# A single option name, a list of option names, or `None` (= all options).
Options = Optional[Union[str, List[str]]]


def get_options(options: Options) -> List[str]:
    r"""Normalizes :obj:`options` into a list of option names.

    :obj:`None` expands to all registered experimental options, a single
    string is wrapped into a one-element list, and a list is returned as is.
    """
    if options is None:
        options = list(__experimental_flag__.keys())
    if isinstance(options, str):
        options = [options]
    return options


def is_experimental_mode_enabled(options: Options = None) -> bool:
    r"""Returns :obj:`True` if the experimental mode is enabled. See
    :class:`torch_geometric.experimental_mode` for a list of (optional)
    options.

    All requested options need to be enabled. During TorchScript scripting or
    tracing, the experimental mode is always reported as disabled. Unknown
    option names raise a :obj:`KeyError`.
    """
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return False
    options = get_options(options)
    return all([__experimental_flag__[option] for option in options])


def set_experimental_mode_enabled(mode: bool, options: Options = None) -> None:
    r"""Sets the given experimental :obj:`options` (default: all) to
    :obj:`mode`.
    """
    for option in get_options(options):
        __experimental_flag__[option] = mode


class experimental_mode:
    r"""Context-manager that enables the experimental mode to test new but
    potentially unstable features.

    .. code-block:: python

        with torch_geometric.experimental_mode():
            out = model(data.x, data.edge_index)

    Args:
        options (str or list, optional): The experimental option(s) to enable,
            *e.g.*, :obj:`"disable_dynamic_shapes"`. If set to :obj:`None`,
            all experimental options are enabled. The previous state of each
            option is restored on exit. (default: :obj:`None`)
    """
    def __init__(self, options: Options = None) -> None:
        self.options = get_options(options)
        # Snapshot taken at construction time and restored in `__exit__`:
        self.previous_state = {
            option: __experimental_flag__[option]
            for option in self.options
        }

    def __enter__(self) -> None:
        set_experimental_mode_enabled(True, self.options)

    def __exit__(self, *args: Any) -> None:
        for option, value in self.previous_state.items():
            __experimental_flag__[option] = value


class set_experimental_mode:
    r"""Context-manager that sets the experimental mode on or off.

    :class:`set_experimental_mode` will enable or disable the experimental mode
    based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function: the mode is applied
    immediately on construction, and the previous state is restored when the
    context is exited.

    See :class:`experimental_mode` above for more details.
    """
    def __init__(self, mode: bool, options: Options = None) -> None:
        self.options = get_options(options)
        self.previous_state = {
            option: __experimental_flag__[option]
            for option in self.options
        }
        # Applied eagerly so that plain function-call usage also works:
        set_experimental_mode_enabled(mode, self.options)

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> None:
        for option, value in self.previous_state.items():
            __experimental_flag__[option] = value


def disable_dynamic_shapes(required_args: List[str]) -> Callable:
    r"""A decorator that disables the usage of dynamic shapes for the given
    arguments, i.e., it will raise an error in case :obj:`required_args` are
    not passed and needs to be automatically inferred.

    The check is only performed while the :obj:`"disable_dynamic_shapes"`
    experimental option is enabled.

    Raises:
        ValueError: At decoration time, if the decorated function has no
            argument named as in :obj:`required_args`; at call time, if a
            required argument resolves to :obj:`None`.
    """
    def decorator(func: Callable) -> Callable:
        spec = inspect.getfullargspec(func)

        # Position of every required argument within the function signature:
        required_args_pos: Dict[str, int] = {}
        for arg_name in required_args:
            if arg_name not in spec.args:
                raise ValueError(f"The function '{func}' does not have a "
                                 f"'{arg_name}' argument")
            required_args_pos[arg_name] = spec.args.index(arg_name)

        num_args = len(spec.args)
        num_default_args = 0 if spec.defaults is None else len(spec.defaults)
        # Number of leading arguments that do not have a default value:
        num_positional_args = num_args - num_default_args

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not is_experimental_mode_enabled('disable_dynamic_shapes'):
                return func(*args, **kwargs)

            for required_arg in required_args:
                index = required_args_pos[required_arg]

                # Resolve the value: positional, then keyword, then default.
                value: Optional[Any] = None
                if index < len(args):
                    value = args[index]
                elif required_arg in kwargs:
                    value = kwargs[required_arg]
                elif num_default_args > 0:
                    assert spec.defaults is not None
                    # NOTE(review): if `required_arg` has no default (i.e.
                    # `index < num_positional_args`), this index is negative
                    # and reads an unrelated default (or raises `IndexError`).
                    value = spec.defaults[index - num_positional_args]

                if value is None:
                    raise ValueError(f"Dynamic shapes disabled. Argument "
                                     f"'{required_arg}' needs to be set")

            return func(*args, **kwargs)

        return wrapper

    return decorator

================================================
FILE: torch_geometric/explain/__init__.py
================================================
from .config import ExplainerConfig, ModelConfig, ThresholdConfig
from .explanation import Explanation, HeteroExplanation
from .algorithm import *  # noqa
from .explainer import Explainer
from .metric import *  # noqa

__all__ = [
    'ExplainerConfig',
    'ModelConfig',
    'ThresholdConfig',
    'Explanation',
    'HeteroExplanation',
    'Explainer',
]

================================================
FILE: torch_geometric/explain/algorithm/__init__.py
================================================
from .base import ExplainerAlgorithm
from .dummy_explainer import DummyExplainer
from .gnn_explainer import GNNExplainer
from .captum_explainer import CaptumExplainer
from .pg_explainer import PGExplainer
from .attention_explainer import AttentionExplainer
from .graphmask_explainer import GraphMaskExplainer

__all__ = classes = [
    'ExplainerAlgorithm',
    'DummyExplainer',
    'GNNExplainer',
    'CaptumExplainer',
    'PGExplainer',
    'AttentionExplainer',
    'GraphMaskExplainer',
]

================================================
FILE: torch_geometric/explain/algorithm/attention_explainer.py
================================================
import logging
from typing import Dict, List, Optional, Union, overload

import torch
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
from torch_geometric.nn.conv.message_passing import MessagePassing
from torch_geometric.typing import EdgeType, NodeType


class AttentionExplainer(ExplainerAlgorithm):
    r"""An explainer that uses the attention coefficients produced by an
    attention-based GNN (*e.g.*,
    :class:`~torch_geometric.nn.conv.GATConv`,
    :class:`~torch_geometric.nn.conv.GATv2Conv`, or
    :class:`~torch_geometric.nn.conv.TransformerConv`) as edge explanation.
    Attention scores across layers and heads will be aggregated according to
    the :obj:`reduce` argument.

    Args:
        reduce (str, optional): The method to reduce the attention scores
            across layers and heads. It is looked up as a function of the
            :obj:`torch` namespace and called with :obj:`dim=-1`.
            (default: :obj:`"max"`)
    """
    def __init__(self, reduce: str = 'max'):
        super().__init__()
        self.reduce = reduce
        # Set on every `forward` call depending on the type of `x`:
        self.is_hetero = False

    @overload
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        ...

    @overload
    def forward(
        self,
        model: torch.nn.Module,
        x: Dict[NodeType, Tensor],
        edge_index: Dict[EdgeType, Tensor],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> HeteroExplanation:
        ...

    def forward(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation]:
        """Generate explanations based on attention coefficients.

        Runs :obj:`model` once with message-passing hooks attached and turns
        the collected attention coefficients into an edge mask. A dictionary
        :obj:`x` selects the heterogeneous code path. :obj:`target` is not
        used by this algorithm.
        """
        self.is_hetero = isinstance(x, dict)

        # Collect attention coefficients
        alphas_dict = self._collect_attention_coefficients(
            model, x, edge_index, **kwargs)

        # Process attention coefficients
        if self.is_hetero:
            return self._create_hetero_explanation(model, alphas_dict,
                                                   edge_index, index, x)
        else:
            return self._create_homo_explanation(model, alphas_dict,
                                                 edge_index, index, x)

    @overload
    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        **kwargs,
    ) -> List[Tensor]:
        ...

    @overload
    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Dict[NodeType, Tensor],
        edge_index: Dict[EdgeType, Tensor],
        **kwargs,
    ) -> Dict[EdgeType, List[Tensor]]:
        ...

    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        **kwargs,
    ) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:
        """Collect attention coefficients from model layers.

        Registers a message-forward hook on every :class:`MessagePassing`
        module whose :obj:`explain` attribute is not :obj:`False`, runs a
        single forward pass and records the detached coefficients, read from
        the message kwargs (:obj:`"alpha"`) or the module's :obj:`_alpha`.

        Raises:
            ValueError: If no attention coefficients were collected.
        """
        if self.is_hetero:
            # For heterogeneous graphs, store alphas by edge type
            alphas_dict: Dict[EdgeType, List[Tensor]] = {}

            # Get list of edge types
            edge_types = list(edge_index.keys())

            # Hook function to capture attention coefficients by edge type
            def hook(module, msg_kwargs, out):
                # Find edge type from the module's full name
                module_name = getattr(module, '_name', None)
                if module_name is None:
                    return

                edge_type = None
                for edge_tuple in edge_types:
                    src_type, edge_name, dst_type = edge_tuple
                    # Check if all components appear in the module name in
                    # order
                    try:
                        src_idx = module_name.index(src_type)
                        edge_idx = module_name.index(edge_name, src_idx)
                        dst_idx = module_name.index(dst_type, edge_idx)
                        if src_idx < edge_idx < dst_idx:
                            edge_type = edge_tuple
                            break
                    except ValueError:  # Component not found
                        continue

                if edge_type is None:
                    return

                if edge_type not in alphas_dict:
                    alphas_dict[edge_type] = []

                # Extract alpha from message kwargs or module
                # (presumably `msg_kwargs[0]` is the kwargs dict passed to
                # `message()` - confirm against `register_message_forward_hook`)
                if 'alpha' in msg_kwargs[0]:
                    alphas_dict[edge_type].append(
                        msg_kwargs[0]['alpha'].detach())
                elif getattr(module, '_alpha', None) is not None:
                    alphas_dict[edge_type].append(module._alpha.detach())
        else:
            # For homogeneous graphs, store all alphas in a list
            alphas: List[Tensor] = []

            def hook(module, msg_kwargs, out):
                if 'alpha' in msg_kwargs[0]:
                    alphas.append(msg_kwargs[0]['alpha'].detach())
                elif getattr(module, '_alpha', None) is not None:
                    alphas.append(module._alpha.detach())

        # Register hooks for all message passing modules
        hook_handles = []
        for name, module in model.named_modules():
            if isinstance(module,
                          MessagePassing) and module.explain is not False:
                # Store name for hetero graph lookup in the hook
                # NOTE(review): `_name` is never removed from the module.
                if self.is_hetero:
                    module._name = name

                hook_handles.append(module.register_message_forward_hook(hook))

        # Forward pass to collect attention coefficients.
        # NOTE(review): no `try`/`finally`, so the hooks stay registered if
        # the forward pass raises.
        model(x, edge_index, **kwargs)

        # Remove hooks
        for handle in hook_handles:
            handle.remove()

        # Check if we collected any attention coefficients.
        if self.is_hetero:
            if not alphas_dict:
                raise ValueError(
                    "Could not collect any attention coefficients. "
                    "Please ensure that your model is using "
                    "attention-based GNN layers.")
            return alphas_dict
        else:
            if not alphas:
                raise ValueError(
                    "Could not collect any attention coefficients. "
                    "Please ensure that your model is using "
                    "attention-based GNN layers.")
            return alphas

    def _process_attention_coefficients(
        self,
        alphas: List[Tensor],
        edge_index_size: int,
    ) -> Tensor:
        """Process collected attention coefficients into a single mask.

        Each coefficient tensor is truncated to :obj:`edge_index_size`
        entries and reduced over its last (head) dimension if it is 2D; the
        per-layer results are then reduced across layers.
        Note that the reduced tensors are written back into :obj:`alphas`.

        Raises:
            ValueError: If a coefficient tensor has more than two dimensions.
        """
        for i, alpha in enumerate(alphas):
            # Ensure alpha doesn't exceed edge_index size
            # (presumably drops coefficients of self-loops that layers append
            # after the input edges - TODO confirm)
            alpha = alpha[:edge_index_size]

            # Reduce multi-head attention
            if alpha.dim() == 2:
                alpha = getattr(torch, self.reduce)(alpha, dim=-1)
                if isinstance(alpha, tuple):  # Handle torch.max output
                    alpha = alpha[0]
            elif alpha.dim() > 2:
                raise ValueError(f"Cannot reduce attention coefficients of "
                                 f"shape {list(alpha.size())}")
            alphas[i] = alpha

        # Combine attention coefficients across layers
        if len(alphas) > 1:
            alpha = torch.stack(alphas, dim=-1)
            alpha = getattr(torch, self.reduce)(alpha, dim=-1)
            if isinstance(alpha, tuple):  # Handle torch.max output
                alpha = alpha[0]
        else:
            alpha = alphas[0]

        return alpha

    def _create_homo_explanation(
        self,
        model: torch.nn.Module,
        alphas: List[Tensor],
        edge_index: Tensor,
        index: Optional[Union[int, Tensor]],
        x: Tensor,
    ) -> Explanation:
        """Create explanation for homogeneous graph."""
        # Get hard edge mask for node-level tasks
        hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
                                                     num_nodes=x.size(0))

        # Process attention coefficients
        alpha = self._process_attention_coefficients(alphas,
                                                     edge_index.size(1))

        # Post-process mask with hard edge mask if needed
        alpha = self._post_process_mask(alpha, hard_edge_mask,
                                        apply_sigmoid=False)

        return Explanation(edge_mask=alpha)

    def _create_hetero_explanation(
        self,
        model: torch.nn.Module,
        alphas_dict: Dict[EdgeType, List[Tensor]],
        edge_index: Dict[EdgeType, Tensor],
        index: Optional[Union[int, Tensor]],
        x: Dict[NodeType, Tensor],
    ) -> HeteroExplanation:
        """Create explanation for heterogeneous graph.

        One edge mask is computed per edge type that produced attention
        coefficients; edge types without coefficients are skipped.
        """
        edge_masks_dict = {}

        # Process each edge type separately
        for edge_type, alphas in alphas_dict.items():
            if not alphas:
                continue

            # Get hard edge mask for node-level tasks
            hard_edge_mask = None
            if self.model_config.task_level == ModelTaskLevel.node:
                src_type, _, dst_type = edge_type
                # The larger of both node counts bounds all indices in this
                # edge type:
                _, hard_edge_mask = self._get_hard_masks(
                    model, index, edge_index[edge_type],
                    num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))

            # Process attention coefficients for this edge type
            alpha = self._process_attention_coefficients(
                alphas, edge_index[edge_type].size(1))

            # Apply hard mask if available
            edge_masks_dict[edge_type] = self._post_process_mask(
                alpha, hard_edge_mask, apply_sigmoid=False)

        # Create heterogeneous explanation
        explanation = HeteroExplanation()
        explanation.set_value_dict('edge_mask', edge_masks_dict)
        return explanation

    def supports(self) -> bool:
        # Only model explanations without node feature masks are supported:
        explanation_type = self.explainer_config.explanation_type
        if explanation_type != ExplanationType.model:
            logging.error(f"'{self.__class__.__name__}' only supports "
                          f"model explanations "
                          f"got (`explanation_type={explanation_type.value}`)")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type is not None:
            logging.error(f"'{self.__class__.__name__}' does not support "
                          f"explaining input node features "
                          f"got (`node_mask_type={node_mask_type.value}`)")
            return False

        return True

================================================
FILE: torch_geometric/explain/algorithm/base.py
================================================
from abc import \
    abstractmethod
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.config import (
    ExplainerConfig,
    ModelConfig,
    ModelReturnType,
)
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.utils import k_hop_subgraph


class ExplainerAlgorithm(torch.nn.Module):
    r"""An abstract base class for implementing explainer algorithms."""
    @abstractmethod
    def forward(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation]:
        r"""Computes the explanation.

        Args:
            model (torch.nn.Module): The model to explain.
            x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input
                node features of a homogeneous or heterogeneous graph.
            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]):
                The input edge indices of a homogeneous or heterogeneous
                graph.
            target (torch.Tensor): The target of the model.
            index (Union[int, Tensor], optional): The index of the model
                output to explain. Can be a single index or a tensor of
                indices. (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to
                :obj:`model`.
        """

    @abstractmethod
    def supports(self) -> bool:
        r"""Checks if the explainer supports the user-defined settings provided
        in :obj:`self.explainer_config`, :obj:`self.model_config`.
        """

    ###########################################################################

    @property
    def explainer_config(self) -> ExplainerConfig:
        r"""Returns the connected explainer configuration.

        Raises:
            ValueError: If :meth:`connect` has not been called yet.
        """
        if not hasattr(self, '_explainer_config'):
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' is "
                f"not yet connected to any explainer configuration. Please "
                f"call `{self.__class__.__name__}.connect(...)` before "
                f"proceeding.")
        return self._explainer_config

    @property
    def model_config(self) -> ModelConfig:
        r"""Returns the connected model configuration.

        Raises:
            ValueError: If :meth:`connect` has not been called yet.
        """
        if not hasattr(self, '_model_config'):
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' is "
                f"not yet connected to any model configuration. Please call "
                f"`{self.__class__.__name__}.connect(...)` before "
                f"proceeding.")
        return self._model_config

    def connect(
        self,
        explainer_config: ExplainerConfig,
        model_config: ModelConfig,
    ):
        r"""Connects an explainer and model configuration to the explainer
        algorithm.

        Raises:
            ValueError: If :meth:`supports` rejects the given settings. The
                configurations are stored before this check is performed.
        """
        self._explainer_config = ExplainerConfig.cast(explainer_config)
        self._model_config = ModelConfig.cast(model_config)

        if not self.supports():
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' does "
                f"not support the given explanation settings.")

    # Helper functions ########################################################

    @staticmethod
    def _post_process_mask(
        mask: Optional[Tensor],
        hard_mask: Optional[Tensor] = None,
        apply_sigmoid: bool = True,
    ) -> Optional[Tensor]:
        r"""Post processes any mask to not include any attributions of
        elements not involved during message passing.

        The hard mask is only applied if its first dimension matches the one
        of :obj:`mask`.
        """
        if mask is None:
            return mask

        # NOTE(review): `detach()` shares storage with the input, so with
        # `apply_sigmoid=False` the zeroing below also writes into the
        # caller's tensor.
        mask = mask.detach()

        if apply_sigmoid:
            mask = mask.sigmoid()

        if hard_mask is not None and mask.size(0) == hard_mask.size(0):
            mask[~hard_mask] = 0.

        return mask

    @staticmethod
    def _get_hard_masks(
        model: torch.nn.Module,
        node_index: Optional[Union[int, Tensor]],
        edge_index: Tensor,
        num_nodes: int,
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        r"""Returns hard node and edge masks that only include the nodes and
        edges visited during message passing.

        The receptive field is the :obj:`k`-hop subgraph around
        :obj:`node_index`, where :obj:`k` is the number of message passing
        layers in :obj:`model`.
        """
        if node_index is None:
            return None, None  # Consider all nodes and edges.

        index, _, _, edge_mask = k_hop_subgraph(
            node_index,
            num_hops=ExplainerAlgorithm._num_hops(model),
            edge_index=edge_index,
            num_nodes=num_nodes,
            flow=ExplainerAlgorithm._flow(model),
        )

        node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
        node_mask[index] = True

        return node_mask, edge_mask

    @staticmethod
    def _num_hops(model: torch.nn.Module) -> int:
        r"""Returns the number of hops the :obj:`model` is aggregating
        information from, i.e., the number of :class:`MessagePassing` modules
        it contains.
        """
        num_hops = 0
        for module in model.modules():
            if isinstance(module, MessagePassing):
                num_hops += 1
        return num_hops

    @staticmethod
    def _flow(model: torch.nn.Module) -> str:
        r"""Determines the message passing flow of the :obj:`model` from its
        first :class:`MessagePassing` module (default:
        :obj:`"source_to_target"`).
        """
        for module in model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def _loss_binary_classification(self, y_hat: Tensor, y: Tensor) -> Tensor:
        # Raw outputs are treated as logits, `probs` as probabilities:
        if self.model_config.return_type == ModelReturnType.raw:
            loss_fn = F.binary_cross_entropy_with_logits
        elif self.model_config.return_type == ModelReturnType.probs:
            loss_fn = F.binary_cross_entropy
        else:
            raise AssertionError()

        return loss_fn(y_hat.view_as(y), y.float())

    def _loss_multiclass_classification(
        self,
        y_hat: Tensor,
        y: Tensor,
    ) -> Tensor:
        # Probabilities are converted to log-probabilities for `nll_loss`:
        if self.model_config.return_type == ModelReturnType.raw:
            loss_fn = F.cross_entropy
        elif self.model_config.return_type == ModelReturnType.probs:
            loss_fn = F.nll_loss
            y_hat = y_hat.log()
        elif self.model_config.return_type == ModelReturnType.log_probs:
            loss_fn = F.nll_loss
        else:
            raise AssertionError()

        return loss_fn(y_hat, y)

    def _loss_regression(self, y_hat: Tensor, y: Tensor) -> Tensor:
        assert self.model_config.return_type == ModelReturnType.raw
        return F.mse_loss(y_hat, y)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

================================================
FILE: torch_geometric/explain/algorithm/captum.py
================================================
from enum import Enum
from typing import Dict, Optional, Tuple, Union

import torch
from torch \
import Tensor from torch_geometric.explain.algorithm.utils import ( clear_masks, set_hetero_masks, set_masks, ) from torch_geometric.explain.config import ( ModelConfig, ModelMode, ModelReturnType, ) from torch_geometric.typing import EdgeType, Metadata, NodeType class MaskLevelType(Enum): """Enum class for the mask level type.""" node = 'node' edge = 'edge' node_and_edge = 'node_and_edge' @property def with_edge(self) -> bool: return self in [MaskLevelType.edge, MaskLevelType.node_and_edge] class CaptumModel(torch.nn.Module): def __init__( self, model: torch.nn.Module, mask_type: Union[str, MaskLevelType], output_idx: Optional[Union[int, Tensor]] = None, model_config: Optional[ModelConfig] = None, ): super().__init__() self.mask_type = MaskLevelType(mask_type) self.model = model self.output_idx = output_idx self.model_config = model_config def forward(self, mask, *args): """""" # noqa: D419 # The mask tensor, which comes from Captum's attribution methods, # contains the number of samples in dimension 0. Since we are # working with only one sample, we squeeze the tensors below. assert mask.shape[0] == 1, "Dimension 0 of input should be 1" if self.mask_type == MaskLevelType.edge: assert len(args) >= 2, "Expects at least x and edge_index as args." if self.mask_type == MaskLevelType.node: assert len(args) >= 1, "Expects at least edge_index as args." if self.mask_type == MaskLevelType.node_and_edge: assert args[0].shape[0] == 1, "Dimension 0 of input should be 1" assert len(args[1:]) >= 1, "Expects at least edge_index as args." 
# Set edge mask: if self.mask_type == MaskLevelType.edge: set_masks(self.model, mask.squeeze(0), args[1], apply_sigmoid=False) elif self.mask_type == MaskLevelType.node_and_edge: set_masks(self.model, args[0].squeeze(0), args[1], apply_sigmoid=False) args = args[1:] if self.mask_type == MaskLevelType.edge: x = self.model(*args) else: x = self.model(mask.squeeze(0), *args) return self.postprocess(x) def postprocess(self, x: Tensor) -> Tensor: if self.mask_type.with_edge: clear_masks(self.model) if self.output_idx is not None: # Filter by output index: x = x[self.output_idx] if (isinstance(self.output_idx, int) or self.output_idx.dim() == 0): x = x.unsqueeze(0) # Convert binary classification to multi-class classification: if (self.model_config is not None and self.model_config.mode == ModelMode.binary_classification): assert self.model_config.return_type == ModelReturnType.probs x = x.view(-1, 1) x = torch.cat([1 - x, x], dim=-1) return x # TODO(jinu) Is there any point of inheriting from `CaptumModel` class CaptumHeteroModel(CaptumModel): def __init__( self, model: torch.nn.Module, mask_type: Union[str, MaskLevelType], output_idx: Optional[Union[int, Tensor]], metadata: Metadata, model_config: Optional[ModelConfig] = None, ): super().__init__(model, mask_type, output_idx, model_config) self.node_types = metadata[0] self.edge_types = metadata[1] self.num_node_types = len(self.node_types) self.num_edge_types = len(self.edge_types) def _captum_data_to_hetero_data( self, *args ) -> Tuple[Dict[NodeType, Tensor], Dict[EdgeType, Tensor], Optional[Dict[ EdgeType, Tensor]]]: """Converts tuple of tensors to `x_dict`, `edge_index_dict` and `edge_mask_dict`. 
""" if self.mask_type == MaskLevelType.node: node_tensors = args[:self.num_node_types] node_tensors = [mask.squeeze(0) for mask in node_tensors] x_dict = dict(zip(self.node_types, node_tensors)) edge_index_dict = args[self.num_node_types] elif self.mask_type == MaskLevelType.edge: edge_mask_tensors = args[:self.num_edge_types] x_dict = args[self.num_edge_types] edge_index_dict = args[self.num_edge_types + 1] else: node_tensors = args[:self.num_node_types] node_tensors = [mask.squeeze(0) for mask in node_tensors] x_dict = dict(zip(self.node_types, node_tensors)) edge_mask_tensors = args[self.num_node_types:self.num_node_types + self.num_edge_types] edge_index_dict = args[self.num_node_types + self.num_edge_types] if self.mask_type.with_edge: edge_mask_tensors = [mask.squeeze(0) for mask in edge_mask_tensors] edge_mask_dict = dict(zip(self.edge_types, edge_mask_tensors)) else: edge_mask_dict = None return x_dict, edge_index_dict, edge_mask_dict def forward(self, *args): # Validate args: if self.mask_type == MaskLevelType.node: assert len(args) >= self.num_node_types + 1 len_remaining_args = len(args) - (self.num_node_types + 1) elif self.mask_type == MaskLevelType.edge: assert len(args) >= self.num_edge_types + 2 len_remaining_args = len(args) - (self.num_edge_types + 2) else: assert len(args) >= self.num_node_types + self.num_edge_types + 1 len_remaining_args = len(args) - (self.num_node_types + self.num_edge_types + 1) # Get main args: (x_dict, edge_index_dict, edge_mask_dict) = self._captum_data_to_hetero_data(*args) if self.mask_type.with_edge: set_hetero_masks(self.model, edge_mask_dict, edge_index_dict) if len_remaining_args > 0: # If there are args other than `x_dict` and `edge_index_dict` x = self.model(x_dict, edge_index_dict, *args[-len_remaining_args:]) else: x = self.model(x_dict, edge_index_dict) return self.postprocess(x) def _to_edge_mask(edge_index: Tensor) -> Tensor: num_edges = edge_index.shape[1] return torch.ones(num_edges, requires_grad=True, 
device=edge_index.device) def to_captum_input( x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], mask_type: Union[str, MaskLevelType], *args, ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: r"""Given :obj:`x`, :obj:`edge_index` and :obj:`mask_type`, converts it to a format to use in `Captum `_ attribution methods. Returns :obj:`inputs` and :obj:`additional_forward_args` required for :captum:`Captum's` :obj:`attribute` functions. See :meth:`~torch_geometric.nn.models.to_captum_model` for example usage. Args: x (torch.Tensor or Dict[NodeType, torch.Tensor]): The node features. For heterogeneous graphs this is a dictionary holding node features for each node type. edge_index(torch.Tensor or Dict[EdgeType, torch.Tensor]): The edge indices. For heterogeneous graphs this is a dictionary holding the :obj:`edge index` for each edge type. mask_type (str): Denotes the type of mask to be created with a Captum explainer. Valid inputs are :obj:`"edge"`, :obj:`"node"`, and :obj:`"node_and_edge"`. *args: Additional forward arguments of the model being explained which will be added to :obj:`additional_forward_args`. 
""" mask_type = MaskLevelType(mask_type) additional_forward_args = [] if isinstance(x, Tensor) and isinstance(edge_index, Tensor): if mask_type == MaskLevelType.node: inputs = [x.unsqueeze(0)] elif mask_type == MaskLevelType.edge: inputs = [_to_edge_mask(edge_index).unsqueeze(0)] additional_forward_args.append(x) else: inputs = [x.unsqueeze(0), _to_edge_mask(edge_index).unsqueeze(0)] additional_forward_args.append(edge_index) elif isinstance(x, Dict) and isinstance(edge_index, Dict): node_types = x.keys() edge_types = edge_index.keys() inputs = [] if mask_type == MaskLevelType.node: for key in node_types: inputs.append(x[key].unsqueeze(0)) elif mask_type == MaskLevelType.edge: for key in edge_types: inputs.append(_to_edge_mask(edge_index[key]).unsqueeze(0)) additional_forward_args.append(x) else: for key in node_types: inputs.append(x[key].unsqueeze(0)) for key in edge_types: inputs.append(_to_edge_mask(edge_index[key]).unsqueeze(0)) additional_forward_args.append(edge_index) else: raise ValueError( "'x' and 'edge_index' need to be either" f"'Dict' or 'Tensor' got({type(x)}, {type(edge_index)})") additional_forward_args.extend(args) return tuple(inputs), tuple(additional_forward_args) def captum_output_to_dicts( captum_attrs: Tuple[Tensor, ...], mask_type: Union[str, MaskLevelType], metadata: Metadata, ) -> Tuple[Optional[Dict[NodeType, Tensor]], Optional[Dict[EdgeType, Tensor]]]: r"""Convert the output of `Captum `_ attribution methods which is a tuple of attributions to two dictionaries with node and edge attribution tensors. This function is used while explaining :class:`~torch_geometric.data.HeteroData` objects. See :meth:`~torch_geometric.nn.models.to_captum_model` for example usage. Args: captum_attrs (tuple[torch.Tensor]): The output of attribution methods. mask_type (str): Denotes the type of mask to be created with a Captum explainer. Valid inputs are :obj:`"edge"`, :obj:`"node"`, and :obj:`"node_and_edge"`: 1. 
:obj:`"edge"`: :obj:`captum_attrs` contains only edge attributions. The returned tuple has no node attributions, and an edge attribution dictionary edge types as keys and edge mask tensors of shape :obj:`[num_edges]` as values. 2. :obj:`"node"`: :obj:`captum_attrs` contains only node attributions. The returned tuple has a node attribution dictionary with node types as keys and node mask tensors of shape :obj:`[num_nodes, num_features]` as values, and no edge attributions. 3. :obj:`"node_and_edge"`: :obj:`captum_attrs` contains node and edge attributions. metadata (Metadata): The metadata of the heterogeneous graph. """ mask_type = MaskLevelType(mask_type) node_types = metadata[0] edge_types = metadata[1] x_attr_dict, edge_attr_dict = None, None captum_attrs = [captum_attr.squeeze(0) for captum_attr in captum_attrs] if mask_type == MaskLevelType.node: assert len(node_types) == len(captum_attrs) x_attr_dict = dict(zip(node_types, captum_attrs)) elif mask_type == MaskLevelType.edge: assert len(edge_types) == len(captum_attrs) edge_attr_dict = dict(zip(edge_types, captum_attrs)) elif mask_type == MaskLevelType.node_and_edge: assert len(edge_types) + len(node_types) == len(captum_attrs) x_attr_dict = dict(zip(node_types, captum_attrs[:len(node_types)])) edge_attr_dict = dict(zip(edge_types, captum_attrs[len(node_types):])) return x_attr_dict, edge_attr_dict def convert_captum_output( captum_attrs: Tuple[Tensor, ...], mask_type: Union[str, MaskLevelType], metadata: Optional[Metadata] = None, ): r"""Convert the output of `Captum.ai `_ attribution methods which is a tuple of attributions to either :obj:`(node_mask, edge_mask)` or :obj:`(node_mask_dict, edge_mask_dict)`. 
""" mask_type = MaskLevelType(mask_type) if metadata is not None: return captum_output_to_dicts(captum_attrs, mask_type, metadata) node_mask = edge_mask = None if mask_type == MaskLevelType.edge: edge_mask = captum_attrs[0].squeeze(0) elif mask_type == MaskLevelType.node: node_mask = captum_attrs[0].squeeze(0) else: node_mask = captum_attrs[0].squeeze(0) edge_mask = captum_attrs[1].squeeze(0) return node_mask, edge_mask ================================================ FILE: torch_geometric/explain/algorithm/captum_explainer.py ================================================ import inspect import logging import warnings from typing import Any, Dict, Optional, Union import torch from torch import Tensor from torch_geometric.explain import Explanation, HeteroExplanation from torch_geometric.explain.algorithm import ExplainerAlgorithm from torch_geometric.explain.algorithm.captum import ( CaptumHeteroModel, CaptumModel, MaskLevelType, convert_captum_output, to_captum_input, ) from torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType from torch_geometric.typing import EdgeType, NodeType class CaptumExplainer(ExplainerAlgorithm): """A `Captum `__-based explainer for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN. This explainer algorithm uses :captum:`null` `Captum `_ to compute attributions. Currently, the following attribution methods are supported: * :class:`captum.attr.IntegratedGradients` * :class:`captum.attr.Saliency` * :class:`captum.attr.InputXGradient` * :class:`captum.attr.Deconvolution` * :class:`captum.attr.ShapleyValueSampling` * :class:`captum.attr.GuidedBackprop` Args: attribution_method (Attribution or str): The Captum attribution method to use. Can be a string or a :class:`captum.attr` method. **kwargs: Additional arguments for the Captum attribution method. """ SUPPORTED_METHODS = [ # TODO: Add support for more methods. 
'IntegratedGradients', 'Saliency', 'InputXGradient', 'Deconvolution', 'ShapleyValueSampling', 'GuidedBackprop', ] def __init__( self, attribution_method: Union[str, Any], **kwargs, ): super().__init__() import captum.attr if isinstance(attribution_method, str): self.attribution_method_class = getattr( captum.attr, attribution_method, ) else: self.attribution_method_class = attribution_method if not self._is_supported_attribution_method(): raise ValueError(f"{self.__class__.__name__} does not support " f"attribution method " f"{self.attribution_method_class.__name__}") if kwargs.get('internal_batch_size', 1) != 1: warnings.warn("Overriding 'internal_batch_size' to 1", stacklevel=2) if 'internal_batch_size' in self._get_attribute_parameters(): kwargs['internal_batch_size'] = 1 self.kwargs = kwargs def _get_mask_type(self) -> MaskLevelType: r"""Based on the explainer config, return the mask type.""" node_mask_type = self.explainer_config.node_mask_type edge_mask_type = self.explainer_config.edge_mask_type if node_mask_type is not None and edge_mask_type is not None: mask_type = MaskLevelType.node_and_edge elif node_mask_type is not None: mask_type = MaskLevelType.node elif edge_mask_type is not None: mask_type = MaskLevelType.edge else: raise ValueError("Neither node mask type nor " "edge mask type is specified.") return mask_type def _get_attribute_parameters(self) -> Dict[str, Any]: r"""Returns the attribute arguments.""" signature = inspect.signature(self.attribution_method_class.attribute) return signature.parameters def _needs_baseline(self) -> bool: r"""Checks if the method needs a baseline.""" parameters = self._get_attribute_parameters() if 'baselines' in parameters: param = parameters['baselines'] if param.default is inspect.Parameter.empty: return True return False def _is_supported_attribution_method(self) -> bool: r"""Returns :obj:`True` if `self.attribution_method` is supported.""" # This is redundant for now since all supported methods need a baseline if 
self._needs_baseline(): return False elif self.attribution_method_class.__name__ in self.SUPPORTED_METHODS: return True return False def forward( self, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> Union[Explanation, HeteroExplanation]: mask_type = self._get_mask_type() inputs, add_forward_args = to_captum_input( x, edge_index, mask_type, *kwargs.values(), ) if isinstance(x, dict): # Heterogeneous GNN: metadata = (list(x.keys()), list(edge_index.keys())) captum_model = CaptumHeteroModel( model, mask_type, index, metadata, self.model_config, ) else: # Homogeneous GNN: metadata = None captum_model = CaptumModel( model, mask_type, index, self.model_config, ) self.attribution_method_instance = self.attribution_method_class( captum_model) # In Captum, the target is the class index for which the attribution is # computed. Within CaptumModel, we transform the binary classification # into a multi-class classification task. 
if self.model_config.mode == ModelMode.regression: target = None elif index is not None: target = target[index] attributions = self.attribution_method_instance.attribute( inputs=inputs, target=target, additional_forward_args=add_forward_args, **self.kwargs, ) node_mask, edge_mask = convert_captum_output( attributions, mask_type, metadata, ) if not isinstance(x, dict): return Explanation(node_mask=node_mask, edge_mask=edge_mask) explanation = HeteroExplanation() explanation.set_value_dict('node_mask', node_mask) explanation.set_value_dict('edge_mask', edge_mask) return explanation def supports(self) -> bool: node_mask_type = self.explainer_config.node_mask_type if node_mask_type not in [None, MaskType.attributes]: logging.error(f"'{self.__class__.__name__}' expects " f"'node_mask_type' to be 'None' or 'attributes' " f"(got '{node_mask_type.value}')") return False return_type = self.model_config.return_type if (self.model_config.mode == ModelMode.binary_classification and return_type != ModelReturnType.probs): logging.error(f"'{self.__class__.__name__}' expects " f"'return_type' to be 'probs' for binary " f"classification tasks (got '{return_type.value}')") return False # TODO (ramona) Confirm that output type is valid. return True ================================================ FILE: torch_geometric/explain/algorithm/dummy_explainer.py ================================================ from collections import defaultdict from typing import Dict, Optional, Union import torch from torch import Tensor from torch_geometric.explain import Explanation, HeteroExplanation from torch_geometric.explain.algorithm import ExplainerAlgorithm from torch_geometric.explain.config import MaskType from torch_geometric.typing import EdgeType, NodeType class DummyExplainer(ExplainerAlgorithm): r"""A dummy explainer that returns random explanations (useful for testing purposes). 
""" def forward( self, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], edge_attr: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None, **kwargs, ) -> Union[Explanation, HeteroExplanation]: assert isinstance(x, (Tensor, dict)) node_mask_type = self.explainer_config.node_mask_type edge_mask_type = self.explainer_config.edge_mask_type if isinstance(x, Tensor): # Homogeneous graph. assert isinstance(edge_index, Tensor) node_mask = None if node_mask_type == MaskType.object: node_mask = torch.rand(x.size(0), 1, device=x.device) elif node_mask_type == MaskType.common_attributes: node_mask = torch.rand(1, x.size(1), device=x.device) elif node_mask_type == MaskType.attributes: node_mask = torch.rand_like(x) edge_mask = None if edge_mask_type == MaskType.object: edge_mask = torch.rand(edge_index.size(1), device=x.device) return Explanation(node_mask=node_mask, edge_mask=edge_mask) else: # isinstance(x, dict): # Heterogeneous graph. 
assert isinstance(edge_index, dict) node_dict = defaultdict(dict) for k, v in x.items(): node_mask = None if node_mask_type == MaskType.object: node_mask = torch.rand(v.size(0), 1, device=v.device) elif node_mask_type == MaskType.common_attributes: node_mask = torch.rand(1, v.size(1), device=v.device) elif node_mask_type == MaskType.attributes: node_mask = torch.rand_like(v) if node_mask is not None: node_dict[k]['node_mask'] = node_mask edge_dict = defaultdict(dict) for k, v in edge_index.items(): edge_mask = None if edge_mask_type == MaskType.object: edge_mask = torch.rand(v.size(1), device=v.device) if edge_mask is not None: edge_dict[k]['edge_mask'] = edge_mask return HeteroExplanation({**node_dict, **edge_dict}) def supports(self) -> bool: return True ================================================ FILE: torch_geometric/explain/algorithm/gnn_explainer.py ================================================ from math import sqrt from typing import Dict, Optional, Tuple, Union, overload import torch from torch import Tensor from torch.nn.parameter import Parameter from torch_geometric.explain import ( ExplainerConfig, Explanation, HeteroExplanation, ModelConfig, ) from torch_geometric.explain.algorithm import ExplainerAlgorithm from torch_geometric.explain.algorithm.utils import ( clear_masks, set_hetero_masks, set_masks, ) from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel from torch_geometric.typing import EdgeType, NodeType class GNNExplainer(ExplainerAlgorithm): r"""The GNN-Explainer model from the `"GNNExplainer: Generating Explanations for Graph Neural Networks" `_ paper for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN. .. note:: For an example of using :class:`GNNExplainer`, see `examples/explain/gnn_explainer.py `_, `examples/explain/gnn_explainer_ba_shapes.py `_, and `examples/explain/ gnn_explainer_link_pred.py `_. .. 
note:: The :obj:`edge_size` coefficient is multiplied by the number of nodes in the explanation at every iteration, and the resulting value is added to the loss as a regularization term, with the goal of producing compact explanations. A higher value will push the algorithm towards explanations with less elements. Consider adjusting the :obj:`edge_size` coefficient according to the average node degree in the dataset, especially if this value is bigger than in the datasets used in the original paper. Args: epochs (int, optional): The number of epochs to train. (default: :obj:`100`) lr (float, optional): The learning rate to apply. (default: :obj:`0.01`) **kwargs (optional): Additional hyper-parameters to override default settings in :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`. """ default_coeffs = { 'edge_size': 0.005, 'edge_reduction': 'sum', 'node_feat_size': 1.0, 'node_feat_reduction': 'mean', 'edge_ent': 1.0, 'node_feat_ent': 0.1, 'EPS': 1e-15, } def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs): super().__init__() self.epochs = epochs self.lr = lr self.coeffs = dict(self.default_coeffs) self.coeffs.update(kwargs) self.node_mask = self.hard_node_mask = None self.edge_mask = self.hard_edge_mask = None self.is_hetero = False @overload def forward( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> Explanation: ... @overload def forward( self, model: torch.nn.Module, x: Dict[NodeType, Tensor], edge_index: Dict[EdgeType, Tensor], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> HeteroExplanation: ... 
def forward( self, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> Union[Explanation, HeteroExplanation]: self.is_hetero = isinstance(x, dict) self._train(model, x, edge_index, target=target, index=index, **kwargs) explanation = self._create_explanation() self._clean_model(model) return explanation def _create_explanation(self) -> Union[Explanation, HeteroExplanation]: """Create an explanation object from the current masks.""" if self.is_hetero: # For heterogeneous graphs, process each type separately node_mask_dict = {} edge_mask_dict = {} for node_type, mask in self.node_mask.items(): if mask is not None: node_mask_dict[node_type] = self._post_process_mask( mask, self.hard_node_mask[node_type], apply_sigmoid=True, ) for edge_type, mask in self.edge_mask.items(): if mask is not None: edge_mask_dict[edge_type] = self._post_process_mask( mask, self.hard_edge_mask[edge_type], apply_sigmoid=True, ) # Create heterogeneous explanation explanation = HeteroExplanation() explanation.set_value_dict('node_mask', node_mask_dict) explanation.set_value_dict('edge_mask', edge_mask_dict) else: # For homogeneous graphs, process single masks node_mask = self._post_process_mask( self.node_mask, self.hard_node_mask, apply_sigmoid=True, ) edge_mask = self._post_process_mask( self.edge_mask, self.hard_edge_mask, apply_sigmoid=True, ) # Create homogeneous explanation explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask) return explanation def supports(self) -> bool: return True @overload def _train( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> None: ... 
@overload def _train( self, model: torch.nn.Module, x: Dict[NodeType, Tensor], edge_index: Dict[EdgeType, Tensor], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> None: ... def _train( self, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> None: # Initialize masks based on input type self._initialize_masks(x, edge_index) # Collect parameters for optimization parameters = self._collect_parameters(model, edge_index) # Create optimizer optimizer = torch.optim.Adam(parameters, lr=self.lr) # Training loop for i in range(self.epochs): optimizer.zero_grad() # Forward pass with masked inputs y_hat = self._forward_with_masks(model, x, edge_index, **kwargs) y = target # Handle index if provided if index is not None: y_hat, y = y_hat[index], y[index] # Calculate loss loss = self._loss(y_hat, y) # Backward pass loss.backward() optimizer.step() # In the first iteration, collect gradients to identify important # nodes/edges if i == 0: self._collect_gradients() def _collect_parameters(self, model, edge_index): """Collect parameters for optimization.""" parameters = [] if self.is_hetero: # For heterogeneous graphs, collect parameters from all types for mask in self.node_mask.values(): if mask is not None: parameters.append(mask) if any(v is not None for v in self.edge_mask.values()): set_hetero_masks(model, self.edge_mask, edge_index) for mask in self.edge_mask.values(): if mask is not None: parameters.append(mask) else: # For homogeneous graphs, collect single parameters if self.node_mask is not None: parameters.append(self.node_mask) if self.edge_mask is not None: set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True) parameters.append(self.edge_mask) return parameters @overload def _forward_with_masks( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, **kwargs, ) -> Tensor: ... 
@overload def _forward_with_masks( self, model: torch.nn.Module, x: Dict[NodeType, Tensor], edge_index: Dict[EdgeType, Tensor], **kwargs, ) -> Tensor: ... def _forward_with_masks( self, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], **kwargs, ) -> Tensor: """Forward pass with masked inputs.""" if self.is_hetero: # Apply masks to heterogeneous inputs h_dict = {} for node_type, features in x.items(): if node_type in self.node_mask and self.node_mask[ node_type] is not None: h_dict[node_type] = features * self.node_mask[ node_type].sigmoid() else: h_dict[node_type] = features # Forward pass with masked features return model(h_dict, edge_index, **kwargs) else: # Apply mask to homogeneous input h = x if self.node_mask is None else x * self.node_mask.sigmoid() # Forward pass with masked features return model(h, edge_index, **kwargs) def _initialize_masks( self, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], ) -> None: node_mask_type = self.explainer_config.node_mask_type edge_mask_type = self.explainer_config.edge_mask_type if self.is_hetero: # Initialize dictionaries for heterogeneous masks self.node_mask = {} self.hard_node_mask = {} self.edge_mask = {} self.hard_edge_mask = {} # Initialize node masks for each node type for node_type, features in x.items(): device = features.device N, F = features.size() self._initialize_node_mask(node_mask_type, node_type, N, F, device) # Initialize edge masks for each edge type for edge_type, indices in edge_index.items(): device = indices.device E = indices.size(1) N = max(indices.max().item() + 1, max(feat.size(0) for feat in x.values())) self._initialize_edge_mask(edge_mask_type, edge_type, E, N, device) else: # Initialize masks for homogeneous graph device = x.device (N, F), E = x.size(), edge_index.size(1) # Initialize homogeneous node and edge masks self._initialize_homogeneous_masks(node_mask_type, 
edge_mask_type, N, F, E, device) def _initialize_node_mask( self, node_mask_type, node_type, N, F, device, ) -> None: """Initialize node mask for a specific node type.""" std = 0.1 if node_mask_type is None: self.node_mask[node_type] = None self.hard_node_mask[node_type] = None elif node_mask_type == MaskType.object: self.node_mask[node_type] = Parameter( torch.randn(N, 1, device=device) * std) self.hard_node_mask[node_type] = None elif node_mask_type == MaskType.attributes: self.node_mask[node_type] = Parameter( torch.randn(N, F, device=device) * std) self.hard_node_mask[node_type] = None elif node_mask_type == MaskType.common_attributes: self.node_mask[node_type] = Parameter( torch.randn(1, F, device=device) * std) self.hard_node_mask[node_type] = None else: raise ValueError(f"Invalid node mask type: {node_mask_type}") def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device): """Initialize edge mask for a specific edge type.""" if edge_mask_type is None: self.edge_mask[edge_type] = None self.hard_edge_mask[edge_type] = None elif edge_mask_type == MaskType.object: std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N)) self.edge_mask[edge_type] = Parameter( torch.randn(E, device=device) * std) self.hard_edge_mask[edge_type] = None else: raise ValueError(f"Invalid edge mask type: {edge_mask_type}") def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N, F, E, device): """Initialize masks for homogeneous graph.""" # Initialize node mask std = 0.1 if node_mask_type is None: self.node_mask = None elif node_mask_type == MaskType.object: self.node_mask = Parameter(torch.randn(N, 1, device=device) * std) elif node_mask_type == MaskType.attributes: self.node_mask = Parameter(torch.randn(N, F, device=device) * std) elif node_mask_type == MaskType.common_attributes: self.node_mask = Parameter(torch.randn(1, F, device=device) * std) else: raise ValueError(f"Invalid node mask type: {node_mask_type}") # Initialize edge mask if 
edge_mask_type is None:
            self.edge_mask = None
        elif edge_mask_type == MaskType.object:
            # Random initialization whose scale depends on the node count `N`.
            std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
            self.edge_mask = Parameter(torch.randn(E, device=device) * std)
        else:
            raise ValueError(f"Invalid edge mask type: {edge_mask_type}")

    def _collect_gradients(self) -> None:
        """Record which mask entries received a non-zero gradient.

        Dispatches to the heterogeneous or homogeneous variant depending on
        :attr:`is_hetero`.
        """
        if self.is_hetero:
            self._collect_hetero_gradients()
        else:
            self._collect_homo_gradients()

    def _collect_hetero_gradients(self):
        """Collect gradients for heterogeneous graph.

        For every non-``None`` node/edge mask, stores ``mask.grad != 0.0`` in
        ``self.hard_node_mask[node_type]`` / ``self.hard_edge_mask[edge_type]``.

        Raises:
            ValueError: If a non-``None`` mask has no gradient, i.e. it was
                not used inside the model.
        """
        for node_type, mask in self.node_mask.items():
            if mask is not None:
                if mask.grad is None:
                    raise ValueError(
                        f"Could not compute gradients for node masks of type "
                        f"'{node_type}'. Please make sure that node masks are "
                        f"used inside the model or disable it via "
                        f"`node_mask_type=None`.")
                self.hard_node_mask[node_type] = mask.grad != 0.0
        for edge_type, mask in self.edge_mask.items():
            if mask is not None:
                if mask.grad is None:
                    raise ValueError(
                        f"Could not compute gradients for edge masks of type "
                        f"'{edge_type}'. Please make sure that edge masks are "
                        f"used inside the model or disable it via "
                        f"`edge_mask_type=None`.")
                self.hard_edge_mask[edge_type] = mask.grad != 0.0

    def _collect_homo_gradients(self):
        """Collect gradients for homogeneous graph.

        Stores ``grad != 0.0`` of the node/edge mask (when present) in
        ``self.hard_node_mask`` / ``self.hard_edge_mask``.

        Raises:
            ValueError: If a present mask has no gradient.
        """
        if self.node_mask is not None:
            if self.node_mask.grad is None:
                raise ValueError("Could not compute gradients for node "
                                 "features. Please make sure that node "
                                 "features are used inside the model or "
                                 "disable it via `node_mask_type=None`.")
            self.hard_node_mask = self.node_mask.grad != 0.0
        if self.edge_mask is not None:
            if self.edge_mask.grad is None:
                raise ValueError("Could not compute gradients for edges. "
                                 "Please make sure that edges are used "
                                 "via message passing inside the model or "
                                 "disable it via `edge_mask_type=None`.")
            self.hard_edge_mask = self.edge_mask.grad != 0.0

    def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
        """Return the prediction loss plus mask size/entropy regularization.

        Args:
            y_hat (torch.Tensor): The model prediction.
            y (torch.Tensor): The target.
        """
        # Calculate base loss based on model configuration
        loss = self._calculate_base_loss(y_hat, y)

        # Apply regularization based on graph type
        if self.is_hetero:
            # Apply regularization for heterogeneous graph
            loss = self._apply_hetero_regularization(loss)
        else:
            # Apply regularization for homogeneous graph
            loss = self._apply_homo_regularization(loss)

        return loss

    def _calculate_base_loss(self, y_hat, y):
        """Calculate base loss based on model configuration.

        Raises:
            ValueError: If :obj:`model_config.mode` is not one of the
                binary/multiclass classification or regression modes.
        """
        if self.model_config.mode == ModelMode.binary_classification:
            return self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            return self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            return self._loss_regression(y_hat, y)
        else:
            raise ValueError(f"Invalid model mode: {self.model_config.mode}")

    def _apply_hetero_regularization(self, loss):
        """Apply regularization for heterogeneous graph.

        Adds one regularization term per edge type and per node type whose
        mask and hard mask are both present.
        """
        # Apply regularization for each edge type
        for edge_type, mask in self.edge_mask.items():
            if (mask is not None
                    and self.hard_edge_mask[edge_type] is not None):
                loss = self._add_mask_regularization(
                    loss, mask, self.hard_edge_mask[edge_type],
                    self.coeffs['edge_size'], self.coeffs['edge_reduction'],
                    self.coeffs['edge_ent'])
        # Apply regularization for each node type
        for node_type, mask in self.node_mask.items():
            if (mask is not None
                    and self.hard_node_mask[node_type] is not None):
                loss = self._add_mask_regularization(
                    loss, mask, self.hard_node_mask[node_type],
                    self.coeffs['node_feat_size'],
                    self.coeffs['node_feat_reduction'],
                    self.coeffs['node_feat_ent'])
        return loss

    def _apply_homo_regularization(self, loss):
        """Apply regularization for homogeneous graph.

        Adds edge- and node-mask regularization terms for whichever hard
        masks are present.
        """
        # Apply regularization for edge mask
        if self.hard_edge_mask is not None:
            assert self.edge_mask is not None
            loss = self._add_mask_regularization(loss, self.edge_mask,
                                                 self.hard_edge_mask,
                                                 self.coeffs['edge_size'],
                                                 self.coeffs['edge_reduction'],
                                                 self.coeffs['edge_ent'])
        # Apply regularization for node mask
        if self.hard_node_mask is not None:
            assert self.node_mask is not None
            loss = self._add_mask_regularization(
                loss, self.node_mask, self.hard_node_mask,
                self.coeffs['node_feat_size'],
                self.coeffs['node_feat_reduction'],
                self.coeffs['node_feat_ent'])
        return loss

    def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,
                                 reduction_name, ent_coeff):
        """Add size and entropy regularization for a mask.

        Only the entries selected by :obj:`hard_mask` are regularized.
        :obj:`reduction_name` names a :mod:`torch` function (looked up with
        :func:`getattr`) used to reduce the sigmoid-activated mask for the
        size term; the entropy term is always averaged.
        """
        m = mask[hard_mask].sigmoid()
        reduce_fn = getattr(torch, reduction_name)
        # Add size regularization
        loss = loss + size_coeff * reduce_fn(m)
        # Add entropy regularization (binary entropy of each mask entry,
        # with `EPS` guarding against log(0))
        ent = -m * torch.log(m + self.coeffs['EPS']) - (
            1 - m) * torch.log(1 - m + self.coeffs['EPS'])
        loss = loss + ent_coeff * ent.mean()
        return loss

    def _clean_model(self, model):
        """Remove explainer masks from :obj:`model` and reset mask state."""
        clear_masks(model)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None


class GNNExplainer_:
    r"""Deprecated version for :class:`GNNExplainer`."""

    coeffs = GNNExplainer.default_coeffs

    # Maps the legacy `feat_mask_type` names to `MaskType` values:
    conversion_node_mask_type = {
        'feature': 'common_attributes',
        'individual_feature': 'attributes',
        'scalar': 'object',
    }

    # Maps the legacy `return_type` names to `ModelReturnType` values:
    conversion_return_type = {
        'log_prob': 'log_probs',
        'prob': 'probs',
        'raw': 'raw',
        'regression': 'raw',
    }

    def __init__(
        self,
        model: torch.nn.Module,
        epochs: int = 100,
        lr: float = 0.01,
        return_type: str = 'log_prob',
        feat_mask_type: str = 'feature',
        allow_edge_mask: bool = True,
        **kwargs,
    ):
        """Wrap a :class:`GNNExplainer` behind the legacy interface.

        Args:
            model (torch.nn.Module): The model to explain.
            epochs (int, optional): Passed to :class:`GNNExplainer`.
            lr (float, optional): Passed to :class:`GNNExplainer`.
            return_type (str, optional): Legacy return type; one of the keys
                of :attr:`conversion_return_type`. :obj:`'regression'` also
                selects regression mode.
            feat_mask_type (str, optional): Legacy node mask type; one of the
                keys of :attr:`conversion_node_mask_type`.
            allow_edge_mask (bool, optional): Whether to learn an edge mask.
            **kwargs (optional): Forwarded to :class:`GNNExplainer`.
        """
        assert feat_mask_type in ['feature', 'individual_feature', 'scalar']

        explainer_config = ExplainerConfig(
            explanation_type='model',
            node_mask_type=self.conversion_node_mask_type[feat_mask_type],
            edge_mask_type=MaskType.object if allow_edge_mask else None,
        )
        model_config = ModelConfig(
            mode='regression'
            if return_type == 'regression' else 'multiclass_classification',
            task_level=ModelTaskLevel.node,
            return_type=self.conversion_return_type[return_type],
        )

        self.model = model
        self._explainer = GNNExplainer(epochs=epochs, lr=lr, **kwargs)
        self._explainer.connect(explainer_config, model_config)

    @torch.no_grad()
    def get_initial_prediction(self, *args, **kwargs) -> Tensor:
        """Run the model in eval mode (without gradients) and return its
        output; for multiclass classification the arg-max class is returned.
        The model's original training flag is restored afterwards.
        """
        training = self.model.training
        self.model.eval()

        out = self.model(*args, **kwargs)
        if (self._explainer.model_config.mode ==
                ModelMode.multiclass_classification):
            out = out.argmax(dim=-1)

        self.model.train(training)

        return out

    def explain_graph(
        self,
        x: Tensor,
        edge_index: Tensor,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """Explain a graph-level prediction; returns
        ``(node_mask, edge_mask)``.
        """
        self._explainer.model_config.task_level = ModelTaskLevel.graph

        explanation = self._explainer(
            self.model,
            x,
            edge_index,
            target=self.get_initial_prediction(x, edge_index, **kwargs),
            **kwargs,
        )
        return self._convert_output(explanation, edge_index)

    def explain_node(
        self,
        node_idx: int,
        x: Tensor,
        edge_index: Tensor,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """Explain the prediction for node :obj:`node_idx`; returns
        ``(node_mask, edge_mask)``.
        """
        self._explainer.model_config.task_level = ModelTaskLevel.node
        explanation = self._explainer(
            self.model,
            x,
            edge_index,
            target=self.get_initial_prediction(x, edge_index, **kwargs),
            index=node_idx,
            **kwargs,
        )
        return self._convert_output(explanation, edge_index, index=node_idx,
                                    x=x)

    def _convert_output(self, explanation, edge_index, index=None, x=None):
        """Convert an :class:`Explanation` into the legacy tuple output.

        Object/common-attribute node masks are flattened. A missing edge mask
        is replaced by the hard edge mask around :obj:`index` (node-level)
        or by all ones (graph-level).
        """
        node_mask = explanation.get('node_mask')
        edge_mask = explanation.get('edge_mask')

        if node_mask is not None:
            node_mask_type = self._explainer.explainer_config.node_mask_type
            if node_mask_type in {MaskType.object, MaskType.common_attributes}:
                node_mask = node_mask.view(-1)

        if edge_mask is None:
            if index is not None:
                _, edge_mask = self._explainer._get_hard_masks(
                    self.model, index, edge_index, num_nodes=x.size(0))
                edge_mask = edge_mask.to(x.dtype)
            else:
                edge_mask = torch.ones(edge_index.size(1),
                                       device=edge_index.device)

        return node_mask, edge_mask


================================================
FILE:
torch_geometric/explain/algorithm/graphmask_explainer.py ================================================ import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import LayerNorm, Linear, Parameter, ReLU from tqdm import tqdm from torch_geometric.explain import Explanation from torch_geometric.explain.algorithm import ExplainerAlgorithm from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel from torch_geometric.nn import MessagePassing def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor: basis_messages = F.layer_norm(out, (out.size(-1), )).relu() if getattr(self, 'message_scale', None) is not None: basis_messages = basis_messages * self.message_scale.unsqueeze(-1) if self.message_replacement is not None: if basis_messages.shape == self.message_replacement.shape: basis_messages = (basis_messages + (1 - self.message_scale).unsqueeze(-1) * self.message_replacement) else: basis_messages = (basis_messages + ((1 - self.message_scale).unsqueeze(-1) * self.message_replacement.unsqueeze(0))) self.latest_messages = basis_messages self.latest_source_embeddings = x_j self.latest_target_embeddings = x_i return basis_messages class GraphMaskExplainer(ExplainerAlgorithm): r"""The GraphMask-Explainer model from the `"Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking" `_ paper for identifying layer-wise compact subgraph structures and node features that play a crucial role in the predictions made by a GNN. .. note:: For an example of using :class:`GraphMaskExplainer`, see `examples/explain/graphmask_explainer.py `_. A working real-time example of :class:`GraphMaskExplainer` in the form of a deployed app can be accessed `here `_. Args: num_layers (int): The number of layers to use. epochs (int, optional): The number of epochs to train. (default: :obj:`100`) lr (float, optional): The learning rate to apply. 
(default: :obj:`0.01`) penalty_scaling (int, optional): Scaling value of penalty term. Value must lie between 0 and 10. (default: :obj:`5`) lambda_optimizer_lr (float, optional): The learning rate to optimize the Lagrange multiplier. (default: :obj:`1e-2`) init_lambda (float, optional): The Lagrange multiplier. Value must lie between :obj:`0` and `1`. (default: :obj:`0.55`) allowance (float, optional): A float value between :obj:`0` and :obj:`1` denotes tolerance level. (default: :obj:`0.03`) log (bool, optional): If set to :obj:`False`, will not log any learning progress. (default: :obj:`True`) **kwargs (optional): Additional hyper-parameters to override default settings in :attr:`~torch_geometric.nn.models.GraphMaskExplainer.coeffs`. """ coeffs = { 'node_feat_size': 1.0, 'node_feat_reduction': 'mean', 'node_feat_ent': 0.1, 'EPS': 1e-15, } def __init__( self, num_layers: int, epochs: int = 100, lr: float = 0.01, penalty_scaling: int = 5, lambda_optimizer_lr: int = 1e-2, init_lambda: int = 0.55, allowance: int = 0.03, allow_multiple_explanations: bool = False, log: bool = True, **kwargs, ): super().__init__() assert 0 <= penalty_scaling <= 10 assert 0 <= init_lambda <= 1 assert 0 <= allowance <= 1 self.num_layers = num_layers self.init_lambda = init_lambda self.lambda_optimizer_lr = lambda_optimizer_lr self.penalty_scaling = penalty_scaling self.allowance = allowance self.allow_multiple_explanations = allow_multiple_explanations self.epochs = epochs self.lr = lr self.log = log self.coeffs.update(kwargs) def forward( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> Explanation: hard_node_mask = None if self.model_config.task_level == ModelTaskLevel.node: hard_node_mask, hard_edge_mask = self._get_hard_masks( model, index, edge_index, num_nodes=x.size(0)) self._train_explainer(model, x, edge_index, target=target, index=index, **kwargs) node_mask = 
self._post_process_mask(self.node_feat_mask, hard_node_mask, apply_sigmoid=True) edge_mask = self._explain(model, index=index) edge_mask = edge_mask[:edge_index.size(1)] return Explanation(node_mask=node_mask, edge_mask=edge_mask) def supports(self) -> bool: return True def _hard_concrete( self, input_element: Tensor, summarize_penalty: bool = True, beta: float = 1 / 3, gamma: float = -0.2, zeta: float = 1.2, loc_bias: int = 2, min_val: int = 0, max_val: int = 1, training: bool = True, ) -> Tuple[Tensor, Tensor]: r"""Helps to set the edge mask while sampling its values from the hard-concrete distribution. """ input_element = input_element + loc_bias if training: u = torch.empty_like(input_element).uniform_(1e-6, 1.0 - 1e-6) s = torch.sigmoid( (torch.log(u) - torch.log(1 - u) + input_element) / beta) penalty = torch.sigmoid(input_element - beta * math.log(-gamma / zeta)) else: s = torch.sigmoid(input_element) penalty = torch.zeros_like(input_element) if summarize_penalty: penalty = penalty.mean() s = s * (zeta - gamma) + gamma clipped_s = s.clamp(min_val, max_val) clip_value = (torch.min(clipped_s) + torch.max(clipped_s)) / 2 hard_concrete = (clipped_s > clip_value).float() clipped_s = clipped_s + (hard_concrete - clipped_s).detach() return clipped_s, penalty def _set_masks( self, i_dim: List[int], j_dim: List[int], h_dim: List[int], x: Tensor, ): r"""Sets the node masks and edge masks.""" (num_nodes, num_feat), std, device = x.size(), 0.1, x.device self.feat_mask_type = self.explainer_config.node_mask_type if self.feat_mask_type == MaskType.attributes: self.node_feat_mask = torch.nn.Parameter( torch.randn(num_nodes, num_feat, device=device) * std) elif self.feat_mask_type == MaskType.object: self.node_feat_mask = torch.nn.Parameter( torch.randn(num_nodes, 1, device=device) * std) else: self.node_feat_mask = torch.nn.Parameter( torch.randn(1, num_feat, device=device) * std) baselines, self.gates, full_biases = [], torch.nn.ModuleList(), [] for v_dim, m_dim, o_dim in 
zip(i_dim, j_dim, h_dim): self.transform, self.layer_norm = [], [] input_dims = [v_dim, m_dim, v_dim] for _, input_dim in enumerate(input_dims): self.transform.append( Linear(input_dim, o_dim, bias=False).to(device)) self.layer_norm.append(LayerNorm(o_dim).to(device)) self.transforms = torch.nn.ModuleList(self.transform) self.layer_norms = torch.nn.ModuleList(self.layer_norm) self.full_bias = Parameter( torch.tensor(o_dim, dtype=torch.float, device=device)) full_biases.append(self.full_bias) self.reset_parameters(input_dims, o_dim) self.non_linear = ReLU() self.output_layer = Linear(o_dim, 1).to(device) gate = [ self.transforms, self.layer_norms, self.non_linear, self.output_layer ] self.gates.extend(gate) baseline = torch.tensor(m_dim, dtype=torch.float, device=device) stdv = 1. / math.sqrt(m_dim) baseline.uniform_(-stdv, stdv) baseline = torch.nn.Parameter(baseline) baselines.append(baseline) full_biases = torch.nn.ParameterList(full_biases) self.full_biases = full_biases baselines = torch.nn.ParameterList(baselines) self.baselines = baselines for parameter in self.parameters(): parameter.requires_grad = False def _enable_layer(self, layer: int): r"""Enables the input layer's edge mask.""" for d in range(layer * 4, (layer * 4) + 4): for parameter in self.gates[d].parameters(): parameter.requires_grad = True self.full_biases[layer].requires_grad = True self.baselines[layer].requires_grad = True def reset_parameters(self, input_dims: List[int], h_dim: List[int]): r"""Resets all learnable parameters of the module.""" fan_in = sum(input_dims) std = math.sqrt(2.0 / float(fan_in + h_dim)) a = math.sqrt(3.0) * std for transform in self.transforms: torch.nn.init._no_grad_uniform_(transform.weight, -a, a) torch.nn.init.zeros_(self.full_bias) for layer_norm in self.layer_norms: layer_norm.reset_parameters() def _loss(self, y_hat: Tensor, y: Tensor, penalty: float) -> Tensor: if self.model_config.mode == ModelMode.binary_classification: loss = 
self._loss_binary_classification(y_hat, y) elif self.model_config.mode == ModelMode.multiclass_classification: loss = self._loss_multiclass_classification(y_hat, y) elif self.model_config.mode == ModelMode.regression: loss = self._loss_regression(y_hat, y) else: raise AssertionError() g = torch.relu(loss - self.allowance).mean() f = penalty * self.penalty_scaling loss = f + F.softplus(self.lambda_op) * g m = self.node_feat_mask.sigmoid() node_feat_reduce = getattr(torch, self.coeffs['node_feat_reduction']) loss = loss + self.coeffs['node_feat_size'] * node_feat_reduce(m) ent = -m * torch.log(m + self.coeffs['EPS']) - ( 1 - m) * torch.log(1 - m + self.coeffs['EPS']) loss = loss + self.coeffs['node_feat_ent'] * ent.mean() return loss def _freeze_model(self, module: torch.nn.Module): r"""Freezes the parameters of the original GNN model by disabling their gradients. """ for param in module.parameters(): param.requires_grad = False def _set_flags(self, model: torch.nn.Module): r"""Initializes the underlying explainer model's parameters for each layer of the original GNN model. """ for module in model.modules(): if isinstance(module, MessagePassing): module.explain_message = explain_message.__get__( module, MessagePassing) module.explain = True def _inject_messages( self, model: torch.nn.Module, message_scale: List[Tensor], message_replacement: torch.nn.ParameterList, set: bool = False, ): r"""Injects the computed messages into each layer of the original GNN model. """ i = 0 for module in model.modules(): if isinstance(module, MessagePassing): if not set: module.message_scale = message_scale[i] module.message_replacement = message_replacement[i] i = i + 1 else: module.message_scale = None module.message_replacement = None def _train_explainer( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ): r"""Trains the underlying explainer model. 
Args: model (torch.nn.Module): The model to explain. x (torch.Tensor): The input node features. edge_index (torch.Tensor): The input edge indices. target (torch.Tensor): The target of the model. index (int or torch.Tensor, optional): The index of the model output to explain. Needs to be a single index. (default: :obj:`None`) **kwargs (optional): Additional keyword arguments passed to :obj:`model`. """ if (not isinstance(index, Tensor) and not isinstance(index, int) and index is not None): raise ValueError("'index' parameter can only be a 'Tensor', " "'integer' or set to 'None' instead.") self._freeze_model(model) self._set_flags(model) input_dims, output_dims = [], [] for module in model.modules(): if isinstance(module, MessagePassing): input_dims.append(module.in_channels) output_dims.append(module.out_channels) self._set_masks(input_dims, output_dims, output_dims, x) optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) for layer in reversed(list(range(self.num_layers))): if self.log: pbar = tqdm(total=self.epochs) if self.model_config.task_level == ModelTaskLevel.node: pbar.set_description( f'Train explainer for node(s) {index} with layer ' f'{layer}') elif self.model_config.task_level == ModelTaskLevel.edge: pbar.set_description( f"Train explainer for edge-level task with layer " f"{layer}") else: pbar.set_description( f'Train explainer for graph {index} with layer ' f'{layer}') self._enable_layer(layer) for _ in range(self.epochs): with torch.no_grad(): model(x, edge_index, **kwargs) gates, total_penalty = [], 0 latest_source_embeddings, latest_messages = [], [] latest_target_embeddings = [] for module in model.modules(): if isinstance(module, MessagePassing): latest_source_embeddings.append( module.latest_source_embeddings) latest_messages.append(module.latest_messages) latest_target_embeddings.append( module.latest_target_embeddings) gate_input = [ latest_source_embeddings, latest_messages, latest_target_embeddings ] for i in range(self.num_layers): 
output = self.full_biases[i] for j in range(len(gate_input)): try: partial = self.gates[i * 4][j](gate_input[j][i]) except Exception: try: self._set_masks(output_dims, output_dims, output_dims, x) partial = self.gates[i * 4][j]( gate_input[j][i]) except Exception: self._set_masks(input_dims, input_dims, output_dims, x) partial = self.gates[i * 4][j]( gate_input[j][i]) result = self.gates[(i * 4) + 1][j](partial) output = output + result relu_output = self.gates[(i * 4) + 2](output / len(gate_input)) sampling_weights = self.gates[(i * 4) + 3](relu_output).squeeze( dim=-1) sampling_weights, penalty = self._hard_concrete( sampling_weights) gates.append(sampling_weights) total_penalty += penalty self._inject_messages(model, gates, self.baselines) self.lambda_op = torch.tensor(self.init_lambda, requires_grad=True) optimizer_lambda = torch.optim.RMSprop( [self.lambda_op], lr=self.lambda_optimizer_lr, centered=True) optimizer.zero_grad() optimizer_lambda.zero_grad() h = x * self.node_feat_mask.sigmoid() y_hat, y = model(x=h, edge_index=edge_index, **kwargs), target if (self.model_config.task_level == ModelTaskLevel.node or self.model_config.task_level == ModelTaskLevel.edge): if index is not None: y_hat, y = y_hat[index], y[index] self._inject_messages(model, gates, self.baselines, True) loss = self._loss(y_hat, y, total_penalty) loss.backward() optimizer.step() self.lambda_op.grad *= -1 optimizer_lambda.step() if self.lambda_op.item() < -2: self.lambda_op.data = torch.full_like( self.lambda_op.data, -2) elif self.lambda_op.item() > 30: self.lambda_op.data = torch.full_like( self.lambda_op.data, 30) if self.log: pbar.update(1) if self.log: pbar.close() def _explain( self, model: torch.nn.Module, *, index: Optional[Union[int, Tensor]] = None, ) -> Tensor: r"""Generates explanations for the original GNN model. Args: model (torch.nn.Module): The model to explain. index (int or torch.Tensor, optional): The index of the model output to explain. Needs to be a single index. 
(default: :obj:`None`). """ if (not isinstance(index, Tensor) and not isinstance(index, int) and index is not None): raise ValueError("'index' parameter can only be a 'Tensor', " "'integer' or set to 'None' instead.") self._freeze_model(model) self._set_flags(model) with torch.no_grad(): latest_source_embeddings, latest_messages = [], [] latest_target_embeddings = [] for module in model.modules(): if isinstance(module, MessagePassing): latest_source_embeddings.append( module.latest_source_embeddings) latest_messages.append(module.latest_messages) latest_target_embeddings.append( module.latest_target_embeddings) gate_input = [ latest_source_embeddings, latest_messages, latest_target_embeddings ] if self.log: pbar = tqdm(total=self.num_layers) for i in range(self.num_layers): if self.log: pbar.set_description("Explain") output = self.full_biases[i] for j in range(len(gate_input)): partial = self.gates[i * 4][j](gate_input[j][i]) result = self.gates[(i * 4) + 1][j](partial) output = output + result relu_output = self.gates[(i * 4) + 2](output / len(gate_input)) sampling_weights = self.gates[(i * 4) + 3](relu_output).squeeze(dim=-1) sampling_weights, _ = self._hard_concrete( sampling_weights, training=False) if i == 0: edge_weight = sampling_weights else: edge_weight = torch.cat((edge_weight, sampling_weights), 0) if self.log: pbar.update(1) if self.log: pbar.close() edge_mask = edge_weight.view(-1, edge_weight.size(0) // self.num_layers) edge_mask = torch.mean(edge_mask, 0) return edge_mask ================================================ FILE: torch_geometric/explain/algorithm/pg_explainer.py ================================================ import logging from typing import Dict, Optional, Tuple, Union, overload import torch from torch import Tensor from torch.nn import ReLU, Sequential from torch_geometric.explain import Explanation, HeteroExplanation from torch_geometric.explain.algorithm import ExplainerAlgorithm from torch_geometric.explain.algorithm.utils import 
( clear_masks, set_hetero_masks, set_masks, ) from torch_geometric.explain.config import ( ExplanationType, ModelMode, ModelTaskLevel, ) from torch_geometric.nn import HANConv, HeteroConv, HGTConv, Linear from torch_geometric.nn.inits import reset from torch_geometric.typing import EdgeType, NodeType from torch_geometric.utils import get_embeddings, get_embeddings_hetero class PGExplainer(ExplainerAlgorithm): r"""The PGExplainer model from the `"Parameterized Explainer for Graph Neural Network" `_ paper. Internally, it utilizes a neural network to identify subgraph structures that play a crucial role in the predictions made by a GNN. Importantly, the :class:`PGExplainer` needs to be trained via :meth:`~PGExplainer.train` before being able to generate explanations: .. code-block:: python explainer = Explainer( model=model, algorithm=PGExplainer(epochs=30, lr=0.003), explanation_type='phenomenon', edge_mask_type='object', model_config=ModelConfig(...), ) # Train against a variety of node-level or graph-level predictions: for epoch in range(30): for index in [...]: # Indices to train against. loss = explainer.algorithm.train(epoch, model, x, edge_index, target=target, index=index) # Get the final explanations: explanation = explainer(x, edge_index, target=target, index=0) Args: epochs (int): The number of epochs to train. lr (float, optional): The learning rate to apply. (default: :obj:`0.003`). **kwargs (optional): Additional hyper-parameters to override default settings in :attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`. """ coeffs = { 'edge_size': 0.05, 'edge_ent': 1.0, 'temp': [5.0, 2.0], 'bias': 0.01, } # NOTE: Add more in the future as needed. 
SUPPORTED_HETERO_MODELS = [ HGTConv, HANConv, HeteroConv, ] def __init__(self, epochs: int, lr: float = 0.003, **kwargs): super().__init__() self.epochs = epochs self.lr = lr self.coeffs.update(kwargs) self.mlp = Sequential( Linear(-1, 64), ReLU(), Linear(64, 1), ) self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr) self._curr_epoch = -1 self.is_hetero = False def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset(self.mlp) @overload def train( self, epoch: int, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> float: ... @overload def train( self, epoch: int, model: torch.nn.Module, x: Dict[NodeType, Tensor], edge_index: Dict[EdgeType, Tensor], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> float: ... def train( self, epoch: int, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> float: r"""Trains the underlying explainer model. Needs to be called before being able to make predictions. Args: epoch (int): The current epoch of the training phase. model (torch.nn.Module): The model to explain. x (torch.Tensor or Dict[str, torch.Tensor]): The input node features. Can be either homogeneous or heterogeneous. edge_index (torch.Tensor or Dict[Tuple[str, str, str]): The input edge indices. Can be either homogeneous or heterogeneous. target (torch.Tensor): The target of the model. index (int or torch.Tensor, optional): The index of the model output to explain. Needs to be a single index. (default: :obj:`None`) **kwargs (optional): Additional keyword arguments passed to :obj:`model`. 
""" self.is_hetero = isinstance(x, dict) if self.is_hetero: assert isinstance(edge_index, dict) if self.model_config.task_level == ModelTaskLevel.node: if index is None: raise ValueError(f"The 'index' argument needs to be provided " f"in '{self.__class__.__name__}' for " f"node-level explanations") if isinstance(index, Tensor) and index.numel() > 1: raise ValueError(f"Only scalars are supported for the 'index' " f"argument in '{self.__class__.__name__}'") # Get embeddings based on whether the graph is homogeneous or # heterogeneous node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs) # Train the model self.optimizer.zero_grad() temperature = self._get_temperature(epoch) # Process embeddings and generate edge masks edge_mask = self._generate_edge_masks(node_embeddings, edge_index, index, temperature) # Apply masks to the model if self.is_hetero: set_hetero_masks(model, edge_mask, edge_index, apply_sigmoid=True) # For node-level tasks, we can compute hard masks if self.model_config.task_level == ModelTaskLevel.node: # Process each edge type separately for edge_type, mask in edge_mask.items(): # Get the edge indices for this edge type edges = edge_index[edge_type] src_type, _, dst_type = edge_type # Get hard masks for this specific edge type _, hard_mask = self._get_hard_masks( model, index, edges, num_nodes=max(x[src_type].size(0), x[dst_type].size(0))) edge_mask[edge_type] = mask[hard_mask] else: # Apply masks for homogeneous graphs set_masks(model, edge_mask, edge_index, apply_sigmoid=True) # For node-level tasks, we may need to apply hard masks hard_edge_mask = None if self.model_config.task_level == ModelTaskLevel.node: _, hard_edge_mask = self._get_hard_masks( model, index, edge_index, num_nodes=x.size(0)) edge_mask = edge_mask[hard_edge_mask] # Forward pass with masks applied y_hat, y = model(x, edge_index, **kwargs), target if index is not None: y_hat, y = y_hat[index], y[index] # Calculate loss loss = self._loss(y_hat, y, edge_mask) # 
Backward pass and optimization loss.backward() self.optimizer.step() # Clean up clear_masks(model) self._curr_epoch = epoch return float(loss) @overload def forward( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> Explanation: ... @overload def forward( self, model: torch.nn.Module, x: Dict[NodeType, Tensor], edge_index: Dict[EdgeType, Tensor], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> HeteroExplanation: ... def forward( self, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> Union[Explanation, HeteroExplanation]: self.is_hetero = isinstance(x, dict) if self._curr_epoch < self.epochs - 1: # Safety check: raise ValueError(f"'{self.__class__.__name__}' is not yet fully " f"trained (got {self._curr_epoch + 1} epochs " f"from {self.epochs} epochs). 
Please first train " f"the underlying explainer model by running " f"`explainer.algorithm.train(...)`.") if self.model_config.task_level == ModelTaskLevel.node: if index is None: raise ValueError(f"The 'index' argument needs to be provided " f"in '{self.__class__.__name__}' for " f"node-level explanations") if isinstance(index, Tensor) and index.numel() > 1: raise ValueError(f"Only scalars are supported for the 'index' " f"argument in '{self.__class__.__name__}'") # Get embeddings node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs) # Generate explanations if self.is_hetero: # Generate edge masks for each edge type edge_masks = {} # Generate masks for each edge type for edge_type, edge_idx in edge_index.items(): src_node_type, _, dst_node_type = edge_type assert src_node_type in node_embeddings assert dst_node_type in node_embeddings inputs = self._get_inputs_hetero(node_embeddings, edge_type, edge_idx, index) logits = self.mlp(inputs).view(-1) # For node-level explanations, get hard masks for this # specific edge type hard_edge_mask = None if self.model_config.task_level == ModelTaskLevel.node: _, hard_edge_mask = self._get_hard_masks( model, index, edge_idx, num_nodes=max(x[src_node_type].size(0), x[dst_node_type].size(0))) # Apply hard mask if available and it has any True values edge_masks[edge_type] = self._post_process_mask( logits, hard_edge_mask, apply_sigmoid=True) explanation = HeteroExplanation() explanation.set_value_dict('edge_mask', edge_masks) return explanation else: hard_edge_mask = None if self.model_config.task_level == ModelTaskLevel.node: # We need to compute hard masks to properly clean up edges _, hard_edge_mask = self._get_hard_masks( model, index, edge_index, num_nodes=x.size(0)) inputs = self._get_inputs(node_embeddings, edge_index, index) logits = self.mlp(inputs).view(-1) edge_mask = self._post_process_mask(logits, hard_edge_mask, apply_sigmoid=True) return Explanation(edge_mask=edge_mask) def supports(self) -> bool: 
explanation_type = self.explainer_config.explanation_type if explanation_type != ExplanationType.phenomenon: logging.error(f"'{self.__class__.__name__}' only supports " f"phenomenon explanations " f"got (`explanation_type={explanation_type.value}`)") return False task_level = self.model_config.task_level if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}: logging.error(f"'{self.__class__.__name__}' only supports " f"node-level or graph-level explanations " f"got (`task_level={task_level.value}`)") return False node_mask_type = self.explainer_config.node_mask_type if node_mask_type is not None: logging.error(f"'{self.__class__.__name__}' does not support " f"explaining input node features " f"got (`node_mask_type={node_mask_type.value}`)") return False return True ########################################################################### def _get_embeddings(self, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], **kwargs) -> Union[Tensor, Dict[NodeType, Tensor]]: """Get embeddings from the model based on input type.""" if self.is_hetero: # For heterogeneous graphs, get embeddings for each node type embeddings_dict = get_embeddings_hetero( model, self.SUPPORTED_HETERO_MODELS, x, edge_index, **kwargs, ) # Use the last layer's embeddings for each node type last_embedding_dict = { node_type: embs[-1] if embs and len(embs) > 0 else None for node_type, embs in embeddings_dict.items() } # Skip if no embeddings were captured if not any(emb is not None for emb in last_embedding_dict.values()): raise ValueError( "No embeddings were captured from the model. 
" "Please check if the model architecture is supported.") return last_embedding_dict else: # For homogeneous graphs, get embeddings directly return get_embeddings(model, x, edge_index, **kwargs)[-1]

    def _generate_edge_masks(
            self, emb: Union[Tensor, Dict[NodeType, Tensor]],
            edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
            index: Optional[Union[int, Tensor]],
            temperature: float) -> Union[Tensor, Dict[EdgeType, Tensor]]:
        """Generate edge masks based on embeddings.

        Runs ``self.mlp`` on per-edge input features and perturbs the
        resulting logits via :meth:`_concrete_sample`. No sigmoid is applied
        here (:meth:`_add_mask_regularization` applies one).

        Args:
            emb: Node embeddings, either a single tensor (homogeneous) or a
                dictionary keyed by node type (heterogeneous).
            edge_index: Edge indices, either a single tensor or a dictionary
                keyed by edge type.
            index: Index of the node to explain; only used for node-level
                tasks.
            temperature: Temperature forwarded to :meth:`_concrete_sample`.

        Returns:
            A single mask tensor, or a dictionary mapping every edge type of
            :obj:`edge_index` to its mask.

        Raises:
            ValueError: In the heterogeneous case, if no mask was produced
                (i.e. :obj:`edge_index` is an empty dictionary).
        """
        if self.is_hetero:
            # For heterogeneous graphs, generate masks for each edge type
            edge_masks = {}
            for edge_type, edge_idx in edge_index.items():
                src, _, dst = edge_type
                # Both endpoint node types need embeddings.
                # NOTE(review): `assert` is stripped under `python -O`;
                # consider raising an explicit error instead.
                assert src in emb and dst in emb
                # Generate inputs for this edge type
                inputs = self._get_inputs_hetero(emb, edge_type, edge_idx,
                                                 index)
                logits = self.mlp(inputs).view(-1)
                edge_masks[edge_type] = self._concrete_sample(
                    logits, temperature)

            # Ensure we have at least one valid edge mask
            if not edge_masks:
                raise ValueError(
                    "Could not generate edge masks for any edge type. "
                    "Please ensure the model architecture is supported.")

            return edge_masks
        else:
            # For homogeneous graphs, generate a single mask
            inputs = self._get_inputs(emb, edge_index, index)
            logits = self.mlp(inputs).view(-1)
            return self._concrete_sample(logits, temperature)

    def _get_inputs(self, embedding: Tensor, edge_index: Tensor,
                    index: Optional[int] = None) -> Tensor:
        r"""Builds the MLP input for every edge of a homogeneous graph.

        Each row concatenates the source and destination node embeddings of
        one edge. For node-level tasks, the embedding of the node at
        :obj:`index` is appended to every row.
        """
        zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
        if self.model_config.task_level == ModelTaskLevel.node:
            assert index is not None
            # NOTE(review): `.view(1, -1)` flattens the selected
            # embedding(s) into one row, which presumably assumes a single
            # target node -- confirm for tensor-valued `index`.
            zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
        return torch.cat(zs, dim=-1)

    def _get_inputs_hetero(self, embedding_dict: Dict[NodeType, Tensor],
                           edge_type: Tuple[str, str, str],
                           edge_index: Tensor,
                           index: Optional[int] = None) -> Tensor:
        r"""Builds the MLP input for every edge of a single edge type.

        Mirrors :meth:`_get_inputs`, but looks up the source and destination
        embeddings by the node types of :obj:`edge_type`.
        """
        src, _, dst = edge_type

        # Get embeddings for source and destination nodes
        src_emb = embedding_dict[src]
        dst_emb = embedding_dict[dst]

        # Source and destination node embeddings
        zs = [src_emb[edge_index[0]], dst_emb[edge_index[1]]]

        # For node-level explanations, add the target node embedding
        if self.model_config.task_level == ModelTaskLevel.node:
            assert index is not None
            # Assuming index refers to a node of type 'src'
            # NOTE(review): for edge types whose source type differs from the
            # type of the explained node, this selects an unrelated node (or
            # indexes out of range) -- confirm the intended behavior.
            target_emb = src_emb[index].view(1, -1).repeat(zs[0].size(0), 1)
            zs.append(target_emb)

        return torch.cat(zs, dim=-1)

    def _get_temperature(self, epoch: int) -> float:
        r"""Returns the temperature used at :obj:`epoch`.

        Interpolates geometrically from ``coeffs['temp'][0]`` (epoch ``0``)
        to ``coeffs['temp'][1]`` (epoch ``self.epochs``).
        """
        temp = self.coeffs['temp']
        return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)

    def _concrete_sample(self, logits: Tensor,
                         temperature: float = 1.0) -> Tensor:
        r"""Adds noise ``log(eps) - log(1 - eps)`` to :obj:`logits` and
        divides by :obj:`temperature`.

        ``eps`` is drawn uniformly from ``[bias, 1 - bias)``, which keeps both
        logarithms away from ``log(0)`` as long as ``bias > 0``.
        """
        bias = self.coeffs['bias']
        eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
        return (eps.log() - (1 - eps).log() + logits) / temperature

    def _loss(self, y_hat: Tensor, y: Tensor,
              edge_mask: Union[Tensor, Dict[EdgeType, Tensor]]) -> Tensor:
        r"""Returns the prediction loss plus mask regularization terms."""
        # Calculate base loss based on model configuration
        loss = self._calculate_base_loss(y_hat, y)

        # Apply regularization based on graph type
        if self.is_hetero:
            loss = self._apply_hetero_regularization(loss, edge_mask)
        else:
            loss = self._apply_homo_regularization(loss, edge_mask)

        return loss

    def _calculate_base_loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
        """Calculate base loss based on model configuration.

        Dispatches on ``self.model_config.mode``.

        Raises:
            ValueError: If the model mode is not supported.
        """
        if self.model_config.mode == ModelMode.binary_classification:
            return self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            return self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            return self._loss_regression(y_hat, y)
        else:
            raise ValueError(
                f"Unsupported model mode: {self.model_config.mode}")

    def _apply_hetero_regularization(
            self, loss: Tensor, edge_mask: Dict[EdgeType, Tensor]) -> Tensor:
        """Apply regularization for heterogeneous graph.

        Adds the regularization of every per-edge-type mask to :obj:`loss`.
        """
        for _, mask in edge_mask.items():
            loss = self._add_mask_regularization(loss, mask)
        return loss

    def _apply_homo_regularization(self, loss: Tensor,
                                   edge_mask: Tensor) -> Tensor:
        """Apply regularization for homogeneous graph."""
        return self._add_mask_regularization(loss, edge_mask)

    def _add_mask_regularization(self, loss: Tensor, mask: Tensor) -> Tensor:
        """Add size and entropy regularization for a mask.

        Args:
            loss: The loss to add the regularization terms to.
            mask: Unnormalized mask values; a sigmoid is applied here.

        Returns:
            ``loss`` plus ``coeffs['edge_size']`` times the sum of the mask
            and ``coeffs['edge_ent']`` times its mean binary entropy.
        """
        # Apply sigmoid for mask values
        mask = mask.sigmoid()

        # Size regularization
        size_loss = mask.sum() * self.coeffs['edge_size']

        # Entropy regularization; rescaling into [0.005, 0.995] keeps both
        # `log` terms away from log(0):
        masked = 0.99 * mask + 0.005
        mask_ent = -masked * masked.log() - (1 - masked) * (1 - masked).log()
        mask_ent_loss = mask_ent.mean() * self.coeffs['edge_ent']

        return loss + size_loss + mask_ent_loss

================================================
FILE: torch_geometric/explain/algorithm/utils.py
================================================
from typing import Dict, Union

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn import MessagePassing
from torch_geometric.typing import EdgeType


def set_masks(
    model: torch.nn.Module,
    mask: Union[Tensor, Parameter],
    edge_index: Tensor,
    apply_sigmoid: bool = True,
):
    r"""Apply mask to every graph layer in the
:obj:`model`.

    Layers whose :obj:`explain` attribute is explicitly :obj:`False` are
    skipped. Every other :class:`MessagePassing` layer gets
    :obj:`explain = True`, the edge mask, a self-loop mask and the
    :obj:`apply_sigmoid` flag.
    """
    # `True` for every edge that is not a self-loop:
    loop_mask = edge_index[0] != edge_index[1]

    # Loop over layers and set masks on MessagePassing layers:
    for module in model.modules():
        if isinstance(module, MessagePassing):
            # Skip layers that have been explicitly set to `False`:
            if module.explain is False:
                continue

            # Convert mask to a param if it was previously registered as one.
            # This is a workaround for the fact that PyTorch does not allow
            # assignments of pure tensors to parameter attributes:
            if (not isinstance(mask, Parameter)
                    and '_edge_mask' in module._parameters):
                mask = Parameter(mask)

            module.explain = True
            module._edge_mask = mask
            module._loop_mask = loop_mask
            module._apply_sigmoid = apply_sigmoid


def set_hetero_masks(
    model: torch.nn.Module,
    mask_dict: Dict[EdgeType, Union[Tensor, Parameter]],
    edge_index_dict: Dict[EdgeType, Tensor],
    apply_sigmoid: bool = True,
):
    r"""Apply masks to every heterogeneous graph layer in the :obj:`model`
    according to edge types.
    """
    # Per-edge-type submodules are looked up in every `ModuleDict`, either by
    # the edge-type tuple itself or by its '__'-joined string key:
    for module in model.modules():
        if isinstance(module, torch.nn.ModuleDict):
            for edge_type in mask_dict.keys():
                if edge_type in module:
                    edge_level_module = module[edge_type]
                elif '__'.join(edge_type) in module:
                    edge_level_module = module['__'.join(edge_type)]
                else:
                    continue

                set_masks(
                    edge_level_module,
                    mask_dict[edge_type],
                    edge_index_dict[edge_type],
                    apply_sigmoid=apply_sigmoid,
                )


def clear_masks(model: torch.nn.Module):
    r"""Clear all masks from the model."""
    for module in model.modules():
        if isinstance(module, MessagePassing):
            # Only reset the flag of layers that were switched on; layers
            # explicitly opted out (`explain is False`) keep it:
            if module.explain is True:
                module.explain = None
            module._edge_mask = None
            module._loop_mask = None
            module._apply_sigmoid = True
    # NOTE(review): this returns the last module visited by
    # `model.modules()`, not `model` itself -- confirm whether any caller
    # relies on the return value.
    return module

================================================
FILE: torch_geometric/explain/config.py
================================================
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

from torch_geometric.utils.mixin import CastMixin


class ExplanationType(Enum):
    """Enum class for the explanation type."""
    model = 'model'
phenomenon = 'phenomenon' class MaskType(Enum): """Enum class for the mask type.""" object = 'object' common_attributes = 'common_attributes' attributes = 'attributes' class ModelMode(Enum): """Enum class for the model return type.""" binary_classification = 'binary_classification' multiclass_classification = 'multiclass_classification' regression = 'regression' class ModelTaskLevel(Enum): """Enum class for the model task level.""" node = 'node' edge = 'edge' graph = 'graph' class ModelReturnType(Enum): """Enum class for the model return type.""" raw = 'raw' probs = 'probs' log_probs = 'log_probs' class ThresholdType(Enum): """Enum class for the threshold type.""" hard = 'hard' topk = 'topk' topk_hard = 'topk_hard' # connected = 'connected' # TODO @dataclass class ExplainerConfig(CastMixin): r"""Configuration class to store and validate high level explanation parameters. Args: explanation_type (ExplanationType or str): The type of explanation to compute. The possible values are: - :obj:`"model"`: Explains the model prediction. - :obj:`"phenomenon"`: Explains the phenomenon that the model is trying to predict. In practice, this means that the explanation algorithm will either compute their losses with respect to the model output (:obj:`"model"`) or the target output (:obj:`"phenomenon"`). node_mask_type (MaskType or str, optional): The type of mask to apply on nodes. The possible values are (default: :obj:`None`): - :obj:`None`: Will not apply any mask on nodes. - :obj:`"object"`: Will mask each node. - :obj:`"common_attributes"`: Will mask each feature. - :obj:`"attributes"`: Will mask each feature across all nodes. edge_mask_type (MaskType or str, optional): The type of mask to apply on edges. Has the sample possible values as :obj:`node_mask_type`. 
(default: :obj:`None`)
    """
    explanation_type: ExplanationType
    node_mask_type: Optional[MaskType]
    edge_mask_type: Optional[MaskType]

    def __init__(
        self,
        explanation_type: Union[ExplanationType, str],
        node_mask_type: Optional[Union[MaskType, str]] = None,
        edge_mask_type: Optional[Union[MaskType, str]] = None,
    ):
        # Normalize string inputs to enum members (the enum lookup raises a
        # `ValueError` for unknown values):
        if node_mask_type is not None:
            node_mask_type = MaskType(node_mask_type)
        if edge_mask_type is not None:
            edge_mask_type = MaskType(edge_mask_type)

        # Only object-level edge masks are supported:
        if edge_mask_type is not None and edge_mask_type != MaskType.object:
            raise ValueError(f"'edge_mask_type' needs be None or of type "
                             f"'object' (got '{edge_mask_type.value}')")

        # At least one kind of mask has to be requested:
        if node_mask_type is None and edge_mask_type is None:
            raise ValueError("Either 'node_mask_type' or 'edge_mask_type' "
                             "must be provided")

        self.explanation_type = ExplanationType(explanation_type)
        self.node_mask_type = node_mask_type
        self.edge_mask_type = edge_mask_type


@dataclass
class ModelConfig(CastMixin):
    r"""Configuration class to store model parameters.

    Args:
        mode (ModelMode or str): The mode of the model. The possible values
            are:

            - :obj:`"binary_classification"`: A binary classification
              model.

            - :obj:`"multiclass_classification"`: A multiclass
              classification model.

            - :obj:`"regression"`: A regression model.

        task_level (ModelTaskLevel or str): The task-level of the model.
            The possible values are:

            - :obj:`"node"`: A node-level prediction model.

            - :obj:`"edge"`: An edge-level prediction model.

            - :obj:`"graph"`: A graph-level prediction model.

        return_type (ModelReturnType or str, optional): The return type of the
            model. The possible values are (default: :obj:`None`):

            - :obj:`"raw"`: The model returns raw values.

            - :obj:`"probs"`: The model returns probabilities.

            - :obj:`"log_probs"`: The model returns log-probabilities.

            If set to :obj:`None`, regression models default to
            :obj:`"raw"`. Regression models need to return :obj:`"raw"`
            outputs, binary classification models :obj:`"raw"` or
            :obj:`"probs"` outputs.
""" mode: ModelMode task_level: ModelTaskLevel return_type: ModelReturnType def __init__( self, mode: Union[ModelMode, str], task_level: Union[ModelTaskLevel, str], return_type: Optional[Union[ModelReturnType, str]] = None, ): self.mode = ModelMode(mode) self.task_level = ModelTaskLevel(task_level) if return_type is None and self.mode == ModelMode.regression: return_type = ModelReturnType.raw self.return_type = ModelReturnType(return_type) if (self.mode == ModelMode.regression and self.return_type != ModelReturnType.raw): raise ValueError(f"A model for regression needs to return raw " f"outputs (got {self.return_type.value})") if (self.mode == ModelMode.binary_classification and self.return_type not in [ModelReturnType.raw, ModelReturnType.probs]): raise ValueError( f"A model for binary classification needs to return raw " f"outputs or probabilities (got {self.return_type.value})") @dataclass class ThresholdConfig(CastMixin): r"""Configuration class to store and validate threshold parameters. Args: threshold_type (ThresholdType or str): The type of threshold to apply. The possible values are: - :obj:`None`: No threshold is applied. - :obj:`"hard"`: A hard threshold is applied to each mask. The elements of the mask with a value below the :obj:`value` are set to :obj:`0`, the others are set to :obj:`1`. - :obj:`"topk"`: A soft threshold is applied to each mask. The top obj:`value` elements of each mask are kept, the others are set to :obj:`0`. - :obj:`"topk_hard"`: Same as :obj:`"topk"` but values are set to :obj:`1` for all elements which are kept. value (int or float, optional): The value to use when thresholding. 
(default: :obj:`None`) """ type: ThresholdType value: Union[float, int] def __init__( self, threshold_type: Union[ThresholdType, str], value: Union[float, int], ): self.type = ThresholdType(threshold_type) self.value = value if not isinstance(self.value, (int, float)): raise ValueError(f"Threshold value must be a float or int " f"(got {type(self.value)}).") if (self.type == ThresholdType.hard and (self.value < 0 or self.value > 1)): raise ValueError(f"Threshold value must be between 0 and 1 " f"(got {self.value})") if self.type in [ThresholdType.topk, ThresholdType.topk_hard]: if not isinstance(self.value, int): raise ValueError(f"Threshold value needs to be an integer " f"(got {type(self.value)}).") if self.value <= 0: raise ValueError(f"Threshold value needs to be positive " f"(got {self.value}).") ================================================ FILE: torch_geometric/explain/explainer.py ================================================ import warnings from typing import Any, Dict, Optional, Union import torch from torch import Tensor from torch_geometric.explain import ( ExplainerAlgorithm, Explanation, HeteroExplanation, ) from torch_geometric.explain.algorithm.utils import ( clear_masks, set_hetero_masks, set_masks, ) from torch_geometric.explain.config import ( ExplainerConfig, ExplanationType, MaskType, ModelConfig, ModelMode, ModelReturnType, ThresholdConfig, ) from torch_geometric.typing import EdgeType, NodeType class Explainer: r"""An explainer class for instance-level explanations of Graph Neural Networks. Args: model (torch.nn.Module): The model to explain. algorithm (ExplainerAlgorithm): The explanation algorithm. explanation_type (ExplanationType or str): The type of explanation to compute. The possible values are: - :obj:`"model"`: Explains the model prediction. - :obj:`"phenomenon"`: Explains the phenomenon that the model is trying to predict. 
In practice, this means that the explanation algorithm will either compute their losses with respect to the model output (:obj:`"model"`) or the target output (:obj:`"phenomenon"`). model_config (ModelConfig): The model configuration. See :class:`~torch_geometric.explain.config.ModelConfig` for available options. (default: :obj:`None`) node_mask_type (MaskType or str, optional): The type of mask to apply on nodes. The possible values are (default: :obj:`None`): - :obj:`None`: Will not apply any mask on nodes. - :obj:`"object"`: Will mask each node. - :obj:`"common_attributes"`: Will mask each feature. - :obj:`"attributes"`: Will mask each feature across all nodes. edge_mask_type (MaskType or str, optional): The type of mask to apply on edges. Has the sample possible values as :obj:`node_mask_type`. (default: :obj:`None`) threshold_config (ThresholdConfig, optional): The threshold configuration. See :class:`~torch_geometric.explain.config.ThresholdConfig` for available options. (default: :obj:`None`) """ def __init__( self, model: torch.nn.Module, algorithm: ExplainerAlgorithm, explanation_type: Union[ExplanationType, str], model_config: Union[ModelConfig, Dict[str, Any]], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None, threshold_config: Optional[ThresholdConfig] = None, ): explainer_config = ExplainerConfig( explanation_type=explanation_type, node_mask_type=node_mask_type, edge_mask_type=edge_mask_type, ) self.model = model self.algorithm = algorithm self.explanation_type = explainer_config.explanation_type self.model_config = ModelConfig.cast(model_config) self.node_mask_type = explainer_config.node_mask_type self.edge_mask_type = explainer_config.edge_mask_type self.threshold_config = ThresholdConfig.cast(threshold_config) self.algorithm.connect(explainer_config, self.model_config) @torch.no_grad() def get_prediction(self, *args, **kwargs) -> Tensor: r"""Returns the prediction of the model on the input 
graph. If the model mode is :obj:`"regression"`, the prediction is returned as a scalar value. If the model mode is :obj:`"multiclass_classification"` or :obj:`"binary_classification"`, the prediction is returned as the predicted class label. Args: *args: Arguments passed to the model. **kwargs (optional): Additional keyword arguments passed to the model. """ training = self.model.training self.model.eval() with torch.no_grad(): out = self.model(*args, **kwargs) self.model.train(training) return out def get_masked_prediction( self, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], node_mask: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None, **kwargs, ) -> Tensor: r"""Returns the prediction of the model on the input graph with node and edge masks applied. """ if isinstance(x, Tensor) and node_mask is not None: x = node_mask * x elif isinstance(x, dict) and node_mask is not None: x = {key: value * node_mask[key] for key, value in x.items()} if isinstance(edge_mask, Tensor): set_masks(self.model, edge_mask, edge_index, apply_sigmoid=False) elif isinstance(edge_mask, dict): set_hetero_masks(self.model, edge_mask, edge_index, apply_sigmoid=False) out = self.get_prediction(x, edge_index, **kwargs) clear_masks(self.model) return out def __call__( self, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> Union[Explanation, HeteroExplanation]: r"""Computes the explanation of the GNN for the given inputs and target. .. note:: If you get an error message like "Trying to backward through the graph a second time", make sure that the target you provided was computed with :meth:`torch.no_grad`. Args: x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input node features of a homogeneous or heterogeneous graph. 
edge_index (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input edge indices of a homogeneous or heterogeneous graph. target (torch.Tensor): The target of the model. If the explanation type is :obj:`"phenomenon"`, the target has to be provided. If the explanation type is :obj:`"model"`, the target should be set to :obj:`None` and will get automatically inferred. For classification tasks, the target needs to contain the class labels. (default: :obj:`None`) index (Union[int, Tensor], optional): The indices in the first-dimension of the model output to explain. Can be a single index or a tensor of indices. If set to :obj:`None`, all model outputs will be explained. (default: :obj:`None`) **kwargs: additional arguments to pass to the GNN. """ # Choose the `target` depending on the explanation type: prediction: Optional[Tensor] = None if self.explanation_type == ExplanationType.phenomenon: if target is None: raise ValueError( f"The 'target' has to be provided for the explanation " f"type '{self.explanation_type.value}'") elif self.explanation_type == ExplanationType.model: if target is not None: warnings.warn( f"The 'target' should not be provided for the explanation " f"type '{self.explanation_type.value}'", stacklevel=2) prediction = self.get_prediction(x, edge_index, **kwargs) target = self.get_target(prediction) if isinstance(index, int): index = torch.tensor([index]) training = self.model.training self.model.eval() explanation = self.algorithm( self.model, x, edge_index, target=target, index=index, **kwargs, ) self.model.train(training) # Add explainer objectives to the `Explanation` object: explanation._model_config = self.model_config explanation.prediction = prediction explanation.target = target explanation.index = index # Add model inputs to the `Explanation` object: if isinstance(explanation, Explanation): explanation._model_args = list(kwargs.keys()) explanation.x = x explanation.edge_index = edge_index for key, arg in kwargs.items(): # Add 
remaining `kwargs`: explanation[key] = arg elif isinstance(explanation, HeteroExplanation): # TODO Add `explanation._model_args` assert isinstance(x, dict) explanation.set_value_dict('x', x) assert isinstance(edge_index, dict) explanation.set_value_dict('edge_index', edge_index) for key, arg in kwargs.items(): # Add remaining `kwargs`: if isinstance(arg, dict): # Keyword arguments are likely named `{attr_name}_dict` # while we only want to assign the `{attr_name}` to the # `HeteroExplanation` object: key = key[:-5] if key.endswith('_dict') else key explanation.set_value_dict(key, arg) else: explanation[key] = arg explanation.validate_masks() return explanation.threshold(self.threshold_config) def get_target(self, prediction: Tensor) -> Tensor: r"""Returns the target of the model from a given prediction. If the model mode is of type :obj:`"regression"`, the prediction is returned as it is. If the model mode is of type :obj:`"multiclass_classification"` or :obj:`"binary_classification"`, the prediction is returned as the predicted class label. """ if self.model_config.mode == ModelMode.binary_classification: # TODO: Allow customization of the thresholds used below. 
# `Explainer.get_target` (continued): binary classification branch.
            # Raw outputs are thresholded at 0, probabilities at 0.5:
            if self.model_config.return_type == ModelReturnType.raw:
                return (prediction > 0).long().view(-1)
            if self.model_config.return_type == ModelReturnType.probs:
                return (prediction > 0.5).long().view(-1)
            # `ModelConfig` rejects other return types for binary
            # classification, so this should be unreachable:
            raise AssertionError()

        if self.model_config.mode == ModelMode.multiclass_classification:
            return prediction.argmax(dim=-1)

        # Regression: the prediction itself serves as the target.
        return prediction

================================================
FILE: torch_geometric/explain/explanation.py
================================================
import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.data.data import Data, warn_or_raise
from torch_geometric.data.hetero_data import HeteroData
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.visualization import (
    visualize_graph,
    visualize_hetero_graph,
)


class ExplanationMixin:
    r"""Mask validation and thresholding logic shared by
    :class:`Explanation` and :class:`HeteroExplanation`.
    """
    @property
    def available_explanations(self) -> List[str]:
        """Returns the available explanation masks."""
        return [key for key in self.keys() if key.endswith('_mask')]

    def validate_masks(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the :class:`Explanation` masks.

        A :obj:`node_mask` has to be two-dimensional, with a first dimension
        of :obj:`1` or the number of nodes and a second dimension of
        :obj:`1` or the number of features. An :obj:`edge_mask` has to be
        one-dimensional with one entry per edge.

        Args:
            raise_on_error (bool, optional): Forwarded to
                :meth:`warn_or_raise` for every violation (presumably raising
                on :obj:`True` and warning otherwise -- confirm).
                (default: :obj:`True`)

        Returns:
            :obj:`True` if all masks are valid, :obj:`False` otherwise.
        """
        status = True

        for store in self.node_stores:
            if 'node_mask' not in store:
                continue

            if store.node_mask.dim() != 2:
                status = False
                warn_or_raise(
                    f"Expected a 'node_mask' with two dimensions (got "
                    f"{store.node_mask.dim()} dimensions)", raise_on_error)

            if store.node_mask.size(0) not in {1, store.num_nodes}:
                status = False
                warn_or_raise(
                    f"Expected a 'node_mask' with {store.num_nodes} nodes "
                    f"(got {store.node_mask.size(0)} nodes)", raise_on_error)

            # Without node features, the mask's own last dimension is used as
            # reference, so for a 2-D mask the check below always passes:
            if 'x' in store:
                num_features = store.x.size(-1)
            else:
                num_features = store.node_mask.size(-1)

            if store.node_mask.size(1) not in {1, num_features}:
                status = False
                warn_or_raise(
                    f"Expected a 'node_mask' with {num_features} features ("
                    f"got {store.node_mask.size(1)} features)", raise_on_error)

        for store in self.edge_stores:
            if 'edge_mask' not in store:
                continue

            if store.edge_mask.dim() != 1:
                status = False
                warn_or_raise(
                    f"Expected an 'edge_mask' with one dimension (got "
                    f"{store.edge_mask.dim()} dimensions)", raise_on_error)

            if store.edge_mask.size(0) != store.num_edges:
                status = False
                warn_or_raise(
                    f"Expected an 'edge_mask' with {store.num_edges} edges "
                    f"(got {store.edge_mask.size(0)} edges)", raise_on_error)

        return status

    def _threshold_mask(
        self,
        mask: Optional[Tensor],
        threshold_config: ThresholdConfig,
    ) -> Optional[Tensor]:
        r"""Returns :obj:`mask` thresholded according to
        :obj:`threshold_config`, or :obj:`None` if no mask is given.
        """
        if mask is None:
            return None

        if threshold_config.type == ThresholdType.hard:
            # Strictly greater than the threshold -> 1, otherwise 0:
            return (mask > threshold_config.value).float()

        if threshold_config.type in [
                ThresholdType.topk,
                ThresholdType.topk_hard,
        ]:
            # `k` covers the whole mask: every element is kept.
            if threshold_config.value >= mask.numel():
                if threshold_config.type == ThresholdType.topk:
                    return mask
                else:
                    return torch.ones_like(mask)

            # Select the top-k entries across the flattened mask:
            value, index = torch.topk(
                mask.flatten(),
                k=threshold_config.value,
            )

            out = torch.zeros_like(mask.flatten())
            if threshold_config.type == ThresholdType.topk:
                out[index] = value
            else:
                out[index] = 1.0
            return out.view(mask.size())

        raise AssertionError()

    def threshold(
        self,
        *args,
        **kwargs,
    ) -> Union['Explanation', 'HeteroExplanation']:
        """Thresholds the explanation masks according to the thresholding
        method.

        Returns :obj:`self` unchanged if no threshold configuration is given,
        otherwise a shallow copy with thresholded :obj:`node_mask` and
        :obj:`edge_mask` attributes.

        Args:
            *args: Arguments passed to :class:`ThresholdConfig`.
            **kwargs: Keyword arguments passed to :class:`ThresholdConfig`.
        """
        threshold_config = ThresholdConfig.cast(*args, **kwargs)

        if threshold_config is None:
            return self

        # Avoid modification of the original explanation:
        out = copy.copy(self)

        for store in out.node_stores:
            store.node_mask = self._threshold_mask(store.get('node_mask'),
                                                   threshold_config)

        for store in out.edge_stores:
            store.edge_mask = self._threshold_mask(store.get('edge_mask'),
                                                   threshold_config)

        return out


class Explanation(Data, ExplanationMixin):
    r"""Holds all the obtained explanations of a homogeneous graph.
The explanation object is a :obj:`~torch_geometric.data.Data` object and
    can hold node attributions and edge attributions.
    It can also hold the original graph if needed.

    Args:
        node_mask (Tensor, optional): Node-level mask with shape
            :obj:`[num_nodes, 1]`, :obj:`[1, num_features]` or
            :obj:`[num_nodes, num_features]`. (default: :obj:`None`)
        edge_mask (Tensor, optional): Edge-level mask with shape
            :obj:`[num_edges]`. (default: :obj:`None`)
        **kwargs (optional): Additional attributes.
    """
    def validate(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the :class:`Explanation` object.

        Combines the generic :class:`~torch_geometric.data.Data` validation
        with :meth:`validate_masks`.
        """
        status = super().validate(raise_on_error)
        status &= self.validate_masks(raise_on_error)
        return status

    def get_explanation_subgraph(self) -> 'Explanation':
        r"""Returns the induced subgraph, in which all nodes and edges with
        zero attribution are masked out.

        A node is kept if its :obj:`node_mask` row sums to a positive value;
        an edge is kept if its :obj:`edge_mask` entry is positive.
        """
        node_mask = self.get('node_mask')
        if node_mask is not None:
            # NOTE(review): a `[1, num_features]` mask yields a single entry
            # here rather than one per node -- confirm how `subgraph()` is
            # expected to handle that case.
            node_mask = node_mask.sum(dim=-1) > 0
        edge_mask = self.get('edge_mask')
        if edge_mask is not None:
            edge_mask = edge_mask > 0
        return self._apply_masks(node_mask, edge_mask)

    def get_complement_subgraph(self) -> 'Explanation':
        r"""Returns the induced subgraph, in which all nodes and edges with any
        attribution are masked out.
"""
        node_mask = self.get('node_mask')
        if node_mask is not None:
            node_mask = node_mask.sum(dim=-1) == 0
        edge_mask = self.get('edge_mask')
        if edge_mask is not None:
            edge_mask = edge_mask == 0
        return self._apply_masks(node_mask, edge_mask)

    def _apply_masks(
        self,
        node_mask: Optional[Tensor] = None,
        edge_mask: Optional[Tensor] = None,
    ) -> 'Explanation':
        r"""Returns a copy restricted to the selected edges and nodes.

        :obj:`edge_index` and every edge-level attribute are filtered by
        :obj:`edge_mask` first; the result is then reduced to the subgraph
        induced by :obj:`node_mask` via :meth:`subgraph`.
        """
        out = copy.copy(self)

        if edge_mask is not None:
            for key, value in self.items():
                if key == 'edge_index':
                    out.edge_index = value[:, edge_mask]
                elif self.is_edge_attr(key):
                    out[key] = value[edge_mask]

        if node_mask is not None:
            out = out.subgraph(node_mask)

        return out

    def visualize_feature_importance(
        self,
        path: Optional[str] = None,
        feat_labels: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ):
        r"""Creates a bar plot of the node feature importances by summing up
        the node mask across all nodes.

        Args:
            path (str, optional): The path to where the plot is saved.
                If set to :obj:`None`, will visualize the plot on-the-fly.
                (default: :obj:`None`)
            feat_labels (List[str], optional): The labels of features.
                (default :obj:`None`)
            top_k (int, optional): Top k features to plot. If :obj:`None`
                plots all features. (default: :obj:`None`)

        Raises:
            ValueError: If no :obj:`node_mask` is available, or if it is not
                two-dimensional with more than one feature column.
        """
        node_mask = self.get('node_mask')
        if node_mask is None:
            raise ValueError(f"The attribute 'node_mask' is not available "
                             f"in '{self.__class__.__name__}' "
                             f"(got {self.available_explanations})")
        # A single column is treated as an object-level (per-node) mask:
        if node_mask.dim() != 2 or node_mask.size(1) <= 1:
            raise ValueError(f"Cannot compute feature importance for "
                             f"object-level 'node_mask' "
                             f"(got shape {node_mask.size()})")

        if feat_labels is None:
            feat_labels = range(node_mask.size(1))

        # Importance of a feature = its mask values summed over all nodes:
        score = node_mask.sum(dim=0)

        return _visualize_score(score, feat_labels, path, top_k)

    def visualize_graph(
        self,
        path: Optional[str] = None,
        backend: Optional[str] = None,
        node_labels: Optional[List[str]] = None,
    ) -> None:
        r"""Visualizes the explanation graph with edge opacity corresponding to
        edge importance.
Args:
            path (str, optional): The path to where the plot is saved.
                If set to :obj:`None`, will visualize the plot on-the-fly.
                (default: :obj:`None`)
            backend (str, optional): The graph drawing backend to use for
                visualization (:obj:`"graphviz"`, :obj:`"networkx"`).
                If set to :obj:`None`, will use the most appropriate
                visualization backend based on available system packages.
                (default: :obj:`None`)
            node_labels (list[str], optional): The labels/IDs of nodes.
                (default: :obj:`None`)

        Raises:
            ValueError: If no :obj:`edge_mask` is available.
        """
        edge_mask = self.get('edge_mask')
        if edge_mask is None:
            raise ValueError(f"The attribute 'edge_mask' is not available "
                             f"in '{self.__class__.__name__}' "
                             f"(got {self.available_explanations})")
        # Draws all edges of `self.edge_index` (not only the explanation
        # subgraph), together with their `edge_mask` values:
        visualize_graph(self.edge_index, edge_mask, path, backend,
                        node_labels)


class HeteroExplanation(HeteroData, ExplanationMixin):
    r"""Holds all the obtained explanations of a heterogeneous graph.

    The explanation object is a :obj:`~torch_geometric.data.HeteroData` object
    and can hold node attributions and edge attributions.
    It can also hold the original graph if needed.
    """
    def validate(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the :class:`Explanation` object.

        Combines the generic :class:`~torch_geometric.data.HeteroData`
        validation with :meth:`validate_masks`.
        """
        status = super().validate(raise_on_error)
        status &= self.validate_masks(raise_on_error)
        return status

    def get_explanation_subgraph(self) -> 'HeteroExplanation':
        r"""Returns the induced subgraph, in which all nodes and edges with
        zero attribution are masked out.
        """
        # NOTE(review): the `True` argument of `collect()` presumably allows
        # an empty result when no store carries the mask -- confirm.
        return self._apply_masks(
            node_mask_dict={
                key: mask.sum(dim=-1) > 0
                for key, mask in self.collect('node_mask', True).items()
            },
            edge_mask_dict={
                key: mask > 0
                for key, mask in self.collect('edge_mask', True).items()
            },
        )

    def get_complement_subgraph(self) -> 'HeteroExplanation':
        r"""Returns the induced subgraph, in which all nodes and edges with any
        attribution are masked out.
"""
        return self._apply_masks(
            node_mask_dict={
                key: mask.sum(dim=-1) == 0
                for key, mask in self.collect('node_mask', True).items()
            },
            edge_mask_dict={
                key: mask == 0
                for key, mask in self.collect('edge_mask', True).items()
            },
        )

    def _apply_masks(
        self,
        node_mask_dict: Dict[NodeType, Tensor],
        edge_mask_dict: Dict[EdgeType, Tensor],
    ) -> 'HeteroExplanation':
        r"""Returns a copy restricted to the selected edges and nodes.

        For every edge type in :obj:`edge_mask_dict`, :obj:`edge_index` and
        all edge-level attributes are filtered by the mask; edge types not
        in the dictionary are left untouched. The result is then reduced via
        :meth:`subgraph` with :obj:`node_mask_dict`.
        """
        out = copy.copy(self)

        for edge_type, edge_mask in edge_mask_dict.items():
            for key, value in self[edge_type].items():
                if key == 'edge_index':
                    out[edge_type].edge_index = value[:, edge_mask]
                elif self[edge_type].is_edge_attr(key):
                    out[edge_type][key] = value[edge_mask]

        return out.subgraph(node_mask_dict)

    def visualize_feature_importance(
        self,
        path: Optional[str] = None,
        feat_labels: Optional[Dict[NodeType, List[str]]] = None,
        top_k: Optional[int] = None,
    ):
        r"""Creates a bar plot of the node feature importances by summing up
        node masks across all nodes for each node type.

        Args:
            path (str, optional): The path to where the plot is saved.
                If set to :obj:`None`, will visualize the plot on-the-fly.
                (default: :obj:`None`)
            feat_labels (Dict[NodeType, List[str]], optional): The labels of
                features for each node type. (default :obj:`None`)
            top_k (int, optional): Top k features to plot. If :obj:`None`
                plots all features.
(default: :obj:`None`)

        Raises:
            ValueError: If any node mask is not two-dimensional.
        """
        node_mask_dict = self.node_mask_dict
        for node_mask in node_mask_dict.values():
            # NOTE(review): unlike `Explanation.visualize_feature_importance`,
            # single-column (`[num_nodes, 1]`) masks are not rejected here.
            if node_mask.dim() != 2:
                raise ValueError(f"Cannot compute feature importance for "
                                 f"object-level 'node_mask' "
                                 f"(got shape {node_mask.size()})")

        if feat_labels is None:
            feat_labels = {}
            for node_type, node_mask in node_mask_dict.items():
                feat_labels[node_type] = range(node_mask.size(1))

        # Per-node-type feature scores, concatenated in `node_mask_dict`
        # order:
        score = torch.cat(
            [node_mask.sum(dim=0) for node_mask in node_mask_dict.values()],
            dim=0)

        # Labels are prefixed with their node type, in the same order as
        # `score`. A user-provided `feat_labels` needs an entry for every
        # node type with a mask, otherwise the lookup raises a `KeyError`:
        all_feat_labels = []
        for node_type in node_mask_dict.keys():
            all_feat_labels += [
                f'{node_type}#{label}' for label in feat_labels[node_type]
            ]

        return _visualize_score(score, all_feat_labels, path, top_k)

    def visualize_graph(
        self,
        path: Optional[str] = None,
        node_labels: Optional[Dict[NodeType, List[str]]] = None,
        node_size_range: Tuple[float, float] = (50, 500),
        node_opacity_range: Tuple[float, float] = (0.2, 1.0),
        edge_width_range: Tuple[float, float] = (0.1, 2.0),
        edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
    ) -> None:
        r"""Visualizes the explanation subgraph using networkx, with edge
        opacity corresponding to edge importance and node colors
        corresponding to node types.

        Args:
            path (str, optional): The path to where the plot is saved.
                If set to :obj:`None`, will visualize the plot on-the-fly.
                (default: :obj:`None`)
            node_labels (Dict[NodeType, List[str]], optional): The display
                names of nodes for each node type that will be shown in the
                visualization. (default: :obj:`None`)
            node_size_range (Tuple[float, float], optional): The minimum and
                maximum node size in the visualization.
                (default: :obj:`(50, 500)`)
            node_opacity_range (Tuple[float, float], optional): The minimum
                and maximum node opacity in the visualization.
                (default: :obj:`(0.2, 1.0)`)
            edge_width_range (Tuple[float, float], optional): The minimum and
                maximum edge width in the visualization.
(default: :obj:`(0.1, 2.0)`) edge_opacity_range (Tuple[float, float], optional): The minimum and maximum edge opacity in the visualization. (default: :obj:`(0.2, 1.0)`) """ # Validate node labels if provided if node_labels is not None: for node_type, labels in node_labels.items(): if node_type not in self.node_types: raise ValueError( f"Node type '{node_type}' in node_labels " f"does not exist in the explanation graph") if len(labels) != self[node_type].num_nodes: raise ValueError(f"Number of labels for node type " f"'{node_type}' (got {len(labels)}) does " f"not match the number of nodes " f"(got {self[node_type].num_nodes})") # Get the explanation subgraph subgraph = self.get_explanation_subgraph() # Prepare edge indices and weights for each edge type edge_index_dict = {} edge_weight_dict = {} for edge_type in subgraph.edge_types: if edge_type[0] == 'x' or edge_type[-1] == 'x': # Skip edges continue edge_index_dict[edge_type] = subgraph[edge_type].edge_index edge_weight_dict[edge_type] = subgraph[edge_type].get( 'edge_mask', torch.ones(subgraph[edge_type].edge_index.size(1))) # Prepare node weights for each node type node_weight_dict = {} for node_type in subgraph.node_types: if node_type == 'x': # Skip the global store continue node_weight_dict[node_type] = subgraph[node_type] \ .get('node_mask', torch.ones(subgraph[node_type].num_nodes)).squeeze(-1) # Call the visualization function visualize_hetero_graph( edge_index_dict=edge_index_dict, edge_weight_dict=edge_weight_dict, path=path, node_labels_dict=node_labels, node_weight_dict=node_weight_dict, node_size_range=node_size_range, node_opacity_range=node_opacity_range, edge_width_range=edge_width_range, edge_opacity_range=edge_opacity_range, ) def _visualize_score( score: torch.Tensor, labels: List[str], path: Optional[str] = None, top_k: Optional[int] = None, ): import matplotlib.pyplot as plt import pandas as pd if len(labels) != score.numel(): raise ValueError(f"The number of labels (got {len(labels)}) must " 
f"match the number of scores (got {score.numel()})") score = score.cpu().numpy() df = pd.DataFrame({'score': score}, index=labels) df = df.sort_values('score', ascending=False) df = df.round(decimals=3) if top_k is not None: df = df.head(top_k) title = f"Feature importance for top {len(df)} features" else: title = f"Feature importance for {len(df)} features" ax = df.plot( kind='barh', figsize=(10, 7), title=title, ylabel='Feature label', xlim=[0, float(df['score'].max()) + 0.3], legend=False, ) plt.gca().invert_yaxis() ax.bar_label(container=ax.containers[0], label_type='edge') if path is not None: plt.savefig(path) else: plt.show() plt.close() ================================================ FILE: torch_geometric/explain/metric/__init__.py ================================================ from .basic import groundtruth_metrics from .fidelity import fidelity, characterization_score, fidelity_curve_auc from .faithfulness import unfaithfulness __all__ = classes = [ 'groundtruth_metrics', 'fidelity', 'characterization_score', 'fidelity_curve_auc', 'unfaithfulness', ] ================================================ FILE: torch_geometric/explain/metric/basic.py ================================================ from typing import List, Optional, Tuple, Union from torch import Tensor METRICS = ['accuracy', 'recall', 'precision', 'f1_score', 'auroc'] def groundtruth_metrics( pred_mask: Tensor, target_mask: Tensor, metrics: Optional[Union[str, List[str]]] = None, threshold: float = 0.5, ) -> Union[float, Tuple[float, ...]]: r"""Compares and evaluates an explanation mask with the ground-truth explanation mask. Args: pred_mask (torch.Tensor): The prediction mask to evaluate. target_mask (torch.Tensor): The ground-truth target mask. metrics (str or List[str], optional): The metrics to return (:obj:`"accuracy"`, :obj:`"recall"`, :obj:`"precision"`, :obj:`"f1_score"`, :obj:`"auroc"`). 
(default: :obj:`["accuracy", "recall", "precision", "f1_score", "auroc"]`) threshold (float, optional): The threshold value to perform hard thresholding of :obj:`mask` and :obj:`groundtruth`. (default: :obj:`0.5`) """ import torchmetrics if metrics is None: metrics = METRICS if isinstance(metrics, str): metrics = [metrics] if not isinstance(metrics, (tuple, list)): raise ValueError(f"Expected metrics to be a string or a list of " f"strings (got {type(metrics)})") pred_mask = pred_mask.view(-1) target_mask = (target_mask >= threshold).view(-1) outs = [] for metric in metrics: if metric not in METRICS: raise ValueError(f"Encountered invalid metric {metric}") fn = getattr(torchmetrics.functional, metric) if metric in {'auroc'}: out = fn(pred_mask, target_mask, 'binary') else: out = fn(pred_mask, target_mask, 'binary', threshold) outs.append(float(out)) return tuple(outs) if len(outs) > 1 else outs[0] ================================================ FILE: torch_geometric/explain/metric/faithfulness.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch_geometric.explain import Explainer, Explanation from torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType def unfaithfulness( explainer: Explainer, explanation: Explanation, top_k: Optional[int] = None, ) -> float: r"""Evaluates how faithful an :class:`~torch_geometric.explain.Explanation` is to an underlying GNN predictor, as described in the `"Evaluating Explainability for Graph Neural Networks" `_ paper. In particular, the graph explanation unfaithfulness metric is defined as .. math:: \textrm{GEF}(y, \hat{y}) = 1 - \exp(- \textrm{KL}(y || \hat{y})) where :math:`y` refers to the prediction probability vector obtained from the original graph, and :math:`\hat{y}` refers to the prediction probability vector obtained from the masked subgraph. 
Finally, the Kullback-Leibler (KL) divergence score quantifies the distance between the two probability distributions. Args: explainer (Explainer): The explainer to evaluate. explanation (Explanation): The explanation to evaluate. top_k (int, optional): If set, will only keep the original values of the top-:math:`k` node features identified by an explanation. If set to :obj:`None`, will use :obj:`explanation.node_mask` as it is for masking node features. (default: :obj:`None`) """ if explainer.model_config.mode == ModelMode.regression: raise ValueError("Fidelity not defined for 'regression' models") if top_k is not None and explainer.node_mask_type == MaskType.object: raise ValueError("Cannot apply top-k feature selection based on a " "node mask of type 'object'") node_mask = explanation.get('node_mask') edge_mask = explanation.get('edge_mask') x, edge_index = explanation.x, explanation.edge_index kwargs = {key: explanation[key] for key in explanation._model_args} y = explanation.get('prediction') if y is None: # == ExplanationType.phenomenon y = explainer.get_prediction(x, edge_index, **kwargs) if node_mask is not None and top_k is not None: feat_importance = node_mask.sum(dim=0) _, top_k_index = feat_importance.topk(top_k) node_mask = torch.zeros_like(node_mask) node_mask[:, top_k_index] = 1.0 y_hat = explainer.get_masked_prediction(x, edge_index, node_mask, edge_mask, **kwargs) if explanation.get('index') is not None: y, y_hat = y[explanation.index], y_hat[explanation.index] if explainer.model_config.return_type == ModelReturnType.raw: y, y_hat = y.softmax(dim=-1), y_hat.softmax(dim=-1) elif explainer.model_config.return_type == ModelReturnType.log_probs: y, y_hat = y.exp(), y_hat.exp() kl_div = F.kl_div(y.log(), y_hat, reduction='batchmean') return 1 - float(torch.exp(-kl_div)) ================================================ FILE: torch_geometric/explain/metric/fidelity.py ================================================ from typing import Tuple import torch 
from torch import Tensor from torch_geometric.explain import Explainer, Explanation from torch_geometric.explain.config import ExplanationType, ModelMode def fidelity( explainer: Explainer, explanation: Explanation, ) -> Tuple[float, float]: r"""Evaluates the fidelity of an :class:`~torch_geometric.explain.Explainer` given an :class:`~torch_geometric.explain.Explanation`, as described in the `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" `_ paper. Fidelity evaluates the contribution of the produced explanatory subgraph to the initial prediction, either by giving only the subgraph to the model (fidelity-) or by removing it from the entire graph (fidelity+). The fidelity scores capture how good an explainable model reproduces the natural phenomenon or the GNN model logic. For **phenomenon** explanations, the fidelity scores are given by: .. math:: \textrm{fid}_{+} &= \frac{1}{N} \sum_{i = 1}^N \| \mathbb{1}(\hat{y}_i = y_i) - \mathbb{1}( \hat{y}_i^{G_{C \setminus S}} = y_i) \| \textrm{fid}_{-} &= \frac{1}{N} \sum_{i = 1}^N \| \mathbb{1}(\hat{y}_i = y_i) - \mathbb{1}( \hat{y}_i^{G_S} = y_i) \| For **model** explanations, the fidelity scores are given by: .. math:: \textrm{fid}_{+} &= 1 - \frac{1}{N} \sum_{i = 1}^N \mathbb{1}( \hat{y}_i^{G_{C \setminus S}} = \hat{y}_i) \textrm{fid}_{-} &= 1 - \frac{1}{N} \sum_{i = 1}^N \mathbb{1}( \hat{y}_i^{G_S} = \hat{y}_i) Args: explainer (Explainer): The explainer to evaluate. explanation (Explanation): The explanation to evaluate. 
""" if explainer.model_config.mode == ModelMode.regression: raise ValueError("Fidelity not defined for 'regression' models") node_mask = explanation.get('node_mask') edge_mask = explanation.get('edge_mask') kwargs = {key: explanation[key] for key in explanation._model_args} y = explanation.target if explainer.explanation_type == ExplanationType.phenomenon: y_hat = explainer.get_prediction( explanation.x, explanation.edge_index, **kwargs, ) y_hat = explainer.get_target(y_hat) explain_y_hat = explainer.get_masked_prediction( explanation.x, explanation.edge_index, node_mask, edge_mask, **kwargs, ) explain_y_hat = explainer.get_target(explain_y_hat) complement_y_hat = explainer.get_masked_prediction( explanation.x, explanation.edge_index, 1. - node_mask if node_mask is not None else None, 1. - edge_mask if edge_mask is not None else None, **kwargs, ) complement_y_hat = explainer.get_target(complement_y_hat) if explanation.get('index') is not None: y = y[explanation.index] if explainer.explanation_type == ExplanationType.phenomenon: y_hat = y_hat[explanation.index] explain_y_hat = explain_y_hat[explanation.index] complement_y_hat = complement_y_hat[explanation.index] if explainer.explanation_type == ExplanationType.model: pos_fidelity = 1. - (complement_y_hat == y).float().mean() neg_fidelity = 1. - (explain_y_hat == y).float().mean() else: pos_fidelity = ((y_hat == y).float() - (complement_y_hat == y).float()).abs().mean() neg_fidelity = ((y_hat == y).float() - (explain_y_hat == y).float()).abs().mean() return float(pos_fidelity), float(neg_fidelity) def characterization_score( pos_fidelity: Tensor, neg_fidelity: Tensor, pos_weight: float = 0.5, neg_weight: float = 0.5, ) -> Tensor: r"""Returns the componentwise characterization score as described in the `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" `_ paper. .. 
math:: \textrm{charact} = \frac{w_{+} + w_{-}}{\frac{w_{+}}{\textrm{fid}_{+}} + \frac{w_{-}}{1 - \textrm{fid}_{-}}} Args: pos_fidelity (torch.Tensor): The positive fidelity :math:`\textrm{fid}_{+}`. neg_fidelity (torch.Tensor): The negative fidelity :math:`\textrm{fid}_{-}`. pos_weight (float, optional): The weight :math:`w_{+}` for :math:`\textrm{fid}_{+}`. (default: :obj:`0.5`) neg_weight (float, optional): The weight :math:`w_{-}` for :math:`\textrm{fid}_{-}`. (default: :obj:`0.5`) """ if (pos_weight + neg_weight) != 1.0: raise ValueError(f"The weights need to sum up to 1 " f"(got {pos_weight} and {neg_weight})") denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity)) return 1. / denom def fidelity_curve_auc( pos_fidelity: Tensor, neg_fidelity: Tensor, x: Tensor, ) -> Tensor: r"""Returns the AUC for the fidelity curve as described in the `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" `_ paper. More precisely, returns the AUC of .. math:: f(x) = \frac{\textrm{fid}_{+}}{1 - \textrm{fid}_{-}} Args: pos_fidelity (torch.Tensor): The positive fidelity :math:`\textrm{fid}_{+}`. neg_fidelity (torch.Tensor): The negative fidelity :math:`\textrm{fid}_{-}`. x (torch.Tensor): Tensor containing the points on the :math:`x`-axis. Needs to be sorted in ascending order. """ if torch.any(neg_fidelity == 1): raise ValueError("There exists negative fidelity values containing 1, " "leading to a division by zero") y = pos_fidelity / (1. 
- neg_fidelity) return auc(x, y) def auc(x: Tensor, y: Tensor) -> Tensor: if torch.any(x.diff() < 0): raise ValueError("'x' must be given in ascending order") return torch.trapezoid(y, x) ================================================ FILE: torch_geometric/graphgym/__init__.py ================================================ from .contrib import * # noqa from .models import * # noqa from .utils import * # noqa from .checkpoint import load_ckpt, save_ckpt, remove_ckpt, clean_ckpt from .cmd_args import parse_args from .config import (cfg, set_cfg, load_cfg, dump_cfg, set_run_dir, set_out_dir, get_fname) from .init import init_weights from .loader import create_loader from .logger import set_printing, create_logger from .loss import compute_loss from .model_builder import create_model from .optim import create_optimizer, create_scheduler from .train import train from .register import (register_base, register_act, register_node_encoder, register_edge_encoder, register_stage, register_head, register_layer, register_pooling, register_network, register_config, register_dataset, register_loader, register_optimizer, register_scheduler, register_loss, register_train, register_metric) __all__ = classes = [ 'load_ckpt', 'save_ckpt', 'remove_ckpt', 'clean_ckpt', 'parse_args', 'cfg', 'set_cfg', 'load_cfg', 'dump_cfg', 'set_run_dir', 'set_out_dir', 'get_fname', 'init_weights', 'create_loader', 'set_printing', 'create_logger', 'compute_loss', 'create_model', 'create_optimizer', 'create_scheduler', 'train', 'register_base', 'register_act', 'register_node_encoder', 'register_edge_encoder', 'register_stage', 'register_head', 'register_layer', 'register_pooling', 'register_network', 'register_config', 'register_dataset', 'register_loader', 'register_optimizer', 'register_scheduler', 'register_loss', 'register_train', 'register_metric', ] ================================================ FILE: torch_geometric/graphgym/benchmark.py ================================================ # Do 
not change; required for benchmarking import torch_geometric_benchmark.torchprof_local as torchprof # noqa from pytorch_memlab import LineProfiler # noqa from torch_geometric_benchmark.utils import count_parameters # noqa from torch_geometric_benchmark.utils import get_gpu_memory_nvdia # noqa from torch_geometric_benchmark.utils import get_memory_status # noqa from torch_geometric_benchmark.utils import get_model_size # noqa global_line_profiler = LineProfiler() global_line_profiler.enable() ================================================ FILE: torch_geometric/graphgym/checkpoint.py ================================================ import glob import os import os.path as osp from typing import Any, Dict, List, Optional, Union import torch from torch_geometric.graphgym.config import cfg from torch_geometric.io import fs MODEL_STATE = 'model_state' OPTIMIZER_STATE = 'optimizer_state' SCHEDULER_STATE = 'scheduler_state' def load_ckpt( model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, epoch: int = -1, ) -> int: r"""Loads the model checkpoint at a given epoch.""" epoch = get_ckpt_epoch(epoch) path = get_ckpt_path(epoch) if not osp.exists(path): return 0 ckpt = fs.torch_load(path) model.load_state_dict(ckpt[MODEL_STATE]) if optimizer is not None and OPTIMIZER_STATE in ckpt: optimizer.load_state_dict(ckpt[OPTIMIZER_STATE]) if scheduler is not None and SCHEDULER_STATE in ckpt: scheduler.load_state_dict(ckpt[SCHEDULER_STATE]) return epoch + 1 def save_ckpt( model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, epoch: int = 0, ): r"""Saves the model checkpoint at a given epoch.""" ckpt: Dict[str, Any] = {} ckpt[MODEL_STATE] = model.state_dict() if optimizer is not None: ckpt[OPTIMIZER_STATE] = optimizer.state_dict() if scheduler is not None: ckpt[SCHEDULER_STATE] = scheduler.state_dict() os.makedirs(get_ckpt_dir(), exist_ok=True) torch.save(ckpt, 
get_ckpt_path(get_ckpt_epoch(epoch))) def remove_ckpt(epoch: int = -1): r"""Removes the model checkpoint at a given epoch.""" os.remove(get_ckpt_path(get_ckpt_epoch(epoch))) def clean_ckpt(): r"""Removes all but the last model checkpoint.""" for epoch in get_ckpt_epochs()[:-1]: os.remove(get_ckpt_path(epoch)) ############################################################################### def get_ckpt_dir() -> str: return osp.join(cfg.run_dir, 'ckpt') def get_ckpt_path(epoch: Union[int, str]) -> str: return osp.join(get_ckpt_dir(), f'{epoch}.ckpt') def get_ckpt_epochs() -> List[int]: paths = glob.glob(get_ckpt_path('*')) return sorted([int(osp.basename(path).split('.')[0]) for path in paths]) def get_ckpt_epoch(epoch: int) -> int: if epoch < 0: epochs = get_ckpt_epochs() epoch = epochs[epoch] if len(epochs) > 0 else 0 return epoch ================================================ FILE: torch_geometric/graphgym/cmd_args.py ================================================ import argparse def parse_args() -> argparse.Namespace: r"""Parses the command line arguments.""" parser = argparse.ArgumentParser(description='GraphGym') parser.add_argument('--cfg', dest='cfg_file', type=str, required=True, help='The configuration file path.') parser.add_argument('--repeat', type=int, default=1, help='The number of repeated jobs.') parser.add_argument('--mark_done', action='store_true', help='Mark yaml as done after a job has finished.') parser.add_argument('opts', default=None, nargs=argparse.REMAINDER, help='See graphgym/config.py for remaining options.') return parser.parse_args() ================================================ FILE: torch_geometric/graphgym/config.py ================================================ import functools import inspect import logging import os import os.path as osp import warnings from collections.abc import Iterable from dataclasses import asdict from typing import Any import torch_geometric.graphgym.register as register from torch_geometric.io 
import fs try: # Define global config object from yacs.config import CfgNode as CN cfg = CN() except ImportError: cfg = None warnings.warn( "Could not define global config object. Please install " "'yacs' via 'pip install yacs' in order to use GraphGym", stacklevel=2) def set_cfg(cfg): r"""This function sets the default config value. 1) Note that for an experiment, only part of the arguments will be used The remaining unused arguments won't affect anything. So feel free to register any argument in graphgym.contrib.config 2) We support *at most* two levels of configs, *e.g.*, :obj:`cfg.dataset.name`. :return: Configuration use by the experiment. """ if cfg is None: return cfg # ----------------------------------------------------------------------- # # Basic options # ----------------------------------------------------------------------- # # Set print destination: stdout / file / both cfg.print = 'both' # Select device: 'cpu', 'cuda', 'auto' cfg.accelerator = 'auto' # number of devices: eg. for 2 GPU set cfg.devices=2 cfg.devices = 1 # Output directory cfg.out_dir = 'results' # Config name (in out_dir) cfg.cfg_dest = 'config.yaml' # Names of registered custom metric funcs to be used (use defaults if none) cfg.custom_metrics = [] # Random seed cfg.seed = 0 # Print rounding cfg.round = 4 # Tensorboard support for each run cfg.tensorboard_each_run = False # Tensorboard support for aggregated results cfg.tensorboard_agg = True # Additional num of worker for data loading cfg.num_workers = 0 # Max threads used by PyTorch cfg.num_threads = 6 # The metric for selecting the best epoch for each run cfg.metric_best = 'auto' # argmax or argmin in aggregating results cfg.metric_agg = 'argmax' # If visualize embedding. 
cfg.view_emb = False # If get GPU usage cfg.gpu_mem = False # If do benchmark analysis cfg.benchmark = False # ----------------------------------------------------------------------- # # Globally shared variables: # These variables will be set dynamically based on the input dataset # Do not directly set them here or in .yaml files # ----------------------------------------------------------------------- # cfg.share = CN() # Size of input dimension cfg.share.dim_in = 1 # Size of out dimension, i.e., number of labels to be predicted cfg.share.dim_out = 1 # Number of dataset splits: train/val/test cfg.share.num_splits = 1 # ----------------------------------------------------------------------- # # Dataset options # ----------------------------------------------------------------------- # cfg.dataset = CN() # Name of the dataset cfg.dataset.name = 'Cora' # if PyG: look for it in Pytorch Geometric dataset # if NetworkX/nx: load data in NetworkX format cfg.dataset.format = 'PyG' # Dir to load the dataset. If the dataset is downloaded, this is the # cache dir cfg.dataset.dir = './datasets' # Task: node, edge, graph, link_pred cfg.dataset.task = 'node' # Type of task: classification, regression, classification_binary # classification_multi cfg.dataset.task_type = 'classification' # Transductive / Inductive # Graph classification is always inductive cfg.dataset.transductive = True # Split ratio of dataset. Len=2: Train, Val. 
Len=3: Train, Val, Test cfg.dataset.split = [0.8, 0.1, 0.1] # Whether to shuffle the graphs for splitting cfg.dataset.shuffle_split = True # Whether random split or use custom split: random / custom cfg.dataset.split_mode = 'random' # Whether to use an encoder for general attribute features cfg.dataset.encoder = True # Name of general encoder cfg.dataset.encoder_name = 'db' # If add batchnorm after general encoder cfg.dataset.encoder_bn = True # Whether to use an encoder for the node features cfg.dataset.node_encoder = False # Name of node encoder cfg.dataset.node_encoder_name = 'Atom' # If add batchnorm after node encoder cfg.dataset.node_encoder_bn = True # Whether to use an encoder for the edge features cfg.dataset.edge_encoder = False # Name of edge encoder cfg.dataset.edge_encoder_name = 'Bond' # If add batchnorm after edge encoder cfg.dataset.edge_encoder_bn = True # Dimension of the encoded features. # For now the node and edge encoding dimensions # are the same. cfg.dataset.encoder_dim = 128 # Dimension for edge feature. Updated by the real dim of the dataset cfg.dataset.edge_dim = 128 # ============== Link/edge tasks only # all or disjoint cfg.dataset.edge_train_mode = 'all' # Used in disjoint edge_train_mode. 
The proportion of edges used for # message-passing cfg.dataset.edge_message_ratio = 0.8 # The ratio of negative samples to positive samples cfg.dataset.edge_negative_sampling_ratio = 1.0 # Whether resample disjoint when dataset.edge_train_mode is 'disjoint' cfg.dataset.resample_disjoint = False # Whether resample negative edges at training time (link prediction only) cfg.dataset.resample_negative = False # What transformation function is applied to the dataset cfg.dataset.transform = 'none' # Whether cache the splitted dataset # NOTE: it should be cautiouslly used, as cached dataset may not have # exactly the same setting as the config file cfg.dataset.cache_save = False cfg.dataset.cache_load = False # Whether remove the original node features in the dataset cfg.dataset.remove_feature = False # Simplify TU dataset for synthetic tasks cfg.dataset.tu_simple = True # Convert to undirected graph (save 2*E edges) cfg.dataset.to_undirected = False # dataset location: local, snowflake cfg.dataset.location = 'local' # Define label: Table name cfg.dataset.label_table = 'none' # Define label: Column name cfg.dataset.label_column = 'none' # ----------------------------------------------------------------------- # # Training options # ----------------------------------------------------------------------- # cfg.train = CN() # Total graph mini-batch size cfg.train.batch_size = 16 # Sampling strategy for a train loader cfg.train.sampler = 'full_batch' # Minibatch node cfg.train.sample_node = False # Num of sampled node per graph cfg.train.node_per_graph = 32 # Radius: same, extend. 
same: same as cfg.gnn.layers_mp, extend: layers+1 cfg.train.radius = 'extend' # Evaluate model on test data every eval period epochs cfg.train.eval_period = 10 # Option to skip training epoch evaluation cfg.train.skip_train_eval = False # Save model checkpoint every checkpoint period epochs cfg.train.ckpt_period = 100 # Enabling checkpoint, set False to disable and save I/O cfg.train.enable_ckpt = True # Resume training from the latest checkpoint in the output directory cfg.train.auto_resume = False # The epoch to resume. -1 means resume the latest epoch. cfg.train.epoch_resume = -1 # Clean checkpoint: only keep the last ckpt cfg.train.ckpt_clean = True # Number of iterations per epoch (for sampling based loaders only) cfg.train.iter_per_epoch = 32 # GraphSAINTRandomWalkSampler: random walk length cfg.train.walk_length = 4 # NeighborSampler: number of sampled nodes per layer cfg.train.neighbor_sizes = [20, 15, 10, 5] # ----------------------------------------------------------------------- # # Validation options # ----------------------------------------------------------------------- # cfg.val = CN() # Minibatch node cfg.val.sample_node = False # Sampling strategy for a val/test loader cfg.val.sampler = 'full_batch' # Num of sampled node per graph cfg.val.node_per_graph = 32 # Radius: same, extend. same: same as cfg.gnn.layers_mp, extend: layers+1 cfg.val.radius = 'extend' # ----------------------------------------------------------------------- # # Model options # ----------------------------------------------------------------------- # cfg.model = CN() # Model type to use cfg.model.type = 'gnn' # Auto match computational budget, match upper bound / lower bound cfg.model.match_upper = True # Loss function: cross_entropy, mse cfg.model.loss_fun = 'cross_entropy' # size average for loss function. 'mean' or 'sum' cfg.model.size_average = 'mean' # Threshold for binary classification cfg.model.thresh = 0.5 # ============== Link/edge tasks only # Edge decoding methods. 
# - dot: compute dot(u, v) to predict link (binary) # - cosine_similarity: use cosine similarity (u, v) to predict link ( # binary) # - concat: use u||v followed by an nn.Linear to obtain edge embedding # (multi-class) cfg.model.edge_decoding = 'dot' # =================================== # ================== Graph tasks only # Pooling methods. # - add: global add pool # - mean: global mean pool # - max: global max pool cfg.model.graph_pooling = 'add' # =================================== # ----------------------------------------------------------------------- # # GNN options # ----------------------------------------------------------------------- # cfg.gnn = CN() # Prediction head. Use cfg.dataset.task by default cfg.gnn.head = 'default' # Number of layers before message passing cfg.gnn.layers_pre_mp = 0 # Number of layers for message passing cfg.gnn.layers_mp = 2 # Number of layers after message passing cfg.gnn.layers_post_mp = 0 # Hidden layer dim. Automatically set if train.auto_match = True cfg.gnn.dim_inner = 16 # Type of graph conv: generalconv, gcnconv, sageconv, gatconv, ... 
cfg.gnn.layer_type = 'generalconv' # Stage type: 'stack', 'skipsum', 'skipconcat' cfg.gnn.stage_type = 'stack' # How many layers to skip each time cfg.gnn.skip_every = 1 # Whether use batch norm cfg.gnn.batchnorm = True # Activation cfg.gnn.act = 'relu' # Dropout cfg.gnn.dropout = 0.0 # Aggregation type: add, mean, max # Note: only for certain layers that explicitly set aggregation type # e.g., when cfg.gnn.layer_type = 'generalconv' cfg.gnn.agg = 'add' # Normalize adj cfg.gnn.normalize_adj = False # Message direction: single, both cfg.gnn.msg_direction = 'single' # Whether add message from node itself: none, add, cat cfg.gnn.self_msg = 'concat' # Number of attention heads cfg.gnn.att_heads = 1 # After concat attention heads, add a linear layer cfg.gnn.att_final_linear = False # After concat attention heads, add a linear layer cfg.gnn.att_final_linear_bn = False # Normalize after message passing cfg.gnn.l2norm = True # randomly use fewer edges for message passing cfg.gnn.keep_edge = 0.5 # clear cached feature_new cfg.gnn.clear_feature = True # ----------------------------------------------------------------------- # # Optimizer options # ----------------------------------------------------------------------- # cfg.optim = CN() # optimizer: sgd, adam cfg.optim.optimizer = 'adam' # Base learning rate cfg.optim.base_lr = 0.01 # L2 regularization cfg.optim.weight_decay = 5e-4 # SGD momentum cfg.optim.momentum = 0.9 # scheduler: none, steps, cos cfg.optim.scheduler = 'cos' # Steps for 'steps' policy (in epochs) cfg.optim.steps = [30, 60, 90] # Learning rate multiplier for 'steps' policy cfg.optim.lr_decay = 0.1 # Maximal number of epochs cfg.optim.max_epoch = 200 # ----------------------------------------------------------------------- # # Batch norm options # ----------------------------------------------------------------------- # cfg.bn = CN() # BN epsilon cfg.bn.eps = 1e-5 # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) cfg.bn.mom = 0.1 # 
----------------------------------------------------------------------- # # Memory options # ----------------------------------------------------------------------- # cfg.mem = CN() # Perform ReLU inplace cfg.mem.inplace = False # Set user customized cfgs for func in register.config_dict.values(): func(cfg) def assert_cfg(cfg): r"""Checks config values, do necessary post processing to the configs.""" if cfg.dataset.task not in ['node', 'edge', 'graph', 'link_pred']: raise ValueError(f"Task '{cfg.dataset.task}' not supported. Must be " f"one of node, edge, graph, link_pred") if 'classification' in cfg.dataset.task_type and cfg.model.loss_fun == \ 'mse': cfg.model.loss_fun = 'cross_entropy' logging.warning( 'model.loss_fun changed to cross_entropy for classification.') if cfg.dataset.task_type == 'regression' and cfg.model.loss_fun == \ 'cross_entropy': cfg.model.loss_fun = 'mse' logging.warning('model.loss_fun changed to mse for regression.') if cfg.dataset.task == 'graph' and cfg.dataset.transductive: cfg.dataset.transductive = False logging.warning('dataset.transductive changed ' 'to False for graph task.') if cfg.gnn.layers_post_mp < 1: cfg.gnn.layers_post_mp = 1 logging.warning('Layers after message passing should be >=1') if cfg.gnn.head == 'default': cfg.gnn.head = cfg.dataset.task cfg.run_dir = cfg.out_dir def dump_cfg(cfg): r"""Dumps the config to the output directory specified in :obj:`cfg.out_dir`. Args: cfg (CfgNode): Configuration node """ os.makedirs(cfg.out_dir, exist_ok=True) cfg_file = osp.join(cfg.out_dir, cfg.cfg_dest) with open(cfg_file, 'w') as f: cfg.dump(stream=f) def load_cfg(cfg, args): r"""Load configurations from file system and command line. 
Args: cfg (CfgNode): Configuration node args (ArgumentParser): Command argument parser """ cfg.merge_from_file(args.cfg_file) cfg.merge_from_list(args.opts) assert_cfg(cfg) def makedirs_rm_exist(dir): if osp.isdir(dir): fs.rm(dir) os.makedirs(dir, exist_ok=True) def get_fname(fname): r"""Extract filename from file name path. Args: fname (str): Filename for the yaml format configuration file """ fname = osp.basename(fname) if fname.endswith('.yaml'): fname = fname[:-5] elif fname.endswith('.yml'): fname = fname[:-4] return fname def set_out_dir(out_dir, fname): r"""Create the directory for full experiment run. Args: out_dir (str): Directory for output, specified in :obj:`cfg.out_dir` fname (str): Filename for the yaml format configuration file """ fname = get_fname(fname) cfg.out_dir = osp.join(out_dir, fname) # Make output directory if cfg.train.auto_resume: os.makedirs(cfg.out_dir, exist_ok=True) else: makedirs_rm_exist(cfg.out_dir) def set_run_dir(out_dir): r"""Create the directory for each random seed experiment run. 
Args: out_dir (str): Directory for output, specified in :obj:`cfg.out_dir` """ cfg.run_dir = osp.join(out_dir, str(cfg.seed)) # Make output directory if cfg.train.auto_resume: os.makedirs(cfg.run_dir, exist_ok=True) else: makedirs_rm_exist(cfg.run_dir) set_cfg(cfg) def from_config(func): if inspect.isclass(func): params = list(inspect.signature(func.__init__).parameters.values())[1:] else: params = list(inspect.signature(func).parameters.values()) arg_names = [p.name for p in params] has_defaults = [p.default != inspect.Parameter.empty for p in params] @functools.wraps(func) def wrapper(*args, cfg: Any = None, **kwargs): if cfg is not None: cfg = dict(cfg) if isinstance(cfg, Iterable) else asdict(cfg) iterator = zip(arg_names[len(args):], has_defaults[len(args):]) for arg_name, has_default in iterator: if arg_name in kwargs: continue elif arg_name in cfg: kwargs[arg_name] = cfg[arg_name] elif not has_default: raise ValueError(f"'cfg.{arg_name}' undefined") return func(*args, **kwargs) return wrapper ================================================ FILE: torch_geometric/graphgym/contrib/__init__.py ================================================ from .act import * # noqa from .config import * # noqa from .encoder import * # noqa from .head import * # noqa from .layer import * # noqa from .loader import * # noqa from .loss import * # noqa from .network import * # noqa from .optimizer import * # noqa from .pooling import * # noqa from .stage import * # noqa from .train import * # noqa from .transform import * # noqa ================================================ FILE: torch_geometric/graphgym/contrib/act/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: 
torch_geometric/graphgym/contrib/config/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/encoder/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/head/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/layer/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/layer/generalconv.py ================================================ import torch from torch.nn import Parameter from torch_geometric.graphgym.config import cfg from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import glorot, zeros from torch_geometric.utils import add_remaining_self_loops, scatter class GeneralConvLayer(MessagePassing): r"""A general GNN layer.""" def __init__(self, in_channels, out_channels, improved=False, cached=False, bias=True, 
**kwargs): super().__init__(aggr=cfg.gnn.agg, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.improved = improved self.cached = cached self.normalize = cfg.gnn.normalize_adj self.weight = Parameter(torch.empty(in_channels, out_channels)) if cfg.gnn.self_msg == 'concat': self.weight_self = Parameter(torch.empty(in_channels, out_channels)) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): glorot(self.weight) if cfg.gnn.self_msg == 'concat': glorot(self.weight_self) zeros(self.bias) self.cached_result = None self.cached_num_edges = None @staticmethod def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None): if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device) fill_value = 1 if not improved else 2 edge_index, edge_weight = add_remaining_self_loops( edge_index, edge_weight, fill_value, num_nodes) row, col = edge_index deg = scatter(edge_weight, row, 0, num_nodes, reduce='sum') deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] def forward(self, x, edge_index, edge_weight=None, edge_feature=None): if cfg.gnn.self_msg == 'concat': x_self = torch.matmul(x, self.weight_self) x = torch.matmul(x, self.weight) if self.cached and self.cached_result is not None: if edge_index.size(1) != self.cached_num_edges: raise RuntimeError( 'Cached {} number of edges, but found {}. 
Please ' 'disable the caching behavior of this layer by removing ' 'the `cached=True` argument in its constructor.'.format( self.cached_num_edges, edge_index.size(1))) if not self.cached or self.cached_result is None: self.cached_num_edges = edge_index.size(1) if self.normalize: edge_index, norm = self.norm(edge_index, x.size(self.node_dim), edge_weight, self.improved, x.dtype) else: norm = edge_weight self.cached_result = edge_index, norm edge_index, norm = self.cached_result x_msg = self.propagate(edge_index, x=x, norm=norm, edge_feature=edge_feature) if cfg.gnn.self_msg == 'none': return x_msg elif cfg.gnn.self_msg == 'add': return x_msg + x elif cfg.gnn.self_msg == 'concat': return x_msg + x_self else: raise ValueError('self_msg {} not defined'.format( cfg.gnn.self_msg)) def message(self, x_j, norm, edge_feature): if edge_feature is None: return norm.view(-1, 1) * x_j if norm is not None else x_j else: return norm.view(-1, 1) * ( x_j + edge_feature) if norm is not None else (x_j + edge_feature) def update(self, aggr_out): if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) class GeneralEdgeConvLayer(MessagePassing): r"""General GNN layer, with edge features.""" def __init__(self, in_channels, out_channels, edge_dim, improved=False, cached=False, bias=True, **kwargs): super().__init__(aggr=cfg.gnn.agg, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.improved = improved self.cached = cached self.normalize = cfg.gnn.normalize_adj self.msg_direction = cfg.gnn.msg_direction if self.msg_direction == 'single': self.linear_msg = torch.nn.Linear( in_channels + edge_dim, out_channels, bias=False, ) else: self.linear_msg = torch.nn.Linear( in_channels * 2 + edge_dim, out_channels, bias=False, ) if cfg.gnn.self_msg == 'concat': self.linear_self = torch.nn.Linear( in_channels, out_channels, bias=False, ) if bias: 
self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): zeros(self.bias) self.cached_result = None self.cached_num_edges = None @staticmethod def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None): if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device) fill_value = 1 if not improved else 2 edge_index, edge_weight = add_remaining_self_loops( edge_index, edge_weight, fill_value, num_nodes) row, col = edge_index deg = scatter(edge_weight, row, 0, num_nodes, reduce='sum') deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] def forward(self, x, edge_index, edge_weight=None, edge_feature=None): if self.cached and self.cached_result is not None: if edge_index.size(1) != self.cached_num_edges: raise RuntimeError( 'Cached {} number of edges, but found {}. 
Please ' 'disable the caching behavior of this layer by removing ' 'the `cached=True` argument in its constructor.'.format( self.cached_num_edges, edge_index.size(1))) if not self.cached or self.cached_result is None: self.cached_num_edges = edge_index.size(1) if self.normalize: edge_index, norm = self.norm(edge_index, x.size(self.node_dim), edge_weight, self.improved, x.dtype) else: norm = edge_weight self.cached_result = edge_index, norm edge_index, norm = self.cached_result x_msg = self.propagate(edge_index, x=x, norm=norm, edge_feature=edge_feature) if cfg.gnn.self_msg == 'concat': x_self = self.linear_self(x) return x_self + x_msg elif cfg.gnn.self_msg == 'add': return x + x_msg else: return x_msg def message(self, x_i, x_j, norm, edge_feature): if self.msg_direction == 'both': x_j = torch.cat((x_i, x_j, edge_feature), dim=-1) else: x_j = torch.cat((x_j, edge_feature), dim=-1) x_j = self.linear_msg(x_j) return norm.view(-1, 1) * x_j if norm is not None else x_j def update(self, aggr_out): if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) ================================================ FILE: torch_geometric/graphgym/contrib/loader/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/loss/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: 
torch_geometric/graphgym/contrib/network/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/optimizer/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/pooling/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/stage/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/train/__init__.py ================================================ from os.path import dirname, basename, isfile, join import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/contrib/transform/__init__.py ================================================ from os.path import dirname, basename, isfile, join 
import glob modules = glob.glob(join(dirname(__file__), "*.py")) __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] ================================================ FILE: torch_geometric/graphgym/imports.py ================================================ import warnings import torch try: import lightning.pytorch as pl _pl_is_available = True except ImportError: try: import pytorch_lightning as pl _pl_is_available = True except ImportError: _pl_is_available = False if _pl_is_available: LightningModule = pl.LightningModule Callback = pl.Callback else: pl = object LightningModule = torch.nn.Module Callback = object warnings.warn( "To use GraphGym, install 'pytorch_lightning' or 'lightning' via " "'pip install pytorch_lightning' or 'pip install lightning'", stacklevel=2) ================================================ FILE: torch_geometric/graphgym/init.py ================================================ import torch def init_weights(m): r"""Performs weight initialization. 
Args: m (nn.Module): PyTorch module """ if (isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d)): m.weight.data.fill_(1.0) m.bias.data.zero_() elif isinstance(m, torch.nn.Linear): m.weight.data = torch.nn.init.xavier_uniform_( m.weight.data, gain=torch.nn.init.calculate_gain('relu')) if m.bias is not None: m.bias.data.zero_() ================================================ FILE: torch_geometric/graphgym/loader.py ================================================ import os.path as osp from typing import Callable import torch import torch_geometric.graphgym.register as register import torch_geometric.transforms as T from torch_geometric.datasets import ( PPI, Amazon, Coauthor, KarateClub, MNISTSuperpixels, Planetoid, QM7b, TUDataset, ) from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.models.transform import ( create_link_label, neg_sampling_transform, ) from torch_geometric.loader import ( ClusterLoader, DataLoader, GraphSAINTEdgeSampler, GraphSAINTNodeSampler, GraphSAINTRandomWalkSampler, NeighborSampler, RandomNodeLoader, ) from torch_geometric.utils import ( index_to_mask, negative_sampling, to_undirected, ) index2mask = index_to_mask # TODO Backward compatibility def planetoid_dataset(name: str) -> Callable: return lambda root: Planetoid(root, name) register.register_dataset('Cora', planetoid_dataset('Cora')) register.register_dataset('CiteSeer', planetoid_dataset('CiteSeer')) register.register_dataset('PubMed', planetoid_dataset('PubMed')) register.register_dataset('PPI', PPI) def load_pyg(name, dataset_dir): """Load PyG dataset objects. (More PyG datasets will be supported). 
Args: name (str): dataset name dataset_dir (str): data directory Returns: PyG dataset object """ dataset_dir = osp.join(dataset_dir, name) if name in ['Cora', 'CiteSeer', 'PubMed']: dataset = Planetoid(dataset_dir, name) elif name[:3] == 'TU_': # TU_IMDB doesn't have node features if name[3:] == 'IMDB': name = 'IMDB-MULTI' dataset = TUDataset(dataset_dir, name, transform=T.Constant()) else: dataset = TUDataset(dataset_dir, name[3:]) elif name == 'Karate': dataset = KarateClub() elif 'Coauthor' in name: if 'CS' in name: dataset = Coauthor(dataset_dir, name='CS') else: dataset = Coauthor(dataset_dir, name='Physics') elif 'Amazon' in name: if 'Computers' in name: dataset = Amazon(dataset_dir, name='Computers') else: dataset = Amazon(dataset_dir, name='Photo') elif name == 'MNIST': dataset = MNISTSuperpixels(dataset_dir) elif name == 'PPI': dataset = PPI(dataset_dir) elif name == 'QM7b': dataset = QM7b(dataset_dir) else: raise ValueError(f"'{name}' not support") return dataset def set_dataset_attr(dataset, name, value, size): dataset._data_list = None dataset.data[name] = value if dataset.slices is not None: dataset.slices[name] = torch.tensor([0, size], dtype=torch.long) def load_ogb(name, dataset_dir): r"""Load OGB dataset objects. 
Args: name (str): dataset name dataset_dir (str): data directory Returns: PyG dataset object """ from ogb.graphproppred import PygGraphPropPredDataset from ogb.linkproppred import PygLinkPropPredDataset from ogb.nodeproppred import PygNodePropPredDataset if name[:4] == 'ogbn': dataset = PygNodePropPredDataset(name=name, root=dataset_dir) splits = dataset.get_idx_split() split_names = ['train_mask', 'val_mask', 'test_mask'] for i, key in enumerate(splits.keys()): mask = index_to_mask(splits[key], size=dataset._data.y.shape[0]) set_dataset_attr(dataset, split_names[i], mask, len(mask)) edge_index = to_undirected(dataset._data.edge_index) set_dataset_attr(dataset, 'edge_index', edge_index, edge_index.shape[1]) elif name[:4] == 'ogbg': dataset = PygGraphPropPredDataset(name=name, root=dataset_dir) splits = dataset.get_idx_split() split_names = [ 'train_graph_index', 'val_graph_index', 'test_graph_index' ] for i, key in enumerate(splits.keys()): id = splits[key] set_dataset_attr(dataset, split_names[i], id, len(id)) elif name[:4] == "ogbl": dataset = PygLinkPropPredDataset(name=name, root=dataset_dir) splits = dataset.get_edge_split() id = splits['train']['edge'].T if cfg.dataset.resample_negative: set_dataset_attr(dataset, 'train_pos_edge_index', id, id.shape[1]) dataset.transform = neg_sampling_transform else: id_neg = negative_sampling(edge_index=id, num_nodes=dataset._data.num_nodes, num_neg_samples=id.shape[1]) id_all = torch.cat([id, id_neg], dim=-1) label = create_link_label(id, id_neg) set_dataset_attr(dataset, 'train_edge_index', id_all, id_all.shape[1]) set_dataset_attr(dataset, 'train_edge_label', label, len(label)) id, id_neg = splits['valid']['edge'].T, splits['valid']['edge_neg'].T id_all = torch.cat([id, id_neg], dim=-1) label = create_link_label(id, id_neg) set_dataset_attr(dataset, 'val_edge_index', id_all, id_all.shape[1]) set_dataset_attr(dataset, 'val_edge_label', label, len(label)) id, id_neg = splits['test']['edge'].T, splits['test']['edge_neg'].T 
id_all = torch.cat([id, id_neg], dim=-1) label = create_link_label(id, id_neg) set_dataset_attr(dataset, 'test_edge_index', id_all, id_all.shape[1]) set_dataset_attr(dataset, 'test_edge_label', label, len(label)) else: raise ValueError('OGB dataset: {} non-exist') return dataset def load_dataset(): r"""Load dataset objects. Returns: PyG dataset object """ format = cfg.dataset.format name = cfg.dataset.name dataset_dir = cfg.dataset.dir # Try to load customized data format for func in register.loader_dict.values(): dataset = func(format, name, dataset_dir) if dataset is not None: return dataset # Load from Pytorch Geometric dataset if format == 'PyG': dataset = load_pyg(name, dataset_dir) # Load from OGB formatted data elif format == 'OGB': dataset = load_ogb(name.replace('_', '-'), dataset_dir) else: raise ValueError(f"Unknown data format '{format}'") return dataset def set_dataset_info(dataset): r"""Set global dataset information. Args: dataset: PyG dataset object """ # get dim_in and dim_out try: cfg.share.dim_in = dataset._data.x.shape[1] except Exception: cfg.share.dim_in = 1 try: if cfg.dataset.task_type == 'classification': cfg.share.dim_out = torch.unique(dataset._data.y).shape[0] else: cfg.share.dim_out = dataset._data.y.shape[1] except Exception: cfg.share.dim_out = 1 # count number of dataset splits cfg.share.num_splits = 1 for key in dataset._data.keys(): if 'val' in key: cfg.share.num_splits += 1 break for key in dataset._data.keys(): if 'test' in key: cfg.share.num_splits += 1 break def create_dataset(): r"""Create dataset object. 
Returns: PyG dataset object """ dataset = load_dataset() set_dataset_info(dataset) return dataset def get_loader(dataset, sampler, batch_size, shuffle=True): pw = cfg.num_workers > 0 if sampler == "full_batch" or len(dataset) > 1: loader_train = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=cfg.num_workers, pin_memory=True, persistent_workers=pw) elif sampler == "neighbor": loader_train = NeighborSampler( dataset[0], sizes=cfg.train.neighbor_sizes[:cfg.gnn.layers_mp], batch_size=batch_size, shuffle=shuffle, num_workers=cfg.num_workers, pin_memory=True) elif sampler == "random_node": loader_train = RandomNodeLoader(dataset[0], num_parts=cfg.train.train_parts, shuffle=shuffle, num_workers=cfg.num_workers, pin_memory=True, persistent_workers=pw) elif sampler == "saint_rw": loader_train = \ GraphSAINTRandomWalkSampler(dataset[0], batch_size=batch_size, walk_length=cfg.train.walk_length, num_steps=cfg.train.iter_per_epoch, sample_coverage=0, shuffle=shuffle, num_workers=cfg.num_workers, pin_memory=True, persistent_workers=pw) elif sampler == "saint_node": loader_train = \ GraphSAINTNodeSampler(dataset[0], batch_size=batch_size, num_steps=cfg.train.iter_per_epoch, sample_coverage=0, shuffle=shuffle, num_workers=cfg.num_workers, pin_memory=True, persistent_workers=pw) elif sampler == "saint_edge": loader_train = \ GraphSAINTEdgeSampler(dataset[0], batch_size=batch_size, num_steps=cfg.train.iter_per_epoch, sample_coverage=0, shuffle=shuffle, num_workers=cfg.num_workers, pin_memory=True, persistent_workers=pw) elif sampler == "cluster": loader_train = ClusterLoader( dataset[0], num_parts=cfg.train.train_parts, save_dir=osp.join( cfg.dataset.dir, cfg.dataset.name.replace("-", "_"), ), batch_size=batch_size, shuffle=shuffle, num_workers=cfg.num_workers, pin_memory=True, persistent_workers=pw, ) else: raise NotImplementedError(f"'{sampler}' is not implemented") return loader_train def create_loader(): """Create data loader object. 
Returns: List of PyTorch data loaders """ dataset = create_dataset() # train loader if cfg.dataset.task == 'graph': id = dataset.data['train_graph_index'] loaders = [ get_loader(dataset[id], cfg.train.sampler, cfg.train.batch_size, shuffle=True) ] delattr(dataset.data, 'train_graph_index') else: loaders = [ get_loader(dataset, cfg.train.sampler, cfg.train.batch_size, shuffle=True) ] # val and test loaders for i in range(cfg.share.num_splits - 1): if cfg.dataset.task == 'graph': split_names = ['val_graph_index', 'test_graph_index'] id = dataset.data[split_names[i]] loaders.append( get_loader(dataset[id], cfg.val.sampler, cfg.train.batch_size, shuffle=False)) delattr(dataset.data, split_names[i]) else: loaders.append( get_loader(dataset, cfg.val.sampler, cfg.train.batch_size, shuffle=False)) return loaders ================================================ FILE: torch_geometric/graphgym/logger.py ================================================ import logging import math import os import sys import time from typing import Any, Dict, Optional import torch from torch_geometric.graphgym import register from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.imports import Callback, pl from torch_geometric.graphgym.utils.device import get_current_gpu_usage from torch_geometric.graphgym.utils.io import dict_to_json, dict_to_tb def set_printing(): """Set up printing options.""" logging.root.handlers = [] logging_cfg = {'level': logging.INFO, 'format': '%(message)s'} os.makedirs(cfg.run_dir, exist_ok=True) h_file = logging.FileHandler(f'{cfg.run_dir}/logging.log') h_stdout = logging.StreamHandler(sys.stdout) if cfg.print == 'file': logging_cfg['handlers'] = [h_file] elif cfg.print == 'stdout': logging_cfg['handlers'] = [h_stdout] elif cfg.print == 'both': logging_cfg['handlers'] = [h_file, h_stdout] else: raise ValueError('Print option not supported') logging.basicConfig(**logging_cfg) class Logger: def __init__(self, name='train', task_type=None): 
self.name = name self.task_type = task_type self._epoch_total = cfg.optim.max_epoch self._time_total = 0 # won't be reset self.out_dir = f'{cfg.run_dir}/{name}' os.makedirs(self.out_dir, exist_ok=True) if cfg.tensorboard_each_run: from tensorboardX import SummaryWriter self.tb_writer = SummaryWriter(self.out_dir) self.reset() def __getitem__(self, key): return getattr(self, key, None) def __setitem__(self, key, value): setattr(self, key, value) def reset(self): self._iter = 0 self._size_current = 0 self._loss = 0 self._lr = 0 self._params = 0 self._time_used = 0 self._true = [] self._pred = [] self._custom_stats = {} # basic properties def basic(self): stats = { 'loss': round(self._loss / self._size_current, cfg.round), 'lr': round(self._lr, cfg.round), 'params': self._params, 'time_iter': round(self.time_iter(), cfg.round), } gpu_memory = get_current_gpu_usage() if gpu_memory > 0: stats['gpu_memory'] = gpu_memory return stats # customized input properties def custom(self): if len(self._custom_stats) == 0: return {} out = {} for key, val in self._custom_stats.items(): out[key] = val / self._size_current return out def _get_pred_int(self, pred_score): if len(pred_score.shape) == 1 or pred_score.shape[1] == 1: return (pred_score > cfg.model.thresh).long() else: return pred_score.max(dim=1)[1] # task properties def classification_binary(self): from sklearn.metrics import ( accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, ) true, pred_score = torch.cat(self._true), torch.cat(self._pred) pred_int = self._get_pred_int(pred_score) try: r_a_score = roc_auc_score(true, pred_score) except ValueError: r_a_score = 0.0 return { 'accuracy': round(accuracy_score(true, pred_int), cfg.round), 'precision': round(precision_score(true, pred_int), cfg.round), 'recall': round(recall_score(true, pred_int), cfg.round), 'f1': round(f1_score(true, pred_int), cfg.round), 'auc': round(r_a_score, cfg.round), } def classification_multi(self): from sklearn.metrics import 
accuracy_score true, pred_score = torch.cat(self._true), torch.cat(self._pred) pred_int = self._get_pred_int(pred_score) return {'accuracy': round(accuracy_score(true, pred_int), cfg.round)} def regression(self): from sklearn.metrics import mean_absolute_error, mean_squared_error true, pred = torch.cat(self._true), torch.cat(self._pred) return { 'mae': float(round(mean_absolute_error(true, pred), cfg.round)), 'mse': float(round(mean_squared_error(true, pred), cfg.round)), 'rmse': float(round(math.sqrt(mean_squared_error(true, pred)), cfg.round)) } def time_iter(self): return self._time_used / self._iter def eta(self, epoch_current): epoch_current += 1 # since counter starts from 0 time_per_epoch = self._time_total / epoch_current return time_per_epoch * (self._epoch_total - epoch_current) def update_stats(self, true, pred, loss, lr, time_used, params, **kwargs): assert true.shape[0] == pred.shape[0] self._iter += 1 self._true.append(true) self._pred.append(pred) batch_size = true.shape[0] self._size_current += batch_size self._loss += loss * batch_size self._lr = lr self._params = params self._time_used += time_used self._time_total += time_used for key, val in kwargs.items(): if key not in self._custom_stats: self._custom_stats[key] = val * batch_size else: self._custom_stats[key] += val * batch_size def write_iter(self): raise NotImplementedError def write_epoch(self, cur_epoch): basic_stats = self.basic() # Try to load customized metrics task_stats = {} for custom_metric in cfg.custom_metrics: func = register.metric_dict.get(custom_metric) if not func: raise ValueError( f'Unknown custom metric function name: {custom_metric}') custom_metric_score = func(self._true, self._pred, self.task_type) task_stats[custom_metric] = custom_metric_score if not task_stats: # use default metrics if no matching custom metric if self.task_type == 'regression': task_stats = self.regression() elif self.task_type == 'classification_binary': task_stats = self.classification_binary() 
elif self.task_type == 'classification_multi': task_stats = self.classification_multi() else: raise ValueError('Task has to be regression or classification') epoch_stats = {'epoch': cur_epoch} eta_stats = {'eta': round(self.eta(cur_epoch), cfg.round)} custom_stats = self.custom() if self.name == 'train': stats = { **epoch_stats, **eta_stats, **basic_stats, **task_stats, **custom_stats } else: stats = { **epoch_stats, **basic_stats, **task_stats, **custom_stats } # print logging.info(f'{self.name}: {stats}') # json dict_to_json(stats, f'{self.out_dir}/stats.json') # tensorboard if cfg.tensorboard_each_run: dict_to_tb(stats, self.tb_writer, cur_epoch) self.reset() def close(self): if cfg.tensorboard_each_run: self.tb_writer.close() def infer_task(): num_label = cfg.share.dim_out if cfg.dataset.task_type == 'classification': if num_label <= 2: task_type = 'classification_binary' else: task_type = 'classification_multi' else: task_type = cfg.dataset.task_type return task_type def create_logger(): r"""Create logger for the experiment.""" loggers = [] names = ['train', 'val', 'test'] for i, _ in enumerate(range(cfg.share.num_splits)): loggers.append(Logger(name=names[i], task_type=infer_task())) return loggers class LoggerCallback(Callback): def __init__(self): self._logger = create_logger() self._train_epoch_start_time = None self._val_epoch_start_time = None self._test_epoch_start_time = None @property def train_logger(self) -> Any: return self._logger[0] @property def val_logger(self) -> Any: return self._logger[1] @property def test_logger(self) -> Any: return self._logger[2] def close(self): for logger in self._logger: logger.close() def _get_stats( self, epoch_start_time: int, outputs: Dict[str, Any], trainer: 'pl.Trainer', ) -> Dict: return dict( true=outputs['true'].detach().cpu(), pred=outputs['pred_score'].detach().cpu(), loss=float(outputs['loss']), lr=trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0], time_used=time.time() - epoch_start_time, 
params=cfg.params, ) def on_train_epoch_start( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', ): self._train_epoch_start_time = time.time() def on_validation_epoch_start( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', ): self._val_epoch_start_time = time.time() def on_test_epoch_start( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', ): self._test_epoch_start_time = time.time() def on_train_batch_end( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: Dict[str, Any], batch: Any, batch_idx: int, unused: int = 0, ): stats = self._get_stats(self._train_epoch_start_time, outputs, trainer) self.train_logger.update_stats(**stats) def on_validation_batch_end( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: Optional[Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int = 0, ): stats = self._get_stats(self._val_epoch_start_time, outputs, trainer) self.val_logger.update_stats(**stats) def on_test_batch_end( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: Optional[Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int = 0, ): stats = self._get_stats(self._test_epoch_start_time, outputs, trainer) self.test_logger.update_stats(**stats) def on_train_epoch_end( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', ): self.train_logger.write_epoch(trainer.current_epoch) def on_validation_epoch_end( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', ): self.val_logger.write_epoch(trainer.current_epoch) def on_test_epoch_end( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', ): self.test_logger.write_epoch(trainer.current_epoch) def on_fit_end(self, trainer, pl_module): self.close() ================================================ FILE: torch_geometric/graphgym/loss.py ================================================ import torch import torch.nn.functional as F import torch_geometric.graphgym.register as register from 
torch_geometric.graphgym.config import cfg def compute_loss(pred, true): """Compute loss and prediction score. Args: pred (torch.tensor): Unnormalized prediction true (torch.tensor): Ground truth labels Returns: Loss, normalized prediction score """ bce_loss = torch.nn.BCEWithLogitsLoss(reduction=cfg.model.size_average) mse_loss = torch.nn.MSELoss(reduction=cfg.model.size_average) # default manipulation for pred and true # can be skipped if special loss computation is needed pred = pred.squeeze(-1) if pred.ndim > 1 else pred true = true.squeeze(-1) if true.ndim > 1 else true # Try to load customized loss for func in register.loss_dict.values(): value = func(pred, true) if value is not None: return value if cfg.model.loss_fun == 'cross_entropy': # multiclass if pred.ndim > 1 and true.ndim == 1: pred = F.log_softmax(pred, dim=-1) return F.nll_loss(pred, true), pred # binary or multilabel else: true = true.float() return bce_loss(pred, true), torch.sigmoid(pred) elif cfg.model.loss_fun == 'mse': true = true.float() return mse_loss(pred, true), pred else: raise ValueError(f"Loss function '{cfg.model.loss_fun}' not supported") ================================================ FILE: torch_geometric/graphgym/model_builder.py ================================================ import time from typing import Any, Dict, Tuple import torch from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.imports import LightningModule from torch_geometric.graphgym.loss import compute_loss from torch_geometric.graphgym.models.gnn import GNN from torch_geometric.graphgym.optim import create_optimizer, create_scheduler from torch_geometric.graphgym.register import network_dict, register_network register_network('gnn', GNN) class GraphGymModule(LightningModule): def __init__(self, dim_in, dim_out, cfg): super().__init__() self.cfg = cfg self.model = network_dict[cfg.model.type](dim_in=dim_in, dim_out=dim_out) def forward(self, *args, **kwargs): return self.model(*args, 
**kwargs) def configure_optimizers(self) -> Tuple[Any, Any]: optimizer = create_optimizer(self.model.parameters(), self.cfg.optim) scheduler = create_scheduler(optimizer, self.cfg.optim) return [optimizer], [scheduler] def _shared_step(self, batch, split: str) -> Dict: batch.split = split pred, true = self(batch) loss, pred_score = compute_loss(pred, true) step_end_time = time.time() return dict(loss=loss, true=true, pred_score=pred_score.detach(), step_end_time=step_end_time) def training_step(self, batch, *args, **kwargs): return self._shared_step(batch, split="train") def validation_step(self, batch, *args, **kwargs): return self._shared_step(batch, split="val") def test_step(self, batch, *args, **kwargs): return self._shared_step(batch, split="test") @property def encoder(self) -> torch.nn.Module: return self.model.encoder @property def mp(self) -> torch.nn.Module: return self.model.mp @property def post_mp(self) -> torch.nn.Module: return self.model.post_mp @property def pre_mp(self) -> torch.nn.Module: return self.model.pre_mp def lr_scheduler_step(self, *args, **kwargs): # Needed for PyTorch 2.0 since the base class of LR schedulers changed. # TODO Remove once we only want to support PyTorch Lightning >= 2.0. return super().lr_scheduler_step(*args, **kwargs) def create_model(to_device=True, dim_in=None, dim_out=None) -> GraphGymModule: r"""Create model for graph machine learning. Args: to_device (bool, optional): Whether to transfer the model to the specified device. 
(default: :obj:`True`) dim_in (int, optional): Input dimension to the model dim_out (int, optional): Output dimension to the model """ dim_in = cfg.share.dim_in if dim_in is None else dim_in dim_out = cfg.share.dim_out if dim_out is None else dim_out # binary classification, output dim = 1 if 'classification' == cfg.dataset.task_type and dim_out == 2: dim_out = 1 model = GraphGymModule(dim_in, dim_out, cfg) if to_device: model.to(torch.device(cfg.accelerator)) return model ================================================ FILE: torch_geometric/graphgym/models/__init__.py ================================================ from .encoder import (IntegerFeatureEncoder, AtomEncoder, BondEncoder) from .gnn import (GNNLayer, GNNPreMP, GNNStackStage, FeatureEncoder, GNN) from .head import (GNNNodeHead, GNNEdgeHead, GNNGraphHead) from .layer import (GeneralLayer, GeneralMultiLayer, Linear, BatchNorm1dNode, BatchNorm1dEdge, MLP, GCNConv, SAGEConv, GATConv, GINConv, SplineConv, GeneralConv, GeneralEdgeConv, GeneralSampleEdgeConv) from .pooling import (global_add_pool, global_mean_pool, global_max_pool) __all__ = [ 'IntegerFeatureEncoder', 'AtomEncoder', 'BondEncoder', 'GNNLayer', 'GNNPreMP', 'GNNStackStage', 'FeatureEncoder', 'GNN', 'GNNNodeHead', 'GNNEdgeHead', 'GNNGraphHead', 'GeneralLayer', 'GeneralMultiLayer', 'Linear', 'BatchNorm1dNode', 'BatchNorm1dEdge', 'MLP', 'GCNConv', 'SAGEConv', 'GATConv', 'GINConv', 'SplineConv', 'GeneralConv', 'GeneralEdgeConv', 'GeneralSampleEdgeConv', 'global_add_pool', 'global_mean_pool', 'global_max_pool', ] classes = __all__ ================================================ FILE: torch_geometric/graphgym/models/act.py ================================================ import torch from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.register import register_act def relu(): return torch.nn.ReLU(inplace=cfg.mem.inplace) def selu(): return torch.nn.SELU(inplace=cfg.mem.inplace) def prelu(): return torch.nn.PReLU() def elu(): 
return torch.nn.ELU(inplace=cfg.mem.inplace) def lrelu_01(): return torch.nn.LeakyReLU(0.1, inplace=cfg.mem.inplace) def lrelu_025(): return torch.nn.LeakyReLU(0.25, inplace=cfg.mem.inplace) def lrelu_05(): return torch.nn.LeakyReLU(0.5, inplace=cfg.mem.inplace) if cfg is not None: register_act('relu', relu) register_act('selu', selu) register_act('prelu', prelu) register_act('elu', elu) register_act('lrelu_01', lrelu_01) register_act('lrelu_025', lrelu_025) register_act('lrelu_05', lrelu_05) ================================================ FILE: torch_geometric/graphgym/models/encoder.py ================================================ import torch from torch_geometric.graphgym.register import ( register_edge_encoder, register_node_encoder, ) @register_node_encoder('Integer') class IntegerFeatureEncoder(torch.nn.Module): r"""Provides an encoder for integer node features. Args: emb_dim (int): The output embedding dimension. num_classes (int): The number of classes/integers. Example: >>> encoder = IntegerFeatureEncoder(emb_dim=16, num_classes=10) >>> batch = torch.randint(0, 10, (10, 2)) >>> encoder(batch).size() torch.Size([10, 16]) """ def __init__(self, emb_dim: int, num_classes: int): super().__init__() self.encoder = torch.nn.Embedding(num_classes, emb_dim) torch.nn.init.xavier_uniform_(self.encoder.weight.data) def forward(self, batch): # Encode just the first dimension if more exist batch.x = self.encoder(batch.x[:, 0]) return batch @register_node_encoder('Atom') class AtomEncoder(torch.nn.Module): r"""The atom encoder used in OGB molecule dataset. Args: emb_dim (int): The output embedding dimension. 
Example: >>> encoder = AtomEncoder(emb_dim=16) >>> batch = torch.randint(0, 10, (10, 3)) >>> encoder(batch).size() torch.Size([10, 16]) """ def __init__(self, emb_dim, *args, **kwargs): super().__init__() from ogb.utils.features import get_atom_feature_dims self.atom_embedding_list = torch.nn.ModuleList() for dim in get_atom_feature_dims(): emb = torch.nn.Embedding(dim, emb_dim) torch.nn.init.xavier_uniform_(emb.weight.data) self.atom_embedding_list.append(emb) def forward(self, batch): encoded_features = 0 for i in range(batch.x.shape[1]): encoded_features += self.atom_embedding_list[i](batch.x[:, i]) batch.x = encoded_features return batch @register_edge_encoder('Bond') class BondEncoder(torch.nn.Module): r"""The bond encoder used in OGB molecule dataset. Args: emb_dim (int): The output embedding dimension. Example: >>> encoder = BondEncoder(emb_dim=16) >>> batch = torch.randint(0, 10, (10, 3)) >>> encoder(batch).size() torch.Size([10, 16]) """ def __init__(self, emb_dim: int): super().__init__() from ogb.utils.features import get_bond_feature_dims self.bond_embedding_list = torch.nn.ModuleList() for dim in get_bond_feature_dims(): emb = torch.nn.Embedding(dim, emb_dim) torch.nn.init.xavier_uniform_(emb.weight.data) self.bond_embedding_list.append(emb) def forward(self, batch): bond_embedding = 0 for i in range(batch.edge_attr.shape[1]): edge_attr = batch.edge_attr bond_embedding += self.bond_embedding_list[i](edge_attr[:, i]) batch.edge_attr = bond_embedding return batch ================================================ FILE: torch_geometric/graphgym/models/gnn.py ================================================ import torch import torch.nn.functional as F import torch_geometric.graphgym.register as register from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.init import init_weights from torch_geometric.graphgym.models.layer import ( BatchNorm1dNode, GeneralLayer, GeneralMultiLayer, new_layer_config, ) from 
torch_geometric.graphgym.register import register_stage def GNNLayer(dim_in: int, dim_out: int, has_act: bool = True) -> GeneralLayer: r"""Creates a GNN layer, given the specified input and output dimensions and the underlying configuration in :obj:`cfg`. Args: dim_in (int): The input dimension dim_out (int): The output dimension. has_act (bool, optional): Whether to apply an activation function after the layer. (default: :obj:`True`) """ return GeneralLayer( cfg.gnn.layer_type, layer_config=new_layer_config( dim_in, dim_out, 1, has_act=has_act, has_bias=False, cfg=cfg, ), ) def GNNPreMP(dim_in: int, dim_out: int, num_layers: int) -> GeneralMultiLayer: r"""Creates a NN layer used before message passing, given the specified input and output dimensions and the underlying configuration in :obj:`cfg`. Args: dim_in (int): The input dimension dim_out (int): The output dimension. num_layers (int): The number of layers. """ return GeneralMultiLayer( 'linear', layer_config=new_layer_config( dim_in, dim_out, num_layers, has_act=False, has_bias=False, cfg=cfg, ), ) @register_stage('stack') @register_stage('skipsum') @register_stage('skipconcat') class GNNStackStage(torch.nn.Module): r"""Stacks a number of GNN layers. Args: dim_in (int): The input dimension dim_out (int): The output dimension. num_layers (int): The number of layers. 
""" def __init__(self, dim_in, dim_out, num_layers): super().__init__() self.num_layers = num_layers for i in range(num_layers): if cfg.gnn.stage_type == 'skipconcat': d_in = dim_in if i == 0 else dim_in + i * dim_out else: d_in = dim_in if i == 0 else dim_out layer = GNNLayer(d_in, dim_out) self.add_module(f'layer{i}', layer) def forward(self, batch): for i, layer in enumerate(self.children()): x = batch.x batch = layer(batch) if cfg.gnn.stage_type == 'skipsum': batch.x = x + batch.x elif (cfg.gnn.stage_type == 'skipconcat' and i < self.num_layers - 1): batch.x = torch.cat([x, batch.x], dim=1) if cfg.gnn.l2norm: batch.x = F.normalize(batch.x, p=2, dim=-1) return batch class FeatureEncoder(torch.nn.Module): r"""Encodes node and edge features, given the specified input dimension and the underlying configuration in :obj:`cfg`. Args: dim_in (int): The input feature dimension. """ def __init__(self, dim_in: int): super().__init__() self.dim_in = dim_in if cfg.dataset.node_encoder: # Encode integer node features via `torch.nn.Embedding`: NodeEncoder = register.node_encoder_dict[ cfg.dataset.node_encoder_name] self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) if cfg.dataset.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode( new_layer_config( cfg.gnn.dim_inner, -1, -1, has_act=False, has_bias=False, cfg=cfg, )) # Update `dim_in` to reflect the new dimension fo the node features self.dim_in = cfg.gnn.dim_inner if cfg.dataset.edge_encoder: # Encode integer edge features via `torch.nn.Embedding`: EdgeEncoder = register.edge_encoder_dict[ cfg.dataset.edge_encoder_name] self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner) if cfg.dataset.edge_encoder_bn: self.edge_encoder_bn = BatchNorm1dNode( new_layer_config( cfg.gnn.dim_inner, -1, -1, has_act=False, has_bias=False, cfg=cfg, )) def forward(self, batch): for module in self.children(): batch = module(batch) return batch class GNN(torch.nn.Module): r"""A general Graph Neural Network (GNN) model. 
The GNN model consists of three main components: 1. An encoder to transform input features into a fixed-size embedding space. 2. A processing or message passing stage for information exchange between nodes. 3. A head to produce the final output features/predictions. The configuration of each component is determined by the underlying configuration in :obj:`cfg`. Args: dim_in (int): The input feature dimension. dim_out (int): The output feature dimension. **kwargs (optional): Additional keyword arguments. """ def __init__(self, dim_in: int, dim_out: int, **kwargs): super().__init__() GNNStage = register.stage_dict[cfg.gnn.stage_type] GNNHead = register.head_dict[cfg.gnn.head] self.encoder = FeatureEncoder(dim_in) dim_in = self.encoder.dim_in if cfg.gnn.layers_pre_mp > 0: self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) dim_in = cfg.gnn.dim_inner if cfg.gnn.layers_mp > 0: self.mp = GNNStage(dim_in=dim_in, dim_out=cfg.gnn.dim_inner, num_layers=cfg.gnn.layers_mp) self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) self.apply(init_weights) def forward(self, batch): for module in self.children(): batch = module(batch) return batch ================================================ FILE: torch_geometric/graphgym/models/head.py ================================================ import torch import torch_geometric.graphgym.register as register from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.models.layer import MLP, new_layer_config from torch_geometric.graphgym.register import register_head @register_head('node') class GNNNodeHead(torch.nn.Module): r"""A GNN prediction head for node-level prediction tasks. Args: dim_in (int): The input feature dimension. dim_out (int): The output feature dimension. 
""" def __init__(self, dim_in: int, dim_out: int): super().__init__() self.layer_post_mp = MLP( new_layer_config( dim_in, dim_out, cfg.gnn.layers_post_mp, has_act=False, has_bias=True, cfg=cfg, )) def _apply_index(self, batch): x = batch.x y = batch.y if 'y' in batch else None if 'split' not in batch: return x, y mask = batch[f'{batch.split}_mask'] return x[mask], y[mask] if y is not None else None def forward(self, batch): batch = self.layer_post_mp(batch) pred, label = self._apply_index(batch) return pred, label @register_head('edge') @register_head('link_pred') class GNNEdgeHead(torch.nn.Module): r"""A GNN prediction head for edge-level/link-level prediction tasks. Args: dim_in (int): The input feature dimension. dim_out (int): The output feature dimension. """ def __init__(self, dim_in: int, dim_out: int): super().__init__() # Module to decode edges from node embeddings: if cfg.model.edge_decoding == 'concat': self.layer_post_mp = MLP( new_layer_config( dim_in * 2, dim_out, cfg.gnn.layers_post_mp, has_act=False, has_bias=True, cfg=cfg, )) self.decode_module = lambda v1, v2: \ self.layer_post_mp(torch.cat((v1, v2), dim=-1)) else: if dim_out > 1: raise ValueError(f"Binary edge decoding " f"'{cfg.model.edge_decoding}' is used for " f"multi-class classification") self.layer_post_mp = MLP( new_layer_config( dim_in, dim_in, cfg.gnn.layers_post_mp, has_act=False, has_bias=True, cfg=cfg, )) if cfg.model.edge_decoding == 'dot': self.decode_module = lambda v1, v2: torch.sum(v1 * v2, dim=-1) elif cfg.model.edge_decoding == 'cosine_similarity': self.decode_module = torch.nn.CosineSimilarity(dim=-1) else: raise ValueError(f"Unknown edge decoding " f"'{cfg.model.edge_decoding}'") def _apply_index(self, batch): index = f'{batch.split}_edge_index' label = f'{batch.split}_edge_label' return batch.x[batch[index]], batch[label] def forward(self, batch): if cfg.model.edge_decoding != 'concat': batch = self.layer_post_mp(batch) pred, label = self._apply_index(batch) nodes_first = 
pred[0] nodes_second = pred[1] pred = self.decode_module(nodes_first, nodes_second) return pred, label @register_head('graph') class GNNGraphHead(torch.nn.Module): r"""A GNN prediction head for graph-level prediction tasks. A post message passing layer (as specified by :obj:`cfg.gnn.post_mp`) is used to transform the pooled graph-level embeddings using an MLP. Args: dim_in (int): The input feature dimension. dim_out (int): The output feature dimension. """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.layer_post_mp = MLP( new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, has_act=False, has_bias=True, cfg=cfg)) self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling] def _apply_index(self, batch): return batch.graph_feature, batch.y def forward(self, batch): graph_emb = self.pooling_fun(batch.x, batch.batch) graph_emb = self.layer_post_mp(graph_emb) batch.graph_feature = graph_emb pred, label = self._apply_index(batch) return pred, label ================================================ FILE: torch_geometric/graphgym/models/layer.py ================================================ import copy from dataclasses import dataclass, replace import torch import torch.nn.functional as F import torch_geometric as pyg import torch_geometric.graphgym.models.act import torch_geometric.graphgym.register as register from torch_geometric.graphgym.contrib.layer.generalconv import ( GeneralConvLayer, GeneralEdgeConvLayer, ) from torch_geometric.graphgym.register import register_layer from torch_geometric.nn import Linear as Linear_pyg @dataclass class LayerConfig: # batchnorm parameters. has_batchnorm: bool = False bn_eps: float = 1e-5 bn_mom: float = 0.1 # mem parameters. mem_inplace: bool = False # gnn parameters. dim_in: int = -1 dim_out: int = -1 edge_dim: int = -1 dim_inner: int = None num_layers: int = 2 has_bias: bool = True # regularizer parameters. has_l2norm: bool = True dropout: float = 0.0 # activation parameters. 
has_act: bool = True final_act: bool = True act: str = 'relu' # other parameters. keep_edge: float = 0.5 def new_layer_config( dim_in: int, dim_out: int, num_layers: int, has_act: bool, has_bias: bool, cfg, ) -> LayerConfig: r"""Create a layer configuration for a GNN layer. Args: dim_in (int): The input feature dimension. dim_out (int): The output feature dimension. num_layers (int): The number of hidden layers has_act (bool): Whether to apply an activation function after the layer. has_bias (bool): Whether to apply a bias term in the layer. cfg (ConfigNode): The underlying configuration. """ return LayerConfig( has_batchnorm=cfg.gnn.batchnorm, bn_eps=cfg.bn.eps, bn_mom=cfg.bn.mom, mem_inplace=cfg.mem.inplace, dim_in=dim_in, dim_out=dim_out, edge_dim=cfg.dataset.edge_dim, has_l2norm=cfg.gnn.l2norm, dropout=cfg.gnn.dropout, has_act=has_act, final_act=True, act=cfg.gnn.act, has_bias=has_bias, keep_edge=cfg.gnn.keep_edge, dim_inner=cfg.gnn.dim_inner, num_layers=num_layers, ) class GeneralLayer(torch.nn.Module): r"""A general wrapper for layers. Args: name (str): The registered name of the layer. layer_config (LayerConfig): The configuration of the layer. **kwargs (optional): Additional keyword arguments. 
""" def __init__(self, name, layer_config: LayerConfig, **kwargs): super().__init__() self.has_l2norm = layer_config.has_l2norm has_bn = layer_config.has_batchnorm layer_config.has_bias = not has_bn self.layer = register.layer_dict[name](layer_config, **kwargs) layer_wrapper = [] if has_bn: layer_wrapper.append( torch.nn.BatchNorm1d( layer_config.dim_out, eps=layer_config.bn_eps, momentum=layer_config.bn_mom, )) if layer_config.dropout > 0: layer_wrapper.append( torch.nn.Dropout( p=layer_config.dropout, inplace=layer_config.mem_inplace, )) if layer_config.has_act: layer_wrapper.append(register.act_dict[layer_config.act]()) self.post_layer = torch.nn.Sequential(*layer_wrapper) def forward(self, batch): batch = self.layer(batch) if isinstance(batch, torch.Tensor): batch = self.post_layer(batch) if self.has_l2norm: batch = F.normalize(batch, p=2, dim=1) else: batch.x = self.post_layer(batch.x) if self.has_l2norm: batch.x = F.normalize(batch.x, p=2, dim=1) return batch class GeneralMultiLayer(torch.nn.Module): r"""A general wrapper class for a stacking multiple NN layers. Args: name (str): The registered name of the layer. layer_config (LayerConfig): The configuration of the layer. **kwargs (optional): Additional keyword arguments. 
""" def __init__(self, name, layer_config: LayerConfig, **kwargs): super().__init__() if layer_config.dim_inner: dim_inner = layer_config.dim_out else: dim_inner = layer_config.dim_inner for i in range(layer_config.num_layers): d_in = layer_config.dim_in if i == 0 else dim_inner d_out = layer_config.dim_out \ if i == layer_config.num_layers - 1 else dim_inner has_act = layer_config.final_act \ if i == layer_config.num_layers - 1 else True inter_layer_config = copy.deepcopy(layer_config) inter_layer_config.dim_in = d_in inter_layer_config.dim_out = d_out inter_layer_config.has_act = has_act layer = GeneralLayer(name, inter_layer_config, **kwargs) self.add_module(f'Layer_{i}', layer) def forward(self, batch): for layer in self.children(): batch = layer(batch) return batch # ---------- Core basic layers. Input: batch; Output: batch ----------------- # @register_layer('linear') class Linear(torch.nn.Module): r"""A basic Linear layer. Args: layer_config (LayerConfig): The configuration of the layer. **kwargs (optional): Additional keyword arguments. """ def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() self.model = Linear_pyg( layer_config.dim_in, layer_config.dim_out, bias=layer_config.has_bias, ) def forward(self, batch): if isinstance(batch, torch.Tensor): batch = self.model(batch) else: batch.x = self.model(batch.x) return batch class BatchNorm1dNode(torch.nn.Module): r"""A batch normalization layer for node-level features. Args: layer_config (LayerConfig): The configuration of the layer. """ def __init__(self, layer_config: LayerConfig): super().__init__() self.bn = torch.nn.BatchNorm1d( layer_config.dim_in, eps=layer_config.bn_eps, momentum=layer_config.bn_mom, ) def forward(self, batch): batch.x = self.bn(batch.x) return batch class BatchNorm1dEdge(torch.nn.Module): r"""A batch normalization layer for edge-level features. Args: layer_config (LayerConfig): The configuration of the layer. 
""" def __init__(self, layer_config: LayerConfig): super().__init__() self.bn = torch.nn.BatchNorm1d( layer_config.dim_in, eps=layer_config.bn_eps, momentum=layer_config.bn_mom, ) def forward(self, batch): batch.edge_attr = self.bn(batch.edge_attr) return batch @register_layer('mlp') class MLP(torch.nn.Module): """A basic MLP model. Args: layer_config (LayerConfig): The configuration of the layer. **kwargs (optional): Additional keyword arguments. """ def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() if layer_config.dim_inner is None: dim_inner = layer_config.dim_in else: dim_inner = layer_config.dim_inner layer_config.has_bias = True layers = [] if layer_config.num_layers > 1: sub_layer_config = LayerConfig( num_layers=layer_config.num_layers - 1, dim_in=layer_config.dim_in, dim_out=dim_inner, dim_inner=dim_inner, final_act=True) layers.append(GeneralMultiLayer('linear', sub_layer_config)) layer_config = replace(layer_config, dim_in=dim_inner) layers.append(Linear(layer_config)) else: layers.append(Linear(layer_config)) self.model = torch.nn.Sequential(*layers) def forward(self, batch): if isinstance(batch, torch.Tensor): batch = self.model(batch) else: batch.x = self.model(batch.x) return batch @register_layer('gcnconv') class GCNConv(torch.nn.Module): r"""A Graph Convolutional Network (GCN) layer.""" def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() self.model = pyg.nn.GCNConv( layer_config.dim_in, layer_config.dim_out, bias=layer_config.has_bias, ) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index) return batch @register_layer('sageconv') class SAGEConv(torch.nn.Module): r"""A GraphSAGE layer.""" def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() self.model = pyg.nn.SAGEConv( layer_config.dim_in, layer_config.dim_out, bias=layer_config.has_bias, ) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index) return batch @register_layer('gatconv') 
class GATConv(torch.nn.Module): r"""A Graph Attention Network (GAT) layer.""" def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() self.model = pyg.nn.GATConv( layer_config.dim_in, layer_config.dim_out, bias=layer_config.has_bias, ) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index) return batch @register_layer('ginconv') class GINConv(torch.nn.Module): r"""A Graph Isomorphism Network (GIN) layer.""" def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() gin_nn = torch.nn.Sequential( Linear_pyg(layer_config.dim_in, layer_config.dim_out), torch.nn.ReLU(), Linear_pyg(layer_config.dim_out, layer_config.dim_out), ) self.model = pyg.nn.GINConv(gin_nn) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index) return batch @register_layer('splineconv') class SplineConv(torch.nn.Module): r"""A SplineCNN layer.""" def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() self.model = pyg.nn.SplineConv( layer_config.dim_in, layer_config.dim_out, dim=1, kernel_size=2, bias=layer_config.has_bias, ) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index, batch.edge_attr) return batch @register_layer('generalconv') class GeneralConv(torch.nn.Module): r"""A general GNN layer.""" def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() self.model = GeneralConvLayer( layer_config.dim_in, layer_config.dim_out, bias=layer_config.has_bias, ) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index) return batch @register_layer('generaledgeconv') class GeneralEdgeConv(torch.nn.Module): r"""A general GNN layer with edge feature support.""" def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() self.model = GeneralEdgeConvLayer( layer_config.dim_in, layer_config.dim_out, layer_config.edge_dim, bias=layer_config.has_bias, ) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index, 
edge_feature=batch.edge_attr) return batch @register_layer('generalsampleedgeconv') class GeneralSampleEdgeConv(torch.nn.Module): r"""A general GNN layer that supports edge features and edge sampling.""" def __init__(self, layer_config: LayerConfig, **kwargs): super().__init__() self.model = GeneralEdgeConvLayer( layer_config.dim_in, layer_config.dim_out, layer_config.edge_dim, bias=layer_config.has_bias, ) self.keep_edge = layer_config.keep_edge def forward(self, batch): edge_mask = torch.rand(batch.edge_index.shape[1]) < self.keep_edge edge_index = batch.edge_index[:, edge_mask] edge_feature = batch.edge_attr[edge_mask, :] batch.x = self.model(batch.x, edge_index, edge_feature=edge_feature) return batch ================================================ FILE: torch_geometric/graphgym/models/pooling.py ================================================ from torch_geometric.graphgym.register import register_pooling from torch_geometric.nn import ( global_add_pool, global_max_pool, global_mean_pool, ) register_pooling('add', global_add_pool) register_pooling('mean', global_mean_pool) register_pooling('max', global_max_pool) ================================================ FILE: torch_geometric/graphgym/models/transform.py ================================================ import torch from torch_geometric.utils import negative_sampling def create_link_label(pos_edge_index, neg_edge_index): """Create labels for link prediction, based on positive and negative edges. Args: pos_edge_index (torch.tensor): Positive edge index [2, num_edges] neg_edge_index (torch.tensor): Negative edge index [2, num_edges] Returns: Link label tensor, [num_positive_edges + num_negative_edges] """ num_links = pos_edge_index.size(1) + neg_edge_index.size(1) link_labels = torch.zeros(num_links, dtype=torch.float, device=pos_edge_index.device) link_labels[:pos_edge_index.size(1)] = 1. return link_labels def neg_sampling_transform(data): """Do negative sampling for link prediction tasks. 
Args: data (torch_geometric.data): Input data object Returns: Transformed data object with negative edges + link pred labels """ train_neg_edge_index = negative_sampling( edge_index=data.train_pos_edge_index, num_nodes=data.num_nodes, num_neg_samples=data.train_pos_edge_index.size(1)) data.train_edge_index = torch.cat( [data.train_pos_edge_index, train_neg_edge_index], dim=-1) data.train_edge_label = create_link_label(data.train_pos_edge_index, train_neg_edge_index) return data ================================================ FILE: torch_geometric/graphgym/optim.py ================================================ from dataclasses import dataclass, field from typing import Any, Iterator, List, Optional from torch.nn import Parameter from torch.optim import SGD, Adam, Optimizer from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR, StepLR import torch_geometric.graphgym.register as register from torch_geometric.graphgym.config import from_config @dataclass class OptimizerConfig: optimizer: str = 'adam' # ['sgd', 'adam'] base_lr: float = 0.01 weight_decay: float = 5e-4 momentum: float = 0.9 # 'sgd' policy @register.register_optimizer('adam') def adam_optimizer(params: Iterator[Parameter], base_lr: float, weight_decay: float) -> Adam: return Adam(params, lr=base_lr, weight_decay=weight_decay) @register.register_optimizer('sgd') def sgd_optimizer(params: Iterator[Parameter], base_lr: float, momentum: float, weight_decay: float) -> SGD: return SGD(params, lr=base_lr, momentum=momentum, weight_decay=weight_decay) def create_optimizer(params: Iterator[Parameter], cfg: Any) -> Any: r"""Creates a config-driven optimizer.""" params = filter(lambda p: p.requires_grad, params) func = register.optimizer_dict.get(cfg.optimizer, None) if func is not None: return from_config(func)(params, cfg=cfg) raise ValueError(f"Optimizer '{cfg.optimizer}' not supported") @dataclass class SchedulerConfig: scheduler: Optional[str] = 'cos' # [None, 'steps', 'cos'] steps: List[int] = 
field(default_factory=[30, 60, 90]) # 'steps' policy lr_decay: float = 0.1 # 'steps' policy max_epoch: int = 200 @register.register_scheduler(None) @register.register_scheduler('none') def none_scheduler(optimizer: Optimizer, max_epoch: int) -> StepLR: return StepLR(optimizer, step_size=max_epoch + 1) @register.register_scheduler('step') def step_scheduler(optimizer: Optimizer, steps: List[int], lr_decay: float) -> MultiStepLR: return MultiStepLR(optimizer, milestones=steps, gamma=lr_decay) @register.register_scheduler('cos') def cos_scheduler(optimizer: Optimizer, max_epoch: int) -> CosineAnnealingLR: return CosineAnnealingLR(optimizer, T_max=max_epoch) def create_scheduler(optimizer: Optimizer, cfg: Any) -> Any: r"""Creates a config-driven learning rate scheduler.""" func = register.scheduler_dict.get(cfg.scheduler, None) if func is not None: return from_config(func)(optimizer, cfg=cfg) raise ValueError(f"Scheduler '{cfg.scheduler}' not supported") ================================================ FILE: torch_geometric/graphgym/register.py ================================================ from typing import Any, Callable, Dict, Union act_dict: Dict[str, Any] = {} node_encoder_dict: Dict[str, Any] = {} edge_encoder_dict: Dict[str, Any] = {} stage_dict: Dict[str, Any] = {} head_dict: Dict[str, Any] = {} layer_dict: Dict[str, Any] = {} pooling_dict: Dict[str, Any] = {} network_dict: Dict[str, Any] = {} config_dict: Dict[str, Any] = {} dataset_dict: Dict[str, Any] = {} loader_dict: Dict[str, Any] = {} optimizer_dict: Dict[str, Any] = {} scheduler_dict: Dict[str, Any] = {} loss_dict: Dict[str, Any] = {} train_dict: Dict[str, Any] = {} metric_dict: Dict[str, Any] = {} def register_base(mapping: Dict[str, Any], key: str, module: Any = None) -> Union[None, Callable]: r"""Base function for registering a module in GraphGym. Args: mapping (dict): :python:`Python` dictionary to register the module. hosting all the registered modules key (str): The name of the module. 
module (any, optional): The module. If set to :obj:`None`, will return a decorator to register a module. """ if module is not None: if key in mapping: raise KeyError(f"Module with '{key}' already defined") mapping[key] = module return # Other-wise, use it as a decorator: def bounded_register(module): register_base(mapping, key, module) return module return bounded_register def register_act(key: str, module: Any = None): r"""Registers an activation function in GraphGym.""" return register_base(act_dict, key, module) def register_node_encoder(key: str, module: Any = None): r"""Registers a node feature encoder in GraphGym.""" return register_base(node_encoder_dict, key, module) def register_edge_encoder(key: str, module: Any = None): r"""Registers an edge feature encoder in GraphGym.""" return register_base(edge_encoder_dict, key, module) def register_stage(key: str, module: Any = None): r"""Registers a customized GNN stage in GraphGym.""" return register_base(stage_dict, key, module) def register_head(key: str, module: Any = None): r"""Registers a GNN prediction head in GraphGym.""" return register_base(head_dict, key, module) def register_layer(key: str, module: Any = None): r"""Registers a GNN layer in GraphGym.""" return register_base(layer_dict, key, module) def register_pooling(key: str, module: Any = None): r"""Registers a GNN global pooling/readout layer in GraphGym.""" return register_base(pooling_dict, key, module) def register_network(key: str, module: Any = None): r"""Registers a GNN model in GraphGym.""" return register_base(network_dict, key, module) def register_config(key: str, module: Any = None): r"""Registers a configuration group in GraphGym.""" return register_base(config_dict, key, module) def register_dataset(key: str, module: Any = None): r"""Registers a dataset in GraphGym.""" return register_base(dataset_dict, key, module) def register_loader(key: str, module: Any = None): r"""Registers a data loader in GraphGym.""" return 
register_base(loader_dict, key, module) def register_optimizer(key: str, module: Any = None): r"""Registers an optimizer in GraphGym.""" return register_base(optimizer_dict, key, module) def register_scheduler(key: str, module: Any = None): r"""Registers a learning rate scheduler in GraphGym.""" return register_base(scheduler_dict, key, module) def register_loss(key: str, module: Any = None): r"""Registers a loss function in GraphGym.""" return register_base(loss_dict, key, module) def register_train(key: str, module: Any = None): r"""Registers a training function in GraphGym.""" return register_base(train_dict, key, module) def register_metric(key: str, module: Any = None): r"""Register a metric function in GraphGym.""" return register_base(metric_dict, key, module) ================================================ FILE: torch_geometric/graphgym/train.py ================================================ import warnings from typing import Any, Dict, Optional import torch from torch.utils.data import DataLoader from torch_geometric.data.lightning.datamodule import LightningDataModule from torch_geometric.graphgym import create_loader from torch_geometric.graphgym.checkpoint import get_ckpt_dir from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.imports import pl from torch_geometric.graphgym.logger import LoggerCallback from torch_geometric.graphgym.model_builder import GraphGymModule class GraphGymDataModule(LightningDataModule): r"""A :class:`pytorch_lightning.LightningDataModule` for handling data loading routines in GraphGym. This class provides data loaders for training, validation, and testing, and can be accessed through the :meth:`train_dataloader`, :meth:`val_dataloader`, and :meth:`test_dataloader` methods, respectively. 
""" def __init__(self): self.loaders = create_loader() super().__init__(has_val=True, has_test=True) def train_dataloader(self) -> DataLoader: return self.loaders[0] def val_dataloader(self) -> DataLoader: # better way would be to test after fit. # First call trainer.fit(...) then trainer.test(...) return self.loaders[1] def test_dataloader(self) -> DataLoader: return self.loaders[2] def train( model: GraphGymModule, datamodule: GraphGymDataModule, logger: bool = True, trainer_config: Optional[Dict[str, Any]] = None, ): r"""Trains a GraphGym model using PyTorch Lightning. Args: model (GraphGymModule): The GraphGym model. datamodule (GraphGymDataModule): The GraphGym data module. logger (bool, optional): Whether to enable logging during training. (default: :obj:`True`) trainer_config (dict, optional): Additional trainer configuration. """ warnings.filterwarnings('ignore', '.*use `CSVLogger` as the default.*') callbacks = [] if logger: callbacks.append(LoggerCallback()) if cfg.train.enable_ckpt: ckpt_cbk = pl.callbacks.ModelCheckpoint(dirpath=get_ckpt_dir()) callbacks.append(ckpt_cbk) trainer_config = trainer_config or {} trainer = pl.Trainer( **trainer_config, enable_checkpointing=cfg.train.enable_ckpt, callbacks=callbacks, default_root_dir=cfg.out_dir, max_epochs=cfg.optim.max_epoch, accelerator=cfg.accelerator, devices='auto' if not torch.cuda.is_available() else cfg.devices, ) trainer.fit(model, datamodule=datamodule) trainer.test(model, datamodule=datamodule) ================================================ FILE: torch_geometric/graphgym/utils/LICENSE ================================================ ================================================ FILE: torch_geometric/graphgym/utils/__init__.py ================================================ from .agg_runs import agg_runs, agg_batch from .comp_budget import params_count, match_baseline_cfg from .device import get_current_gpu_usage, auto_select_device from .epoch import is_eval_epoch, is_ckpt_epoch from .io 
import dict_to_json, dict_list_to_json, dict_to_tb, makedirs_rm_exist from .tools import dummy_context __all__ = [ 'agg_runs', 'agg_batch', 'params_count', 'match_baseline_cfg', 'get_current_gpu_usage', 'auto_select_device', 'is_eval_epoch', 'is_ckpt_epoch', 'dict_to_json', 'dict_list_to_json', 'dict_to_tb', 'makedirs_rm_exist', 'dummy_context', ] classes = __all__ ================================================ FILE: torch_geometric/graphgym/utils/agg_runs.py ================================================ import logging import os import os.path as osp import numpy as np from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.utils.io import ( dict_list_to_json, dict_list_to_tb, dict_to_json, json_to_dict_list, makedirs_rm_exist, string_to_python, ) try: from tensorboardX import SummaryWriter except ImportError: SummaryWriter = None def is_seed(s): try: int(s) return True except Exception: return False def is_split(s): if s in ['train', 'val']: return True else: return False def join_list(l1, l2): assert len(l1) == len(l2), \ 'Results with different seeds must have the save format' for i in range(len(l1)): l1[i] += l2[i] return l1 def agg_dict_list(dict_list): """Aggregate a list of dictionaries: mean + std Args: dict_list: list of dictionaries. 
""" dict_agg = {'epoch': dict_list[0]['epoch']} for key in dict_list[0]: if key != 'epoch': value = np.array([dict[key] for dict in dict_list]) dict_agg[key] = np.mean(value).round(cfg.round) dict_agg[f'{key}_std'] = np.std(value).round(cfg.round) return dict_agg def name_to_dict(run): run = run.split('-', 1)[-1] cols = run.split('=') keys, vals = [], [] keys.append(cols[0]) for col in cols[1:-1]: try: val, key = col.rsplit('-', 1) except Exception: print(col) keys.append(key) vals.append(string_to_python(val)) vals.append(cols[-1]) return dict(zip(keys, vals)) def rm_keys(dict, keys): for key in keys: dict.pop(key, None) def agg_runs(dir, metric_best='auto'): r"""Aggregate over different random seeds of a single experiment. Args: dir (str): Directory of the results, containing 1 experiment metric_best (str, optional): The metric for selecting the best validation performance. Options: auto, accuracy, auc. """ results = {'train': None, 'val': None} results_best = {'train': None, 'val': None} for seed in os.listdir(dir): if is_seed(seed): dir_seed = osp.join(dir, seed) split = 'val' if split in os.listdir(dir_seed): dir_split = osp.join(dir_seed, split) fname_stats = osp.join(dir_split, 'stats.json') stats_list = json_to_dict_list(fname_stats) if metric_best == 'auto': metric = 'auc' if 'auc' in stats_list[0] else 'accuracy' else: metric = metric_best performance_np = np.array( # noqa [stats[metric] for stats in stats_list]) best_epoch = \ stats_list[ eval(f"performance_np.{cfg.metric_agg}()")][ 'epoch'] print(best_epoch) for split in os.listdir(dir_seed): if is_split(split): dir_split = osp.join(dir_seed, split) fname_stats = osp.join(dir_split, 'stats.json') stats_list = json_to_dict_list(fname_stats) stats_best = [ stats for stats in stats_list if stats['epoch'] == best_epoch ][0] print(stats_best) stats_list = [[stats] for stats in stats_list] if results[split] is None: results[split] = stats_list else: results[split] = join_list(results[split], stats_list) if 
results_best[split] is None: results_best[split] = [stats_best] else: results_best[split] += [stats_best] results = {k: v for k, v in results.items() if v is not None} results_best = {k: v for k, v in results_best.items() if v is not None} for key in results: for i in range(len(results[key])): results[key][i] = agg_dict_list(results[key][i]) for key in results_best: results_best[key] = agg_dict_list(results_best[key]) # save aggregated results for key, value in results.items(): dir_out = osp.join(dir, 'agg', key) makedirs_rm_exist(dir_out) fname = osp.join(dir_out, 'stats.json') dict_list_to_json(value, fname) if cfg.tensorboard_agg: if SummaryWriter is None: raise ImportError( 'Tensorboard support requires `tensorboardX`.') writer = SummaryWriter(dir_out) dict_list_to_tb(value, writer) writer.close() for key, value in results_best.items(): dir_out = osp.join(dir, 'agg', key) fname = osp.join(dir_out, 'best.json') dict_to_json(value, fname) logging.info('Results aggregated across runs saved in {}'.format( osp.join(dir, 'agg'))) def agg_batch(dir, metric_best='auto'): r"""Aggregate across results from multiple experiments via grid search. Args: dir (str): Directory of the results, containing multiple experiments metric_best (str, optional): The metric for selecting the best validation performance. Options: auto, accuracy, auc. 
""" import pandas as pd results = {'train': [], 'val': [], 'test': []} for run in os.listdir(dir): if run != 'agg': dict_name = name_to_dict(run) dir_run = osp.join(dir, run, 'agg') if osp.isdir(dir_run): for split in os.listdir(dir_run): dir_split = osp.join(dir_run, split) fname_stats = osp.join(dir_split, 'best.json') dict_stats = json_to_dict_list(fname_stats)[ -1] # get best val epoch rm_keys(dict_stats, ['lr', 'lr_std', 'eta', 'eta_std', 'params_std']) results[split].append({**dict_name, **dict_stats}) dir_out = osp.join(dir, 'agg') makedirs_rm_exist(dir_out) for key in results: if len(results[key]) > 0: results[key] = pd.DataFrame(results[key]) results[key] = results[key].sort_values( list(dict_name.keys()), ascending=[True] * len(dict_name)) fname = osp.join(dir_out, f'{key}_best.csv') results[key].to_csv(fname, index=False) results = {'train': [], 'val': [], 'test': []} for run in os.listdir(dir): if run != 'agg': dict_name = name_to_dict(run) dir_run = osp.join(dir, run, 'agg') if osp.isdir(dir_run): for split in os.listdir(dir_run): dir_split = osp.join(dir_run, split) fname_stats = osp.join(dir_split, 'stats.json') dict_stats = json_to_dict_list(fname_stats)[ -1] # get last epoch rm_keys(dict_stats, ['lr', 'lr_std', 'eta', 'eta_std', 'params_std']) results[split].append({**dict_name, **dict_stats}) dir_out = osp.join(dir, 'agg') for key in results: if len(results[key]) > 0: results[key] = pd.DataFrame(results[key]) results[key] = results[key].sort_values( list(dict_name.keys()), ascending=[True] * len(dict_name)) fname = osp.join(dir_out, f'{key}.csv') results[key].to_csv(fname, index=False) results = {'train': [], 'val': [], 'test': []} for run in os.listdir(dir): if run != 'agg': dict_name = name_to_dict(run) dir_run = osp.join(dir, run, 'agg') if osp.isdir(dir_run): for split in os.listdir(dir_run): dir_split = osp.join(dir_run, split) fname_stats = osp.join(dir_split, 'stats.json') dict_stats = json_to_dict_list( fname_stats) # get best epoch if 
metric_best == 'auto': metric = 'auc' if 'auc' in dict_stats[0] \ else 'accuracy' else: metric = metric_best performance_np = np.array( # noqa [stats[metric] for stats in dict_stats]) dict_stats = dict_stats[eval("performance_np.{}()".format( cfg.metric_agg))] rm_keys(dict_stats, ['lr', 'lr_std', 'eta', 'eta_std', 'params_std']) results[split].append({**dict_name, **dict_stats}) dir_out = osp.join(dir, 'agg') for key in results: if len(results[key]) > 0: results[key] = pd.DataFrame(results[key]) results[key] = results[key].sort_values( list(dict_name.keys()), ascending=[True] * len(dict_name)) fname = osp.join(dir_out, f'{key}_bestepoch.csv') results[key].to_csv(fname, index=False) print(f'Results aggregated across models saved in {dir_out}') ================================================ FILE: torch_geometric/graphgym/utils/comp_budget.py ================================================ import math from torch_geometric.graphgym.config import cfg, set_cfg from torch_geometric.graphgym.model_builder import create_model def params_count(model): """Computes the number of parameters. 
Args: model (nn.Module): PyTorch model """ return sum([p.numel() for p in model.parameters()]) def get_stats(): model = create_model(to_device=False, dim_in=1, dim_out=1) return params_count(model) def match_computation(stats_baseline, key=None, mode='sqrt'): """Match computation budget by modifying :obj:`cfg.gnn.dim_inner`.""" key = key or ['gnn', 'dim_inner'] stats = get_stats() if stats != stats_baseline: # Phase 1: fast approximation while True: if mode == 'sqrt': scale = math.sqrt(stats_baseline / stats) elif mode == 'linear': scale = stats_baseline / stats step = int(round(cfg[key[0]][key[1]] * scale)) \ - cfg[key[0]][key[1]] cfg[key[0]][key[1]] += step stats = get_stats() if abs(step) <= 1: break # Phase 2: fine tune flag_init = 1 if stats < stats_baseline else -1 step = 1 while True: cfg[key[0]][key[1]] += flag_init * step stats = get_stats() flag = 1 if stats < stats_baseline else -1 if stats == stats_baseline: return stats if flag != flag_init: if not cfg.model.match_upper: # stats is SMALLER if flag < 0: cfg[key[0]][key[1]] -= flag_init * step return get_stats() else: if flag > 0: cfg[key[0]][key[1]] -= flag_init * step return get_stats() return stats def dict_to_stats(cfg_dict): from yacs.config import CfgNode as CN set_cfg(cfg) cfg_new = CN(cfg_dict) cfg.merge_from_other_cfg(cfg_new) stats = get_stats() set_cfg(cfg) return stats def match_baseline_cfg(cfg_dict, cfg_dict_baseline, verbose=True): """Match the computational budget of a given baseline model. The current configuration dictionary will be modified and returned. 
Args: cfg_dict (dict): Current experiment's configuration cfg_dict_baseline (dict): Baseline configuration verbose (str, optional): If printing matched parameter conunts """ from yacs.config import CfgNode as CN stats_baseline = dict_to_stats(cfg_dict_baseline) set_cfg(cfg) cfg_new = CN(cfg_dict) cfg.merge_from_other_cfg(cfg_new) stats = match_computation(stats_baseline, key=['gnn', 'dim_inner']) if 'gnn' in cfg_dict: cfg_dict['gnn']['dim_inner'] = cfg.gnn.dim_inner else: cfg_dict['gnn'] = {'dim_inner', cfg.gnn.dim_inner} set_cfg(cfg) if verbose: print(f"Computational budget has matched - Baseline params: " f"{stats_baseline}, Current params: {stats}") return cfg_dict ================================================ FILE: torch_geometric/graphgym/utils/device.py ================================================ import os import subprocess import numpy as np import torch from torch_geometric.graphgym.config import cfg def get_gpu_memory_map(): """Get the current GPU usage.""" result = subprocess.check_output([ 'nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader' ], encoding='utf-8') gpu_memory = np.array([int(x) for x in result.strip().split('\n')]) return gpu_memory def get_current_gpu_usage(): """Get the current GPU memory usage.""" if cfg.gpu_mem and cfg.device != 'cpu' and torch.cuda.is_available(): result = subprocess.check_output([ 'nvidia-smi', '--query-compute-apps=pid,used_memory', '--format=csv,nounits,noheader' ], encoding='utf-8') current_pid = os.getpid() used_memory = 0 for line in result.strip().split('\n'): line = line.split(', ') if current_pid == int(line[0]): used_memory += int(line[1]) return used_memory else: return -1 def auto_select_device(): r"""Auto select device for the current experiment.""" if cfg.accelerator == 'auto': if torch.cuda.is_available(): cfg.accelerator = 'cuda' cfg.devices = 1 else: cfg.accelerator = 'cpu' cfg.devices = None ================================================ FILE: 
torch_geometric/graphgym/utils/epoch.py ================================================ from torch_geometric.graphgym.config import cfg def is_train_eval_epoch(cur_epoch): """Determines if the model should be evaluated at the training epoch.""" return is_eval_epoch(cur_epoch) or not cfg.train.skip_train_eval def is_eval_epoch(cur_epoch): """Determines if the model should be evaluated at the current epoch.""" return ((cur_epoch + 1) % cfg.train.eval_period == 0 or cur_epoch == 0 or (cur_epoch + 1) == cfg.optim.max_epoch) def is_ckpt_epoch(cur_epoch): """Determines if the model should be evaluated at the current epoch.""" return ((cur_epoch + 1) % cfg.train.ckpt_period == 0 or (cur_epoch + 1) == cfg.optim.max_epoch) ================================================ FILE: torch_geometric/graphgym/utils/io.py ================================================ import ast import json import os import os.path as osp from torch_geometric.io import fs def string_to_python(string): try: return ast.literal_eval(string) except Exception: return string def dict_to_json(dict, fname): """Dump a :python:`Python` dictionary to a JSON file. Args: dict (dict): The :python:`Python` dictionary. fname (str): The output file name. """ with open(fname, 'a') as f: json.dump(dict, f) f.write('\n') def dict_list_to_json(dict_list, fname): """Dump a list of :python:`Python` dictionaries to a JSON file. Args: dict_list (list of dict): List of :python:`Python` dictionaries. fname (str): the output file name. """ with open(fname, 'a') as f: for dict in dict_list: json.dump(dict, f) f.write('\n') def json_to_dict_list(fname): dict_list = [] epoch_set = set() with open(fname) as f: lines = f.readlines() for line in lines: line = line.rstrip() dict = json.loads(line) if dict['epoch'] not in epoch_set: dict_list.append(dict) epoch_set.add(dict['epoch']) return dict_list def dict_to_tb(dict, writer, epoch): """Add a dictionary of statistics to a Tensorboard writer. 
Args: dict (dict): Statistics of experiments, the keys are attribute names, the values are the attribute values writer: Tensorboard writer object epoch (int): The current epoch """ for key in dict: writer.add_scalar(key, dict[key], epoch) def dict_list_to_tb(dict_list, writer): for dict in dict_list: assert 'epoch' in dict, 'Key epoch must exist in stats dict' dict_to_tb(dict, writer, dict['epoch']) def makedirs_rm_exist(dir): """Make a directory, remove any existing data. Args: dir (str): The directory to be created. """ if osp.isdir(dir): fs.rm(dir) os.makedirs(dir, exist_ok=True) ================================================ FILE: torch_geometric/graphgym/utils/plot.py ================================================ import os.path as osp def view_emb(emb, dir): """Visualize a embedding matrix. Args: emb (torch.tensor): Embedding matrix with shape (N, D). D is the feature dimension. dir (str): Output directory for the embedding figure. """ import matplotlib.pyplot as plt import seaborn as sns from sklearn.decomposition import PCA sns.set_context('poster') if emb.shape[1] > 2: pca = PCA(n_components=2) emb = pca.fit_transform(emb) plt.figure(figsize=(10, 10)) plt.scatter(emb[:, 0], emb[:, 1]) plt.savefig(osp.join(dir, 'emb_pca.png'), dpi=100) ================================================ FILE: torch_geometric/graphgym/utils/tools.py ================================================ class dummy_context(): """Default context manager that does nothing.""" def __enter__(self): return None def __exit__(self, exc_type, exc_value, traceback): return False ================================================ FILE: torch_geometric/hash_tensor.py ================================================ import functools import warnings from typing import ( Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, ) import numpy as np import torch import torch.utils._pytree as pytree import xxhash from torch import Tensor import torch_geometric.typing from 
torch_geometric.typing import CPUHashMap, CUDAHashMap aten = torch.ops.aten HANDLED_FUNCTIONS: Dict[Callable, Callable] = {} def implements(torch_function: Callable) -> Callable: r"""Registers a :pytorch:`PyTorch` function override.""" @functools.wraps(torch_function) def decorator(my_function: Callable) -> Callable: HANDLED_FUNCTIONS[torch_function] = my_function return my_function return decorator def as_key_tensor( key: Any, *, device: Optional[torch.device] = None, ) -> Tensor: try: key = torch.as_tensor(key, device=device) except Exception: device = device or torch.get_default_device() key = torch.tensor( [xxhash.xxh64(x).intdigest() & 0x7FFFFFFFFFFFFFFF for x in key], dtype=torch.int64, device=device) if key.element_size() == 1: key = key.view(torch.uint8) elif key.element_size() == 2: key = key.view(torch.int16) elif key.element_size() == 4: key = key.view(torch.int32) elif key.element_size() == 8: key = key.view(torch.int64) else: raise ValueError(f"Received invalid dtype '{key.dtype}' with " f"{key.element_size()} bytes") return key def get_hash_map(key: Tensor) -> Union[CPUHashMap, CUDAHashMap]: if torch_geometric.typing.WITH_CUDA_HASH_MAP and key.is_cuda: return CUDAHashMap(key, 0.5) if key.is_cuda: warnings.warn( "Fallback to CPU-based mapping algorithm which may " "cause slowdowns and device synchronization. Please " "install 'pyg-lib' for an accelerated 'HashTensor' " "implementation.", stacklevel=2) if torch_geometric.typing.WITH_CPU_HASH_MAP: return CPUHashMap(key.cpu(), -1) import pandas as pd return pd.CategoricalDtype( categories=key.cpu().numpy(), ordered=True, ) class HashTensor(Tensor): r"""A :pytorch:`null` :class:`torch.Tensor` that can be referenced by arbitrary keys rather than indices in the first dimension. :class:`HashTensor` sub-classes a general :pytorch:`null` :class:`torch.Tensor`, and extends it by CPU- and GPU-accelerated mapping routines. 
This allow for fast and efficient access to non-contiguous indices/keys while the underlying data is stored in a compact format. This representation is ideal for scenarios where one needs a fast mapping routine without relying on CPU-based external packages, and can be used, *e.g.*, to perform mapping of global indices to local indices during subgraph creation, or in data-processing pipelines to map non-contiguous input data into a contiguous space, such as * mapping of hashed node IDs to range :obj:`[0, num_nodes - 1]` * mapping of raw input data, *e.g.*, categorical data to range :obj:`[0, num_categories - 1]` Specifically, :class:`HashTensor` supports *any* keys of *any* type, *e.g.*, strings, timestamps, etc. .. code-block:: python from torch_geometric import HashTensor key = torch.tensor([1000, 100, 10000]) value = torch.randn(3, 4) tensor = HashTensor(key, value) assert tensor.size() == (3, 4) # Filtering: query = torch.tensor([10000, 1000]) out = tensor[query] assert out.equal(value[[2, 0]]) # Accessing non-existing keys: out = tensor[[10000, 0]] out.isnan() >>> tensor([[False, False, False, False], ... [True, True, True, True]) # If `value` is not given, indexing returns the position of `query` in # `key`, and `-1` otherwise: key = ['Animation', 'Comedy', 'Fantasy'] tensor = HashTensor(key) out = tensor[['Comedy', 'Romance']] >>> tensor([1, -1]) Args: key: The keys in the first dimension. value: The values to hold. dtype: The desired data type of the values of the returned tensor. device: The device of the returned tensor. 
""" _map: Union[Tensor, CPUHashMap, CUDAHashMap] _value: Optional[Tensor] _min_key: Tensor _max_key: Tensor @staticmethod def __new__( cls: Type, key: Any, value: Optional[Any] = None, *, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> 'HashTensor': if value is not None: value = torch.as_tensor(value, dtype=dtype, device=device) device = value.device key = as_key_tensor(key, device=device) if key.dim() != 1: raise ValueError(f"'key' data in '{cls.__name__}' needs to be " f"one-dimensional (got {key.dim()} dimensions)") if not key.is_contiguous(): raise ValueError(f"'key' data in '{cls.__name__}' needs to be " f"contiguous") if value is not None: if key.device != value.device: raise ValueError(f"'key' and 'value' data in '{cls.__name__}' " f"are expected to be on the same device (got " f"'{key.device}' and '{value.device}')") if key.numel() != value.size(0): raise ValueError(f"'key' and 'value' data in '{cls.__name__}' " f"are expected to have the same size in the " f"first dimension (got {key.size(0)} and " f"{value.size(0)})") min_key = key.min() if key.numel() > 0 else key.new_zeros(()) max_key = key.max() if key.numel() > 0 else key.new_zeros(()) _range = max_key - min_key # TODO Expose fixed threshold as argument. 
if (key.dtype in {torch.uint8, torch.int16} or _range <= 1_000_000 or _range <= 2 * key.numel()): _map = torch.full( size=(_range + 3, ), fill_value=-1, dtype=torch.int64, device=key.device, ) _map[key.long() - (min_key.long() - 1)] = torch.arange( key.numel(), dtype=_map.dtype, device=_map.device, ) else: _map = get_hash_map(key) return cls._from_data( _map, value, min_key, max_key, num_keys=key.numel(), dtype=dtype, ) # Private Methods ######################################################### @classmethod def _from_data( cls, _map: Union[Tensor, CPUHashMap, CUDAHashMap], value: Optional[Tensor], min_key: Tensor, max_key: Tensor, *, num_keys: int, dtype: Optional[torch.dtype], ) -> 'HashTensor': if value is not None: dtype = value.dtype size = value.size() stride = value.stride() layout = value.layout requires_grad = value.requires_grad else: dtype = dtype or torch.int64 size = torch.Size([num_keys]) stride = (1, ) layout = torch.strided requires_grad = False out = Tensor._make_wrapper_subclass( cls, size=size, strides=stride, dtype=dtype, device=min_key.device, layout=layout, requires_grad=requires_grad, ) assert isinstance(out, HashTensor) out._map = _map out._value = value out._min_key = min_key out._max_key = max_key return out @property def _key(self) -> Tensor: if isinstance(self._map, Tensor): mask = self._map >= 0 key = mask.nonzero().view(-1) - 1 key = key[self._map[mask]] elif (torch_geometric.typing.WITH_CUDA_HASH_MAP or torch_geometric.typing.WITH_CPU_HASH_MAP): key = self._map.keys().to(self.device) else: key = torch.from_numpy(self._map.categories.to_numpy()) return key.to(self.device) def _shallow_copy(self) -> 'HashTensor': return self._from_data( self._map, self._value, self._min_key, self._max_key, num_keys=self.size(0), dtype=self.dtype, ) def _get(self, query: Tensor) -> Tensor: if isinstance(self._map, Tensor): index = query.long() - (self._min_key.long() - 1) index = self._map[index.clamp_(min=0, max=self._map.numel() - 1)] elif 
torch_geometric.typing.WITH_CUDA_HASH_MAP and query.is_cuda: index = self._map.get(query) elif torch_geometric.typing.WITH_CPU_HASH_MAP: index = self._map.get(query.cpu()) else: import pandas as pd ser = pd.Series(query.cpu().numpy(), dtype=self._map) index = torch.from_numpy(ser.cat.codes.to_numpy().copy()).long() index = index.to(self.device) if self._value is None: return index.to(self.dtype) out = self._value[index] mask = index != -1 mask = mask.view([-1] + [1] * (out.dim() - 1)) fill_value = float('NaN') if out.is_floating_point() else -1 if torch_geometric.typing.WITH_PT20: other: Union[int, float, Tensor] = fill_value else: other = torch.full_like(out, fill_value) return out.where(mask, other) # Methods ################################################################# def as_tensor(self) -> Tensor: r"""Zero-copies the :class:`HashTensor` representation back to a :class:`torch.Tensor` representation. """ if self._value is not None: return self._value return torch.arange(self.size(0), dtype=self.dtype, device=self.device) # PyTorch/Python builtins ################################################# # Prevent auto-wrapping outputs back into the proper subclass type: __torch_function__ = torch._C._disabled_torch_function_impl # type: ignore @classmethod def __torch_dispatch__( # type: ignore cls: Type, func: Callable[..., Any], types: Iterable[Type[Any]], args: Iterable[Tuple[Any, ...]] = (), kwargs: Optional[Dict[Any, Any]] = None, ) -> Any: # Hold a number of `HANDLED_FUNCTIONS` that implement specific # functions for valid `HashTensor` routines. if func in HANDLED_FUNCTIONS: return HANDLED_FUNCTIONS[func](*args, **(kwargs or {})) # For all other PyTorch functions, we treat them as vanilla tensors. 
args = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(), args) if kwargs is not None: kwargs = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(), kwargs) return func(*args, **(kwargs or {})) def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]: attrs = ['_map', '_min_key', '_max_key'] if self._value is not None: attrs.append('_value') ctx = (self.size(0), self.dtype) return attrs, ctx @staticmethod def __tensor_unflatten__( inner_tensors: Dict[str, Any], ctx: Tuple[Any, ...], outer_size: Tuple[int, ...], outer_stride: Tuple[int, ...], ) -> 'HashTensor': return HashTensor._from_data( inner_tensors['_map'], inner_tensors.get('_value', None), inner_tensors['_min_key'], inner_tensors['_min_key'], num_keys=ctx[0], dtype=ctx[1], ) def __repr__(self) -> str: # type: ignore indent = len(f'{self.__class__.__name__}(') tensor_str = torch._tensor_str._tensor_str(self.as_tensor(), indent) return torch._tensor_str._str_intern(self, tensor_contents=tensor_str) def tolist(self) -> List[Any]: """""" # noqa: D419 return self.as_tensor().tolist() def numpy(self, *, force: bool = False) -> np.ndarray: """""" # noqa: D419 return self.as_tensor().numpy(force=force) def index_select( # type: ignore self, dim: int, index: Any, ) -> Union['HashTensor', Tensor]: """""" # noqa: D419 return torch.index_select(self, dim, index) def select( # type: ignore self, dim: int, index: Any, ) -> Union['HashTensor', Tensor]: """""" # noqa: D419 return torch.select(self, dim, index) def share_memory_(self) -> 'HashTensor': """""" # noqa: D419 if isinstance(self._map, Tensor): self._map.share_memory_() if self._value is not None: self._value.share_memory_() self._min_key.share_memory_() self._max_key.share_memory_() return self def is_shared(self) -> bool: """""" # noqa: D419 return self._min_key.is_shared() def detach_(self) -> 'HashTensor': """""" # noqa: D419 if self._value is not None: self._value.detach_() return super().detach_() # type: ignore def __getitem__(self, 
indices: Any) -> Union['HashTensor', Tensor]: if not isinstance(indices, tuple): indices = (indices, ) assert len(indices) > 0 # We convert any index tensor in the first dimension into a tensor. # This means that downstream handling (i.e. in `aten.index.Tensor`) # needs to take this pre-conversion into account. However, detecting # whether the first dimension is indexed can be tricky at times: # * We need to take into account `Ellipsis` # * We need to take any unsqueezing into account if indices[0] is Ellipsis and len(indices) > 1: nonempty_indices = [i for i in indices[1:] if i is not None] if len(nonempty_indices) == self.dim(): indices = indices[1:] if isinstance(indices[0], (int, bool)): index: Union[int, Tensor] = int(as_key_tensor([indices[0]])) indices = (index, ) + indices[1:] elif isinstance(indices[0], (Tensor, list, np.ndarray)): index = as_key_tensor(indices[0], device=self.device) indices = (index, ) + indices[1:] indices = indices[0] if len(indices) == 1 else indices return super().__getitem__(indices) @implements(aten.alias.default) def _alias(tensor: HashTensor) -> HashTensor: return tensor._shallow_copy() @implements(aten.clone.default) def _clone( tensor: HashTensor, *, memory_format: torch.memory_format = torch.preserve_format, ) -> HashTensor: value = tensor._value if value is not None: value = aten.clone.default(value, memory_format=memory_format) return tensor._from_data( tensor._map, # NOTE No reason to do clone since it is read-only. value, tensor._min_key, # NOTE No reason to do clone since it is read-only. tensor._max_key, # NOTE No reason to do clone since it is read-only. 
# (continued) tail of the `tensor._from_data(...)` call in `_clone`:
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten.detach.default)
def _detach(tensor: HashTensor) -> HashTensor:
    r"""Implements :obj:`aten.detach`: detaches the value tensor (if any)
    and re-uses the key map and min/max keys as they are.
    """
    value = tensor._value
    if value is not None:
        value = aten.detach.default(value)

    return tensor._from_data(
        tensor._map,
        value,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten._to_copy.default)
def _to_copy(
    tensor: HashTensor,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    non_blocking: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> HashTensor:
    r"""Implements :obj:`aten._to_copy` (*e.g.*, :meth:`Tensor.to`).

    All conversion options are forwarded to the value tensor (if any). The
    min/max keys (and the key map, in case it is a plain tensor) are only
    moved to :obj:`device`; their data type is kept.
    """
    value = tensor._value
    if value is not None:
        value = aten._to_copy.default(
            value,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=memory_format,
        )

    min_key = aten._to_copy.default(tensor._min_key, device=device)
    max_key = aten._to_copy.default(tensor._max_key, device=device)

    _map = tensor._map
    if isinstance(_map, Tensor):
        _map = aten._to_copy.default(_map, device=device)
    # Only convert `_map` in case `CUDAHashMap` exists - otherwise we use
    # CPU-based mapping anyway and there is no need for a copy.
elif (torch_geometric.typing.WITH_CUDA_HASH_MAP and tensor.is_cuda and tensor.device != min_key.device): key = _map.keys() key = aten._to_copy.default(key, device=device) _map = get_hash_map(key) return tensor._from_data( _map, value, min_key, max_key, num_keys=tensor.size(0), dtype=dtype or tensor.dtype, ) @implements(aten._pin_memory.default) def _pin_memory(tensor: HashTensor) -> HashTensor: _map = tensor._map if isinstance(_map, Tensor): _map = aten._pin_memory.default(_map) value = tensor._value if value is not None: value = aten._pin_memory.default(value) return tensor._from_data( _map, value, aten._pin_memory.default(tensor._min_key), aten._pin_memory.default(tensor._max_key), num_keys=tensor.size(0), dtype=tensor.dtype, ) @implements(aten.unsqueeze.default) def _unsqueeze(tensor: HashTensor, dim: int) -> HashTensor: if dim == 0 or dim == -(tensor.dim() + 1): raise IndexError(f"Cannot unsqueeze '{tensor.__class__.__name__}' in " f"the first dimension. Please call `as_tensor()` " f"beforehand") return tensor._from_data( tensor._map, aten.unsqueeze.default(tensor.as_tensor(), dim), tensor._min_key, tensor._max_key, num_keys=tensor.size(0), dtype=tensor.dtype, ) @implements(aten.squeeze.default) def _squeeze_default(tensor: HashTensor) -> HashTensor: if tensor._value is None: return tensor._shallow_copy() value = tensor.as_tensor() for d in range(tensor.dim() - 1, 0, -1): value = value.squeeze(d) return tensor._from_data( tensor._map, value, tensor._min_key, tensor._max_key, num_keys=tensor.size(0), dtype=tensor.dtype, ) @implements(aten.squeeze.dim) @implements(getattr(aten.squeeze, 'dims', aten.squeeze.dim)) def _squeeze_dim( tensor: HashTensor, dim: Union[int, List[int]], ) -> HashTensor: if isinstance(dim, int): dim = [dim] for d in dim: if d < -tensor.dim() or d >= tensor.dim(): raise IndexError(f"Dimension out of range (expected to be in " f"range of [{-tensor.dim()}, {tensor.dim()-1}], " f"but got {d})") if tensor._value is None: return 
tensor._shallow_copy() value = tensor.as_tensor() for d in dim[::-1]: if d != 0 and d != -tensor.dim(): value = value.squeeze(d) return tensor._from_data( tensor._map, value, tensor._min_key, tensor._max_key, num_keys=tensor.size(0), dtype=tensor.dtype, ) @implements(aten.slice.Tensor) def _slice( tensor: HashTensor, dim: int, start: Optional[int] = None, end: Optional[int] = None, step: int = 1, ) -> HashTensor: if dim == 0 or dim == -tensor.dim(): copy = start is None or start == 0 or start <= -tensor.size(0) copy &= end is None or end > tensor.size(0) copy &= step == 1 if copy: return tensor._shallow_copy() key = aten.slice.Tensor(tensor._key, 0, start, end, step) value = aten.slice.Tensor(tensor.as_tensor(), 0, start, end, step) return tensor.__class__(key, value) return tensor._from_data( tensor._map, aten.slice.Tensor(tensor.as_tensor(), dim, start, end, step), tensor._min_key, tensor._max_key, num_keys=tensor.size(0), dtype=tensor.dtype, ) # Since PyTorch does only allow PyTorch tensors as indices in `index_select`, # we need to create a wrapper function and monkey patch `index_select` :( _old_index_select = torch.index_select def _new_index_select( input: Tensor, dim: Union[int, str], index: Tensor, out: Optional[Tensor] = None, ) -> Tensor: if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()): raise IndexError(f"Dimension out of range (expected to be in range of " f"[{-input.dim()}, {input.dim()-1}], but got {dim})") # We convert any index tensor in the first dimension into a tensor. This # means that downstream handling (i.e. in `aten.index_select.default`) # needs to take this pre-conversion into account. if (not torch.jit.is_scripting() and isinstance(input, HashTensor) and isinstance(dim, int) and (dim == 0 or dim == -input.dim())): index = as_key_tensor(index, device=input.device) if isinstance(dim, int): # Type narrowing... 
# (continued) tail of `_new_index_select`:
        if out is None:
            return _old_index_select(input, dim, index)
        else:
            return _old_index_select(input, dim, index, out=out)
    else:
        if out is None:
            return _old_index_select(input, dim, index)
        else:
            return _old_index_select(input, dim, index, out=out)


# Globally replaces `torch.index_select` with the key-aware wrapper above:
torch.index_select = _new_index_select  # type: ignore


@implements(aten.index_select.default)
def _index_select(
    tensor: HashTensor,
    dim: int,
    index: Tensor,
) -> Union[HashTensor, Tensor]:
    r"""Implements :obj:`aten.index_select`. Selecting along the first (key)
    dimension is delegated to :meth:`HashTensor._get`, which receives the
    index as already pre-converted by :func:`_new_index_select`.
    """
    if dim == 0 or dim == -tensor.dim():
        return tensor._get(index)

    return tensor._from_data(
        tensor._map,
        aten.index_select.default(tensor.as_tensor(), dim, index),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


# Since PyTorch does only allow PyTorch tensors as indices in `select`, we need
# to create a wrapper function and monkey patch `select` :(
_old_select = torch.select


def _new_select(
    input: Tensor,
    dim: Union[int, str],
    index: int,
) -> Tensor:
    r"""Replacement for :func:`torch.select` that converts a key in the first
    dimension of a :class:`HashTensor` via :func:`as_key_tensor` before
    dispatching.

    Raises:
        IndexError: If :obj:`dim` is an integer that is out of range.
    """
    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")

    # We convert any index in the first dimension into an integer. This means
    # that downstream handling (i.e. in `aten.select.int`) needs to take this
    # pre-conversion into account.
    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)
            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):
        index = int(as_key_tensor([index]))

    # Both branches are identical at runtime; the split only narrows the type
    # of `dim` for static type checkers.
    if isinstance(dim, int):  # Type narrowing...
        return _old_select(input, dim, index)
    else:
        return _old_select(input, dim, index)


# Globally replaces `torch.select` with the key-aware wrapper above:
torch.select = _new_select  # type: ignore


@implements(aten.select.int)
def _select(
    tensor: HashTensor,
    dim: int,
    index: int,
) -> Union[HashTensor, Tensor]:
    r"""Implements :obj:`aten.select`. Selecting along the first (key)
    dimension looks up a single-element key tensor via
    :meth:`HashTensor._get` and squeezes the result.
    """
    if dim == 0 or dim == -tensor.dim():
        key = torch.tensor(
            [index],
            dtype=tensor._min_key.dtype,
            device=tensor.device,
        )
        return tensor._get(key).squeeze(0)

    return tensor._from_data(
        tensor._map,
        aten.select.int(tensor.as_tensor(), dim, index),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )


@implements(aten.index.Tensor)
def _index(
    tensor: HashTensor,
    indices: List[Optional[Tensor]],
) -> Union[HashTensor, Tensor]:
    r"""Implements :obj:`aten.index` (advanced indexing). An index in the
    first (key) dimension is resolved via :meth:`HashTensor._get`; any
    remaining indices are then applied to the result.
    """
    assert len(indices) > 0

    if indices[0] is not None:
        out = tensor._get(indices[0])
        if len(indices) > 1:
            out = aten.index.Tensor(out, [None] + indices[1:])
        return out

    return tensor._from_data(
        tensor._map,
        aten.index.Tensor(tensor.as_tensor(), indices),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )

================================================
FILE: torch_geometric/home.py
================================================
import os
import os.path as osp
from typing import Optional

ENV_PYG_HOME = 'PYG_HOME'
DEFAULT_CACHE_DIR = osp.join('~', '.cache', 'pyg')

# Explicit override set via `set_home_dir`; takes precedence over `$PYG_HOME`:
_home_dir: Optional[str] = None


def get_home_dir() -> str:
    r"""Get the cache directory used for storing all :pyg:`PyG`-related data.

    If :meth:`set_home_dir` is not called, the path is given by the environment
    variable :obj:`$PYG_HOME` which defaults to :obj:`"~/.cache/pyg"`.
    """
    if _home_dir is not None:
        return _home_dir

    # Only the environment/default path is user-expanded (`~`):
    return osp.expanduser(os.getenv(ENV_PYG_HOME, DEFAULT_CACHE_DIR))


def set_home_dir(path: str) -> None:
    r"""Set the cache directory used for storing all :pyg:`PyG`-related data.

    Args:
        path (str): The path to a local folder.
"""
    global _home_dir
    _home_dir = path

================================================
FILE: torch_geometric/index.py
================================================
import functools
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Type,
    Union,
)

import numpy as np
import torch
import torch.utils._pytree as pytree
from torch import Tensor

from torch_geometric.typing import INDEX_DTYPES

aten = torch.ops.aten

# Registry of `aten` operators that `Index.__torch_dispatch__` overrides:
HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}


def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
    r"""Converts a compressed pointer vector back into an index vector by
    repeating each position according to the differences of :obj:`ptr`.
    """
    index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
    return index.repeat_interleave(ptr.diff(), output_size=output_size)


def index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor:
    r"""Converts an index vector into its compressed pointer representation.

    If :obj:`size` is not given, it is inferred as :obj:`index.max() + 1`
    (or :obj:`0` for empty indices). The output is :obj:`int32` unless
    :obj:`index` is :obj:`int64`.
    """
    if size is None:
        size = int(index.max()) + 1 if index.numel() > 0 else 0

    return torch._convert_indices_from_coo_to_csr(
        index, size, out_int32=index.dtype != torch.int64)


class CatMetadata(NamedTuple):
    # Per-input metadata recorded by `torch.cat` over `Index` objects:
    nnz: List[int]
    dim_size: List[Optional[int]]
    is_sorted: List[bool]


def implements(torch_function: Callable) -> Callable:
    r"""Registers a :pytorch:`PyTorch` function override."""
    @functools.wraps(torch_function)
    def decorator(my_function: Callable) -> Callable:
        HANDLED_FUNCTIONS[torch_function] = my_function
        return my_function

    return decorator


def assert_valid_dtype(tensor: Tensor) -> None:
    # Raises `ValueError` if `tensor.dtype` is not one of `INDEX_DTYPES`.
    if tensor.dtype not in INDEX_DTYPES:
        raise ValueError(f"'Index' holds an unsupported data type "
                         f"(got '{tensor.dtype}', but expected one of "
                         f"{INDEX_DTYPES})")


def assert_one_dimensional(tensor: Tensor) -> None:
    # Raises `ValueError` unless `tensor` is one-dimensional.
    if tensor.dim() != 1:
        raise ValueError(f"'Index' needs to be one-dimensional "
                         f"(got {tensor.dim()} dimensions)")


def assert_contiguous(tensor: Tensor) -> None:
    # Raises `ValueError` unless `tensor` is contiguous in memory.
    if not tensor.is_contiguous():
        raise ValueError("'Index' needs to be contiguous. Please call "
                         "`index.contiguous()` before proceeding.")


def assert_sorted(func: Callable) -> Callable:
    r"""Decorates an :class:`Index` method so that calling it on an unsorted
    :class:`Index` raises a :class:`ValueError`.
    """
    @functools.wraps(func)
    def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:
        if not self.is_sorted:
            cls_name = self.__class__.__name__
            raise ValueError(
                f"Cannot call '{func.__name__}' since '{cls_name}' is not "
                f"sorted. Please call `{cls_name}.sort()` first.")
        return func(self, *args, **kwargs)

    return wrapper


class Index(Tensor):
    r"""A one-dimensional :obj:`index` tensor with additional (meta)data
    attached.

    :class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
    indices of shape :obj:`[num_indices]`.

    While :class:`Index` sub-classes a general :pytorch:`null`
    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:

    * :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*,
      the size of a dimension that can be indexed via :obj:`index`.
      By default, it is inferred as :obj:`dim_size=index.max() + 1`.
    * :obj:`is_sorted`: Whether indices are sorted in ascending order.

    Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
    conversion in case its representation is sorted.
    Caches are filled based on demand (*e.g.*, when calling
    :meth:`Index.get_indptr`), or when explicitly requested via
    :meth:`Index.fill_cache_`, and are maintained and adjusted over its
    lifespan.

    This representation ensures optimal computation in GNN message passing
    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
    workflows.

    .. code-block:: python

        from torch_geometric import Index

        index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
        >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
        assert index.dim_size == 3
        assert index.is_sorted

        # Flipping order:
        out = index.flip(0)
        >>> Index([2, 1, 1, 0], dim_size=3)
        assert not out.is_sorted

        # Filtering:
        mask = torch.tensor([True, True, True, False])
        out = index[mask]
        >>> Index([0, 1, 1], dim_size=3, is_sorted=True)
        assert out.is_sorted
    """
    # See "https://pytorch.org/docs/stable/notes/extending.html"
    # for a basic tutorial on how to subclass `torch.Tensor`.

    # The underlying tensor representation:
    _data: Tensor

    # The size of the underlying sparse vector, e.g. `_data.max() + 1` :
    _dim_size: Optional[int] = None

    # Whether the `index` representation is sorted:
    _is_sorted: bool = False

    # A cache for its compressed representation:
    _indptr: Optional[Tensor] = None

    # Whenever we perform a concatenation of indices, we cache the original
    # metadata to be able to reconstruct individual indices:
    _cat_metadata: Optional[CatMetadata] = None

    @staticmethod
    def __new__(
        cls: Type,
        data: Any,
        *args: Any,
        dim_size: Optional[int] = None,
        is_sorted: bool = False,
        **kwargs: Any,
    ) -> 'Index':
        # Non-tensor `data` is converted via `torch.tensor(data, *args,
        # **kwargs)`; extra arguments are rejected for tensor inputs.
        if not isinstance(data, Tensor):
            data = torch.tensor(data, *args, **kwargs)
        elif len(args) > 0:
            raise TypeError(
                f"new() received an invalid combination of arguments - got "
                f"(Tensor, {', '.join(str(type(arg)) for arg in args)})")
        elif len(kwargs) > 0:
            raise TypeError(f"new() received invalid keyword arguments - got "
                            f"{set(kwargs.keys())})")

        assert isinstance(data, Tensor)

        indptr: Optional[Tensor] = None

        if isinstance(data, cls):  # If passed `Index`, inherit metadata:
            indptr = data._indptr
            dim_size = dim_size or data.dim_size
            is_sorted = is_sorted or data.is_sorted

        assert_valid_dtype(data)
        assert_one_dimensional(data)
        assert_contiguous(data)

        out = Tensor._make_wrapper_subclass(
            cls,
            size=data.size(),
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            layout=data.layout,
            requires_grad=False,
        )
        assert isinstance(out, Index)

        # Attach metadata:
        out._data = data
        out._dim_size = dim_size
        out._is_sorted = is_sorted
        out._indptr = indptr

        if isinstance(data, cls):
            # Unwrap to avoid nesting `Index` inside `Index`:
            out._data = data._data

            # Reset metadata if cache is invalidated:
            if dim_size is not None and dim_size != data.dim_size:
                out._indptr = None

        return out

    # Validation ##############################################################

    def validate(self) -> 'Index':
        r"""Validates the :class:`Index` representation.

        In particular, it ensures that

        * it only holds valid indices.
        * the sort order is correctly set.

        Raises:
            ValueError: If any of the checks above fail, or if the data type,
                dimensionality or contiguity is invalid.
        """
        assert_valid_dtype(self._data)
        assert_one_dimensional(self._data)
        assert_contiguous(self._data)

        if self.numel() > 0 and self._data.min() < 0:
            raise ValueError(f"'{self.__class__.__name__}' contains negative "
                             f"indices (got {int(self.min())})")

        if (self.numel() > 0 and self.dim_size is not None
                and self._data.max() >= self.dim_size):
            raise ValueError(f"'{self.__class__.__name__}' contains larger "
                             f"indices than its registered size "
                             f"(got {int(self._data.max())}, but expected "
                             f"values smaller than {self.dim_size})")

        if self.is_sorted and (self._data.diff() < 0).any():
            raise ValueError(f"'{self.__class__.__name__}' is not sorted")

        return self

    # Properties ##############################################################

    @property
    def dim_size(self) -> Optional[int]:
        r"""The size of the underlying sparse vector."""
        return self._dim_size

    @property
    def is_sorted(self) -> bool:
        r"""Returns whether indices are sorted in ascending order."""
        return self._is_sorted

    @property
    def dtype(self) -> torch.dtype:  # type: ignore
        # TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
        return self._data.dtype

    # Cache Interface #########################################################

    def get_dim_size(self) -> int:
        r"""The size of the underlying sparse vector.
        Automatically computed and cached when not explicitly set.
        """
        if self._dim_size is None:
            dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
            self._dim_size = dim_size

        assert isinstance(self._dim_size, int)
        return self._dim_size

    def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
        r"""Assigns or re-assigns the size of the underlying sparse vector."""
        # A cached `indptr` is kept consistent with the new size: dropped for
        # `None`, truncated when shrinking, padded with its last value when
        # growing.
        if self.is_sorted and self._indptr is not None:
            if dim_size is None:
                self._indptr = None

            elif self._indptr.numel() - 1 >= dim_size:
                self._indptr = self._indptr[:dim_size + 1]

            else:
                fill_value = self._indptr.new_full(
                    (dim_size - self._indptr.numel() + 1, ),
                    fill_value=self._indptr[-1],  # type: ignore
                )
                self._indptr = torch.cat([self._indptr, fill_value], dim=0)

        self._dim_size = dim_size

        return self

    @assert_sorted
    def get_indptr(self) -> Tensor:
        r"""Returns the compressed index representation in case :class:`Index`
        is sorted.
        """
        if self._indptr is None:
            self._indptr = index2ptr(self._data, self.get_dim_size())

        assert isinstance(self._indptr, Tensor)
        return self._indptr

    def fill_cache_(self) -> 'Index':
        r"""Fills the cache with (meta)data information."""
        self.get_dim_size()

        if self.is_sorted:
            self.get_indptr()

        return self

    # Methods #################################################################

    def share_memory_(self) -> 'Index':
        """"""  # noqa: D419
        self._data.share_memory_()
        if self._indptr is not None:
            self._indptr.share_memory_()
        return self

    def is_shared(self) -> bool:
        """"""  # noqa: D419
        return self._data.is_shared()

    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`Index` representation back to a
        :class:`torch.Tensor` representation.
        """
        return self._data

    # PyTorch/Python builtins #################################################

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
        # Inner tensors: `_data` (and `_indptr` if cached); the remaining
        # metadata travels in the context tuple.
        attrs = ['_data']
        if self._indptr is not None:
            attrs.append('_indptr')

        ctx = (
            self._dim_size,
            self._is_sorted,
            self._cat_metadata,
        )

        return attrs, ctx

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: Dict[str, Any],
        ctx: Tuple[Any, ...],
        outer_size: Tuple[int, ...],
        outer_stride: Tuple[int, ...],
    ) -> 'Index':
        # Inverse of `__tensor_flatten__`; `outer_size`/`outer_stride` are
        # unused.
        index = Index(
            inner_tensors['_data'],
            dim_size=ctx[0],
            is_sorted=ctx[1],
        )

        index._indptr = inner_tensors.get('_indptr', None)
        index._cat_metadata = ctx[2]

        return index

    # Prevent auto-wrapping outputs back into the proper subclass type:
    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore

    @classmethod
    def __torch_dispatch__(  # type: ignore
        cls: Type,
        func: Callable[..., Any],
        types: Iterable[Type[Any]],
        args: Iterable[Tuple[Any, ...]] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        # `Index` should be treated as a regular PyTorch tensor for all
        # standard PyTorch functionalities. However,
        # * some of its metadata can be transferred to new functions, e.g.,
        #   `torch.narrow()` can inherit the `is_sorted` property.
        # * not all operations lead to valid `Index` tensors again, e.g.,
        #   `torch.sum()` does not yield a `Index` as its output, or
        #   `torch.stack() violates the [*] shape assumption.

        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that
        # implement specific functions for valid `Index` routines.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

        # For all other PyTorch functions, we treat them as vanilla tensors.
        args = pytree.tree_map_only(Index, lambda x: x._data, args)
        if kwargs is not None:
            kwargs = pytree.tree_map_only(Index, lambda x: x._data, kwargs)
        return func(*args, **(kwargs or {}))

    def __repr__(self) -> str:  # type: ignore
        prefix = f'{self.__class__.__name__}('
        indent = len(prefix)
        tensor_str = torch._tensor_str._tensor_str(self._data, indent)

        # Metadata is only shown when it deviates from the defaults:
        suffixes = []
        if self.dim_size is not None:
            suffixes.append(f'dim_size={self.dim_size}')
        if (self.device.type != torch._C._get_default_device()
                or (self.device.type == 'cuda'
                    and torch.cuda.current_device() != self.device.index)
                or (self.device.type == 'mps')):
            suffixes.append(f"device='{self.device}'")
        if self.dtype != torch.int64:
            suffixes.append(f'dtype={self.dtype}')
        if self.is_sorted:
            suffixes.append('is_sorted=True')

        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
                                               indent, force_newline=False)

    def tolist(self) -> List[Any]:
        """"""  # noqa: D419
        return self._data.tolist()

    def numpy(self, *, force: bool = False) -> np.ndarray:
        """"""  # noqa: D419
        return self._data.numpy(force=force)

    # Helpers #################################################################

    def _shallow_copy(self) -> 'Index':
        # New `Index` sharing `_data` and all (meta)data with `self`.
        out = Index(self._data)
        out._dim_size = self._dim_size
        out._is_sorted = self._is_sorted
        out._indptr = self._indptr
        out._cat_metadata = self._cat_metadata
        return out

    def _clear_metadata(self) -> 'Index':
        # Resets all (meta)data and caches in-place.
        self._dim_size = None
        self._is_sorted = False
        self._indptr = None
        self._cat_metadata = None
        return self


def apply_(
    tensor: Index,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Union[Index, Tensor]:
    r"""Applies :obj:`fn` to the underlying data (and the :obj:`indptr`
    cache) of :obj:`tensor` while carrying over its metadata.

    Returns a plain :class:`torch.Tensor` if the result is not of an index
    data type. If :obj:`fn` operated in-place (same data pointer),
    :obj:`tensor` itself is updated and returned.
    """
    data = fn(tensor._data, *args, **kwargs)

    if data.dtype not in INDEX_DTYPES:
        return data

    if tensor._data.data_ptr() != data.data_ptr():
        out = Index(data)
    else:  # In-place:
        tensor._data = data
        out = tensor

    # Copy metadata:
    out._dim_size = tensor._dim_size
    out._is_sorted = tensor._is_sorted
    out._cat_metadata = tensor._cat_metadata

    # Convert cache:
    if tensor._indptr is not None:
        out._indptr = fn(tensor._indptr, *args,
**kwargs) return out @implements(aten.clone.default) def _clone( tensor: Index, *, memory_format: torch.memory_format = torch.preserve_format, ) -> Index: out = apply_(tensor, aten.clone.default, memory_format=memory_format) assert isinstance(out, Index) return out @implements(aten._to_copy.default) def _to_copy( tensor: Index, *, dtype: Optional[torch.dtype] = None, layout: Optional[torch.layout] = None, device: Optional[torch.device] = None, pin_memory: bool = False, non_blocking: bool = False, memory_format: Optional[torch.memory_format] = None, ) -> Union[Index, Tensor]: return apply_( tensor, aten._to_copy.default, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, non_blocking=non_blocking, memory_format=memory_format, ) @implements(aten.alias.default) def _alias(tensor: Index) -> Index: return tensor._shallow_copy() @implements(aten._pin_memory.default) def _pin_memory(tensor: Index) -> Index: out = apply_(tensor, aten._pin_memory.default) assert isinstance(out, Index) return out @implements(aten.sort.default) def _sort( tensor: Index, dim: int = -1, descending: bool = False, ) -> Tuple[Index, Tensor]: if tensor.is_sorted and not descending: return tensor, torch.arange(tensor._data.numel(), device=tensor._data.device) data, perm = aten.sort.default(tensor._data, dim, descending) out = Index(data) out._dim_size = tensor._dim_size if not descending: out._is_sorted = True return out, perm @implements(aten.sort.stable) def _sort_stable( tensor: Index, *, stable: bool = False, dim: int = -1, descending: bool = False, ) -> Tuple[Index, Tensor]: if tensor.is_sorted and not descending: return tensor, torch.arange(tensor._data.numel(), device=tensor._data.device) data, perm = aten.sort.stable(tensor._data, stable=stable, dim=dim, descending=descending) out = Index(data) out._dim_size = tensor._dim_size if not descending: out._is_sorted = True return out, perm @implements(aten.cat.default) def _cat( tensors: List[Union[Index, Tensor]], dim: int = 0, ) 
-> Union[Index, Tensor]: data_list = pytree.tree_map_only(Index, lambda x: x._data, tensors) data = aten.cat.default(data_list, dim=dim) if any([not isinstance(tensor, Index) for tensor in tensors]): return data out = Index(data) nnz_list = [t.numel() for t in tensors] dim_size_list = [t.dim_size for t in tensors] # type: ignore is_sorted_list = [t.is_sorted for t in tensors] # type: ignore # Post-process `dim_size`: total_dim_size: Optional[int] = 0 for dim_size in dim_size_list: if dim_size is None: total_dim_size = None break assert isinstance(total_dim_size, int) total_dim_size = max(dim_size, total_dim_size) out._dim_size = total_dim_size out._cat_metadata = CatMetadata( nnz=nnz_list, dim_size=dim_size_list, is_sorted=is_sorted_list, ) return out @implements(aten.flip.default) def _flip( input: Index, dims: Union[List[int], Tuple[int, ...]], ) -> Index: data = aten.flip.default(input._data, dims) out = Index(data) out._dim_size = input.dim_size return out @implements(aten.index_select.default) def _index_select( input: Union[Index, Tensor], dim: int, index: Union[Index, Tensor], ) -> Union[Index, Tensor]: out = aten.index_select.default( input._data if isinstance(input, Index) else input, dim, index._data if isinstance(index, Index) else index, ) if isinstance(input, Index): out = Index(out) out._dim_size = input.dim_size return out @implements(aten.slice.Tensor) def _slice( input: Index, dim: int, start: Optional[int] = None, end: Optional[int] = None, step: int = 1, ) -> Index: if ((start is None or start <= 0 or start <= -input.size(dim)) and (end is None or end > input.size(dim)) and step == 1): return input._shallow_copy() # No-op. data = aten.slice.Tensor(input._data, dim, start, end, step) if step != 1: data = data.contiguous() out = Index(data) out._dim_size = input.dim_size # NOTE We could potentially maintain the `indptr` attribute here, # but it is not really clear if this is worth it. 
The most important # information `is_sorted` needs to be maintained though: if step >= 0: out._is_sorted = input.is_sorted return out @implements(aten.index.Tensor) def _index( input: Union[Index, Tensor], indices: List[Optional[Union[Tensor, Index]]], ) -> Union[Index, Tensor]: if not isinstance(input, Index): indices = pytree.tree_map_only(Index, lambda x: x._data, indices) return aten.index.Tensor(input, indices) data = aten.index.Tensor(input._data, indices) if data.dim() != 1: return data assert len(indices) == 1 index = indices[0] assert index is not None out = Index(data) if index.dtype in (torch.bool, torch.uint8): # 1. `index[mask]`. out._dim_size = input.dim_size out._is_sorted = input.is_sorted else: # 2. `index[index]`. out._dim_size = input.dim_size return out @implements(aten.add.Tensor) def _add( input: Union[int, Tensor, Index], other: Union[int, Tensor, Index], *, alpha: int = 1, ) -> Union[Index, Tensor]: data = aten.add.Tensor( input._data if isinstance(input, Index) else input, other._data if isinstance(other, Index) else other, alpha=alpha, ) if data.dtype not in INDEX_DTYPES: return data if data.dim() != 1: return data out = Index(data) if isinstance(input, Tensor) and input.numel() <= 1: input = int(input) if isinstance(other, Tensor) and other.numel() <= 1: other = int(other) if isinstance(other, int): assert isinstance(input, Index) if input.dim_size is not None: out._dim_size = input.dim_size + alpha * other out._is_sorted = input.is_sorted elif isinstance(input, int): assert isinstance(other, Index) if other.dim_size is not None: out._dim_size = input + alpha * other.dim_size out._is_sorted = other.is_sorted elif isinstance(input, Index) and isinstance(other, Index): if input.dim_size is not None and other.dim_size is not None: out._dim_size = input.dim_size + alpha * other.dim_size return out @implements(aten.add_.Tensor) def add_( input: Index, other: Union[int, Tensor, Index], *, alpha: int = 1, ) -> Index: dim_size = input.dim_size 
# (continued) tail of `add_`:
    is_sorted = input.is_sorted
    # Metadata is captured above and cleared before the in-place update, then
    # restored below where it remains valid:
    input._clear_metadata()

    aten.add_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    # Single-element tensors are treated as scalars:
    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size + alpha * other
        input._is_sorted = is_sorted

    elif isinstance(other, Index):
        if dim_size is not None and other.dim_size is not None:
            input._dim_size = dim_size + alpha * other.dim_size

    return input


@implements(aten.sub.Tensor)
def _sub(
    input: Union[int, Tensor, Index],
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:
    r"""Implements :obj:`aten.sub`. Metadata is only carried over when
    :obj:`input` is an :class:`Index` and :obj:`other` is a scalar (or
    single-element tensor): :obj:`dim_size` is shifted and the sort order is
    kept.
    """
    data = aten.sub.Tensor(
        input._data if isinstance(input, Index) else input,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    if not isinstance(input, Index):
        return out

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if input.dim_size is not None:
            out._dim_size = input.dim_size - alpha * other
        out._is_sorted = input.is_sorted

    return out


@implements(aten.sub_.Tensor)
def sub_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:
    r"""Implements in-place :obj:`aten.sub_`. Metadata is cleared and only
    restored (shifted :obj:`dim_size`, kept sort order) for scalar
    :obj:`other`.
    """
    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.sub_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size - alpha * other
        input._is_sorted = is_sorted

    return input

================================================
FILE: torch_geometric/inspector.py
================================================
import inspect
import re
import sys
import typing
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union

import torch
from torch import Tensor


class Parameter(NamedTuple):
    # Description of a single inspected function parameter:
    name: str
type: Type type_repr: str default: Any class Signature(NamedTuple): param_dict: Dict[str, Parameter] return_type: Type return_type_repr: str class Inspector: r"""Inspects a given class and collects information about its instance methods. Args: cls (Type): The class to inspect. """ def __init__(self, cls: Type): self._cls = cls self._signature_dict: Dict[str, Signature] = {} self._source_dict: Dict[str, str] = {} def _get_modules(self, cls: Type) -> List[str]: from torch_geometric.nn import MessagePassing modules: List[str] = [] for base_cls in cls.__bases__: if base_cls not in {object, torch.nn.Module, MessagePassing}: modules.extend(self._get_modules(base_cls)) modules.append(cls.__module__) return modules @property def _modules(self) -> List[str]: return self._get_modules(self._cls) @property def _globals(self) -> Dict[str, Any]: out: Dict[str, Any] = {} for module in self._modules: out.update(sys.modules[module].__dict__) return out def __repr__(self) -> str: return f'{self.__class__.__name__}({self._cls.__name__})' def eval_type(self, value: Any) -> Type: r"""Returns the type hint of a string.""" return eval_type(value, self._globals) def type_repr(self, obj: Any) -> str: r"""Returns the type hint representation of an object.""" return type_repr(obj, self._globals) def implements(self, func_name: str) -> bool: r"""Returns :obj:`True` in case the inspected class implements the :obj:`func_name` method. Args: func_name (str): The function name to check for existence. """ func = getattr(self._cls, func_name, None) if not callable(func): return False return not getattr(func, '__isabstractmethod__', False) # Inspecting Method Signatures ############################################ def inspect_signature( self, func: Union[Callable, str], exclude: Optional[List[Union[str, int]]] = None, ) -> Signature: r"""Inspects the function signature of :obj:`func` and returns a tuple of parameter types and return type. Args: func (callabel or str): The function. 
exclude (list[int or str]): A list of parameters to exclude, either given by their name or index. (default: :obj:`None`) """ if isinstance(func, str): func = getattr(self._cls, func) assert callable(func) if func.__name__ in self._signature_dict: return self._signature_dict[func.__name__] signature = inspect.signature(func) params = [p for p in signature.parameters.values() if p.name != 'self'] param_dict: Dict[str, Parameter] = {} for i, param in enumerate(params): if exclude is not None and (i in exclude or param.name in exclude): continue param_type = param.annotation # Mimic TorchScript to auto-infer `Tensor` on non-present types: param_type = Tensor if param_type is inspect._empty else param_type param_dict[param.name] = Parameter( name=param.name, type=self.eval_type(param_type), type_repr=self.type_repr(param_type), default=param.default, ) return_type = signature.return_annotation # Mimic TorchScript to auto-infer `Tensor` on non-present types: return_type = Tensor if return_type is inspect._empty else return_type self._signature_dict[func.__name__] = Signature( param_dict=param_dict, return_type=self.eval_type(return_type), return_type_repr=self.type_repr(return_type), ) return self._signature_dict[func.__name__] def get_signature( self, func: Union[Callable, str], exclude: Optional[List[str]] = None, ) -> Signature: r"""Returns the function signature of the inspected function :obj:`func`. Args: func (callabel or str): The function. exclude (list[str], optional): The parameter names to exclude. (default: :obj:`None`) """ func_name = func if isinstance(func, str) else func.__name__ signature = self._signature_dict.get(func_name) if signature is None: raise IndexError(f"Could not access signature for function " f"'{func_name}'. 
Did you forget to inspect it?") if exclude is None: return signature param_dict = { name: param for name, param in signature.param_dict.items() if name not in exclude } return Signature( param_dict=param_dict, return_type=signature.return_type, return_type_repr=signature.return_type_repr, ) def remove_signature( self, func: Union[Callable, str], ) -> Optional[Signature]: r"""Removes the inspected function signature :obj:`func`. Args: func (callabel or str): The function. """ func_name = func if isinstance(func, str) else func.__name__ return self._signature_dict.pop(func_name, None) def get_param_dict( self, func: Union[Callable, str], exclude: Optional[List[str]] = None, ) -> Dict[str, Parameter]: r"""Returns the parameters of the inspected function :obj:`func`. Args: func (str or callable): The function. exclude (list[str], optional): The parameter names to exclude. (default: :obj:`None`) """ return self.get_signature(func, exclude).param_dict def get_params( self, func: Union[Callable, str], exclude: Optional[List[str]] = None, ) -> List[Parameter]: r"""Returns the parameters of the inspected function :obj:`func`. Args: func (str or callable): The function. exclude (list[str], optional): The parameter names to exclude. (default: :obj:`None`) """ return list(self.get_param_dict(func, exclude).values()) def get_flat_param_dict( self, funcs: List[Union[Callable, str]], exclude: Optional[List[str]] = None, ) -> Dict[str, Parameter]: r"""Returns the union of parameters of all inspected functions in :obj:`funcs`. Args: funcs (list[str or callable]): The functions. exclude (list[str], optional): The parameter names to exclude. (default: :obj:`None`) """ param_dict: Dict[str, Parameter] = {} for func in funcs: params = self.get_params(func, exclude) for param in params: expected = param_dict.get(param.name) if expected is not None and param.type != expected.type: raise ValueError(f"Found inconsistent types for argument " f"'{param.name}'. 
Expected type " f"'{expected.type}' but found type " f"'{param.type}'.") if expected is not None and param.default != expected.default: if (param.default is not inspect._empty and expected.default is not inspect._empty): raise ValueError(f"Found inconsistent defaults for " f"argument '{param.name}'. Expected " f"'{expected.default}' but found " f"'{param.default}'.") default = expected.default if default is inspect._empty: default = param.default param_dict[param.name] = Parameter( name=param.name, type=param.type, type_repr=param.type_repr, default=default, ) if expected is None: param_dict[param.name] = param return param_dict def get_flat_params( self, funcs: List[Union[Callable, str]], exclude: Optional[List[str]] = None, ) -> List[Parameter]: r"""Returns the union of parameters of all inspected functions in :obj:`funcs`. Args: funcs (list[str or callable]): The functions. exclude (list[str], optional): The parameter names to exclude. (default: :obj:`None`) """ return list(self.get_flat_param_dict(funcs, exclude).values()) def get_param_names( self, func: Union[Callable, str], exclude: Optional[List[str]] = None, ) -> List[str]: r"""Returns the parameter names of the inspected function :obj:`func`. Args: func (str or callable): The function. exclude (list[str], optional): The parameter names to exclude. (default: :obj:`None`) """ return list(self.get_param_dict(func, exclude).keys()) def get_flat_param_names( self, funcs: List[Union[Callable, str]], exclude: Optional[List[str]] = None, ) -> List[str]: r"""Returns the union of parameter names of all inspected functions in :obj:`funcs`. Args: funcs (list[str or callable]): The functions. exclude (list[str], optional): The parameter names to exclude. 
(default: :obj:`None`) """ return list(self.get_flat_param_dict(funcs, exclude).keys()) def collect_param_data( self, func: Union[Callable, str], kwargs: Dict[str, Any], ) -> Dict[str, Any]: r"""Collects the input data of the inspected function :obj:`func` according to its function signature from a data blob. Args: func (callable or str): The function. kwargs (dict[str, Any]): The data blob which may serve as inputs. """ out_dict: Dict[str, Any] = {} for param in self.get_params(func): if param.name not in kwargs: if param.default is inspect._empty: raise TypeError(f"Parameter '{param.name}' is required") out_dict[param.name] = param.default else: out_dict[param.name] = kwargs[param.name] return out_dict # Inspecting Method Bodies ################################################ def get_source(self, cls: Optional[Type] = None) -> str: r"""Returns the source code of :obj:`cls`.""" from torch_geometric.nn import MessagePassing cls = cls or self._cls if cls.__name__ in self._source_dict: return self._source_dict[cls.__name__] if cls in {object, torch.nn.Module, MessagePassing}: return '' source = inspect.getsource(cls) self._source_dict[cls.__name__] = source return source def get_params_from_method_call( self, func: Union[Callable, str], exclude: Optional[List[Union[int, str]]] = None, ) -> Dict[str, Parameter]: r"""Parses a method call of :obj:`func` and returns its keyword arguments. .. note:: The method is required to be called via keyword arguments in case type annotations are not found. Args: func (callable or str): The function. exclude (list[int or str]): A list of parameters to exclude, either given by their name or index. (default: :obj:`None`) """ func_name = func if isinstance(func, str) else func.__name__ param_dict: Dict[str, Parameter] = {} # Three ways to specify the parameters of an unknown function header: # 1. Defined as class attributes in `{func_name}_type`. # 2. Defined via type annotations in `# {func_name}_type: (...)`. # 3. 
Defined via parsing of the function call. # (1) Find class attribute: if hasattr(self._cls, f'{func_name}_type'): type_dict = getattr(self._cls, f'{func_name}_type') if not isinstance(type_dict, dict): raise ValueError(f"'{func_name}_type' is expected to be a " f"dictionary (got '{type(type_dict)}')") for name, param_type in type_dict.items(): param_dict[name] = Parameter( name=name, type=self.eval_type(param_type), type_repr=self.type_repr(param_type), default=inspect._empty, ) return param_dict # (2) Find type annotation: for cls in self._cls.__mro__: source = self.get_source(cls) match = find_parenthesis_content(source, f'{func_name}_type:') if match is not None: for arg in split(match, sep=','): name_and_type_repr = re.split(r'\s*:\s*', arg) if len(name_and_type_repr) != 2: raise ValueError(f"Could not parse argument '{arg}' " f"of '{func_name}_type' annotation") name, type_repr = name_and_type_repr param_dict[name] = Parameter( name=name, type=self.eval_type(type_repr), type_repr=type_repr, default=inspect._empty, ) return param_dict # (3) Parse the function call: for cls in self._cls.__mro__: source = self.get_source(cls) source = remove_comments(source) match = find_parenthesis_content(source, f'self.{func_name}') if match is not None: for i, kwarg in enumerate(split(match, sep=',')): if ('=' not in kwarg and exclude is not None and i in exclude): continue name_and_content = re.split(r'\s*=\s*', kwarg) if len(name_and_content) != 2: raise ValueError(f"Could not parse keyword argument " f"'{kwarg}' in 'self.{func_name}()'") name, _ = name_and_content if exclude is not None and name in exclude: continue param_dict[name] = Parameter( name=name, type=Tensor, type_repr=self.type_repr(Tensor), default=inspect._empty, ) return param_dict return {} # (4) No function call found: def eval_type(value: Any, _globals: Dict[str, Any]) -> Type: r"""Returns the type hint of a string.""" if isinstance(value, str): value = typing.ForwardRef(value) return 
typing._eval_type(value, _globals, None) # type: ignore def type_repr(obj: Any, _globals: Dict[str, Any]) -> str: r"""Returns the type hint representation of an object.""" def _get_name(name: str, module: str) -> str: return name if name in _globals else f'{module}.{name}' if isinstance(obj, str): return obj if obj is type(None): return 'None' if obj is ...: return '...' if obj.__module__ == 'typing': # Special logic for `typing.*` types: if not hasattr(obj, '_name'): return repr(obj) name = obj._name if name is None: # In some cases, `_name` is not populated. name = str(obj.__origin__).split('.')[-1] args = getattr(obj, '__args__', None) if args is None or len(args) == 0: return _get_name(name, obj.__module__) if all(isinstance(arg, typing.TypeVar) for arg in args): return _get_name(name, obj.__module__) # Convert `Union[*, None]` to `Optional[*]`. # This is only necessary for old Python versions, e.g. 3.8. # TODO Only convert to `Optional` if `Optional` is importable. if (name == 'Union' and len(args) == 2 and any([arg is type(None) for arg in args])): name = 'Optional' if name == 'Optional': # Remove `None` from `Optional` arguments: args = [arg for arg in obj.__args__ if arg is not type(None)] args_repr = ', '.join([type_repr(arg, _globals) for arg in args]) return f'{_get_name(name, obj.__module__)}[{args_repr}]' if obj.__module__ == 'builtins': return obj.__qualname__ return _get_name(obj.__qualname__, obj.__module__) def find_parenthesis_content(source: str, prefix: str) -> Optional[str]: r"""Returns the content of :obj:`{prefix}.*(...)` within :obj:`source`.""" match = re.search(prefix, source) if match is None: return None offset = source[match.start():].find('(') if offset < 0: return None source = source[match.start() + offset:] depth = 0 for end, char in enumerate(source): if char == '(': depth += 1 if char == ')': depth -= 1 if depth == 0: content = source[1:end] # Properly handle line breaks and multiple white-spaces: content = content.replace('\n', ' 
') content = content.replace('#', ' ') content = re.sub(' +', ' ', content) content = content.strip() return content return None def split(content: str, sep: str) -> List[str]: r"""Splits :obj:`content` based on :obj:`sep`. :obj:`sep` inside parentheses or square brackets are ignored. """ assert len(sep) == 1 outs: List[str] = [] start = depth = 0 for end, char in enumerate(content): if char == '[' or char == '(': depth += 1 elif char == ']' or char == ')': depth -= 1 elif char == sep and depth == 0: outs.append(content[start:end].strip()) start = end + 1 if start != len(content): # Respect dangling `sep`: outs.append(content[start:].strip()) return outs def remove_comments(content: str) -> str: content = re.sub(r'\s*#.*', '', content) content = re.sub(re.compile(r'r"""(.*?)"""', re.DOTALL), '', content) content = re.sub(re.compile(r'"""(.*?)"""', re.DOTALL), '', content) content = re.sub(re.compile(r"r'''(.*?)'''", re.DOTALL), '', content) content = re.sub(re.compile(r"'''(.*?)'''", re.DOTALL), '', content) return content ================================================ FILE: torch_geometric/io/__init__.py ================================================ from .txt_array import parse_txt_array, read_txt_array from .tu import read_tu_data from .planetoid import read_planetoid_data from .ply import read_ply from .obj import read_obj from .sdf import read_sdf, parse_sdf from .off import read_off, write_off from .npz import read_npz, parse_npz __all__ = [ 'read_off', 'write_off', 'parse_txt_array', 'read_txt_array', 'read_tu_data', 'read_planetoid_data', 'read_ply', 'read_obj', 'read_sdf', 'parse_sdf', 'read_npz', 'parse_npz', ] ================================================ FILE: torch_geometric/io/fs.py ================================================ import io import os import os.path as osp import pickle import re import sys import warnings from typing import Any, Dict, List, Literal, Optional, Union, overload from uuid import uuid4 import fsspec import torch 
import torch_geometric DEFAULT_CACHE_PATH = '/tmp/pyg_simplecache' def get_fs(path: str) -> fsspec.AbstractFileSystem: r"""Get filesystem backend given a path URI to the resource. Here are some common example paths and dispatch result: * :obj:`"/home/file"` -> :class:`fsspec.implementations.local.LocalFileSystem` * :obj:`"memory://home/file"` -> :class:`fsspec.implementations.memory.MemoryFileSystem` * :obj:`"https://home/file"` -> :class:`fsspec.implementations.http.HTTPFileSystem` * :obj:`"gs://home/file"` -> :class:`gcsfs.GCSFileSystem` * :obj:`"s3://home/file"` -> :class:`s3fs.S3FileSystem` A full list of supported backend implementations of :class:`fsspec` can be found `here `_. The backend dispatch logic can be updated with custom backends following `this tutorial `_. Args: path (str): The URI to the filesystem location, *e.g.*, :obj:`"gs://home/me/file"`, :obj:`"s3://..."`. """ return fsspec.core.url_to_fs(path)[0] def normpath(path: str) -> str: if isdisk(path): return osp.normpath(path) return path def exists(path: str) -> bool: return get_fs(path).exists(path) def makedirs(path: str, exist_ok: bool = True) -> None: return get_fs(path).makedirs(path, exist_ok) def isdir(path: str) -> bool: return get_fs(path).isdir(path) def isfile(path: str) -> bool: return get_fs(path).isfile(path) def isdisk(path: str) -> bool: return 'file' in get_fs(path).protocol def islocal(path: str) -> bool: return isdisk(path) or 'memory' in get_fs(path).protocol @overload def ls(path: str, detail: Literal[False] = False) -> List[str]: pass @overload def ls(path: str, detail: Literal[True]) -> List[Dict[str, Any]]: pass def ls( path: str, detail: bool = False, ) -> Union[List[str], List[Dict[str, Any]]]: fs = get_fs(path) outputs = fs.ls(path, detail=detail) if not isdisk(path): if detail: for output in outputs: output['name'] = fs.unstrip_protocol(output['name']) else: outputs = [fs.unstrip_protocol(output) for output in outputs] return outputs def cp( path1: str, path2: str, 
extract: bool = False, log: bool = True, use_cache: bool = True, clear_cache: bool = True, ) -> None: kwargs: Dict[str, Any] = {} is_path1_dir = isdir(path1) is_path2_dir = isdir(path2) # Cache result if the protocol is not local: cache_dir: Optional[str] = None if not islocal(path1): if log and 'PYTEST_CURRENT_TEST' not in os.environ: print(f'Downloading {path1}', file=sys.stderr) if extract and use_cache: # Cache seems to confuse the gcs filesystem. home_dir = torch_geometric.get_home_dir() cache_dir = osp.join(home_dir, 'simplecache', uuid4().hex) kwargs.setdefault('simplecache', dict(cache_storage=cache_dir)) path1 = f'simplecache::{path1}' # Handle automatic extraction: multiple_files = False if extract and path1.endswith('.tar.gz'): kwargs.setdefault('tar', dict(compression='gzip')) path1 = f'tar://**::{path1}' multiple_files = True elif extract and path1.endswith('.zip'): path1 = f'zip://**::{path1}' multiple_files = True elif extract and path1.endswith('.gz'): kwargs.setdefault('compression', 'infer') elif extract: raise NotImplementedError( f"Automatic extraction of '{path1}' not yet supported") # If the source path points to a directory, we need to make sure to # recursively copy all files within this directory. Additionally, if the # destination folder does not yet exist, we inherit the basename from the # source folder. if is_path1_dir: if exists(path2): path2 = osp.join(path2, osp.basename(path1)) path1 = osp.join(path1, '**') multiple_files = True # Perform the copy: for open_file in fsspec.open_files(path1, **kwargs): with open_file as f_from: if not multiple_files: if is_path2_dir: basename = osp.basename(path1) if extract and path1.endswith('.gz'): basename = '.'.join(basename.split('.')[:-1]) to_path = osp.join(path2, basename) else: to_path = path2 else: # Open file has protocol stripped. 
common_path = osp.commonprefix( [fsspec.core.strip_protocol(path1), open_file.path]) to_path = osp.join(path2, open_file.path[len(common_path):]) with fsspec.open(to_path, 'wb') as f_to: while True: chunk = f_from.read(10 * 1024 * 1024) if not chunk: break f_to.write(chunk) if use_cache and clear_cache and cache_dir is not None: try: rm(cache_dir) except Exception: # FIXME # Windows test yield "PermissionError: The process cannot access # the file because it is being used by another process". # Users may also observe "OSError: Directory not empty". # This is a quick workaround until we figure out the deeper issue. pass def rm(path: str, recursive: bool = True) -> None: get_fs(path).rm(path, recursive) def mv(path1: str, path2: str) -> None: fs1 = get_fs(path1) fs2 = get_fs(path2) assert fs1.protocol == fs2.protocol fs1.mv(path1, path2) def glob(path: str) -> List[str]: fs = get_fs(path) paths = fs.glob(path) if not isdisk(path): paths = [fs.unstrip_protocol(path) for path in paths] return paths def torch_save(data: Any, path: str) -> None: buffer = io.BytesIO() torch.save(data, buffer) with fsspec.open(path, 'wb') as f: f.write(buffer.getvalue()) def torch_load(path: str, map_location: Any = None) -> Any: if torch_geometric.typing.WITH_PT24: try: with fsspec.open(path, 'rb') as f: return torch.load(f, map_location, weights_only=True) except pickle.UnpicklingError as e: error_msg = str(e) if "add_safe_globals" in error_msg: warn_msg = ("Weights only load failed. 
Please file an issue " "to make `torch.load(weights_only=True)` " "compatible in your case.") match = re.search(r'add_safe_globals\(.*?\)', error_msg) if match is not None: warnings.warn( f"{warn_msg} Please use " f"`torch.serialization.{match.group()}` to " f"allowlist this global.", stacklevel=2) else: warnings.warn(warn_msg, stacklevel=2) with fsspec.open(path, 'rb') as f: return torch.load(f, map_location, weights_only=False) else: raise e with fsspec.open(path, 'rb') as f: return torch.load(f, map_location) ================================================ FILE: torch_geometric/io/npz.py ================================================ from typing import Any, Dict import numpy as np import torch from torch_geometric.data import Data from torch_geometric.utils import remove_self_loops from torch_geometric.utils import to_undirected as to_undirected_fn def read_npz(path: str, to_undirected: bool = True) -> Data: with np.load(path) as f: return parse_npz(f, to_undirected=to_undirected) def parse_npz(f: Dict[str, Any], to_undirected: bool = True) -> Data: import scipy.sparse as sp x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']), f['attr_shape']).todense() x = torch.from_numpy(x).to(torch.float) x[x > 0] = 1 adj = sp.csr_matrix((f['adj_data'], f['adj_indices'], f['adj_indptr']), f['adj_shape']).tocoo() row = torch.from_numpy(adj.row).to(torch.long) col = torch.from_numpy(adj.col).to(torch.long) edge_index = torch.stack([row, col], dim=0) edge_index, _ = remove_self_loops(edge_index) if to_undirected: edge_index = to_undirected_fn(edge_index, num_nodes=x.size(0)) y = torch.from_numpy(f['labels']).to(torch.long) return Data(x=x, edge_index=edge_index, y=y) ================================================ FILE: torch_geometric/io/obj.py ================================================ from typing import Iterator, List, Optional, Tuple, Union import torch from torch_geometric.data import Data def yield_file(in_file: str) -> Iterator[Tuple[str, 
List[Union[int, float]]]]: f = open(in_file) buf = f.read() f.close() for b in buf.split('\n'): if b.startswith('v '): yield 'v', [float(x) for x in b.split(" ")[1:]] elif b.startswith('f '): triangles = b.split(' ')[1:] # -1 as .obj is base 1 but the Data class expects base 0 indices yield 'f', [int(t.split("/")[0]) - 1 for t in triangles] else: yield '', [] def read_obj(in_file: str) -> Optional[Data]: vertices = [] faces = [] for k, v in yield_file(in_file): if k == 'v': vertices.append(v) elif k == 'f': faces.append(v) if not len(faces) or not len(vertices): return None pos = torch.tensor(vertices, dtype=torch.float) face = torch.tensor(faces, dtype=torch.long).t().contiguous() data = Data(pos=pos, face=face) return data ================================================ FILE: torch_geometric/io/off.py ================================================ import re from typing import List import torch from torch import Tensor from torch._tensor_str import PRINT_OPTS, _tensor_str from torch_geometric.data import Data from torch_geometric.io import parse_txt_array def parse_off(src: List[str]) -> Data: # Some files may contain a bug and do not have a carriage return after OFF. 
if src[0] == 'OFF': src = src[1:] else: src[0] = src[0][3:] num_nodes, num_faces = (int(item) for item in src[0].split()[:2]) pos = parse_txt_array(src[1:1 + num_nodes]) face = face_to_tri(src[1 + num_nodes:1 + num_nodes + num_faces]) data = Data(pos=pos) data.face = face return data def face_to_tri(face: List[str]) -> Tensor: face_index = [[int(x) for x in line.strip().split()] for line in face] triangle = torch.tensor([line[1:] for line in face_index if line[0] == 3]) triangle = triangle.to(torch.int64) rect = torch.tensor([line[1:] for line in face_index if line[0] == 4]) rect = rect.to(torch.int64) if rect.numel() > 0: first, second = rect[:, [0, 1, 2]], rect[:, [0, 2, 3]] return torch.cat([triangle, first, second], dim=0).t().contiguous() return triangle.t().contiguous() def read_off(path: str) -> Data: r"""Reads an OFF (Object File Format) file, returning both the position of nodes and their connectivity in a :class:`torch_geometric.data.Data` object. Args: path (str): The path to the file. """ with open(path) as f: src = f.read().split('\n')[:-1] return parse_off(src) def write_off(data: Data, path: str) -> None: r"""Writes a :class:`torch_geometric.data.Data` object to an OFF (Object File Format) file. Args: data (:class:`torch_geometric.data.Data`): The data object. path (str): The path to the file. 
""" assert data.pos is not None assert data.face is not None num_nodes, num_faces = data.pos.size(0), data.face.size(1) pos = data.pos.to(torch.float) face = data.face.t() num_vertices = torch.full((num_faces, 1), face.size(1), dtype=torch.long) face = torch.cat([num_vertices, face], dim=-1) threshold = PRINT_OPTS.threshold torch.set_printoptions(threshold=float('inf')) pos_repr = re.sub(',', '', _tensor_str(pos, indent=0)) pos_repr = '\n'.join([x[2:-1] for x in pos_repr.split('\n')])[:-1] face_repr = re.sub(',', '', _tensor_str(face, indent=0)) face_repr = '\n'.join([x[2:-1] for x in face_repr.split('\n')])[:-1] with open(path, 'w') as f: f.write(f'OFF\n{num_nodes} {num_faces} 0\n') f.write(pos_repr) f.write('\n') f.write(face_repr) f.write('\n') torch.set_printoptions(threshold=threshold) ================================================ FILE: torch_geometric/io/planetoid.py ================================================ import os.path as osp import warnings from itertools import repeat from typing import Dict, List, Optional import fsspec import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.io import read_txt_array from torch_geometric.utils import ( coalesce, index_to_mask, remove_self_loops, to_torch_csr_tensor, ) try: import cPickle as pickle except ImportError: import pickle def read_planetoid_data(folder: str, prefix: str) -> Data: names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index'] items = [read_file(folder, prefix, name) for name in names] x, tx, allx, y, ty, ally, graph, test_index = items train_index = torch.arange(y.size(0), dtype=torch.long) val_index = torch.arange(y.size(0), y.size(0) + 500, dtype=torch.long) sorted_test_index = test_index.sort()[0] if prefix.lower() == 'citeseer': # There are some isolated nodes in the Citeseer graph, resulting in # none consecutive test indices. We need to identify them and add them # as zero vectors to `tx` and `ty`. 
len_test_indices = int(test_index.max() - test_index.min()) + 1 tx_ext = torch.zeros(len_test_indices, tx.size(1), dtype=tx.dtype) tx_ext[sorted_test_index - test_index.min(), :] = tx ty_ext = torch.zeros(len_test_indices, ty.size(1), dtype=ty.dtype) ty_ext[sorted_test_index - test_index.min(), :] = ty tx, ty = tx_ext, ty_ext if prefix.lower() == 'nell.0.001': tx_ext = torch.zeros(len(graph) - allx.size(0), x.size(1)) tx_ext[sorted_test_index - allx.size(0)] = tx ty_ext = torch.zeros(len(graph) - ally.size(0), y.size(1)) ty_ext[sorted_test_index - ally.size(0)] = ty tx, ty = tx_ext, ty_ext x = torch.cat([allx, tx], dim=0) x[test_index] = x[sorted_test_index] # Creating feature vectors for relations. row, col = x.nonzero(as_tuple=True) value = x[row, col] mask = ~index_to_mask(test_index, size=len(graph)) mask[:allx.size(0)] = False isolated_idx = mask.nonzero().view(-1) row = torch.cat([row, isolated_idx]) col = torch.cat([col, torch.arange(isolated_idx.size(0)) + x.size(1)]) value = torch.cat([value, value.new_ones(isolated_idx.size(0))]) x = to_torch_csr_tensor( edge_index=torch.stack([row, col], dim=0), edge_attr=value, size=(x.size(0), isolated_idx.size(0) + x.size(1)), ) else: x = torch.cat([allx, tx], dim=0) x[test_index] = x[sorted_test_index] y = torch.cat([ally, ty], dim=0).max(dim=1)[1] y[test_index] = y[sorted_test_index] train_mask = index_to_mask(train_index, size=y.size(0)) val_mask = index_to_mask(val_index, size=y.size(0)) test_mask = index_to_mask(test_index, size=y.size(0)) edge_index = edge_index_from_dict( graph_dict=graph, # type: ignore num_nodes=y.size(0), ) data = Data(x=x, edge_index=edge_index, y=y) data.train_mask = train_mask data.val_mask = val_mask data.test_mask = test_mask return data def read_file(folder: str, prefix: str, name: str) -> Tensor: path = osp.join(folder, f'ind.{prefix.lower()}.{name}') if name == 'test.index': return read_txt_array(path, dtype=torch.long) with fsspec.open(path, 'rb') as f: 
warnings.filterwarnings('ignore', '.*`scipy.sparse.csr` name.*') out = pickle.load(f, encoding='latin1') if name == 'graph': return out out = out.todense() if hasattr(out, 'todense') else out out = torch.from_numpy(out).to(torch.float) return out def edge_index_from_dict( graph_dict: Dict[int, List[int]], num_nodes: Optional[int] = None, ) -> Tensor: rows: List[int] = [] cols: List[int] = [] for key, value in graph_dict.items(): rows += repeat(key, len(value)) cols += value row = torch.tensor(rows) col = torch.tensor(cols) edge_index = torch.stack([row, col], dim=0) # `torch.compile` is not yet ready for `EdgeIndex` :( # from torch_geometric import EdgeIndex # edge_index: Union[EdgeIndex, Tensor] = EdgeIndex( # torch.stack([row, col], dim=0), # is_undirected=True, # sparse_size=(num_nodes, num_nodes), # ) # NOTE: There are some duplicated edges and self loops in the datasets. # Other implementations do not remove them! edge_index, _ = remove_self_loops(edge_index) edge_index = coalesce(edge_index, num_nodes=num_nodes, sort_by_row=False) return edge_index ================================================ FILE: torch_geometric/io/ply.py ================================================ import torch from torch_geometric.data import Data try: import openmesh except ImportError: openmesh = None def read_ply(path: str) -> Data: if openmesh is None: raise ImportError('`read_ply` requires the `openmesh` package.') mesh = openmesh.read_trimesh(path) pos = torch.from_numpy(mesh.points()).to(torch.float) face = torch.from_numpy(mesh.face_vertex_indices()) face = face.t().to(torch.long).contiguous() return Data(pos=pos, face=face) ================================================ FILE: torch_geometric/io/sdf.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.io import parse_txt_array from torch_geometric.utils import coalesce, one_hot elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} def parse_sdf(src: str) 
-> Data: lines = src.split('\n')[3:] num_atoms, num_bonds = (int(item) for item in lines[0].split()[:2]) atom_block = lines[1:num_atoms + 1] pos = parse_txt_array(atom_block, end=3) x = torch.tensor([elems[item.split()[3]] for item in atom_block]) x = one_hot(x, num_classes=len(elems)) bond_block = lines[1 + num_atoms:1 + num_atoms + num_bonds] row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1 row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) edge_index = torch.stack([row, col], dim=0) edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1 edge_attr = torch.cat([edge_attr, edge_attr], dim=0) edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms) return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos) def read_sdf(path: str) -> Data: with open(path) as f: return parse_sdf(f.read()) ================================================ FILE: torch_geometric/io/tu.py ================================================ import os.path as osp from typing import Dict, List, Optional, Tuple import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.io import fs, read_txt_array from torch_geometric.utils import coalesce, cumsum, one_hot, remove_self_loops names = [ 'A', 'graph_indicator', 'node_labels', 'node_attributes' 'edge_labels', 'edge_attributes', 'graph_labels', 'graph_attributes' ] def read_tu_data( folder: str, prefix: str, ) -> Tuple[Data, Dict[str, Tensor], Dict[str, int]]: files = fs.glob(osp.join(folder, f'{prefix}_*.txt')) names = [osp.basename(f)[len(prefix) + 1:-4] for f in files] edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1 batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1 node_attribute = torch.empty((batch.size(0), 0)) if 'node_attributes' in names: node_attribute = read_file(folder, prefix, 'node_attributes') if node_attribute.dim() == 1: node_attribute = node_attribute.unsqueeze(-1) node_label = 
torch.empty((batch.size(0), 0)) if 'node_labels' in names: node_label = read_file(folder, prefix, 'node_labels', torch.long) if node_label.dim() == 1: node_label = node_label.unsqueeze(-1) node_label = node_label - node_label.min(dim=0)[0] node_labels = list(node_label.unbind(dim=-1)) node_labels = [one_hot(x) for x in node_labels] if len(node_labels) == 1: node_label = node_labels[0] else: node_label = torch.cat(node_labels, dim=-1) edge_attribute = torch.empty((edge_index.size(1), 0)) if 'edge_attributes' in names: edge_attribute = read_file(folder, prefix, 'edge_attributes') if edge_attribute.dim() == 1: edge_attribute = edge_attribute.unsqueeze(-1) edge_label = torch.empty((edge_index.size(1), 0)) if 'edge_labels' in names: edge_label = read_file(folder, prefix, 'edge_labels', torch.long) if edge_label.dim() == 1: edge_label = edge_label.unsqueeze(-1) edge_label = edge_label - edge_label.min(dim=0)[0] edge_labels = list(edge_label.unbind(dim=-1)) edge_labels = [one_hot(e) for e in edge_labels] if len(edge_labels) == 1: edge_label = edge_labels[0] else: edge_label = torch.cat(edge_labels, dim=-1) x = cat([node_attribute, node_label]) edge_attr = cat([edge_attribute, edge_label]) y = None if 'graph_attributes' in names: # Regression problem. y = read_file(folder, prefix, 'graph_attributes') elif 'graph_labels' in names: # Classification problem. 
y = read_file(folder, prefix, 'graph_labels', torch.long) _, y = y.unique(sorted=True, return_inverse=True) num_nodes = int(edge_index.max()) + 1 if x is None else x.size(0) edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) data, slices = split(data, batch) sizes = { 'num_node_attributes': node_attribute.size(-1), 'num_node_labels': node_label.size(-1), 'num_edge_attributes': edge_attribute.size(-1), 'num_edge_labels': edge_label.size(-1), } return data, slices, sizes def read_file( folder: str, prefix: str, name: str, dtype: Optional[torch.dtype] = None, ) -> Tensor: path = osp.join(folder, f'{prefix}_{name}.txt') return read_txt_array(path, sep=',', dtype=dtype) def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]: values = [v for v in seq if v is not None] values = [v for v in values if v.numel() > 0] values = [v.unsqueeze(-1) if v.dim() == 1 else v for v in values] return torch.cat(values, dim=-1) if len(values) > 0 else None def split(data: Data, batch: Tensor) -> Tuple[Data, Dict[str, Tensor]]: node_slice = cumsum(torch.bincount(batch)) assert data.edge_index is not None row, _ = data.edge_index edge_slice = cumsum(torch.bincount(batch[row])) # Edge indices should start at zero for every graph. 
data.edge_index -= node_slice[batch[row]].unsqueeze(0) slices = {'edge_index': edge_slice} if data.x is not None: slices['x'] = node_slice else: # Imitate `collate` functionality: data._num_nodes = torch.bincount(batch).tolist() data.num_nodes = batch.numel() if data.edge_attr is not None: slices['edge_attr'] = edge_slice if data.y is not None: assert isinstance(data.y, Tensor) if data.y.size(0) == batch.size(0): slices['y'] = node_slice else: slices['y'] = torch.arange(0, int(batch[-1]) + 2, dtype=torch.long) return data, slices ================================================ FILE: torch_geometric/io/txt_array.py ================================================ from typing import List, Optional import fsspec import torch from torch import Tensor def parse_txt_array( src: List[str], sep: Optional[str] = None, start: int = 0, end: Optional[int] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tensor: empty = torch.empty(0, dtype=dtype) to_number = float if empty.is_floating_point() else int return torch.tensor([[to_number(x) for x in line.split(sep)[start:end]] for line in src], dtype=dtype).squeeze() def read_txt_array( path: str, sep: Optional[str] = None, start: int = 0, end: Optional[int] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tensor: with fsspec.open(path, 'r') as f: src = f.read().split('\n')[:-1] return parse_txt_array(src, sep, start, end, dtype, device) ================================================ FILE: torch_geometric/isinstance.py ================================================ from typing import Any, Tuple, Type, Union import torch import torch_geometric.typing if torch_geometric.typing.WITH_PT20: import torch._dynamo def is_torch_instance(obj: Any, cls: Union[Type, Tuple[Type]]) -> bool: r"""Checks if the :obj:`obj` is an instance of a :obj:`cls`. 
This function extends :meth:`isinstance` to be applicable during :meth:`torch.compile` usage by checking against the original class of compiled models. """ # `torch.compile` removes the model inheritance and converts the model to # a `torch._dynamo.OptimizedModule` instance, leading to `isinstance` being # unable to check the model's inheritance. This function unwraps the # compiled model before evaluating via `isinstance`. if (torch_geometric.typing.WITH_PT20 and isinstance(obj, torch._dynamo.OptimizedModule)): return isinstance(obj._orig_mod, cls) return isinstance(obj, cls) ================================================ FILE: torch_geometric/lazy_loader.py ================================================ from importlib import import_module from types import ModuleType from typing import Any, Dict, List # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/ # python/util/lazy_loader.py class LazyLoader(ModuleType): def __init__( self, local_name: str, parent_module_globals: Dict[str, Any], name: str, ) -> None: self._local_name = local_name self._parent_module_globals = parent_module_globals super().__init__(name) def _load(self) -> Any: module = import_module(self.__name__) self._parent_module_globals[self._local_name] = module self.__dict__.update(module.__dict__) return module def __getattr__(self, item: str) -> Any: module = self._load() return getattr(module, item) def __dir__(self) -> List[str]: module = self._load() return dir(module) ================================================ FILE: torch_geometric/llm/__init__.py ================================================ from .large_graph_indexer import LargeGraphIndexer from .rag_loader import RAGQueryLoader from .utils import * # noqa from .models import * # noqa __all__ = classes = [ 'LargeGraphIndexer', 'RAGQueryLoader', ] ================================================ FILE: torch_geometric/llm/large_graph_indexer.py ================================================ import os import pickle 
as pkl import shutil from dataclasses import dataclass from itertools import chain, islice, tee from typing import ( Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union, ) import torch from torch import Tensor from tqdm import tqdm from torch_geometric.data import Data from torch_geometric.io import fs from torch_geometric.typing import WITH_PT24 # Could be any hashable type TripletLike = Tuple[str, str, str] KnowledgeGraphLike = Iterable[TripletLike] def ordered_set(values: Iterable[str]) -> List[str]: return list(dict.fromkeys(values)) # TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum? NODE_PID = "pid" # Encodes node id NODE_KEYS = {NODE_PID} EDGE_PID = "e_pid" # Encodes source node, relation, destination node EDGE_HEAD = "h" # Encodes source node EDGE_RELATION = "r" # Encodes relation EDGE_TAIL = "t" # Encodes destination node EDGE_INDEX = "edge_idx" # Encodes source node, destination node EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX} FeatureValueType = Union[Sequence[Any], Tensor] @dataclass class MappedFeature: name: str values: FeatureValueType def __eq__(self, value: "MappedFeature") -> bool: eq = self.name == value.name if isinstance(self.values, torch.Tensor): eq &= torch.equal(self.values, value.values) else: eq &= self.values == value.values return eq if WITH_PT24: torch.serialization.add_safe_globals([MappedFeature]) class LargeGraphIndexer: """For a dataset that consists of multiple subgraphs that are assumed to be part of a much larger graph, collate the values into a large graph store to save resources. """ def __init__( self, nodes: Iterable[str], edges: KnowledgeGraphLike, node_attr: Optional[Dict[str, List[Any]]] = None, edge_attr: Optional[Dict[str, List[Any]]] = None, ) -> None: r"""Constructs a new index that uniquely catalogs each node and edge by id. Not meant to be used directly. Args: nodes (Iterable[str]): Node ids in the graph. 
edges (KnowledgeGraphLike): Edge ids in the graph. Example: [("cats", "eat", "dogs")] node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node attribute name and list of their values in order of unique node ids. Defaults to None. edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge attribute name and list of their values in order of unique edge ids. Defaults to None. """ self._nodes: Dict[str, int] = dict() self._edges: Dict[TripletLike, int] = dict() self._mapped_node_features: Set[str] = set() self._mapped_edge_features: Set[str] = set() if len(nodes) != len(set(nodes)): raise AttributeError("Nodes need to be unique") if len(edges) != len(set(edges)): raise AttributeError("Edges need to be unique") if node_attr is not None: # TODO: Validity checks btw nodes and node_attr self.node_attr = node_attr if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS: raise AttributeError( "Invalid node_attr object. Missing " + f"{NODE_KEYS - set(self.node_attr.keys())}") elif self.node_attr[NODE_PID] != nodes: raise AttributeError( "Nodes provided do not match those in node_attr") else: self.node_attr = dict() self.node_attr[NODE_PID] = nodes for i, node in enumerate(self.node_attr[NODE_PID]): self._nodes[node] = i if edge_attr is not None: # TODO: Validity checks btw edges and edge_attr self.edge_attr = edge_attr if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS: raise AttributeError( "Invalid edge_attr object. 
Missing " + f"{EDGE_KEYS - set(self.edge_attr.keys())}") elif self.node_attr[EDGE_PID] != edges: raise AttributeError( "Edges provided do not match those in edge_attr") else: self.edge_attr = dict() for default_key in EDGE_KEYS: self.edge_attr[default_key] = list() self.edge_attr[EDGE_PID] = edges for tup in edges: h, r, t = tup self.edge_attr[EDGE_HEAD].append(h) self.edge_attr[EDGE_RELATION].append(r) self.edge_attr[EDGE_TAIL].append(t) self.edge_attr[EDGE_INDEX].append( (self._nodes[h], self._nodes[t])) for i, tup in enumerate(edges): self._edges[tup] = i @classmethod def from_triplets( cls, triplets: KnowledgeGraphLike, pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, ) -> "LargeGraphIndexer": r"""Generate a new index from a series of triplets that represent edge relations between nodes. Formatted like (source_node, edge, dest_node). Args: triplets (KnowledgeGraphLike): Series of triplets representing knowledge graph relations. Example: [("cats", "eat", dogs")]. Note: Please ensure triplets are unique. pre_transform (Optional[Callable[[TripletLike], TripletLike]]): Optional preprocessing function to apply to triplets. Defaults to None. Returns: LargeGraphIndexer: Index of unique nodes and edges. """ # NOTE: Right now assumes that all trips can be loaded into memory nodes = [] edges = [] if pre_transform is not None: def apply_transform( trips: KnowledgeGraphLike) -> Iterator[TripletLike]: for trip in trips: yield pre_transform(trip) triplets = list(apply_transform(triplets)) for h, r, t in triplets: for node in (h, t): nodes.append(node) edge_idx = (h, r, t) edges.append(edge_idx) nodes = ordered_set(nodes) edges = ordered_set(edges) return cls(list(nodes), list(edges)) @classmethod def collate(cls, graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer": r"""Combines a series of large graph indexes into a single large graph index. Args: graphs (Iterable[LargeGraphIndexer]): Indices to be combined. 
Returns: LargeGraphIndexer: Singular unique index for all nodes and edges in input indices. """ # FIXME Needs to merge node attrs and edge attrs? trips = chain.from_iterable([graph.to_triplets() for graph in graphs]) return cls.from_triplets(trips) def get_unique_node_features(self, feature_name: str = NODE_PID) -> List[str]: r"""Get all the unique values for a specific node attribute. Args: feature_name (str, optional): Name of feature to get. Defaults to NODE_PID. Returns: List[str]: List of unique values for the specified feature. """ try: if feature_name in self._mapped_node_features: raise IndexError( "Only non-mapped features can be retrieved uniquely.") return ordered_set(self.get_node_features(feature_name)) except KeyError as e: raise AttributeError( f"Nodes do not have a feature called {feature_name}") from e def add_node_feature( self, new_feature_name: str, new_feature_vals: FeatureValueType, map_from_feature: str = NODE_PID, ) -> None: r"""Adds a new feature that corresponds to each unique node in the graph. Args: new_feature_name (str): Name to call the new feature. new_feature_vals (FeatureValueType): Values to map for that new feature. map_from_feature (str, optional): Key of feature to map from. Size must match the number of feature values. Defaults to NODE_PID. 
""" if new_feature_name in self.node_attr: raise AttributeError("Features cannot be overridden once created") if map_from_feature in self._mapped_node_features: raise AttributeError( f"{map_from_feature} is already a feature mapping.") feature_keys = self.get_unique_node_features(map_from_feature) if len(feature_keys) != len(new_feature_vals): raise AttributeError( "Expected encodings for {len(feature_keys)} unique features," + f" but got {len(new_feature_vals)} encodings.") if map_from_feature == NODE_PID: self.node_attr[new_feature_name] = new_feature_vals else: self.node_attr[new_feature_name] = MappedFeature( name=map_from_feature, values=new_feature_vals) self._mapped_node_features.add(new_feature_name) def get_node_features( self, feature_name: str = NODE_PID, pids: Optional[Iterable[str]] = None, ) -> List[Any]: r"""Get node feature values for a given set of unique node ids. Returned values are not necessarily unique. Args: feature_name (str, optional): Name of feature to fetch. Defaults to NODE_PID. pids (Optional[Iterable[str]], optional): Node ids to fetch for. Defaults to None, which fetches all nodes. Returns: List[Any]: Node features corresponding to the specified ids. """ if feature_name in self._mapped_node_features: values = self.node_attr[feature_name].values else: values = self.node_attr[feature_name] # TODO: torch_geometric.utils.select if isinstance(values, torch.Tensor): idxs = list( self.get_node_features_iter(feature_name, pids, index_only=True)) return values[torch.tensor(idxs).long()] return list(self.get_node_features_iter(feature_name, pids)) def get_node_features_iter( self, feature_name: str = NODE_PID, pids: Optional[Iterable[str]] = None, index_only: bool = False, ) -> Iterator[Any]: """Iterator version of get_node_features. If index_only is True, yields indices instead of values. 
""" if pids is None: pids = self.node_attr[NODE_PID] if feature_name in self._mapped_node_features: feature_map_info = self.node_attr[feature_name] from_feature_name, to_feature_vals = ( feature_map_info.name, feature_map_info.values, ) from_feature_vals = self.get_unique_node_features( from_feature_name) feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} for pid in pids: idx = self._nodes[pid] from_feature_val = self.node_attr[from_feature_name][idx] to_feature_idx = feature_mapping[from_feature_val] if index_only: yield to_feature_idx else: yield to_feature_vals[to_feature_idx] else: for pid in pids: idx = self._nodes[pid] if index_only: yield idx else: yield self.node_attr[feature_name][idx] def get_unique_edge_features(self, feature_name: str = EDGE_PID) -> List[str]: r"""Get all the unique values for a specific edge attribute. Args: feature_name (str, optional): Name of feature to get. Defaults to EDGE_PID. Returns: List[str]: List of unique values for the specified feature. """ try: if feature_name in self._mapped_edge_features: raise IndexError( "Only non-mapped features can be retrieved uniquely.") return ordered_set(self.get_edge_features(feature_name)) except KeyError as e: raise AttributeError( f"Edges do not have a feature called {feature_name}") from e def add_edge_feature( self, new_feature_name: str, new_feature_vals: FeatureValueType, map_from_feature: str = EDGE_PID, ) -> None: r"""Adds a new feature that corresponds to each unique edge in the graph. Args: new_feature_name (str): Name to call the new feature. new_feature_vals (FeatureValueType): Values to map for that new feature. map_from_feature (str, optional): Key of feature to map from. Size must match the number of feature values. Defaults to EDGE_PID. 
""" if new_feature_name in self.edge_attr: raise AttributeError("Features cannot be overridden once created") if map_from_feature in self._mapped_edge_features: raise AttributeError( f"{map_from_feature} is already a feature mapping.") feature_keys = self.get_unique_edge_features(map_from_feature) if len(feature_keys) != len(new_feature_vals): raise AttributeError( f"Expected encodings for {len(feature_keys)} unique features, " + f"but got {len(new_feature_vals)} encodings.") if map_from_feature == EDGE_PID: self.edge_attr[new_feature_name] = new_feature_vals else: self.edge_attr[new_feature_name] = MappedFeature( name=map_from_feature, values=new_feature_vals) self._mapped_edge_features.add(new_feature_name) def get_edge_features( self, feature_name: str = EDGE_PID, pids: Optional[Iterable[str]] = None, ) -> List[Any]: r"""Get edge feature values for a given set of unique edge ids. Returned values are not necessarily unique. Args: feature_name (str, optional): Name of feature to fetch. Defaults to EDGE_PID. pids (Optional[Iterable[str]], optional): Edge ids to fetch for. Defaults to None, which fetches all edges. Returns: List[Any]: Node features corresponding to the specified ids. """ if feature_name in self._mapped_edge_features: values = self.edge_attr[feature_name].values else: values = self.edge_attr[feature_name] # TODO: torch_geometric.utils.select if isinstance(values, torch.Tensor): idxs = list( self.get_edge_features_iter(feature_name, pids, index_only=True)) return values[torch.tensor(idxs).long()] return list(self.get_edge_features_iter(feature_name, pids)) def get_edge_features_iter( self, feature_name: str = EDGE_PID, pids: Optional[KnowledgeGraphLike] = None, index_only: bool = False, ) -> Iterator[Any]: """Iterator version of get_edge_features. If index_only is True, yields indices instead of values. 
""" if pids is None: pids = self.edge_attr[EDGE_PID] if feature_name in self._mapped_edge_features: feature_map_info = self.edge_attr[feature_name] from_feature_name, to_feature_vals = ( feature_map_info.name, feature_map_info.values, ) from_feature_vals = self.get_unique_edge_features( from_feature_name) feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} for pid in pids: idx = self._edges[pid] from_feature_val = self.edge_attr[from_feature_name][idx] to_feature_idx = feature_mapping[from_feature_val] if index_only: yield to_feature_idx else: yield to_feature_vals[to_feature_idx] else: for pid in pids: idx = self._edges[pid] if index_only: yield idx else: yield self.edge_attr[feature_name][idx] def to_triplets(self) -> Iterator[TripletLike]: return iter(self.edge_attr[EDGE_PID]) def save(self, path: str) -> None: if os.path.exists(path): shutil.rmtree(path) os.makedirs(path, exist_ok=True) with open(path + "/edges", "wb") as f: pkl.dump(self._edges, f) with open(path + "/nodes", "wb") as f: pkl.dump(self._nodes, f) with open(path + "/mapped_edges", "wb") as f: pkl.dump(self._mapped_edge_features, f) with open(path + "/mapped_nodes", "wb") as f: pkl.dump(self._mapped_node_features, f) node_attr_path = path + "/node_attr" os.makedirs(node_attr_path, exist_ok=True) for attr_name, vals in self.node_attr.items(): torch.save(vals, node_attr_path + f"/{attr_name}.pt") edge_attr_path = path + "/edge_attr" os.makedirs(edge_attr_path, exist_ok=True) for attr_name, vals in self.edge_attr.items(): torch.save(vals, edge_attr_path + f"/{attr_name}.pt") @classmethod def from_disk(cls, path: str) -> "LargeGraphIndexer": indexer = cls(list(), list()) with open(path + "/edges", "rb") as f: indexer._edges = pkl.load(f) with open(path + "/nodes", "rb") as f: indexer._nodes = pkl.load(f) with open(path + "/mapped_edges", "rb") as f: indexer._mapped_edge_features = pkl.load(f) with open(path + "/mapped_nodes", "rb") as f: indexer._mapped_node_features = pkl.load(f) 
node_attr_path = path + "/node_attr" for fname in os.listdir(node_attr_path): full_fname = f"{node_attr_path}/{fname}" key = fname.split(".")[0] indexer.node_attr[key] = fs.torch_load(full_fname) edge_attr_path = path + "/edge_attr" for fname in os.listdir(edge_attr_path): full_fname = f"{edge_attr_path}/{fname}" key = fname.split(".")[0] indexer.edge_attr[key] = fs.torch_load(full_fname) return indexer def to_data(self, node_feature_name: str, edge_feature_name: Optional[str] = None) -> Data: """Return a Data object containing all the specified node and edge features and the graph. Args: node_feature_name (str): Feature to use for nodes edge_feature_name (Optional[str], optional): Feature to use for edges. Defaults to None. Returns: Data: Data object containing the specified node and edge features and the graph. """ x = torch.Tensor(self.get_node_features(node_feature_name)) node_id = torch.LongTensor(range(len(x))) edge_index = torch.t( torch.LongTensor(self.get_edge_features(EDGE_INDEX))) edge_attr = (self.get_edge_features(edge_feature_name) if edge_feature_name is not None else None) edge_id = torch.LongTensor(range(len(edge_attr))) return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, edge_id=edge_id, node_id=node_id) def __eq__(self, value: "LargeGraphIndexer") -> bool: eq = True eq &= self._nodes == value._nodes eq &= self._edges == value._edges eq &= self.node_attr.keys() == value.node_attr.keys() eq &= self.edge_attr.keys() == value.edge_attr.keys() eq &= self._mapped_node_features == value._mapped_node_features eq &= self._mapped_edge_features == value._mapped_edge_features for k in self.node_attr: eq &= isinstance(self.node_attr[k], type(value.node_attr[k])) if isinstance(self.node_attr[k], torch.Tensor): eq &= torch.equal(self.node_attr[k], value.node_attr[k]) else: eq &= self.node_attr[k] == value.node_attr[k] for k in self.edge_attr: eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k])) if isinstance(self.edge_attr[k], torch.Tensor): 
eq &= torch.equal(self.edge_attr[k], value.edge_attr[k]) else: eq &= self.edge_attr[k] == value.edge_attr[k] return eq def get_features_for_triplets_groups( indexer: LargeGraphIndexer, triplet_groups: Iterable[KnowledgeGraphLike], node_feature_name: str = "x", edge_feature_name: str = "edge_attr", pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip, verbose: bool = False, max_batch_size: int = 250, num_workers: Optional[int] = None, ) -> Iterator[Data]: """Given an indexer and a series of triplet groups (like a dataset), retrieve the specified node and edge features for each triplet from the index. Args: indexer (LargeGraphIndexer): Indexer containing desired features triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of triplets to fetch features for node_feature_name (str, optional): Node feature to fetch. Defaults to "x". edge_feature_name (str, optional): edge feature to fetch. Defaults to "edge_attr". pre_transform (Callable[[TripletLike], TripletLike]): Optional preprocessing to perform on triplets. Defaults to None. verbose (bool, optional): Whether to print progress. Defaults to False. max_batch_size (int, optional): Maximum batch size for fetching features. Defaults to 250. num_workers (int, optional): Number of workers to use for fetching features. Defaults to None (all available). Yields: Iterator[Data]: For each triplet group, yield a data object containing the unique graph and features from the index. 
""" def apply_transform(trips: Iterable[TripletLike]) -> Iterator[TripletLike]: for trip in trips: yield pre_transform(tuple(trip)) # Carefully trying to avoid loading all triplets into memory at once # While also still tracking the number of elements for tqdm triplet_groups: List[Iterator[TripletLike]] = [ apply_transform(triplets) for triplets in triplet_groups ] node_keys = [] edge_keys = [] edge_index = [] """ For each KG, we gather the node_indices, edge_keys, and edge_indices needed to construct each Data object """ for kg_triplets in tqdm(triplet_groups, disable=not verbose): kg_triplets_nodes, kg_triplets_edge_keys, kg_triplets_edge_index = tee( kg_triplets, 3) """ Don't apply pre_transform here, because it has already been applied on the triplet groups/ """ small_graph_indexer = LargeGraphIndexer.from_triplets( kg_triplets_nodes) node_keys.append(small_graph_indexer.get_node_features()) edge_keys.append( small_graph_indexer.get_edge_features(pids=kg_triplets_edge_keys)) edge_index.append( small_graph_indexer.get_edge_features( EDGE_INDEX, kg_triplets_edge_index, )) """ We get the embeddings for each node and edge key in the KG, but we need to do so in batches. Batches that are too small waste compute time, as each call to get features has an upfront cost. Batches that are too large waste memory, as we need to store all the result embeddings in memory. 
""" def _fetch_feature_batch(batches): node_key_batch, edge_key_batch, edge_index_batch = batches node_feats = indexer.get_node_features( feature_name=node_feature_name, pids=chain.from_iterable(node_key_batch)) edge_feats = indexer.get_edge_features( feature_name=edge_feature_name, pids=chain.from_iterable(edge_key_batch)) last_node_idx, last_edge_idx = 0, 0 for (nkeys, ekeys, eidx) in zip(node_key_batch, edge_key_batch, edge_index_batch): nlen, elen = len(nkeys), len(ekeys) x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen]) last_node_idx += len(nkeys) edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx + elen]) last_edge_idx += len(ekeys) edge_idx = torch.LongTensor(eidx).T data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx) data_obj[NODE_PID] = node_keys data_obj[EDGE_PID] = edge_keys data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys] data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys] yield data_obj # NOTE: Backport of itertools.batched from Python 3.12 def batched(iterable, n, *, strict=False): # batched('ABCDEFG', 3) → ABC DEF G if n < 1: raise ValueError('n must be at least one') iterator = iter(iterable) while batch := tuple(islice(iterator, n)): if strict and len(batch) != n: raise ValueError('batched(): incomplete batch') yield batch import multiprocessing as mp import multiprocessing.pool as mpp num_workers = num_workers if num_workers is not None else mp.cpu_count() ideal_batch_size = min(max_batch_size, max(1, len(triplet_groups) // num_workers)) node_key_batches = batched(node_keys, ideal_batch_size) edge_key_batches = batched(edge_keys, ideal_batch_size) edge_index_batches = batched(edge_index, ideal_batch_size) batches = zip(node_key_batches, edge_key_batches, edge_index_batches) with mpp.ThreadPool() as pool: result = pool.map(_fetch_feature_batch, batches) yield from chain.from_iterable(result) def get_features_for_triplets( indexer: LargeGraphIndexer, triplets: KnowledgeGraphLike, 
node_feature_name: str = "x", edge_feature_name: str = "edge_attr", pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip, verbose: bool = False, ) -> Data: """For a given set of triplets retrieve a Data object containing the unique graph and features from the index. Args: indexer (LargeGraphIndexer): Indexer containing desired features triplets (KnowledgeGraphLike): Triplets to fetch features for node_feature_name (str, optional): Feature to use for node features. Defaults to "x". edge_feature_name (str, optional): Feature to use for edge features. Defaults to "edge_attr". pre_transform (Callable[[TripletLike], TripletLike]): Optional preprocessing function for triplets. Defaults to None. verbose (bool, optional): Whether to print progress. Defaults to False. Returns: Data: Data object containing the unique graph and features from the index for the given triplets. """ gen = get_features_for_triplets_groups(indexer, [triplets], node_feature_name, edge_feature_name, pre_transform, verbose, max_batch_size=1) return next(gen) ================================================ FILE: torch_geometric/llm/models/__init__.py ================================================ from .sentence_transformer import SentenceTransformer from .vision_transformer import VisionTransformer from .llm import LLM from .txt2kg import TXT2KG from .llm_judge import LLMJudge from .g_retriever import GRetriever from .molecule_gpt import MoleculeGPT from .glem import GLEM from .protein_mpnn import ProteinMPNN from .git_mol import GITMol __all__ = classes = [ 'SentenceTransformer', 'VisionTransformer', 'LLM', 'LLMJudge', 'TXT2KG', 'GRetriever', 'MoleculeGPT', 'GLEM', 'ProteinMPNN', 'GITMol', ] ================================================ FILE: torch_geometric/llm/models/g_retriever.py ================================================ from typing import List, Optional import torch from torch import Tensor from torch_geometric.llm.models.llm import LLM, MAX_NEW_TOKENS from 
torch_geometric.utils import scatter class GRetriever(torch.nn.Module): r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and Question Answering" `_ paper. Args: llm (LLM): The LLM to use. gnn (torch.nn.Module): The GNN to use. use_lora (bool, optional): If set to :obj:`True`, will use LORA from :obj:`peft` for training the LLM, see `here `_ for details. (default: :obj:`False`) mlp_out_tokens (int, optional): Number of LLM prefix tokens to reserve for GNN output. (default: :obj:`1`) .. warning:: This module has been tested with the following HuggingFace models * :obj:`llm_to_use="meta-llama/Meta-Llama-3.1-8B-Instruct"` * :obj:`llm_to_use="Qwen/Qwen3-0.6B"` This module should work with any HuggingFace model. See other models at `HuggingFace Models `_ and let us know if you encounter any issues. .. note:: For an example of using :class:`GRetriever`, see `examples/llm/g_retriever.py `_. """ def __init__( self, llm: LLM, gnn: torch.nn.Module = None, use_lora: bool = False, mlp_out_tokens: int = 1, ) -> None: super().__init__() self.llm = llm self.gnn = gnn.to(self.llm.device) if gnn is not None else None self.word_embedding = self.llm.word_embedding self.llm_generator = self.llm.llm if use_lora: from peft import ( LoraConfig, get_peft_model, prepare_model_for_kbit_training, ) self.llm_generator = prepare_model_for_kbit_training( self.llm_generator) lora_r: int = 8 lora_alpha: int = 16 lora_dropout: float = 0.05 lora_target_modules = ['q_proj', 'v_proj'] config = LoraConfig( r=lora_r, lora_alpha=lora_alpha, target_modules=lora_target_modules, lora_dropout=lora_dropout, bias='none', task_type='CAUSAL_LM', ) self.llm_generator = get_peft_model(self.llm_generator, config) if self.gnn is not None: mlp_out_channels = llm.word_embedding.embedding_dim mlp_hidden_channels = self.gnn.out_channels self.projector = torch.nn.Sequential( torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels), torch.nn.Sigmoid(), 
torch.nn.Linear(mlp_hidden_channels, mlp_out_channels * mlp_out_tokens), torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)), ).to(self.llm.device) self.seq_length_stats = [] def encode( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor], ) -> Tensor: x = x.to(self.llm.device) edge_index = edge_index.to(self.llm.device) if edge_attr is not None: edge_attr = edge_attr.to(self.llm.device) batch = batch.to(self.llm.device) model_specific_kwargs = {} # duck typing for SGFormer to get around circular import if (hasattr(self.gnn, 'trans_conv') and hasattr(self.gnn, 'graph_conv')): model_specific_kwargs['batch'] = batch else: model_specific_kwargs['edge_attr'] = edge_attr out = self.gnn(x, edge_index, **model_specific_kwargs) return scatter(out, batch, dim=0, reduce='mean') def forward( self, question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, label: List[str], edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None, ): r"""The forward pass. Args: question (List[str]): The questions/prompts. x (torch.Tensor): The input node features. edge_index (torch.Tensor): The edge indices. batch (torch.Tensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. label (List[str]): The answers/labels. edge_attr (torch.Tensor, optional): The edge features (if supported by the GNN). (default: :obj:`None`) additional_text_context (List[str], optional): Additional context to give to the LLM, such as textified knowledge graphs. 
(default: :obj:`None`) """ xs = None if self.gnn is not None: x = self.encode(x, edge_index, batch, edge_attr) x = self.projector(x) x = self._align_dtype(x, self.llm_generator) xs = x.split(1, dim=0) # Handle case where theres more than one embedding for each sample xs = [x.squeeze(0) for x in xs] # Handle questions without node features: batch_unique = batch.unique() batch_size = len(question) if len(batch_unique) < batch_size: xs = [ xs[i] if i in batch_unique else None for i in range(batch_size) ] ( inputs_embeds, attention_mask, label_input_ids, ) = self.llm._get_embeds(question, additional_text_context, xs, label) max_seq_len = inputs_embeds.size(1) self.seq_length_stats.append(max_seq_len) with self.llm.autocast_context: outputs = self.llm_generator( inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, labels=label_input_ids, ) return outputs.loss @torch.no_grad() def inference( self, question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None, max_out_tokens: Optional[int] = MAX_NEW_TOKENS, ): r"""The inference pass. Args: question (List[str]): The questions/prompts. x (torch.Tensor): The input node features. edge_index (torch.Tensor): The edge indices. batch (torch.Tensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. edge_attr (torch.Tensor, optional): The edge features (if supported by the GNN). (default: :obj:`None`) additional_text_context (List[str], optional): Additional context to give to the LLM, such as textified knowledge graphs. (default: :obj:`None`) max_out_tokens (int, optional): How many tokens for the LLM to generate. 
(default: :obj:`32`) """ xs = None if self.gnn is not None: x = self.encode(x, edge_index, batch, edge_attr) x = self.projector(x) xs = x.split(1, dim=0) # Handle case where theres more than one embedding for each sample xs = [x.squeeze(0) for x in xs] # Handle questions without node features: batch_unique = batch.unique() batch_size = len(question) if len(batch_unique) < batch_size: xs = [ xs[i] if i in batch_unique else None for i in range(batch_size) ] inputs_embeds, attention_mask, _ = self.llm._get_embeds( question, additional_text_context, xs) # bos_token = self.llm.tokenizer( # self.llm.tokenizer.bos_token_id, # add_special_tokens=False, # ).input_ids[0] with self.llm.autocast_context: outputs = self.llm_generator.generate( inputs_embeds=inputs_embeds, max_new_tokens=max_out_tokens, attention_mask=attention_mask, bos_token_id=self.llm.tokenizer.bos_token_id, pad_token_id=self.llm.tokenizer.eos_token_id, use_cache=True # Important to set! ) return self.llm.tokenizer.batch_decode( outputs, skip_special_tokens=True, ) def __repr__(self) -> str: return (f'{self.__class__.__name__}(\n' f' llm={self.llm},\n' f' gnn={self.gnn},\n' f')') def _align_dtype( self, x: torch.Tensor, llm_generator: torch.nn.Module, ) -> torch.Tensor: llm_dtype = next(iter(llm_generator.parameters())).dtype if x.dtype != llm_dtype: x = x.to(llm_dtype) return x ================================================ FILE: torch_geometric/llm/models/git_mol.py ================================================ from typing import List, Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential from torch_geometric.llm.models import SentenceTransformer, VisionTransformer from torch_geometric.nn import GINEConv from torch_geometric.utils import add_self_loops, to_dense_batch class GraphEncoder(torch.nn.Module): def __init__( self, num_layers: int, in_channels: int, dropout: float = 0., num_atom_type: int = 120, 
num_chirality_tag: int = 3, num_bond_type: int = 6, num_bond_direction: int = 3, ) -> None: super().__init__() self.num_layers = num_layers self.dropout = dropout self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels) self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels) self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels) self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels) self.gnns = torch.nn.ModuleList() self.batch_norms = torch.nn.ModuleList() for _ in range(num_layers): self.gnns.append( GINEConv( nn=Sequential( Linear(in_channels, in_channels * 2), ReLU(), Linear(in_channels * 2, in_channels), ), train_eps=True, edge_dim=in_channels, )) self.batch_norms.append(BatchNorm1d(in_channels)) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.x_embed1.weight.data) torch.nn.init.xavier_uniform_(self.x_embed2.weight.data) torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data) torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data) def forward( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Tensor, ) -> Tensor: x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long()) edge_index, edge_attr = add_self_loops( edge_index, edge_attr, fill_value=0, num_nodes=x.size(0), ) edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2( edge_attr[:, 1]) for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)): x = gnn(x, edge_index, edge_attr) x = bn(x) if i < self.num_layers - 1: x = F.relu(x) x = F.dropout(x, self.dropout, training=self.training) x, mask = to_dense_batch(x, batch) return x, mask class GITFormer(torch.nn.Module): def __init__( self, num_query_token: int, vision_graph_width: int, cross_attention_freq: int = 2, ): super().__init__() from transformers import AutoConfig, AutoModel config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased") config.encoder_width = vision_graph_width # insert cross-attention layer every other 
block config.add_cross_attention = True config.is_decoder = True config.cross_attention_freq = cross_attention_freq config.query_length = num_query_token self.Qformer = AutoModel.from_pretrained( "allenai/scibert_scivocab_uncased", config=config) self.query_tokens = torch.nn.Parameter( torch.zeros(1, num_query_token, config.hidden_size)) self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range) class GITMol(torch.nn.Module): r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text" `_ paper. .. note:: For an example of using :class:`GITMol`, see `examples/llm/git_mol.py `_. """ def __init__(self) -> None: super().__init__() # graph self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16) self.graph_proj = Linear(16, 768) self.ln_graph = LayerNorm(768) # text self.text_encoder = SentenceTransformer( model_name='allenai/scibert_scivocab_uncased', pooling_strategy='last_hidden_state', ) self.text_proj = Linear(768, 768) self.ln_text = LayerNorm(768) # vision self.vision_encoder = VisionTransformer( model_name='microsoft/swin-base-patch4-window7-224', ) self.vision_proj = Linear(1024, 768) self.ln_vision = LayerNorm(768) # cross-attention self.gitformer = GITFormer(384, 768) self.xtm_head = torch.nn.ModuleDict({ 'image': Linear(self.gitformer.Qformer.config.hidden_size, 2), 'graph': Linear(self.gitformer.Qformer.config.hidden_size, 2), 'cs_text': Linear(self.gitformer.Qformer.config.hidden_size, 2), }) self.xtc_proj = torch.nn.ModuleDict({ 'image': Linear(self.gitformer.Qformer.config.hidden_size, 768), 'graph': Linear(self.gitformer.Qformer.config.hidden_size, 768), 'cs_text': Linear(self.gitformer.Qformer.config.hidden_size, 768), }) self.temp = torch.nn.Parameter(0.07 * torch.ones([])) self.model_freeze() def model_freeze(self) -> None: for param in self.graph_encoder.parameters(): param.requires_grad = False for param in self.vision_encoder.parameters(): param.requires_grad = 
False def forward( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor], smiles: List[str], images: Tensor, captions: List[str], ) -> Tensor: batch_size = len(smiles) x_vision = self.vision_encoder(images) x_vision = self.vision_proj(x_vision) x_vision = self.ln_vision(x_vision) # [bs, patch_len, d] vision_atts = torch.ones(x_vision.size()[:-1], dtype=torch.long).to(x_vision.device) vision_targets = torch.arange(batch_size).to(x_vision.device) x_graph, graph_atts = self.graph_encoder(x, edge_index, batch, edge_attr) x_graph = self.graph_proj(x_graph) x_graph = self.ln_graph(x_graph) # [bs, node_len, d] graph_targets = torch.arange(batch_size).to(x_graph.device) x_smiles = self.text_encoder.encode(smiles) # [bs, seq_len, d] smiles_atts = torch.ones(x_smiles.size()[:-1], dtype=torch.long).to(x_smiles.device) smiles_targets = torch.arange(batch_size).to(x_smiles.device) caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids( # noqa: E501 captions) text_output = self.gitformer.Qformer( caption_input_ids, attention_mask=caption_attention_masks, return_dict=True, ) text_feat = F.normalize( self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1) loss = 0 for x_embed, x_atts, x_targets, modal in zip( [x_graph, x_smiles, x_vision], [graph_atts, smiles_atts, vision_atts], [graph_targets, smiles_targets, vision_targets], ['graph', 'cs_text', 'image'], ): loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat, modal) loss += self._calc_xtm_loss(x_embed, caption_input_ids, caption_attention_masks, modal) return loss / 6 def _calc_xtm_loss( self, x_embeds: Tensor, input_ids: Tensor, attention_mask: Tensor, modal: str, ) -> Tensor: # Initializing lists to hold the original and negative samples x_embeds_list = [] text_input_ids_list = [] text_attention_mask_list = [] batch_size = x_embeds.size(0) for i in range(batch_size): # Original samples x_embeds_list.append(x_embeds[i]) 
text_input_ids_list.append(input_ids[i, :]) text_attention_mask_list.append(attention_mask[i, :]) if batch_size > 1: # Negative samples (neg_text_input_ids corresponds to x_embeds) neg_text_input_ids = input_ids[i - 1 if i == batch_size - 1 else i + 1, :] neg_text_attention_mask = attention_mask[i - 1 if i == batch_size - 1 else i + 1, :] text_input_ids_list.append(neg_text_input_ids) text_attention_mask_list.append(neg_text_attention_mask) x_embeds_list.append(x_embeds[i, :]) # Negative samples (text_input_ids corresponds to neg_x_embeds) neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i + 1, :] x_embeds_list.append(neg_x_embeds) text_input_ids_list.append(input_ids[i, :]) text_attention_mask_list.append(attention_mask[i, :]) # Stack all samples into two large tensors x_embeds_all = torch.stack(x_embeds_list, dim=1) \ .reshape(-1, x_embeds.size(1), x_embeds.size(2)) text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \ .reshape(-1, input_ids.size(1)) # Create image attention masks for the concatenated tensor image_attns_all = torch.ones(x_embeds_all.size()[:-1], dtype=torch.long).to(x_embeds_all.device) query_tokens_xtm = self.gitformer.query_tokens.expand( text_input_ids_all.shape[0], -1, -1) query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1], dtype=torch.long).to(x_embeds_all.device) output_xtm = self.gitformer.Qformer( inputs_embeds=query_tokens_xtm, attention_mask=query_attns_xtm, encoder_hidden_states=x_embeds_all, encoder_attention_mask=image_attns_all, return_dict=True, ).last_hidden_state xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :] xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1) # Create labels: 1 for the original samples, 0 for the negative samples if batch_size > 1: labels = torch.cat( [torch.ones(batch_size), torch.zeros(batch_size * 2)], dim=0) else: labels = torch.ones(batch_size) labels = labels.long().to(xtm_logit.device) # Calculate cross entropy loss return F.cross_entropy(xtm_logit, labels) 
def _calc_xtc_loss( self, x_embeds: Tensor, x_atts: Tensor, x_targets: Tensor, text_feat: Tensor, modal: str, ) -> Tensor: query_tokens = self.gitformer.query_tokens.expand( x_embeds.shape[0], -1, -1) query_output = self.gitformer.Qformer( inputs_embeds=query_tokens, encoder_hidden_states=x_embeds, encoder_attention_mask=x_atts, return_dict=True, ).last_hidden_state x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1) sim_q2t = torch.matmul( x_feats.unsqueeze(1), text_feat.unsqueeze(-1), ).squeeze(-1) # modal-text similarity: aggregate across all query tokens sim_x2t, _ = sim_q2t.max(-1) sim_x2t = sim_x2t / self.temp # text-query similarity sim_t2q = torch.matmul( text_feat.unsqueeze(1).unsqueeze(1), x_feats.permute(0, 2, 1), ).squeeze(-2) # text-modal similarity: aggregate across all query tokens sim_t2x, _ = sim_t2q.max(-1) sim_t2x = sim_t2x / self.temp loss_itc = ( F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) + F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2 return loss_itc ================================================ FILE: torch_geometric/llm/models/glem.py ================================================ from typing import List, Optional, Union import torch import torch.nn as nn from tqdm import tqdm from torch_geometric.loader import DataLoader, NeighborLoader from torch_geometric.nn.models import GraphSAGE, basic_gnn def deal_nan(x): if isinstance(x, torch.Tensor): x = x.clone() x[torch.isnan(x)] = 0.0 return x class GLEM(torch.nn.Module): r"""This GNN+LM co-training model is based on GLEM from the `"Learning on Large-scale Text-attributed Graphs via Variational Inference" `_ paper. 
Args: lm_to_use (str): A TextEncoder from huggingface model repo with a classifier(default: TinyBERT) gnn_to_use (torch_geometric.nn.models): (default: GraphSAGE) out_channels (int): output channels for LM and GNN, should be same num_gnn_heads Optional[int]: Number of heads for attention, if needed num_gnn_layers (int): number of gnn layers gnn_loss: loss function for gnn, (default: CrossEntropyLoss) lm_loss: loss function for Language Model, (default: CrossEntropyLoss) alpha (float): pseudo label weight of E-step, LM optimization, (default: 0.5) beta (float): pseudo label weight of M-step, GNN optimization, (default: 0.5) lm_dtype (torch.dtype): the data type once you load LM into memory, (default: torch.bfloat16) lm_use_lora (bool): choose if LM use Lora peft for fine tune, (default: True) lora_target_modules: The names of the target modules to apply the lora adapter to, e.g. ['q_proj', 'v_proj'] for LLM , (default: None) .. note:: See `examples/llm_plus_gnn/glem.py` for example usage. 
""" def __init__( self, lm_to_use: str = 'prajjwal1/bert-tiny', gnn_to_use: basic_gnn = GraphSAGE, out_channels: int = 47, gnn_loss: Optional[nn.Module] = None, lm_loss: Optional[nn.Module] = None, alpha: float = 0.5, beta: float = 0.5, lm_dtype: torch.dtype = torch.bfloat16, lm_use_lora: bool = True, lora_target_modules: Optional[Union[List[str], str]] = None, device: Optional[Union[str, torch.device]] = None, ): super().__init__() if gnn_loss is None: gnn_loss = nn.CrossEntropyLoss(reduction='mean') if lm_loss is None: lm_loss = nn.CrossEntropyLoss(reduction='mean') if device is None: device = torch.device('cpu') self.device = device self.lm_loss = lm_loss self.gnn = gnn_to_use self.gnn_loss = gnn_loss self.alpha = alpha self.beta = beta self.gnn_loss = gnn_loss self.lm = lm_to_use from transformers import AutoModelForSequenceClassification self.lm = AutoModelForSequenceClassification.from_pretrained( lm_to_use, num_labels=out_channels, dtype=lm_dtype, offload_folder="offload", trust_remote_code=True) if lm_use_lora: from peft import ( LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, ) print("Training LM with LORA!") self.lm = prepare_model_for_kbit_training(self.lm) config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16, lora_alpha=16, lora_dropout=0.05, bias="none", target_modules=lora_target_modules) self.lm = get_peft_model(self.lm, config) self.lm.print_trainable_parameters() self.lm.config.pad_token_id = self.lm.config.eos_token_id self.lm_device = self.lm.device if self.lm.num_labels != self.gnn.out_channels: raise ValueError('''The output channel of language model \ and gnn should be the same''') def pre_train_gnn(self, train_loader: NeighborLoader, optimizer: torch.optim.Optimizer, num_epochs: int, patience: int, ext_pseudo_labels: torch.Tensor = None, is_augmented: bool = False, verbose: bool = True): # Pretrain GNN, optional steps if you do not have pseudo labels. 
best_acc = 0 early_stopping = 0 # training only based on gold data for epoch in range(0, num_epochs): acc, loss = self.train_gnn(train_loader, optimizer, epoch, ext_pseudo_labels, is_augmented, verbose) if acc < best_acc: early_stopping += 1 if early_stopping > patience: print(f'Early stopped by Epoch: {epoch}, ' f'Best acc: {best_acc}') break best_acc = max(best_acc, acc) def pre_train_lm(self, train_loader: DataLoader, optimizer: torch.optim.Optimizer, num_epochs: int, patience: int, ext_pseudo_labels: torch.Tensor = None, is_augmented: bool = False, verbose: bool = True): # Pretrain language model best_acc = 0 early_stopping = 0 for epoch in range(1, num_epochs + 1): acc, loss = self.train_lm(train_loader, optimizer, epoch, ext_pseudo_labels, is_augmented, verbose) if acc < best_acc: early_stopping += 1 if early_stopping > patience: print(f'Early stopped by Epoch: {epoch}, ' f'Best acc: {best_acc}') break best_acc = max(best_acc, acc) def train(self, em_phase: str, train_loader: Union[DataLoader, NeighborLoader], optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor, epoch: int, is_augmented: bool = False, verbose: bool = False): r"""GLEM training step, EM steps. Args: em_phase(str): 'gnn' or 'lm' choose which phase you are training on train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for lm training, include tokenized data, labels is_gold mask. use NeighborLoader for gnn training, include x, edge_index. 
optimizer (torch.optim.Optimizer): optimizer for training pseudo_labels(torch.Tensor): the predicted labels used as pseudo labels epoch (int): current epoch is_augmented (bool): will use pseudo_labels or not verbose (bool): print training progress bar or not Returns: acc (float): training accuracy loss (float): loss value """ if pseudo_labels is not None: pseudo_labels = pseudo_labels.to(self.device) if em_phase == 'gnn': acc, loss = self.train_gnn(train_loader, optimizer, epoch, pseudo_labels, is_augmented, verbose) if em_phase == 'lm': acc, loss = self.train_lm(train_loader, optimizer, epoch, pseudo_labels, is_augmented, verbose) return acc, loss def train_lm(self, train_loader: DataLoader, optimizer: torch.optim.Optimizer, epoch: int, pseudo_labels: torch.Tensor = None, is_augmented: bool = False, verbose: bool = True): r"""Language model Training in every epoch. Args: train_loader (loader.dataloader.DataLoader): text token dataloader optimizer (torch.optim.Optimizer): model optimizer epoch (int): current train epoch pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn is_augmented (bool): train with pseudo labels or not verbose (bool): print training progress bar or not Returns: approx_acc (torch.tensor): training accuracy loss (torch.float): loss value """ all_out = [] total_loss = total_correct = 0 num_nodes = train_loader.dataset.indices.size(0) self.lm.train() if verbose: pbar = tqdm(total=num_nodes) pbar.set_description(f'Epoch {epoch:02d}') for batch in train_loader: inputs = {k: v.to(self.device) for k, v in batch['input'].items()} out = self.lm(**inputs).logits labels = batch['labels'].to(self.device).squeeze() # training with pseudo labels or not if is_augmented: pl_batch = pseudo_labels[batch['n_id']].to(self.device) else: pl_batch = None loss = self.loss(out, labels, self.lm_loss, batch['is_gold'].to(self.device), pl_batch, self.alpha, is_augmented) loss.backward() optimizer.step() optimizer.zero_grad() all_out.append(out) total_correct += 
int(out.argmax(dim=-1).eq(labels).sum()) total_loss += float(loss.detach()) if verbose: pbar.update(batch['n_id'].size(0)) all_out = torch.cat(all_out, dim=0) approx_acc = total_correct / num_nodes loss = total_loss / len(train_loader) if verbose: pbar.close() print(f'Epoch {epoch:02d} Loss: {loss:.4f} ' f'Approx. Train: {approx_acc:.4f}') return approx_acc, loss def train_gnn(self, train_loader: NeighborLoader, optimizer: torch.optim.Optimizer, epoch: int, pseudo_labels: torch.Tensor = None, is_augmented: bool = False, verbose: bool = True): r"""GNN training step in every epoch. Args: train_loader (loader.NeighborLoader): gnn Neighbor node loader optimizer (torch.optim.Optimizer): model optimizer epoch (int): current train epoch pseudo_labels(torch.tensor): 1-D tensor, predictions from lm is_augmented(bool): use pseudo labeled node or not verbose (bool): print training progress or not Returns: approx_acc (torch.tensor): training accuracy loss (torch.float): loss value """ self.gnn.train() num_nodes = train_loader.input_nodes.size(0) if verbose: pbar = tqdm(total=num_nodes) pbar.set_description(f'Epoch {epoch:02d}') total_loss = total_correct = 0 all_out = [] for batch in train_loader: batch = batch.to(self.device) out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] all_out.append(out) labels = batch.y[:batch.batch_size].squeeze() is_gold_batch = batch.is_gold[:batch.batch_size].squeeze() # training with pseudo labels or not if is_augmented and pseudo_labels is not None: pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]] else: pl_batch = None loss = self.loss(out, labels, self.gnn_loss, is_gold_batch, pl_batch, self.beta, is_augmented) loss.backward() optimizer.step() optimizer.zero_grad() total_loss += float(loss.detach()) total_correct += int(out.argmax(dim=-1).eq(labels).sum()) if verbose: pbar.update(batch.batch_size) all_out = torch.cat(all_out, dim=0) loss = total_loss / len(train_loader) approx_acc = total_correct / num_nodes if verbose: 
pbar.close() print(f'Epoch: {epoch:02d} Loss: {loss:.4f} ' f'Approx. Train: {approx_acc:.4f}') return approx_acc, loss @torch.no_grad() def inference(self, em_phase: str, data_loader: Union[NeighborLoader, DataLoader], verbose: bool = False): r"""GLEM inference step. Args: em_phase(str): 'gnn' or 'lm' data_loader(dataloader or Neighborloader): dataloader: for lm training, include tokenized data nodeloader: for gnn training, include x, edge_index verbose(bool): print inference progress or not Returns: out (torch.Tensor): n * m tensor, m is number of classes, n is number of nodes """ out = None if em_phase == 'gnn': self.gnn.eval() out = self.inference_gnn(data_loader, verbose) elif em_phase == 'lm': self.lm.eval() out = self.inference_lm(data_loader, verbose) return out @torch.no_grad() def inference_lm(self, data_loader: DataLoader, verbose: bool = True): r"""LM inference step. Args: data_loader (Dataloader): include token, labels, and gold mask verbose (bool): print progress bar or not Returns: preds (tensor): prediction from GNN, convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1) """ if verbose: pbar = tqdm(total=data_loader.dataset._data.num_nodes) pbar.set_description('LM inference stage') self.lm.eval() preds = [] for batch in data_loader: inputs = {k: v.to(self.device) for k, v in batch['input'].items()} logits = self.lm(**inputs).logits preds.append(logits) if verbose: pbar.update(batch['n_id'].size(0)) if verbose: pbar.close() preds = torch.cat(preds) return preds @torch.no_grad() def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True): r"""GNN inference step. 
Args: data_loader(NeighborLoader): include x, edge_index, verbose (bool): print progress bar or not Returns: preds (tensor): prediction from GNN, convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1) """ if verbose: pbar = tqdm(total=data_loader.data.num_nodes) pbar.set_description('GNN inference stage') preds = [] self.gnn.eval() for batch in data_loader: batch = batch.to(self.device) out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] preds.append(out) if verbose: pbar.update(batch.batch_size) if verbose: pbar.close() preds = torch.cat(preds, dim=0) return preds def loss(self, logits: torch.Tensor, labels: torch.Tensor, loss_func: torch.nn.functional, is_gold: torch.Tensor, pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5, is_augmented: bool = True): r"""Core function of variational EM inference, this function is aming on combining loss value on gold(original train) and loss value on pseudo labels. Reference: # noqa Args: logits(torch.tensor): predict results from LM or GNN labels(torch.tensor): combined node labels from ground truth and pseudo labels(if provided) loss_func(torch.nn.modules.loss): loss function for classification is_gold(tensor): a tensor with bool value that mask ground truth label and during training, thus ~is_gold mask pseudo labels pseudo_labels(torch.tensor): predictions from other model pl_weight: the pseudo labels used in E-step and M-step optimization alpha in E-step, beta in M-step respectively is_augmented: use EM or just train GNN and LM with gold data """ if is_augmented and (sum(~is_gold) > 0): mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold])) # all other labels beside from ground truth(gold labels) pseudo_label_loss = deal_nan( loss_func(logits[~is_gold], pseudo_labels[~is_gold])) loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss else: loss = loss_func(logits, labels) return loss ================================================ FILE: torch_geometric/llm/models/llm.py 
================================================ import warnings from contextlib import nullcontext from typing import Any, Dict, List, Optional import torch from torch import Tensor try: from transformers.tokenization_utils_base import BatchEncoding except ImportError: BatchEncoding = Dict IGNORE_INDEX = -100 MAX_TXT_LEN = 512 MAX_NEW_TOKENS = 128 PAD_TOKEN_ID = 0 PADDING_SIDE = 'left' # legacy constants - used for Llama 2 style prompting BOS = '[INST]' EOS_USER = '[/INST]' EOS = '[/s]' def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]: torch.cuda.empty_cache() gpu_memory: List[int] = [] for i in range(torch.cuda.device_count()): gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3) # Use the minimum number of GPUs to fit the LLM on. if sum(gpu_memory) >= required_memory: break if sum(gpu_memory) < required_memory: gpu_memory = [] # If not enough VRAM, use pure CPU. kwargs = dict(revision='main') if len(gpu_memory) > 0: kwargs['max_memory'] = { i: f'{memory}GiB' for i, memory in enumerate(gpu_memory) } kwargs['low_cpu_mem_usage'] = True kwargs['device_map'] = 'auto' kwargs['dtype'] = dtype return kwargs class LLM(torch.nn.Module): r"""A wrapper around a Large Language Model (LLM) from HuggingFace. Args: model_name (str): The HuggingFace model name num_params (float, optional): An integer representing how many params the HuggingFace model has, in billions. This is used to automatically allocate the correct number of GPUs needed (using a rough heuristic), given the available GPU memory of your GPUs. If not specified, the number of parameters is determined using the `huggingface_hub` module. n_gpus (int, optional): Number of GPUs to use. Designed for advanced users to select how many GPU's they want to set this manually and override the automatic set up mechanism. dtype (torch.dtype, optional): The data type to use for the LLM. (default :obj: `torch.bfloat16`) sys_prompt (str, optional): A system prompt to use for the LLM. 
(default: :obj: `None`) """ def __init__( self, model_name: str, num_params: Optional[float] = None, n_gpus: Optional[int] = None, dtype: Optional[torch.dtype] = torch.bfloat16, sys_prompt: Optional[str] = None, ) -> None: super().__init__() self.model_name = model_name from transformers import AutoModelForCausalLM, AutoTokenizer if n_gpus is None: if num_params is None: from huggingface_hub import get_safetensors_metadata safetensors_metadata = get_safetensors_metadata(model_name) param_count = safetensors_metadata.parameter_count num_params = float(list(param_count.values())[0] // 10**9) # A rough heuristic on GPU memory requirements, e.g., we found that # LLAMA3 (8B parameters) fits on a 96GB GPU. required_memory = 96.0 * num_params / 8.0 kwargs = get_llm_kwargs(required_memory, dtype) else: gpu_memory: List[int] = [] for i in range(n_gpus): gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3) kwargs = dict(revision='main') kwargs['max_memory'] = { i: f'{memory}GiB' for i, memory in enumerate(gpu_memory) } kwargs['low_cpu_mem_usage'] = True kwargs['device_map'] = 'auto' kwargs['dtype'] = dtype print(f"Setting up '{model_name}' with configuration: {kwargs}") self.tokenizer = AutoTokenizer.from_pretrained( model_name, use_fast=False, ) if self.tokenizer.chat_template and self.tokenizer.bos_token is None: dummy_convo = [ { "role": "system", "content": "dummy" }, { "role": "user", "content": "convo" }, ] text = self.tokenizer.apply_chat_template( dummy_convo, tokenize=True, ) self.tokenizer.bos_token = self._safe_decode(self.tokenizer, text) if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token_id = PAD_TOKEN_ID if self.tokenizer.padding_side is None: self.tokenizer.padding_side = PADDING_SIDE self.llm = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) self.llm = self.llm.to(dtype) self.word_embedding = self.llm.model.get_input_embeddings() if sys_prompt is not None: self.sys_prompt = sys_prompt else: self.sys_prompt = "" if 'max_memory' 
not in kwargs: # Pure CPU: warnings.warn( "LLM is being used on CPU, which may be slow. This decision " "was made by a rough hueristic that assumes your GPU set up " "does not have enough GPU RAM. This is done to avoid GPU OOM " "errors. If you think this is a mistake, please initialize " "your LLM with the n_gpus param to dictate how many gpus to " "use for the LLM.", stacklevel=2) self.device = torch.device('cpu') self.autocast_context = nullcontext() else: self.device = self.llm.device if dtype == torch.float32: self.autocast_context = nullcontext() else: self.autocast_context = torch.amp.autocast('cuda', dtype=dtype) @staticmethod def _safe_decode(tokenizer, tokens) -> str: """Decode token IDs from various Hugging Face tokenizer outputs. Supports: - list[int] - list[list[int]] - BatchEncoding - tokenizers.Encoding """ if isinstance(tokens, dict): tokens = tokens.get("input_ids", tokens) if hasattr(tokens, "ids"): tokens = tokens.ids if isinstance(tokens, list) and tokens and isinstance(tokens[0], list): tokens = tokens[0] return tokenizer.decode(tokens) # legacy function - used for Llama 2 style prompting def _encode_inputs( self, question: List[str], context: Optional[List[str]] = None, ) -> tuple: batch_size = len(question) questions = self.tokenizer(question, add_special_tokens=False) if context is not None: context = self.tokenizer(context, add_special_tokens=False) eos_user_tokens = self.tokenizer(EOS_USER, add_special_tokens=False) bos_token = self.tokenizer( BOS, add_special_tokens=False, return_tensors='pt', ).input_ids[0].to(self.device) bos_embeds = self.word_embedding(bos_token) pad_token = torch.tensor(self.tokenizer.pad_token_id, device=self.device) pad_embeds = self.word_embedding(pad_token).unsqueeze(0) return (batch_size, questions, context, eos_user_tokens, bos_embeds, pad_embeds) def _label_input_ids( self, i: int, label: BatchEncoding, eos_tokens: BatchEncoding, ) -> List[int]: label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS] 
label_input_ids = label_input_ids + eos_tokens.input_ids return label_input_ids # legacy function - used for Llama 2 style prompting def _input_ids( self, i: int, context: BatchEncoding, question: BatchEncoding, eos_user_tokens: BatchEncoding, ) -> List[int]: input_ids: List[int] = [] if context is not None: input_ids += context.input_ids[i][:MAX_TXT_LEN] input_ids += question.input_ids[i] input_ids += eos_user_tokens.input_ids return input_ids # legacy function - used for Llama 2 style prompting def _inputs_embeds( self, i: int, input_ids: List[int], bos_embeds: Tensor, embedding: Optional[List[Tensor]] = None, ) -> Tensor: inputs_embeds = self.word_embedding( torch.tensor(input_ids, device=self.device)) to_cat = [bos_embeds] if embedding is not None and embedding[i] is not None: to_cat.append(embedding[i]) to_cat.append(inputs_embeds) return torch.cat(to_cat, dim=0).to(self.device) def _append_embeds( self, inputs_embeds: Tensor, batch_inputs_embeds: List[Tensor], batch_attention_mask: List[List[int]], label_input_ids: List[int] = None, batch_label_input_ids: Optional[List[List[int]]] = None, ) -> tuple: batch_inputs_embeds.append(inputs_embeds) batch_attention_mask.append([1] * inputs_embeds.size(0)) if label_input_ids is not None: pad = inputs_embeds.size(0) - len(label_input_ids) label_input_ids = [IGNORE_INDEX] * pad + label_input_ids batch_label_input_ids.append(label_input_ids) return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids def _pad_embeds( self, pad_embeds: Tensor, batch_inputs_embeds: List[Tensor], batch_attention_mask: List[List[int]], batch_label_input_ids: Optional[List[List[int]]] = None, ) -> tuple: max_length = max([x.size(0) for x in batch_inputs_embeds]) batch_size = len(batch_inputs_embeds) for i in range(batch_size): pad = max_length - batch_inputs_embeds[i].size(0) batch_inputs_embeds[i] = torch.cat([ pad_embeds.repeat(pad, 1), batch_inputs_embeds[i], ]) batch_attention_mask[i] = [0] * pad + batch_attention_mask[i] if 
batch_label_input_ids is not None: tmp = [IGNORE_INDEX] * pad + batch_label_input_ids[i] batch_label_input_ids[i] = tmp inputs_embeds = torch.stack(batch_inputs_embeds, dim=0) attention_mask = torch.tensor(batch_attention_mask, device=self.device) label_input_ids = None if batch_label_input_ids is not None: label_input_ids = torch.tensor(batch_label_input_ids, device=self.device) return inputs_embeds, attention_mask, label_input_ids # legacy function - used for Llama 2 style prompting def _get_embeds_old( self, question: List[str], context: Optional[List[str]] = None, embedding: Optional[List[Tensor]] = None, answer: Optional[List[str]] = None, ) -> tuple: (batch_size, question, context, eos_user_tokens, bos_embeds, pad_embeds) = self._encode_inputs(question, context) batch_label_input_ids = None if answer is not None: label = self.tokenizer(answer, add_special_tokens=False) eos_tokens = self.tokenizer(EOS, add_special_tokens=False) batch_label_input_ids = [] batch_inputs_embeds = [] batch_attention_mask = [] for i in range(batch_size): input_ids = self._input_ids(i, context, question, eos_user_tokens) if answer is not None: label_input_ids = self._label_input_ids(i, label, eos_tokens) input_ids += label_input_ids else: label_input_ids = None inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds, embedding) ( batch_inputs_embeds, batch_attention_mask, batch_label_input_ids, ) = self._append_embeds( inputs_embeds, batch_inputs_embeds, batch_attention_mask, label_input_ids, batch_label_input_ids, ) inputs_embeds, attention_mask, label_input_ids = self._pad_embeds( pad_embeds, batch_inputs_embeds, batch_attention_mask, batch_label_input_ids) return inputs_embeds, attention_mask, label_input_ids def _get_embeds( self, question: List[str], context: Optional[List[str]] = None, embedding: Optional[List[Tensor]] = None, answer: Optional[List[str]] = None, ) -> tuple: if not self.tokenizer.chat_template or not self.sys_prompt: warnings.warn( f"HuggingFace model 
{self.model_name} is not using a " "chat template, using Llama 2 style prompting. Please " "consider using a more recent model and initialize the " "LLM with `sys_prompt`.", stacklevel=2) return self._get_embeds_old(question, context, embedding, answer) batch_label_input_ids = None if answer is not None: label = self.tokenizer(answer, add_special_tokens=False) eos_tokens = self.tokenizer(self.tokenizer.eos_token, add_special_tokens=False) batch_label_input_ids = [] batch_inputs_embeds = [] batch_attention_mask = [] for i in range(len(question)): ctx = f"{context[i]} - " if context else "" messages = [ { "role": "system", "content": self.sys_prompt }, { "role": "user", "content": f"{ctx} - {question[i]}" }, ] text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=True, ) text = text[len(self.tokenizer.bos_token):] input_ids = self.tokenizer(text, add_special_tokens=False).input_ids if answer is not None: label_input_ids = self._label_input_ids(i, label, eos_tokens) input_ids += label_input_ids else: label_input_ids = None bos_token = self.tokenizer( self.tokenizer.bos_token, add_special_tokens=False, return_tensors='pt', ).input_ids[0].to(self.device) bos_embeds = self.word_embedding(bos_token) inputs_embeds = self.word_embedding( torch.tensor(input_ids, device=self.device)) to_cat = [bos_embeds] if embedding is not None and embedding[i] is not None: to_cat.append(embedding[i]) to_cat.append(inputs_embeds) inputs_embeds = torch.cat(to_cat, dim=0).to(self.device) ( batch_inputs_embeds, batch_attention_mask, batch_label_input_ids, ) = self._append_embeds( inputs_embeds, batch_inputs_embeds, batch_attention_mask, label_input_ids, batch_label_input_ids, ) pad_token = torch.tensor(self.tokenizer.pad_token_id, device=self.device) pad_embeds = self.word_embedding(pad_token).unsqueeze(0) inputs_embeds, attention_mask, label_input_ids = self._pad_embeds( pad_embeds, batch_inputs_embeds, batch_attention_mask, 
batch_label_input_ids) return inputs_embeds, attention_mask, label_input_ids def forward( self, question: List[str], answer: List[str], context: Optional[List[str]] = None, embedding: Optional[List[Tensor]] = None, ) -> Tensor: r"""The forward pass. Args: question (list[str]): The questions/prompts. answer (list[str]): The answers/labels. context (list[str], optional): Additional context to give to the LLM, such as textified knowledge graphs. (default: :obj:`None`) embedding (list[torch.Tensor], optional): RAG embedding tensors, *i.e.* the embedded form of :obj:`context`. Either :obj:`context` or :obj:`embedding` should be used, not both. (default: :obj:`None`) """ inputs_embeds, attention_mask, label_input_ids = self._get_embeds( question, context, embedding, answer) with self.autocast_context: outputs = self.llm( inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, labels=label_input_ids, ) return outputs.loss @torch.no_grad() def inference( self, question: List[str], context: Optional[List[str]] = None, embedding: Optional[List[Tensor]] = None, max_tokens: Optional[int] = MAX_NEW_TOKENS, ) -> List[str]: r"""The inference pass. Args: question (list[str]): The questions/prompts. answer (list[str]): The answers/labels. context (list[str], optional): Additional context to give to the LLM, such as textified knowledge graphs. (default: :obj:`None`) embedding (list[torch.Tensor], optional): RAG embedding tensors, *i.e.* the embedded form of :obj:`context`. Either :obj:`context` or :obj:`embedding` should be used, not both. (default: :obj:`None`) max_tokens (int, optional): How many tokens for the LLM to generate. 
(default: :obj:`32`) """ inputs_embeds, attention_mask, _ = self._get_embeds( question, context, embedding) with self.autocast_context: outputs = self.llm.generate( inputs_embeds=inputs_embeds, bos_token_id=self.tokenizer.bos_token_id, max_new_tokens=max_tokens, attention_mask=attention_mask, pad_token_id=self.tokenizer.eos_token_id, use_cache=True, ) return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.model_name})' ================================================ FILE: torch_geometric/llm/models/llm_judge.py ================================================ from math import isnan from typing import Optional from torch_geometric.llm.models.txt2kg import \ _chunk_to_triples_str_cloud as call_NIM # Credit for original "Marlin Accuracy" system goes to: # Gilberto Titericz (NVIDIA) # This work is an adaptation of his for PyG SYSTEM_PROMPT_1 = ( "Instruction: You are a world class state of the art " + "assistant for rating " + "a User Answer given a Question. The Question is completely" + " answered by the Reference Answer.\n" + "Say 4, if User Answer is full contained and equivalent to" + " Reference Answer" + "in all terms, topics, numbers, metrics, dates and units.\n" + "Say 2, if User Answer is partially contained and almost " + "equivalent to Reference Answer" + "in all terms, topics, numbers, metrics, dates and units.\n" + "Say 0, if User Answer is not contained in Reference Answer" + " or not accurate in all terms, topics," + "numbers, metrics, dates and units or the User Answer do not" + " answer the question.\n" + "Do not explain or justify your rating. 
Your rating must be " + "only 4, 2 or 0 according to the instructions above.\n" + "### Question: \"{question}\"\n" + "### User Answer: \"{model_pred}\"\n" + "### Reference Answer: \"{correct_answer}\"\n" + "The rating is:\n") SYSTEM_PROMPT_2 = ( "I will rate the User Answer in comparison to the Reference " + "Answer for a given Question.\n" + "A rating of 4 indicates that the User Answer is entirely " + "consistent with the Reference Answer, covering all aspects," + " topics, numbers, metrics, dates, and units.\n" + "A rating of 2 signifies that the User Answer is mostly " + "aligned with the Reference Answer, with minor discrepancies" + " in some areas.\n" + "A rating of 0 means that the User Answer is either " + "inaccurate, incomplete, or unrelated to the Reference " + "Answer, or it fails to address the Question.\n" + "I will provide the rating without any explanation or " + "justification, adhering to the following scale: " + "0 (no match), 2 (partial match), 4 (exact match).\n" + "Do not explain or justify my rating. My rating must" + " be only 4, 2 or 0 only.\n\n" + "Question: \"{question}\"\n\n" + "Reference Answer: \"{model_pred}\"\n\n" + "User Answer: \"{correct_answer}\"\n\n" + "Rating: ") # TODO: add support for Local LM # TODO: add multiproc support like txt2kg class LLMJudge(): """Uses NIMs to score a triple of (question, model_pred, correct_answer) This whole class is an adaptation of Gilberto's work for PyG. Args: NVIDIA_NIM_MODEL : (str, optional) The name of the NVIDIA NIM model to use. (default: "nvidia/llama-3.1-nemotron-70b-instruct"). NVIDIA_API_KEY : (str, optional) The API key for accessing NVIDIA's NIM models. (default: ""). ENDPOINT_URL : (str, optional) The URL hosting your model, in case you are not using the public NIM. (default: "https://integrate.api.nvidia.com/v1"). 
""" def __init__( self, NVIDIA_NIM_MODEL: Optional[ str] = "nvidia/llama-3.1-nemotron-70b-instruct", NVIDIA_API_KEY: Optional[str] = "", ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1", ) -> None: self.NVIDIA_API_KEY = NVIDIA_API_KEY self.NIM_MODEL = NVIDIA_NIM_MODEL self.ENDPOINT_URL = ENDPOINT_URL def _process_score(self, response: str) -> float: """Uses 3 and 1 even though prompt says only 0, 2, 4. This is because LLMs don't always follow instructions. Credit to Gilberto. """ for i in [4, 3, 2, 1, 0]: if str(i) in response: return i / 4 return float("nan") def _average_scores(self, score0: float, score1: float): """Take the average of score0 and score1. Sometimes the LLM fail to respond or have no score in the response. In those cases the failed score is discarded. Credit to Gilberto. Args: score0 (float): judge accuracy score. score1 (float): judge accuracy score by permuting agent answer and ground truth. Returns: (float) average of score0 and score1 of both contains scores, otherwise pick the max. """ score = float("nan") if score0 >= 0 and score1 >= 0: score = (score0 + score1) / 2 else: score = max(score0, score1) return score def score( self, question: str, model_pred: str, correct_answer: str, ) -> float: """Args: question (str): The original question asked to the model. model_pred (str): The prediction made by the model. correct_answer (str): The actual correct answer to the question. Returns: score (float): score of 0-1, may be nan due to LLM judge failure. Evals should skip nan's when aggregating score. 
""" prompt1 = SYSTEM_PROMPT_1.format(question=question, model_pred=model_pred, correct_answer=correct_answer) prompt2 = SYSTEM_PROMPT_2.format(question=question, model_pred=model_pred, correct_answer=correct_answer) score1 = float("nan") score2 = float("nan") for _retry in range(200): try: score1 = self._process_score( call_NIM(prompt1, self.NVIDIA_API_KEY, self.NIM_MODEL, self.ENDPOINT_URL, post_text="")) if not isnan(score1): break except ImportError: raise except: # noqa pass for _retry in range(20): try: score2 = self._process_score( call_NIM(prompt2, self.NVIDIA_API_KEY, self.NIM_MODEL, self.ENDPOINT_URL, post_text="")) if not isnan(score2): break except ImportError: raise except: # noqa pass return self._average_scores(score1, score2) ================================================ FILE: torch_geometric/llm/models/molecule_gpt.py ================================================ from typing import List, Optional import torch from torch import Tensor from torch_geometric.llm.models.llm import BOS, LLM, MAX_NEW_TOKENS from torch_geometric.nn.attention import QFormer from torch_geometric.utils import to_dense_batch def pad_or_truncate(embeddings: Tensor, max_seq_len: int, padding_value: int = 0) -> Tensor: batch_size, current_seq_len, d = embeddings.size() if current_seq_len > max_seq_len: return embeddings[:, :max_seq_len, :] elif current_seq_len < max_seq_len: pad_tensor = torch.full((batch_size, max_seq_len - current_seq_len, d), padding_value, dtype=embeddings.dtype, device=embeddings.device) return torch.cat([embeddings, pad_tensor], dim=1) else: return embeddings class MoleculeGPT(torch.nn.Module): r"""The MoleculeGPT model from the `"MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction" `_ paper. Args: llm (LLM): The LLM to use. graph_encoder (torch.nn.Module): Encode 2D molecule graph. smiles_encoder (torch.nn.Module): Encode 1D SMILES. 
mlp_out_channels (int, optional): The size of each embedding after qformer encoding. (default: :obj:`32`) max_tokens (int, optional): Max output tokens of 1D/2D encoder. (default: :obj:`20`) .. warning:: This module has been tested with the following HuggingFace models * :obj:`llm_to_use="lmsys/vicuna-7b-v1.5"` and may not work with other models. See other models at `HuggingFace Models `_ and let us know if you encounter any issues. .. note:: For an example of using :class:`MoleculeGPT`, see `examples/llm/molecule_gpt.py `_. """ def __init__( self, llm: LLM, graph_encoder: torch.nn.Module, smiles_encoder: torch.nn.Module, mlp_out_channels: int = 32, max_tokens: Optional[int] = 20, ) -> None: super().__init__() self.llm = llm self.graph_encoder = graph_encoder.to(self.llm.device) self.smiles_encoder = smiles_encoder.to(self.llm.device) self.graph_qformer = QFormer( input_dim=self.graph_encoder.nn[-1].out_features, hidden_dim=mlp_out_channels, output_dim=mlp_out_channels, num_heads=4, num_layers=2, ).to(self.llm.device) self.smiles_qformer = QFormer( input_dim=self.smiles_encoder.model.pooler.dense.out_features, hidden_dim=mlp_out_channels, output_dim=mlp_out_channels, num_heads=4, num_layers=2, ).to(self.llm.device) self.max_tokens = max_tokens self.word_embedding = self.llm.word_embedding self.llm_generator = self.llm.llm # LLMs in_dim = 2 * mlp_out_channels * max_tokens out_dim = self.llm.llm.model.embed_tokens.embedding_dim self.projector = torch.nn.Sequential( torch.nn.Linear(in_dim, in_dim), torch.nn.Sigmoid(), torch.nn.Linear(in_dim, out_dim), ).to(self.llm.device) def encode( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor], smiles: List[str], ) -> Tensor: batch_size = len(smiles) # 2D Graph Branch: [bs, node_len, d] x = x.to(self.llm.device) edge_index = edge_index.to(self.llm.device) if edge_attr is not None: edge_attr = edge_attr.to(self.llm.device) batch = batch.to(self.llm.device) x_graph = self.graph_encoder(x, edge_index, 
edge_attr=edge_attr) x_graph = to_dense_batch(x_graph, batch)[0] out_graph = self.graph_qformer(x_graph) out_graph = pad_or_truncate(out_graph, max_seq_len=self.max_tokens, padding_value=0) out_graph = out_graph.view(batch_size, -1) # 1D SMILES Branch: [bs, seq_len, d] x_smiles = self.smiles_encoder.encode(smiles, output_device=self.llm.device) out_smiles = self.smiles_qformer(x_smiles) out_smiles = pad_or_truncate(out_smiles, max_seq_len=self.max_tokens, padding_value=0) out_smiles = out_smiles.view(batch_size, -1) # Merge into LLMs x_cat = torch.cat([out_graph, out_smiles], dim=1) return x_cat def forward( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor], smiles: List[str], instructions: List[str], label: List[str], additional_text_context: Optional[List[str]] = None, ): x = self.encode(x, edge_index, batch, edge_attr, smiles) x = self.projector(x) xs = x.split(1, dim=0) batch_unique = batch.unique() batch_size = len(instructions) if len(batch_unique) < batch_size: xs = [ xs[i] if i in batch_unique else None for i in range(batch_size) ] ( inputs_embeds, attention_mask, label_input_ids, ) = self.llm._get_embeds(instructions, additional_text_context, xs, label) with self.llm.autocast_context: outputs = self.llm_generator( inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, labels=label_input_ids, ) return outputs.loss @torch.no_grad() def inference( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor], smiles: List[str], instructions: List[str], additional_text_context: Optional[List[str]] = None, max_out_tokens: Optional[int] = MAX_NEW_TOKENS, ): x = self.encode(x, edge_index, batch, edge_attr, smiles) x = self.projector(x) xs = x.split(1, dim=0) # Handle questions without node features: batch_unique = batch.unique() batch_size = len(instructions) if len(batch_unique) < batch_size: xs = [ xs[i] if i in batch_unique else None for i in range(batch_size) ] inputs_embeds, 
attention_mask, _ = self.llm._get_embeds( instructions, additional_text_context, xs) bos_token = self.llm.tokenizer( BOS, add_special_tokens=False, ).input_ids[0] with self.llm.autocast_context: outputs = self.llm_generator.generate( inputs_embeds=inputs_embeds, max_new_tokens=max_out_tokens, attention_mask=attention_mask, bos_token_id=bos_token, use_cache=True # Important to set! ) return self.llm.tokenizer.batch_decode( outputs, skip_special_tokens=True, ) def __repr__(self) -> str: return (f'{self.__class__.__name__}(\n' f' llm={self.llm},\n' f' graph={self.graph_encoder.__class__.__name__},\n' f' smiles={self.smiles_encoder},\n' f')') ================================================ FILE: torch_geometric/llm/models/protein_mpnn.py ================================================ from itertools import product from typing import Tuple import torch import torch.nn.functional as F from torch_geometric.nn import knn_graph from torch_geometric.nn.conv import MessagePassing from torch_geometric.utils import to_dense_adj, to_dense_batch class PositionWiseFeedForward(torch.nn.Module): def __init__(self, in_channels: int, hidden_channels: int) -> None: super().__init__() self.out = torch.nn.Sequential( torch.nn.Linear(in_channels, hidden_channels), torch.nn.GELU(), torch.nn.Linear(hidden_channels, in_channels), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out(x) class PositionalEncoding(torch.nn.Module): def __init__(self, hidden_channels: int, max_relative_feature: int = 32) -> None: super().__init__() self.max_relative_feature = max_relative_feature self.emb = torch.nn.Embedding(2 * max_relative_feature + 2, hidden_channels) def forward(self, offset, mask) -> torch.Tensor: d = torch.clip(offset + self.max_relative_feature, 0, 2 * self.max_relative_feature) * mask + (1 - mask) * ( 2 * self.max_relative_feature + 1) # noqa: E501 return self.emb(d.long()) class Encoder(MessagePassing): def __init__( self, in_channels: int, hidden_channels: int, 
dropout: float = 0.1, scale: float = 30, ) -> None: super().__init__() self.out_v = torch.nn.Sequential( torch.nn.Linear(in_channels, hidden_channels), torch.nn.GELU(), torch.nn.Linear(hidden_channels, hidden_channels), torch.nn.GELU(), torch.nn.Linear(hidden_channels, hidden_channels), ) self.out_e = torch.nn.Sequential( torch.nn.Linear(in_channels, hidden_channels), torch.nn.GELU(), torch.nn.Linear(hidden_channels, hidden_channels), torch.nn.GELU(), torch.nn.Linear(hidden_channels, hidden_channels), ) self.dropout1 = torch.nn.Dropout(dropout) self.dropout2 = torch.nn.Dropout(dropout) self.dropout3 = torch.nn.Dropout(dropout) self.norm1 = torch.nn.LayerNorm(hidden_channels) self.norm2 = torch.nn.LayerNorm(hidden_channels) self.norm3 = torch.nn.LayerNorm(hidden_channels) self.scale = scale self.dense = PositionWiseFeedForward(hidden_channels, hidden_channels * 4) def forward( self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, ) -> torch.Tensor: # x: [N, d_v] # edge_index: [2, E] # edge_attr: [E, d_e] # update node features h_message = self.propagate(x=x, edge_index=edge_index, edge_attr=edge_attr) dh = h_message / self.scale x = self.norm1(x + self.dropout1(dh)) dh = self.dense(x) x = self.norm2(x + self.dropout2(dh)) # update edge features row, col = edge_index x_i, x_j = x[row], x[col] h_e = torch.cat([x_i, x_j, edge_attr], dim=-1) h_e = self.out_e(h_e) edge_attr = self.norm3(edge_attr + self.dropout3(h_e)) return x, edge_attr def message(self, x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor: h = torch.cat([x_i, x_j, edge_attr], dim=-1) # [E, 2*d_v + d_e] h = self.out_e(h) # [E, d_e] return h class Decoder(MessagePassing): def __init__( self, in_channels: int, hidden_channels: int, dropout: float = 0.1, scale: float = 30, ) -> None: super().__init__() self.out_v = torch.nn.Sequential( torch.nn.Linear(in_channels, hidden_channels), torch.nn.GELU(), torch.nn.Linear(hidden_channels, hidden_channels), 
torch.nn.GELU(), torch.nn.Linear(hidden_channels, hidden_channels), ) self.dropout1 = torch.nn.Dropout(dropout) self.dropout2 = torch.nn.Dropout(dropout) self.norm1 = torch.nn.LayerNorm(hidden_channels) self.norm2 = torch.nn.LayerNorm(hidden_channels) self.scale = scale self.dense = PositionWiseFeedForward(hidden_channels, hidden_channels * 4) def forward( self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, x_label: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: # x: [N, d_v] # edge_index: [2, E] # edge_attr: [E, d_e] h_message = self.propagate(x=x, x_label=x_label, edge_index=edge_index, edge_attr=edge_attr, mask=mask) dh = h_message / self.scale x = self.norm1(x + self.dropout1(dh)) dh = self.dense(x) x = self.norm2(x + self.dropout2(dh)) return x def message(self, x_i: torch.Tensor, x_j: torch.Tensor, x_label_j: torch.Tensor, edge_attr: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: h_1 = torch.cat([x_j, edge_attr, x_label_j], dim=-1) h_0 = torch.cat([x_j, edge_attr, torch.zeros_like(x_label_j)], dim=-1) h = h_1 * mask + h_0 * (1 - mask) h = torch.concat([x_i, h], dim=-1) h = self.out_v(h) return h class ProteinMPNN(torch.nn.Module): r"""The ProteinMPNN model from the `"Robust deep learning--based protein sequence design using ProteinMPNN" `_ paper. Args: hidden_dim (int): Hidden channels. (default: :obj:`128`) num_encoder_layers (int): Number of encode layers. (default: :obj:`3`) num_decoder_layers (int): Number of decode layers. (default: :obj:`3`) num_neighbors (int): Number of neighbors for each atom. (default: :obj:`30`) num_rbf (int): Number of radial basis functions. (default: :obj:`16`) dropout (float): Dropout rate. (default: :obj:`0.1`) augment_eps (float): Augmentation epsilon for input coordinates. (default: :obj:`0.2`) num_positional_embedding (int): Number of positional embeddings. (default: :obj:`16`) vocab_size (int): Number of vocabulary. (default: :obj:`21`) .. 
note:: For an example of using :class:`ProteinMPNN`, see `examples/llm/protein_mpnn.py `_. """ def __init__( self, hidden_dim: int = 128, num_encoder_layers: int = 3, num_decoder_layers: int = 3, num_neighbors: int = 30, num_rbf: int = 16, dropout: float = 0.1, augment_eps: float = 0.2, num_positional_embedding: int = 16, vocab_size: int = 21, ) -> None: super().__init__() self.augment_eps = augment_eps self.hidden_dim = hidden_dim self.num_neighbors = num_neighbors self.num_rbf = num_rbf self.embedding = PositionalEncoding(num_positional_embedding) self.edge_mlp = torch.nn.Sequential( torch.nn.Linear(num_positional_embedding + 400, hidden_dim), torch.nn.LayerNorm(hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim), ) self.label_embedding = torch.nn.Embedding(vocab_size, hidden_dim) self.encoder_layers = torch.nn.ModuleList([ Encoder(hidden_dim * 3, hidden_dim, dropout) for _ in range(num_encoder_layers) ]) self.decoder_layers = torch.nn.ModuleList([ Decoder(hidden_dim * 4, hidden_dim, dropout) for _ in range(num_decoder_layers) ]) self.output = torch.nn.Linear(hidden_dim, vocab_size) self.reset_parameters() def reset_parameters(self): for p in self.parameters(): if p.dim() > 1: torch.nn.init.xavier_uniform_(p) def _featurize( self, x: torch.Tensor, mask: torch.Tensor, batch: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: N, Ca, C, O = (x[:, i, :] for i in range(4)) # noqa: E741 b = Ca - N c = C - Ca a = torch.cross(b, c, dim=-1) Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca valid_mask = mask.bool() valid_Ca = Ca[valid_mask] valid_batch = batch[valid_mask] edge_index = knn_graph(valid_Ca, k=self.num_neighbors, batch=valid_batch, loop=True) row, col = edge_index original_indices = torch.arange(Ca.size(0), device=x.device)[valid_mask] edge_index_original = torch.stack( [original_indices[row], original_indices[col]], dim=0) row, col = edge_index_original rbf_all = [] for A, B in list(product([N, Ca, C, O, Cb], repeat=2)): distances = 
torch.sqrt(torch.sum((A[row] - B[col])**2, 1) + 1e-6) rbf = self._rbf(distances) rbf_all.append(rbf) return edge_index_original, torch.cat(rbf_all, dim=-1) def _rbf(self, D: torch.Tensor) -> torch.Tensor: D_min, D_max, D_count = 2., 22., self.num_rbf D_mu = torch.linspace(D_min, D_max, D_count, device=D.device) D_mu = D_mu.view([1, -1]) D_sigma = (D_max - D_min) / D_count D_expand = torch.unsqueeze(D, -1) RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2) return RBF def forward( self, x: torch.Tensor, chain_seq_label: torch.Tensor, mask: torch.Tensor, chain_mask_all: torch.Tensor, residue_idx: torch.Tensor, chain_encoding_all: torch.Tensor, batch: torch.Tensor, ) -> torch.Tensor: device = x.device if self.training and self.augment_eps > 0: x = x + self.augment_eps * torch.randn_like(x) edge_index, edge_attr = self._featurize(x, mask, batch) row, col = edge_index offset = residue_idx[row] - residue_idx[col] # find self vs non-self interaction e_chains = ((chain_encoding_all[row] - chain_encoding_all[col]) == 0).long() e_pos = self.embedding(offset, e_chains) h_e = self.edge_mlp(torch.cat([edge_attr, e_pos], dim=-1)) h_v = torch.zeros(x.size(0), self.hidden_dim, device=x.device) # encoder for encoder in self.encoder_layers: h_v, h_e = encoder(h_v, edge_index, h_e) # mask h_label = self.label_embedding(chain_seq_label) batch_chain_mask_all, _ = to_dense_batch(chain_mask_all * mask, batch) # [B, N] # 0 - visible - encoder, 1 - masked - decoder decoding_order = torch.argsort( (batch_chain_mask_all + 1e-4) * (torch.abs( torch.randn(batch_chain_mask_all.shape, device=device)))) mask_size = batch_chain_mask_all.size(1) permutation_matrix_reverse = F.one_hot(decoding_order, num_classes=mask_size).float() order_mask_backward = torch.einsum( 'ij, biq, bjp->bqp', 1 - torch.triu(torch.ones(mask_size, mask_size, device=device)), permutation_matrix_reverse, permutation_matrix_reverse, ) adj = to_dense_adj(edge_index, batch) mask_attend = 
order_mask_backward[adj.bool()].unsqueeze(-1) # decoder for decoder in self.decoder_layers: h_v = decoder( h_v, edge_index, h_e, h_label, mask_attend, ) logits = self.output(h_v) return F.log_softmax(logits, dim=-1) ================================================ FILE: torch_geometric/llm/models/sentence_transformer.py ================================================ from enum import Enum from typing import List, Optional, Union import torch import torch.nn.functional as F from torch import Tensor from tqdm import tqdm class PoolingStrategy(Enum): MEAN = 'mean' LAST = 'last' CLS = 'cls' LAST_HIDDEN_STATE = 'last_hidden_state' class SentenceTransformer(torch.nn.Module): r"""A wrapper around a Sentence-Transformer from HuggingFace. Args: model_name (str): The HuggingFace model name, *e.g.*, :obj:`"BERT"`. pooling_strategy (str, optional): The pooling strategy to use for generating node embeddings. (default: :obj:`"mean"`) """ def __init__( self, model_name: str, pooling_strategy: Union[PoolingStrategy, str] = 'mean', ) -> None: super().__init__() self.model_name = model_name self.pooling_strategy = PoolingStrategy(pooling_strategy) from transformers import AutoModel, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Maximum sequence length from the model configuration (e.g. 8192 for # models like ModernBERT) self.max_seq_length = self.model.config.max_position_embeddings """ Some models define a max sequence length in their configuration. Others only in the tokenizer. This is a hacky heuristic to find the max sequence length that works for the model. 
""" probe_tokens = self.tokenizer("hacky heuristic", padding='max_length', return_tensors='pt') self.max_seq_length = min(self.max_seq_length, probe_tokens.input_ids.shape[1]) def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: out = self.model(input_ids=input_ids, attention_mask=attention_mask) emb = out[0] # First element contains all token embeddings. if self.pooling_strategy == PoolingStrategy.MEAN: emb = mean_pooling(emb, attention_mask) elif self.pooling_strategy == PoolingStrategy.LAST: emb = last_pooling(emb, attention_mask) elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE: emb = out.last_hidden_state else: assert self.pooling_strategy == PoolingStrategy.CLS emb = emb[:, 0, :] emb = F.normalize(emb, p=2, dim=1) return emb def get_input_ids( self, text: List[str], batch_size: Optional[int] = None, output_device: Optional[Union[torch.device, str]] = None, ) -> Tensor: is_empty = len(text) == 0 text = ['dummy'] if is_empty else text batch_size = len(text) if batch_size is None else batch_size input_ids: List[Tensor] = [] attention_masks: List[Tensor] = [] for start in range(0, len(text), batch_size): token = self.tokenizer( text[start:start + batch_size], padding=True, truncation=True, return_tensors='pt', max_length=self.max_seq_length, ) input_ids.append(token.input_ids.to(self.device)) attention_masks.append(token.attention_mask.to(self.device)) def _out(x: List[Tensor]) -> Tensor: out = torch.cat(x, dim=0) if len(x) > 1 else x[0] out = out[:0] if is_empty else out return out.to(output_device) return _out(input_ids), _out(attention_masks) @property def device(self) -> torch.device: return next(iter(self.model.parameters())).device @torch.no_grad() def encode( self, text: List[str], batch_size: Optional[int] = None, output_device: Optional[Union[torch.device, str]] = None, verbose=False, ) -> Tensor: r"""Main function for users. Converts strings to embeddings. Args: text (List[str]): List of strings to embed. 
batch_size (int, optional): How many strings to process. Defaults to processing all at once, but this may lead to OOM errors. (default: obj:`None`) output_device (Union[torch.device, str], optional): By default outputs cpu pytorch tensor, but can choose to output to specific cuda devices. (default: obj:`None`) verbose (bool, optional): Controls the verbosity of outputs. (default: obj:`False`) """ is_empty = len(text) == 0 text = ['dummy'] if is_empty else text batch_size = len(text) if batch_size is None else batch_size embs: List[Tensor] = [] loader = range(0, len(text), batch_size) if verbose: loader = tqdm( loader, desc="Encoding " + str(len(text)) + " strings w/ SentenceTransformer") for start in loader: token = self.tokenizer( text[start:start + batch_size], padding=True, truncation=True, return_tensors='pt', max_length=self.max_seq_length, ) try: emb = self( input_ids=token.input_ids.to(self.device), attention_mask=token.attention_mask.to(self.device), ).to(output_device) embs.append(emb) except: # noqa # fallback to using CPU for huge strings that cause OOMs print("Sentence Transformer failed on cuda, trying w/ cpu...") previous_device = self.device self.model = self.model.to("cpu") emb = self( input_ids=token.input_ids.to(self.device), attention_mask=token.attention_mask.to(self.device), ).to(output_device) embs.append(emb) self.model = self.model.to(previous_device) out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0] out = out[:0] if is_empty else out return out def __repr__(self) -> str: return f'{self.__class__.__name__}(model_name={self.model_name})' def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor: mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype) return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9) def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor: # Check whether language model uses left padding, # which is always used for decoder LLMs left_padding = attention_mask[:, -1].sum() == 
attention_mask.size(0) if left_padding: return emb[:, -1] seq_indices = attention_mask.sum(dim=1) - 1 return emb[torch.arange(emb.size(0), device=emb.device), seq_indices] ================================================ FILE: torch_geometric/llm/models/txt2kg.py ================================================ import os import time from typing import List, Optional, Tuple, Union import torch import torch.multiprocessing as mp CLIENT_INITD = False CLIENT = None GLOBAL_NIM_KEY = "" SYSTEM_PROMPT = "Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Separate each with a new line. Do not output anything else. Try to focus on key triples that form a connected graph." # noqa MAX_OUTER_RETRIES = 5 # Maximum number of times the entire multiprocessing job is retried. # noqa RETRY_DELAY = 5 # Fixed sleep time (in seconds) between outer retries. MAX_NIM_RETRIES = 200 # Maximum number of attempts to call the NIM API inside one worker. # noqa BASE_DELAY = 0.5 # Initial wait time before retrying a failed network call. class TXT2KG(): """A class to convert text data into a Knowledge Graph (KG) format. Uses NVIDIA NIMs + Prompt engineering by default. Default model `nvidia/llama-3.1-nemotron-70b-instruct` is on par or better than GPT4o in benchmarks. We need a high quality model to ensure high quality KG. Otherwise we have garbage in garbage out for the rest of the GNN+LLM RAG pipeline. Use local_lm flag for local debugging/dev. You still need to be able to inference a 14B param LLM, 'VAGOsolutions/SauerkrautLM-v2-14b-DPO'. Smaller LLMs did not work at all in testing. Note this 14B model requires a considerable amount of GPU memory. See examples/llm/txt2kg_rag.py for an example. Args: NVIDIA_NIM_MODEL : str, optional The name of the NVIDIA NIM model to use. (default: "nvidia/llama-3.1-nemotron-70b-instruct"). NVIDIA_API_KEY : str, optional The API key for accessing NVIDIA's NIM models (default: ""). 
ENDPOINT_URL : str, optional
            The URL hosting your model, in case you are not using
            the public NIM.
            (default: "https://integrate.api.nvidia.com/v1").
        local_LM : bool, optional
            A flag indicating whether a local Language Model (LM)
            should be used. This uses HuggingFace and will be slower
            than deploying your own private NIM endpoint. This flag
            is mainly recommended for dev/debug.
            (default: False).
        chunk_size : int, optional
            The size, in characters, of the chunks the text is split into
            before extraction. The local LM path also passes it as the
            generation `max_tokens` (default: 512).
    """
    def __init__(
        self,
        NVIDIA_NIM_MODEL: Optional[
            str] = "nvidia/llama-3.1-nemotron-70b-instruct",
        NVIDIA_API_KEY: Optional[str] = "",
        ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
        local_LM: bool = False,
        chunk_size: int = 512,
    ) -> None:
        self.local_LM = local_LM
        # Initialize the local LM flag and the NIM model info accordingly
        if self.local_LM:
            # The local LM is loaded lazily on the first chunk.
            self.initd_LM = False
        else:
            # Store the provided NIM model info. These attributes only exist
            # when not using a local LM.
            self.NVIDIA_API_KEY = NVIDIA_API_KEY
            self.NIM_MODEL = NVIDIA_NIM_MODEL
            self.ENDPOINT_URL = ENDPOINT_URL

        # Set the chunk size for processing text data
        self.chunk_size = chunk_size

        # Initialize counters and storage for parsing results.
        # `relevant_triples` maps a QA pair (or an integer doc id) to the
        # triples extracted from that document.
        self.doc_id_counter = 0
        self.relevant_triples = {}
        self.total_chars_parsed = 0
        self.time_to_parse = 0.0

    def save_kg(self, path: str) -> None:
        """Saves the relevant triples in the knowledge graph (KG) to a file
        using ``torch.save``.

        Args:
            path (str): The file path where the KG will be saved.
Returns: None """ torch.save(self.relevant_triples, path) def _chunk_to_triples_str_local(self, txt: str) -> str: # call LLM on text chunk_start_time = time.time() if not self.initd_LM: from torch_geometric.llm.models import LLM LM_name = "VAGOsolutions/SauerkrautLM-v2-14b-DPO" self.model = LLM(LM_name).eval() self.initd_LM = True out_str = self.model.inference(question=[txt + '\n' + SYSTEM_PROMPT], max_tokens=self.chunk_size)[0] # for debug self.total_chars_parsed += len(txt) self.time_to_parse += round(time.time() - chunk_start_time, 2) self.avg_chars_parsed_per_sec = self.total_chars_parsed / ( self.time_to_parse + 1e-6) # noqa return out_str def add_doc_2_KG( self, txt: str, QA_pair: Optional[Tuple[str, str]] = None, ) -> None: """Add a document to the Knowledge Graph (KG). Args: txt (str): The text to extract triples from. QA_pair (Tuple[str, str]], optional): A QA pair to associate with the extracted triples. Useful for downstream evaluation. Returns: - None """ if not self.local_LM: # Ensure NVIDIA_API_KEY is set before proceeding assert self.NVIDIA_API_KEY != '', \ "Please init TXT2KG w/ NVIDIA_API_KEY or set local_lm=True" if QA_pair: # QA_pairs should be unique keys, check if already exists in KG if QA_pair in self.relevant_triples.keys(): print("Warning: QA_Pair was already added to the set") print("Q=", QA_pair[0]) print("A=", QA_pair[1]) print("Previously parsed triples=", self.relevant_triples[QA_pair]) print("Skipping...") key = QA_pair else: # If no QA_pair, use the current doc_id_counter as the key key = self.doc_id_counter self.relevant_triples[key] = self._extract_relevant_triples(txt) # Increment the doc_id_counter for the next document self.doc_id_counter += 1 def _extract_relevant_triples( self, txt: str, max_retries: int = MAX_OUTER_RETRIES, retry_delay: float = RETRY_DELAY, ) -> List[Tuple[str, str, str]]: # Handle empty text (context-less QA pairs) if txt == "": return [] # Chunk the text into smaller pieces for processing chunks = 
_chunk_text(txt, chunk_size=self.chunk_size) if self.local_LM: # For debugging purposes... # process chunks sequentially on the local LM return _llm_then_python_parse(chunks, _parse_n_check_triples, self._chunk_to_triples_str_local) # Create deterministic chunk assignment import math num_procs = min(len(chunks), _get_num_procs()) chunk_size = math.ceil(len(chunks) / num_procs) in_chunks_per_proc = [ chunks[j * chunk_size:min((j + 1) * chunk_size, len(chunks))] for j in range(num_procs) ] # Run workers via starmap for deterministic ordering worker_args = [( rank, in_chunks_per_proc[rank], _parse_n_check_triples, _chunk_to_triples_str_cloud, self.NVIDIA_API_KEY, self.NIM_MODEL, self.ENDPOINT_URL, ) for rank in range(num_procs)] for attempt in range(max_retries): try: with mp.get_context("spawn").Pool(num_procs) as pool: results = pool.starmap(_multiproc_helper, worker_args) break # success except Exception as e: if attempt == max_retries - 1: raise # re-raise on final failure print(f"[Retry {attempt+1}/{max_retries}] " f"Multiprocessing failed: {e}") time.sleep(retry_delay) return _merge_triples_deterministically(results) known_reasoners = [ "llama-3.1-nemotron-ultra-253b-v1", "kimi-k2-instruct", "nemotron-super-49b-v1_5", "gpt-oss", ] def _chunk_to_triples_str_cloud( txt: str, GLOBAL_NIM_KEY='', NIM_MODEL="nvidia/llama-3.1-nemotron-ultra-253b-v1", ENDPOINT_URL="https://integrate.api.nvidia.com/v1", post_text=SYSTEM_PROMPT) -> str: global CLIENT_INITD if not CLIENT_INITD: # We use NIMs since most PyG users may not be able to run a 70B+ model try: from openai import OpenAI except ImportError: quit( "Failed to import `openai` package, please install it and rerun the script" # noqa ) global CLIENT CLIENT = OpenAI(base_url=ENDPOINT_URL, api_key=GLOBAL_NIM_KEY) CLIENT_INITD = True txt_input = txt if post_text != "": txt_input += '\n' + post_text messages = [] if any([model_name_str in NIM_MODEL for model_name_str in known_reasoners]): messages.append({"role": "system", 
"content": "detailed thinking on"}) messages.append({"role": "user", "content": txt_input}) completion = CLIENT.chat.completions.create(model=NIM_MODEL, messages=messages, temperature=0, top_p=1, max_tokens=1024, stream=True) out_str = "" for chunk in completion: if chunk.choices[0].delta.content is not None: out_str += chunk.choices[0].delta.content return out_str def _parse_n_check_triples(triples_str: str) -> List[Tuple[str, str, str]]: # use pythonic checks for triples processed = [] split_by_newline = triples_str.split("\n") # sometimes LLM fails to obey the prompt if len(split_by_newline) > 1: split_triples = split_by_newline llm_obeyed = True else: # handles form "(e, r, e) (e, r, e) ... (e, r, e)"" split_triples = triples_str[1:-1].split(") (") llm_obeyed = False for triple_str in split_triples: try: if llm_obeyed: # remove parenthesis and single quotes for parsing triple_str = triple_str.replace("(", "").replace(")", "").replace( "'", "") split_trip = triple_str.split(',') # remove blank space at beginning or end split_trip = [(i[1:] if i[0] == " " else i) for i in split_trip] split_trip = [(i[:-1].lower() if i[-1] == " " else i) for i in split_trip] potential_trip = tuple(split_trip) except: # noqa continue if 'tuple' in str(type(potential_trip)) and len( potential_trip ) == 3 and "note:" not in potential_trip[0].lower(): # additional check for empty node/edge attrs if potential_trip[0] != '' and potential_trip[ 1] != '' and potential_trip[2] != '': processed.append(potential_trip) return processed def _llm_then_python_parse(chunks, py_fn, llm_fn, **kwargs): relevant_triples = [] for chunk in chunks: relevant_triples += py_fn(llm_fn(chunk, **kwargs)) return relevant_triples def _multiproc_helper( rank, chunks_for_rank, py_fn, llm_fn, NIM_KEY, NIM_MODEL, ENDPOINT_URL, max_retries=MAX_NIM_RETRIES, base_delay=BASE_DELAY, ): for attempt in range(max_retries): try: return _llm_then_python_parse( chunks_for_rank, py_fn, llm_fn, GLOBAL_NIM_KEY=NIM_KEY, 
NIM_MODEL=NIM_MODEL, ENDPOINT_URL=ENDPOINT_URL, ) except Exception: # Optional: restrict to network-related exceptions only if attempt == max_retries - 1: raise # exponential backoff with jitter from random import uniform sleep_time = base_delay * (2**min(attempt, 6)) sleep_time += uniform(0, 0.1) time.sleep(sleep_time) def _get_num_procs(): num_proc = None if hasattr(os, "sched_getaffinity"): try: num_proc = len(os.sched_getaffinity(0)) / (2) except Exception: pass if num_proc is None: num_proc = os.cpu_count() / (2) return int(num_proc) def _chunk_text(text: str, chunk_size: int = 512) -> list[str]: """Function to chunk text into sentence-based segments. Co-authored with Claude AI. """ # If the input text is empty or None, return an empty list if not text: return [] # List of punctuation marks that typically end sentences sentence_endings = '.!?' # List to store the resulting chunks chunks = [] # Continue processing the entire text while text: # If the remaining text is shorter than chunk_size, add it and break if len(text) <= chunk_size: chunks.append(text.strip()) break # Start with the maximum possible chunk chunk = text[:chunk_size] # Try to find the last sentence ending within the chunk best_split = chunk_size for ending in sentence_endings: # Find the last occurrence of the ending punctuation last_ending = chunk.rfind(ending) if last_ending != -1: # Ensure we include the punctuation and any following space best_split = min( best_split, last_ending + 1 + (1 if last_ending + 1 < len(chunk) and chunk[last_ending + 1].isspace() else 0)) # Adjust to ensure we don't break words # If the next character is a letter, find the last space if best_split < len(text) and text[best_split].isalpha(): # Find the last space before the current split point space_split = text[:best_split].rfind(' ') if space_split != -1: best_split = space_split # Append the chunk, ensuring it's stripped chunks.append(text[:best_split].strip()) # Remove the processed part from the text text = 
text[best_split:].lstrip()

    return chunks


# A triple represented as a list or tuple of strings.
Triple = Union[List[str], Tuple[str, ...]]


def _merge_triples_deterministically(
        triples: List[List[Triple]]) -> List[Tuple[str, ...]]:
    """Flatten a list of lists of triples and return a deterministic,
    reproducible sorted list of tuples.

    Args:
        triples (List[List[Triple]]): A list of lists of triples, where each
            triple is a list or tuple of strings or other comparable values.
            Typically, each inner list comes from a worker.

    Returns:
        List[Tuple[str, ...]]: A flattened list of triples as tuples, sorted
            deterministically. String elements are compared through
            `str.casefold()`, so the order ignores case. Because `list.sort`
            is stable, triples that compare equal this way keep their input
            (worker) order. Tuples are immutable to ensure hashability and
            stability in dicts/sets.
    """
    # Flatten all sublists and convert inner lists to tuples
    flat_triples = [tuple(t) for sublist in triples for t in sublist]
    # Deterministic sort (Unicode-safe, casefold for strings)
    flat_triples.sort(key=lambda triple: tuple(
        s.casefold() if isinstance(s, str) else s for s in triple))
    return flat_triples

================================================
FILE: torch_geometric/llm/models/vision_transformer.py
================================================
from typing import Optional, Union

import torch
from torch import Tensor


class VisionTransformer(torch.nn.Module):
    r"""A wrapper around a Vision-Transformer from HuggingFace.

    Args:
        model_name (str): The HuggingFace model name, *e.g.*, :obj:`"ViT"`.
            Its Swin configuration is read with
            :obj:`transformers.SwinConfig.from_pretrained`.
""" def __init__( self, model_name: str, ) -> None: super().__init__() self.model_name = model_name from transformers import SwinConfig, SwinModel self.config = SwinConfig.from_pretrained(model_name) self.model = SwinModel(self.config) @torch.no_grad() def forward( self, images: Tensor, output_device: Optional[Union[torch.device, str]] = None, ) -> Tensor: return self.model(images).last_hidden_state.to(output_device) @property def device(self) -> torch.device: return next(iter(self.model.parameters())).device def __repr__(self) -> str: return f'{self.__class__.__name__}(model_name={self.model_name})' ================================================ FILE: torch_geometric/llm/rag_loader.py ================================================ from abc import abstractmethod from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union from torch_geometric.data import Data, FeatureStore, HeteroData from torch_geometric.llm.utils.vectorrag import VectorRetriever from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput from torch_geometric.typing import InputEdges, InputNodes class RAGFeatureStore(Protocol): """Feature store template for remote GNN RAG backend.""" @abstractmethod def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes: """Makes a comparison between the query and all the nodes to get all the closest nodes. Return the indices of the nodes that are to be seeds for the RAG Sampler. """ ... @property @abstractmethod def config(self) -> Dict[str, Any]: """Get the config for the RAGFeatureStore.""" ... @config.setter @abstractmethod def config(self, config: Dict[str, Any]): """Set the config for the RAGFeatureStore.""" ... @abstractmethod def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges: """Makes a comparison between the query and all the edges to get all the closest nodes. Returns the edge indices that are to be the seeds for the RAG Sampler. """ ... 
# RAGFeatureStore (continued)
    @abstractmethod
    def load_subgraph(
        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
    ) -> Union[Data, HeteroData]:
        """Combines sampled subgraph output with features in a Data object."""
        ...


class RAGGraphStore(Protocol):
    """Graph store template for remote GNN RAG backend."""
    @abstractmethod
    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample a subgraph using the seeded nodes and edges."""
        # NOTE(review): `RAGQueryLoader.query` calls this with `seed_nodes`
        # only, so implementations presumably need a default for
        # `seed_edges` - verify.
        ...

    @property
    @abstractmethod
    def config(self) -> Dict[str, Any]:
        """Get the config for the RAGGraphStore."""
        ...

    @config.setter
    @abstractmethod
    def config(self, config: Dict[str, Any]):
        """Set the config for the RAGGraphStore."""
        ...

    @abstractmethod
    def register_feature_store(self, feature_store: FeatureStore):
        """Register a feature store to be used with the sampler. Samplers need
        info from the feature store in order to work properly on HeteroGraphs.
        """
        ...


# TODO: Make compatible with Heterographs


class RAGQueryLoader:
    """Loader meant for making RAG queries from a remote backend."""
    def __init__(self, graph_data: Tuple[RAGFeatureStore, RAGGraphStore],
                 subgraph_filter: Optional[Callable[[Data, Any],
                                                    Data]] = None,
                 augment_query: bool = False,
                 vector_retriever: Optional[VectorRetriever] = None,
                 config: Optional[Dict[str, Any]] = None):
        """Loader meant for making queries from a remote backend.

        Args:
            graph_data (Tuple[RAGFeatureStore, RAGGraphStore]):
                Remote FeatureStore and GraphStore to load from.
                Assumed to conform to the protocols listed above.
            subgraph_filter (Optional[Callable[[Data, Any], Data]], optional):
                Optional local transform to apply to data after retrieval.
                Defaults to None.
            augment_query (bool, optional): Whether to augment the query with
                retrieved documents. Requires ``vector_retriever``.
                Defaults to False.
            vector_retriever (Optional[VectorRetriever], optional):
                VectorRetriever to use for retrieving documents.
                Defaults to None.
config (Optional[Dict[str, Any]], optional): Config to pass into the RAGQueryLoader. Defaults to None. """ fstore, gstore = graph_data self.vector_retriever = vector_retriever self.augment_query = augment_query self.feature_store = fstore self.graph_store = gstore self.graph_store.edge_index = self.graph_store.edge_index.contiguous() self.graph_store.register_feature_store(self.feature_store) self.subgraph_filter = subgraph_filter self.config = config def _propagate_config(self, config: Dict[str, Any]): """Propagate the config the relevant components.""" self.feature_store.config = config self.graph_store.config = config @property def config(self): """Get the config for the RAGQueryLoader.""" return self._config @config.setter def config(self, config: Dict[str, Any]): """Set the config for the RAGQueryLoader. Args: config (Dict[str, Any]): The config to set. """ self._propagate_config(config) self._config = config def query(self, query: Any) -> Data: """Retrieve a subgraph associated with the query with all its feature attributes. 
""" if self.vector_retriever: retrieved_docs = self.vector_retriever.query(query) if self.augment_query: query = [query] + retrieved_docs seed_nodes, query_enc = self.feature_store.retrieve_seed_nodes(query) subgraph_sample = self.graph_store.sample_subgraph(seed_nodes) data = self.feature_store.load_subgraph(sample=subgraph_sample) # apply local filter if self.subgraph_filter: data = self.subgraph_filter(data, query) if self.vector_retriever: data.text_context = retrieved_docs return data ================================================ FILE: torch_geometric/llm/utils/__init__.py ================================================ from .backend_utils import * # noqa from .feature_store import KNNRAGFeatureStore from .graph_store import NeighborSamplingRAGGraphStore from .vectorrag import DocumentRetriever __all__ = classes = [ 'KNNRAGFeatureStore', 'NeighborSamplingRAGGraphStore', 'DocumentRetriever', ] ================================================ FILE: torch_geometric/llm/utils/backend_utils.py ================================================ import os from dataclasses import dataclass from enum import Enum, auto from typing import ( Any, Callable, Dict, Iterable, Iterator, List, Optional, Protocol, Tuple, Type, Union, no_type_check, runtime_checkable, ) import numpy as np import torch from torch import Tensor from torch.nn import Module from torch_geometric.data import Data, FeatureStore, GraphStore from torch_geometric.distributed import ( LocalFeatureStore, LocalGraphStore, Partitioner, ) from torch_geometric.llm.large_graph_indexer import ( EDGE_RELATION, LargeGraphIndexer, TripletLike, ) from torch_geometric.llm.models import SentenceTransformer from torch_geometric.typing import EdgeType, NodeType try: from pandas import DataFrame except ImportError: DataFrame = None RemoteGraphBackend = Tuple[FeatureStore, GraphStore] # TODO: Make everything compatible with Hetero graphs aswell def preprocess_triplet(triplet: TripletLike) -> TripletLike: h, r, t = triplet 
# preprocess_triplet (continued)
    return str(h).lower(), str(r).lower(), str(t).lower()


@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 5,
    cost_e: float = 0.5,
    num_clusters: int = 1,
) -> Tuple[Data, str]:
    """Prune ``data`` to a query-relevant subgraph with PCST.

    Nodes and edges receive prizes from their cosine similarity to ``q_emb``.
    Only the ``topk`` nodes and the ``topk_e`` highest distinct edge scores
    get non-zero prizes, and ``pcst_fast`` selects the subgraph. Edges whose
    prize exceeds ``cost_e`` are modelled as virtual nodes.

    Returns the pruned graph and a CSV description of its nodes and edges.
    Graphs with missing or empty ``x``/``edge_attr``/``edge_index`` are
    returned unchanged, together with a description of all textual
    nodes/edges.
    """
    # skip PCST for bad graphs
    booly = data.edge_attr is None or data.edge_attr.numel() == 0
    booly = booly or data.x is None or data.x.numel() == 0
    booly = booly or data.edge_index is None or data.edge_index.numel() == 0
    if not booly:
        c = 0.01

        from pcst_fast import pcst_fast

        root = -1
        pruning = 'gw'
        verbosity_level = 0
        if topk > 0:
            # Rank-based node prizes: best of the top-k gets k, the next k-1,
            # ..., every other node 0.
            n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
            topk = min(topk, data.num_nodes)
            _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

            n_prizes = torch.zeros_like(n_prizes)
            n_prizes[topk_n_indices] = torch.arange(topk, 0, -1,
                                                    device=n_prizes.device,
                                                    dtype=n_prizes.dtype)
        else:
            n_prizes = torch.zeros(data.num_nodes)

        if topk_e > 0:
            e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
            topk_e = min(topk_e, e_prizes.unique().size(0))

            topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e,
                                          largest=True)
            # Zero out edges below the k-th highest distinct similarity.
            e_prizes[e_prizes < topk_e_values[-1]] = 0.0
            last_topk_e_value = topk_e
            for k in range(topk_e):
                # Edges sharing a similarity value split that level's prize;
                # each level is forced strictly below the previous one.
                indices = e_prizes == topk_e_values[k]
                value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
                e_prizes[indices] = value
                last_topk_e_value = value * (1 - c)
            # reduce the cost of the edges so that at least one edge is chosen
            cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
        else:
            e_prizes = torch.zeros(data.num_edges)

        costs = []
        edges = []
        virtual_n_prizes = []
        virtual_edges = []
        virtual_costs = []
        mapping_n = {}
        mapping_e = {}
        for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
            prize_e = e_prizes[i]
            if prize_e <= cost_e:
                # Ordinary edge with cost (cost_e - prize).
                mapping_e[len(edges)] = i
                edges.append((src, dst))
                costs.append(cost_e - prize_e)
            else:
                # High-prize edge: replaced by a virtual node carrying the
                # surplus prize, linked to both endpoints at zero cost.
                virtual_node_id = data.num_nodes + len(virtual_n_prizes)
                mapping_n[virtual_node_id] = i
                virtual_edges.append((src, virtual_node_id))
# retrieval_via_pcst (continued)
                virtual_edges.append((virtual_node_id, dst))
                virtual_costs.append(0)
                virtual_costs.append(0)
                virtual_n_prizes.append(prize_e - cost_e)

        prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
        num_edges = len(edges)
        # NOTE(review): without virtual nodes, `costs`/`edges` reach pcst_fast
        # as Python lists - presumably converted by the binding; verify.
        if len(virtual_costs) > 0:
            costs = np.array(costs + virtual_costs)
            edges = np.array(edges + virtual_edges)

        vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                    pruning, verbosity_level)

        # Map the solution back: real nodes keep their ids, selected real
        # edges go through `mapping_e`, and selected virtual nodes go through
        # `mapping_n` to the original edge they replaced.
        selected_nodes = vertices[vertices < data.num_nodes]
        selected_edges = [mapping_e[e] for e in edges if e < num_edges]
        virtual_vertices = vertices[vertices >= data.num_nodes]
        if len(virtual_vertices) > 0:
            virtual_vertices = vertices[vertices >= data.num_nodes]
            virtual_edges = [mapping_n[i] for i in virtual_vertices]
            selected_edges = np.array(selected_edges + virtual_edges)

        edge_index = data.edge_index[:, selected_edges]
        # Also keep every endpoint of a selected edge.
        selected_nodes = np.unique(
            np.concatenate(
                [selected_nodes, edge_index[0].numpy(),
                 edge_index[1].numpy()]))

        n = textual_nodes.iloc[selected_nodes]
        e = textual_edges.iloc[selected_edges]
    else:
        n = textual_nodes
        e = textual_edges

    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    if booly:
        return data, desc

    # Relabel the kept nodes to 0..len(selected_nodes)-1.
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    # HACK Added so that the subset of nodes and edges selected can be tracked
    node_idx = np.array(data.node_idx)[selected_nodes]
    edge_idx = np.array(data.edge_idx)[selected_edges]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]).to(torch.long),
        edge_attr=data.edge_attr[selected_edges],
        # HACK: track subset of selected nodes/edges
        node_idx=node_idx,
        edge_idx=edge_idx,
    )

    return data, desc


def batch_knn(query_enc: Tensor, embeds: Tensor,
              k: int) -> Iterator[Tuple[Tensor, Tensor]]:
    """Yield, for each row of ``query_enc``, the indices of the ``k`` most
    cosine-similar rows of ``embeds`` (all of them if there are fewer than
    ``k``), together with that query row kept 2-D via ``unsqueeze(0)``.

    This is a generator, so the similarity matrix is computed only when
    iteration starts.
    """
    from torchmetrics.functional import pairwise_cosine_similarity
    prizes = pairwise_cosine_similarity(query_enc,
embeds.to(query_enc.device))
    topk = min(k, len(embeds))

    for i, q in enumerate(prizes):
        _, indices = torch.topk(q, topk, largest=True)
        yield indices, query_enc[i].unsqueeze(0)


# Adapted from LocalGraphStore
@runtime_checkable
class ConvertableGraphStore(Protocol):
    """Structural type for GraphStore classes that provide the
    ``from_data``, ``from_hetero_data`` and ``from_partition`` constructors.
    """
    @classmethod
    def from_data(
        cls,
        edge_id: Tensor,
        edge_index: Tensor,
        num_nodes: int,
        is_sorted: bool = False,
    ) -> GraphStore:
        ...

    @classmethod
    def from_hetero_data(
        cls,
        edge_id_dict: Dict[EdgeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_nodes_dict: Dict[NodeType, int],
        is_sorted: bool = False,
    ) -> GraphStore:
        ...

    @classmethod
    def from_partition(cls, root: str, pid: int) -> GraphStore:
        ...


# Adapted from LocalFeatureStore
@runtime_checkable
class ConvertableFeatureStore(Protocol):
    """Structural type for FeatureStore classes that provide the
    ``from_data``, ``from_hetero_data`` and ``from_partition`` constructors.
    """
    @classmethod
    def from_data(
        cls,
        node_id: Tensor,
        x: Optional[Tensor] = None,
        y: Optional[Tensor] = None,
        edge_id: Optional[Tensor] = None,
        edge_attr: Optional[Tensor] = None,
    ) -> FeatureStore:
        ...

    @classmethod
    def from_hetero_data(
        cls,
        node_id_dict: Dict[NodeType, Tensor],
        x_dict: Optional[Dict[NodeType, Tensor]] = None,
        y_dict: Optional[Dict[NodeType, Tensor]] = None,
        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> FeatureStore:
        ...

    @classmethod
    def from_partition(cls, root: str, pid: int) -> FeatureStore:
        ...
class RemoteDataType(Enum): DATA = auto() PARTITION = auto() @dataclass class RemoteGraphBackendLoader: """Utility class to load triplets into a RAG Backend.""" path: str datatype: RemoteDataType graph_store_type: Type[ConvertableGraphStore] feature_store_type: Type[ConvertableFeatureStore] def load(self, pid: Optional[int] = None) -> RemoteGraphBackend: if self.datatype == RemoteDataType.DATA: data_obj = torch.load(self.path, weights_only=False) # is_sorted=true since assume nodes come sorted from indexer graph_store = self.graph_store_type.from_data( edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index, num_nodes=data_obj.num_nodes, is_sorted=True) feature_store = self.feature_store_type.from_data( node_id=data_obj['node_id'], x=data_obj.x, edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr) elif self.datatype == RemoteDataType.PARTITION: if pid is None: assert pid is not None, \ "Partition ID must be defined for loading from a " \ + "partitioned store." graph_store = self.graph_store_type.from_partition(self.path, pid) feature_store = self.feature_store_type.from_partition( self.path, pid) else: raise NotImplementedError return (feature_store, graph_store) def __del__(self) -> None: if os.path.exists(self.path): os.remove(self.path) def create_graph_from_triples( triples: Iterable[TripletLike], embedding_model: Union[Module, Callable], embedding_method_kwargs: Optional[Dict[str, Any]] = None, pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, ) -> Data: """Utility function that can be used to create a graph from triples.""" # Resolve callable methods embedding_method_kwargs = embedding_method_kwargs \ if embedding_method_kwargs is not None else dict() indexer = LargeGraphIndexer.from_triplets(triples, pre_transform=pre_transform) node_feats = embedding_model(indexer.get_unique_node_features(), **embedding_method_kwargs) indexer.add_node_feature('x', node_feats) edge_feats = embedding_model( 
indexer.get_unique_edge_features(feature_name=EDGE_RELATION),
        **embedding_method_kwargs)
    indexer.add_edge_feature(new_feature_name="edge_attr",
                             new_feature_vals=edge_feats,
                             map_from_feature=EDGE_RELATION)

    data = indexer.to_data(node_feature_name='x',
                           edge_feature_name='edge_attr')
    data = data.to("cpu")
    return data


def create_remote_backend_from_graph_data(
    graph_data: Data,
    graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
    feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
    path: str = '',
    n_parts: int = 1,
) -> RemoteGraphBackendLoader:
    """Utility function that can be used to create a RAG Backend from graph
    data.

    With ``n_parts == 1`` the graph is written to ``path`` with
    ``torch.save``; otherwise it is split with ``Partitioner(root=path)``.

    Args:
        graph_data (Data): Graph data to load into the RAG Backend.
        graph_db (Type[ConvertableGraphStore], optional): GraphStore class
            to use. Defaults to LocalGraphStore.
        feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
            class to use. Defaults to LocalFeatureStore.
        path (str, optional): path to save resulting stores.
            Defaults to ''.
        n_parts (int, optional): Number of partitions to store in.
            Defaults to 1.

    Returns:
        RemoteGraphBackendLoader: Loader to load RAG backend from disk or
            memory.
""" # Will return attribute errors for missing attributes if not issubclass(graph_db, ConvertableGraphStore): _ = graph_db.from_data _ = graph_db.from_hetero_data _ = graph_db.from_partition elif not issubclass(feature_db, ConvertableFeatureStore): _ = feature_db.from_data _ = feature_db.from_hetero_data _ = feature_db.from_partition if n_parts == 1: torch.save(graph_data, path) return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db, feature_db) else: partitioner = Partitioner(data=graph_data, num_parts=n_parts, root=path) partitioner.generate_partition() return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION, graph_db, feature_db) def make_pcst_filter(triples: List[Tuple[str, str, str]], model: SentenceTransformer, topk: int = 5, topk_e: int = 5, cost_e: float = 0.5, num_clusters: int = 1) -> Callable[[Data, str], Data]: """Creates a PCST (Prize Collecting Tree) filter. :param triples: List of triples (head, relation, tail) representing KG data :param model: SentenceTransformer model for embedding text :param topk: Number of top-K results to return (default: 5) :param topk_e: Number of top-K entity results to return (default: 5) :param cost_e: Cost of edges (default: 0.5) :param num_clusters: Number of connected components in the PCST output. 
:return: PCST Filter function """ if DataFrame is None: raise Exception("PCST requires `pip install pandas`" ) # Check if pandas is installed # Remove duplicate triples to ensure unique set triples = list(dict.fromkeys(triples)) # Initialize empty list to store nodes (entities) from triples nodes = [] # Iterate over triples to extract unique nodes (entities) for h, _, t in triples: for node in (h, t): # Extract head and tail entities from each triple nodes.append(node) # Remove duplicates and create final list of unique nodes nodes = list(dict.fromkeys(nodes)) # Create full list of textual nodes (entities) for filtering full_textual_nodes = nodes def apply_retrieval_via_pcst( graph: Data, # Input graph data query: str, # Search query ) -> Data: """Applies PCST filtering for retrieval. :param graph: Input graph data :param query: Search query :return: Retrieved graph/query data """ # PCST relies on numpy and pcst_fast pypi libs, hence to("cpu") with torch.no_grad(): q_emb = model.encode([query]).to("cpu") textual_nodes = [(int(i), full_textual_nodes[i]) for i in graph["node_idx"]] textual_nodes = DataFrame(textual_nodes, columns=["node_id", "node_attr"]) textual_edges = [triples[i] for i in graph["edge_idx"]] textual_edges = DataFrame(textual_edges, columns=["src", "edge_attr", "dst"]) out_graph, desc = retrieval_via_pcst(graph.to(q_emb.device), q_emb, textual_nodes, textual_edges, topk=topk, topk_e=topk_e, cost_e=cost_e, num_clusters=num_clusters) out_graph["desc"] = desc where_trips_start = desc.find("src,edge_attr,dst") parsed_trips = [] for trip in desc[where_trips_start + 18:-1].split("\n"): parsed_trips.append(tuple(trip.split(","))) # Handle case where PCST returns an isolated node """ TODO find a better solution since these failed subgraphs severely hurt accuracy. 
""" if str(parsed_trips) == "[('',)]" or out_graph.edge_index.numel() == 0: out_graph["triples"] = [] else: out_graph["triples"] = parsed_trips out_graph["question"] = query return out_graph return apply_retrieval_via_pcst ================================================ FILE: torch_geometric/llm/utils/feature_store.py ================================================ import gc from collections.abc import Iterable, Iterator from typing import Any, Dict, List, Tuple, Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.distributed.local_feature_store import LocalFeatureStore from torch_geometric.llm.utils.backend_utils import batch_knn from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput from torch_geometric.typing import InputNodes # NOTE: Only compatible with Homogeneous graphs for now class KNNRAGFeatureStore(LocalFeatureStore): """A feature store that uses a KNN-based retrieval.""" def __init__(self) -> None: """Initializes the feature store.""" # to be set by the config self.encoder_model = None self.k_nodes = None self._config: Dict[str, Any] = {} super().__init__() @property def config(self) -> Dict[str, Any]: """Get the config for the feature store.""" return self._config def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None: """Set an attribute from the config. Args: config (Dict[str, Any]): Config dictionary attr_name (str): Name of attribute to set Raises: ValueError: If required attribute not found in config """ if attr_name not in config: raise ValueError( f"Required config parameter '{attr_name}' not found") setattr(self, attr_name, config[attr_name]) @config.setter # type: ignore def config(self, config: Dict[str, Any]) -> None: """Set the config for the feature store. 
Args:
            config (Dict[str, Any]): Config dictionary containing the
                required parameters ``k_nodes`` and ``encoder_model``.

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "k_nodes")
        self._set_from_config(config, "encoder_model")
        assert self.encoder_model is not None, \
            "Need to define encoder model from config"
        # The encoder is only used for inference here.
        self.encoder_model.eval()

        self._config = config

    @property
    def x(self) -> Tensor:
        """Returns the node features."""
        return Tensor(self.get_tensor(group_name=None, attr_name='x'))

    @property
    def edge_attr(self) -> Tensor:
        """Returns the edge attributes."""
        return Tensor(
            self.get_tensor(group_name=(None, None), attr_name='edge_attr'))

    def retrieve_seed_nodes(  # noqa: D417
            self, query: Union[str, List[str],
                               Tuple[str]]) -> Tuple[InputNodes, Tensor]:
        """Retrieves the k_nodes most similar nodes to the given query.

        Args:
            query (Union[str, List[str], Tuple[str]]): The query
                or list of queries to search for.

        Returns:
            For a single query: the indices of the most similar nodes and
            the encoded query. For several queries, a dict mapping each
            query to such a pair (despite the annotation).
        """
        if not isinstance(query, (list, tuple)):
            query = [query]
        assert self.k_nodes is not None, "please set k_nodes via config"
        if len(query) == 1:
            result, query_enc = next(
                self._retrieve_seed_nodes_batch(query, self.k_nodes))
            # Free unreferenced objects and cached CUDA allocator memory.
            gc.collect()
            torch.cuda.empty_cache()
            return result, query_enc
        else:
            out_dict = {}
            for i, out in enumerate(
                    self._retrieve_seed_nodes_batch(query, self.k_nodes)):
                out_dict[query[i]] = out
            gc.collect()
            torch.cuda.empty_cache()
            return out_dict

    def _retrieve_seed_nodes_batch(  # noqa: D417
            self, query: Iterable[Any],
            k_nodes: int) -> Iterator[Tuple[InputNodes, Tensor]]:
        """Retrieves the k_nodes most similar nodes to each query in the batch.

        Args:
            query (Iterable[Any]): The batch of queries to search for.
            k_nodes (int): The number of nodes to retrieve.

        Yields:
            The indices of the most similar nodes for each query, together
            with that query's encoding.
""" if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): raise NotImplementedError assert self.encoder_model is not None, \ "Need to define encoder model from config" query_enc = self.encoder_model.encode(query) return batch_knn(query_enc, self.x, k_nodes) def load_subgraph( # noqa self, sample: Union[SamplerOutput, HeteroSamplerOutput], induced: bool = True, ) -> Union[Data, HeteroData]: """Loads a subgraph from the given sample. Args: sample: The sample to load the subgraph from. induced: Whether to return the induced subgraph. Resets node and edge ids. Returns: The loaded subgraph. """ if isinstance(sample, HeteroSamplerOutput): raise NotImplementedError """ NOTE: torch_geometric.loader.utils.filter_custom_store can be used here if it supported edge features. """ edge_id = sample.edge x = self.x[sample.node] edge_attr = self.edge_attr[edge_id] edge_idx = torch.stack( [sample.row, sample.col], dim=0) if induced else torch.stack( [sample.global_row, sample.global_col], dim=0) result = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx) # useful for tracking what subset of the graph was sampled result.node_idx = sample.node result.edge_idx = edge_id return result """ TODO: make class CuVSKNNRAGFeatureStore(KNNRAGFeatureStore) include a approximate knn flag for the CuVS. Connect this with a CuGraphGraphStore for enabling a accelerated boolean flag for RAGQueryLoader. On by default if CuGraph+CuVS avail. If not raise note mentioning its speedup. 
""" ================================================ FILE: torch_geometric/llm/utils/graph_store.py ================================================ from typing import Any, Dict, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import FeatureStore from torch_geometric.distributed.local_graph_store import LocalGraphStore from torch_geometric.sampler import ( BidirectionalNeighborSampler, NodeSamplerInput, SamplerOutput, ) from torch_geometric.utils import index_sort # A representation of an edge index, following the possible formats: # * default: Tensor, size = [2, num_edges] # * Tensor[0, :] == row, Tensor[1, :] == col # * COO: (row, col) # * CSC: (row, colptr) # * CSR: (rowptr, col) _EdgeTensorType = Union[Tensor, Tuple[Tensor, Tensor]] class NeighborSamplingRAGGraphStore(LocalGraphStore): """Neighbor sampling based graph-store to store & retrieve graph data.""" def __init__( # type: ignore[no-untyped-def] self, feature_store: Optional[FeatureStore] = None, **kwargs, ): """Initializes the graph store. Optional feature store and neighbor sampling settings. Args: feature_store (optional): The feature store to use. None if not yet registered. **kwargs (optional): Additional keyword arguments for neighbor sampling. """ self.feature_store = feature_store self.sample_kwargs = kwargs self._sampler_is_initialized = False self._config: Dict[str, Any] = {} # to be set by the config self.num_neighbors = None super().__init__() @property def config(self) -> Dict[str, Any]: """Get the config for the feature store.""" return self._config def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None: """Set an attribute from the config. 
Args: config (Dict[str, Any]): Config dictionary attr_name (str): Name of attribute to set Raises: ValueError: If required attribute not found in config """ if attr_name not in config: raise ValueError( f"Required config parameter '{attr_name}' not found") setattr(self, attr_name, config[attr_name]) @config.setter # type: ignore def config(self, config: Dict[str, Any]) -> None: """Set the config for the feature store. Args: config (Dict[str, Any]): Config dictionary containing required parameters Raises: ValueError: If required parameters missing from config """ self._set_from_config(config, "num_neighbors") if hasattr(self, 'sampler'): self.sampler.num_neighbors = ( # type: ignore[has-type] self.num_neighbors) self._config = config def _init_sampler(self) -> None: """Initializes neighbor sampler with the registered feature store.""" if self.feature_store is None: raise AttributeError("Feature store not registered yet.") assert self.num_neighbors is not None, \ "Please set num_neighbors through config" self.sampler = BidirectionalNeighborSampler( data=(self.feature_store, self), num_neighbors=self.num_neighbors, **self.sample_kwargs) self._sampler_is_initialized = True def register_feature_store(self, feature_store: FeatureStore) -> None: """Registers a feature store with the graph store. :param feature_store: The feature store to register. """ self.feature_store = feature_store self._sampler_is_initialized = False def put_edge_id( # type: ignore[no-untyped-def] self, edge_id: Tensor, *args, **kwargs) -> bool: """Stores an edge ID in the graph store. :param edge_id: The edge ID to store. :return: Whether the operation was successful. """ ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs) self._sampler_is_initialized = False return ret @property def edge_index(self) -> _EdgeTensorType: """Gets the edge index of the graph. :return: The edge index as a tensor. 
""" return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs) def put_edge_index( # type: ignore[no-untyped-def] self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool: """Stores an edge index in the graph store. :param edge_index: The edge index to store. :return: Whether the operation was successful. """ ret = super().put_edge_index(edge_index, *args, **kwargs) # HACK self.edge_idx_args = args self.edge_idx_kwargs = kwargs self._sampler_is_initialized = False return ret # HACKY @edge_index.setter # type: ignore def edge_index(self, edge_index: _EdgeTensorType) -> None: """Sets the edge index of the graph. :param edge_index: The edge index to set. """ # correct since we make node list from triples if isinstance(edge_index, Tensor): num_nodes = int(edge_index.max()) + 1 else: assert isinstance(edge_index, tuple) \ and isinstance(edge_index[0], Tensor) \ and isinstance(edge_index[1], Tensor), \ "edge_index must be a Tensor of [2, num_edges] \ or a tuple of Tensors, (row, col)." num_nodes = int(edge_index[0].max()) + 1 attr = dict( edge_type=None, layout='coo', size=(num_nodes, num_nodes), is_sorted=False, ) # edge index needs to be sorted here and the perm saved for later col_sorted, self.perm = index_sort(edge_index[1], num_nodes, stable=True) row_sorted = edge_index[0][self.perm] edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0) self.put_edge_index(edge_index_sorted, **attr) def sample_subgraph( self, seed_nodes: Tensor, ) -> SamplerOutput: """Sample the graph starting from the given nodes using the in-built NeighborSampler. Args: seed_nodes (InputNodes): Seed nodes to start sampling from. num_neighbors (Optional[NumNeighborsType], optional): Parameters to determine how many hops and number of neighbors per hop. Defaults to None. Returns: Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput for the input. 
""" # TODO add support for Hetero if not self._sampler_is_initialized: self._init_sampler() seed_nodes = seed_nodes.unique().contiguous() node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes) out = self.sampler.sample_from_nodes( # type: ignore[has-type] node_sample_input) # edge ids need to be remapped to the original indices out.edge = self.perm[out.edge] return out ================================================ FILE: torch_geometric/llm/utils/vectorrag.py ================================================ # mypy: ignore-errors import os from abc import abstractmethod from typing import Any, Callable, Dict, List, Optional, Protocol, Union import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.llm.models import SentenceTransformer from torch_geometric.llm.utils.backend_utils import batch_knn class VectorRetriever(Protocol): """Protocol for VectorRAG.""" @abstractmethod def query(self, query: Any, **kwargs: Optional[Dict[str, Any]]) -> Data: """Retrieve a context for a given query.""" ... class DocumentRetriever(VectorRetriever): """Retrieve documents from a vector database.""" def __init__(self, raw_docs: List[str], embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2, model: Optional[Union[SentenceTransformer, torch.nn.Module, Callable]] = None, model_kwargs: Optional[Dict[str, Any]] = None): """Retrieve documents from a vector database. Args: raw_docs: List[str]: List of raw documents. embedded_docs: Optional[Tensor]: Embedded documents. k_for_docs: int: Number of documents to retrieve. model: Optional[Union[SentenceTransformer, torch.nn.Module]]: Model to use for encoding. model_kwargs: Optional[Dict[str, Any]]: Keyword arguments to pass to the model. 
""" self.raw_docs = raw_docs self.embedded_docs = embedded_docs self.k_for_docs = k_for_docs self.model = model if self.model is not None: self.encoder = self.model self.model_kwargs = model_kwargs if self.embedded_docs is None: assert self.model is not None, \ "Model must be provided if embedded_docs is not provided" self.model_kwargs = model_kwargs or {} self.embedded_docs = self.encoder(self.raw_docs, **self.model_kwargs) # we don't want to print the verbose output in `query` self.model_kwargs.pop("verbose", None) def query(self, query: Union[str, Tensor]) -> List[str]: """Retrieve documents from the vector database. Args: query: Union[str, Tensor]: Query to retrieve documents for. Returns: List[str]: Documents retrieved from the vector database. """ if isinstance(query, str): with torch.no_grad(): query_enc = self.encoder(query, **self.model_kwargs) else: query_enc = query selected_doc_idxs, _ = next( batch_knn(query_enc, self.embedded_docs, self.k_for_docs)) return [self.raw_docs[i] for i in selected_doc_idxs] def save(self, path: str) -> None: """Save the DocumentRetriever instance to disk. Args: path: str: Path where to save the retriever. """ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) # Prepare data to save save_dict = { 'raw_docs': self.raw_docs, 'embedded_docs': self.embedded_docs, 'k_for_docs': self.k_for_docs, } # We do not serialize the model torch.save(save_dict, path) @classmethod def load(cls, path: str, model: Union[SentenceTransformer, torch.nn.Module, Callable], model_kwargs: Optional[Dict[str, Any]] = None) -> VectorRetriever: """Load a DocumentRetriever instance from disk. Args: path: str: Path to the saved retriever. model: Union[SentenceTransformer, torch.nn.Module, Callable]: Model to use for encoding. If None, the saved model will be used if available. model_kwargs: Optional[Dict[str, Any]] Key word args to be passed to model Returns: DocumentRetriever: The loaded retriever. 
""" if not os.path.exists(path): raise FileNotFoundError( f"No saved document retriever found at {path}") save_dict = torch.load(path, weights_only=False) if save_dict['embedded_docs'] is not None \ and isinstance(save_dict['embedded_docs'], Tensor)\ and model_kwargs is not None: model_kwargs.pop("verbose", None) # Create a new DocumentRetriever with the loaded data return cls(raw_docs=save_dict['raw_docs'], embedded_docs=save_dict['embedded_docs'], k_for_docs=save_dict['k_for_docs'], model=model, model_kwargs=model_kwargs) ================================================ FILE: torch_geometric/loader/__init__.py ================================================ from torch_geometric.deprecation import deprecated from .dataloader import DataLoader from .node_loader import NodeLoader from .link_loader import LinkLoader from .neighbor_loader import NeighborLoader from .link_neighbor_loader import LinkNeighborLoader from .hgt_loader import HGTLoader from .cluster import ClusterData, ClusterLoader from .graph_saint import (GraphSAINTSampler, GraphSAINTNodeSampler, GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler) from .shadow import ShaDowKHopSampler from .random_node_loader import RandomNodeLoader # from .ibmb_loader import IBMBBatchLoader, IBMBNodeLoader from .zip_loader import ZipLoader from .data_list_loader import DataListLoader from .dense_data_loader import DenseDataLoader from .temporal_dataloader import TemporalDataLoader from .neighbor_sampler import NeighborSampler from .imbalanced_sampler import ImbalancedSampler from .dynamic_batch_sampler import DynamicBatchSampler from .prefetch import PrefetchLoader from .cache import CachedLoader from .mixin import AffinityMixin __all__ = classes = [ 'DataLoader', 'NodeLoader', 'LinkLoader', 'NeighborLoader', 'LinkNeighborLoader', 'HGTLoader', 'ClusterData', 'ClusterLoader', 'GraphSAINTSampler', 'GraphSAINTNodeSampler', 'GraphSAINTEdgeSampler', 'GraphSAINTRandomWalkSampler', 'ShaDowKHopSampler', 'RandomNodeLoader', # 
'IBMBBatchLoader', # 'IBMBNodeLoader', 'ZipLoader', 'DataListLoader', 'DenseDataLoader', 'TemporalDataLoader', 'NeighborSampler', 'ImbalancedSampler', 'DynamicBatchSampler', 'PrefetchLoader', 'CachedLoader', 'AffinityMixin', ] RandomNodeSampler = deprecated( details="use 'loader.RandomNodeLoader' instead", func_name='loader.RandomNodeSampler', )(RandomNodeLoader) ================================================ FILE: torch_geometric/loader/base.py ================================================ from typing import Any, Callable from torch.utils.data.dataloader import ( _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, ) class DataLoaderIterator: r"""A data loader iterator extended by a simple post transformation function :meth:`transform_fn`. While the iterator may request items from different sub-processes, :meth:`transform_fn` will always be executed in the main process. This iterator is used in PyG's sampler classes, and is responsible for feature fetching and filtering data objects after sampling has taken place in a sub-process. This has the following advantages: * We do not need to share feature matrices across processes which may prevent any errors due to too many open file handles. * We can execute any expensive post-processing commands on the main thread with full parallelization power (which usually executes faster). * It lets us naturally support data already being present on the GPU. 
""" def __init__(self, iterator: _BaseDataLoaderIter, transform_fn: Callable): self.iterator = iterator self.transform_fn = transform_fn def __iter__(self) -> 'DataLoaderIterator': return self def _reset(self, loader: Any, first_iter: bool = False): self.iterator._reset(loader, first_iter) def __len__(self) -> int: return len(self.iterator) def __next__(self) -> Any: return self.transform_fn(next(self.iterator)) def __del__(self) -> Any: if isinstance(self.iterator, _MultiProcessingDataLoaderIter): self.iterator.__del__() ================================================ FILE: torch_geometric/loader/cache.py ================================================ from collections.abc import Mapping from typing import Any, Callable, List, Optional, Sequence import torch from torch.utils.data import DataLoader def to_device(inputs: Any, device: Optional[torch.device] = None) -> Any: if hasattr(inputs, 'to'): return inputs.to(device) elif isinstance(inputs, Mapping): return {key: to_device(value, device) for key, value in inputs.items()} elif isinstance(inputs, tuple) and hasattr(inputs, '_fields'): return type(inputs)(*(to_device(s, device) for s in zip(*inputs))) elif isinstance(inputs, Sequence) and not isinstance(inputs, str): return [to_device(s, device) for s in zip(*inputs)] return inputs class CachedLoader: r"""A loader to cache mini-batch outputs, e.g., obtained during :class:`NeighborLoader` iterations. Args: loader (torch.utils.data.DataLoader): The data loader. device (torch.device, optional): The device to load the data to. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. 
(default: :obj:`None`) """ def __init__( self, loader: DataLoader, device: Optional[torch.device] = None, transform: Optional[Callable] = None, ): self.loader = loader self.device = device self.transform = transform self._cache: List[Any] = [] def clear(self): r"""Clears the cache.""" self._cache = [] def __iter__(self) -> Any: if len(self._cache): for batch in self._cache: yield batch return for batch in self.loader: if self.transform is not None: batch = self.transform(batch) batch = to_device(batch, self.device) self._cache.append(batch) yield batch def __len__(self) -> int: return len(self.loader) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.loader})' ================================================ FILE: torch_geometric/loader/cluster.py ================================================ import copy import os import os.path as osp import sys from dataclasses import dataclass from typing import List, Literal, Optional import torch import torch.utils.data from torch import Tensor import torch_geometric.typing from torch_geometric.data import Data from torch_geometric.index import index2ptr, ptr2index from torch_geometric.io import fs from torch_geometric.typing import pyg_lib from torch_geometric.utils import index_sort, narrow, select, sort_edge_index from torch_geometric.utils.map import map_index @dataclass class Partition: indptr: Tensor index: Tensor partptr: Tensor node_perm: Tensor edge_perm: Tensor sparse_format: Literal['csr', 'csc'] class ClusterData(torch.utils.data.Dataset): r"""Clusters/partitions a graph data object into multiple subgraphs, as motivated by the `"Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" `_ paper. .. note:: The underlying METIS algorithm requires undirected graphs as input. Args: data (torch_geometric.data.Data): The graph data object. num_parts (int): The number of partitions. 
recursive (bool, optional): If set to :obj:`True`, will use multilevel recursive bisection instead of multilevel k-way partitioning. (default: :obj:`False`) save_dir (str, optional): If set, will save the partitioned data to the :obj:`save_dir` directory for faster re-use. (default: :obj:`None`) filename (str, optional): Name of the stored partitioned file. (default: :obj:`None`) log (bool, optional): If set to :obj:`False`, will not log any progress. (default: :obj:`True`) keep_inter_cluster_edges (bool, optional): If set to :obj:`True`, will keep inter-cluster edge connections. (default: :obj:`False`) sparse_format (str, optional): The sparse format to use for computing partitions. (default: :obj:`"csr"`) """ def __init__( self, data, num_parts: int, recursive: bool = False, save_dir: Optional[str] = None, filename: Optional[str] = None, log: bool = True, keep_inter_cluster_edges: bool = False, sparse_format: Literal['csr', 'csc'] = 'csr', ): assert data.edge_index is not None assert sparse_format in ['csr', 'csc'] self.num_parts = num_parts self.recursive = recursive self.keep_inter_cluster_edges = keep_inter_cluster_edges self.sparse_format = sparse_format recursive_str = '_recursive' if recursive else '' root_dir = osp.join(save_dir or '', f'part_{num_parts}{recursive_str}') path = osp.join(root_dir, filename or 'metis.pt') if save_dir is not None and osp.exists(path): self.partition = fs.torch_load(path) else: if log: # pragma: no cover print('Computing METIS partitioning...', file=sys.stderr) cluster = self._metis(data.edge_index, data.num_nodes) self.partition = self._partition(data.edge_index, cluster) if save_dir is not None: os.makedirs(root_dir, exist_ok=True) torch.save(self.partition, path) if log: # pragma: no cover print('Done!', file=sys.stderr) self.data = self._permute_data(data, self.partition) def _metis(self, edge_index: Tensor, num_nodes: int) -> Tensor: # Computes a node-level partition assignment vector via METIS. 
if self.sparse_format == 'csr': # Calculate CSR representation: row, index = sort_edge_index(edge_index, num_nodes=num_nodes) indptr = index2ptr(row, size=num_nodes) else: # Calculate CSC representation: index, col = sort_edge_index(edge_index, num_nodes=num_nodes, sort_by_row=False) indptr = index2ptr(col, size=num_nodes) # Compute METIS partitioning: cluster: Optional[Tensor] = None if torch_geometric.typing.WITH_TORCH_SPARSE: try: cluster = torch.ops.torch_sparse.partition( indptr.cpu(), index.cpu(), None, self.num_parts, self.recursive, ).to(edge_index.device) except (AttributeError, RuntimeError): pass if cluster is None and torch_geometric.typing.WITH_METIS: cluster = pyg_lib.partition.metis( indptr.cpu(), index.cpu(), self.num_parts, recursive=self.recursive, ).to(edge_index.device) if cluster is None: raise ImportError(f"'{self.__class__.__name__}' requires either " f"'pyg-lib' or 'torch-sparse'") return cluster def _partition(self, edge_index: Tensor, cluster: Tensor) -> Partition: # Computes node-level and edge-level permutations and permutes the edge # connectivity accordingly: # Sort `cluster` and compute boundaries `partptr`: cluster, node_perm = index_sort(cluster, max_value=self.num_parts) partptr = index2ptr(cluster, size=self.num_parts) # Permute `edge_index` based on node permutation: edge_perm = torch.arange(edge_index.size(1), device=edge_index.device) arange = torch.empty_like(node_perm) arange[node_perm] = torch.arange(cluster.numel(), device=cluster.device) edge_index = arange[edge_index] # Compute final CSR representation: (row, col), edge_perm = sort_edge_index( edge_index, edge_attr=edge_perm, num_nodes=cluster.numel(), sort_by_row=self.sparse_format == 'csr', ) if self.sparse_format == 'csr': indptr, index = index2ptr(row, size=cluster.numel()), col else: indptr, index = index2ptr(col, size=cluster.numel()), row return Partition(indptr, index, partptr, node_perm, edge_perm, self.sparse_format) def _permute_data(self, data: Data, 
partition: Partition) -> Data: # Permute node-level and edge-level attributes according to the # calculated permutations in `Partition`: out = copy.copy(data) for key, value in data.items(): if key == 'edge_index': continue elif data.is_node_attr(key): cat_dim = data.__cat_dim__(key, value) out[key] = select(value, partition.node_perm, dim=cat_dim) elif data.is_edge_attr(key): cat_dim = data.__cat_dim__(key, value) out[key] = select(value, partition.edge_perm, dim=cat_dim) out.edge_index = None return out def __len__(self) -> int: return self.partition.partptr.numel() - 1 def __getitem__(self, idx: int) -> Data: node_start = int(self.partition.partptr[idx]) node_end = int(self.partition.partptr[idx + 1]) node_length = node_end - node_start indptr = self.partition.indptr[node_start:node_end + 1] edge_start = int(indptr[0]) edge_end = int(indptr[-1]) edge_length = edge_end - edge_start indptr = indptr - edge_start if self.sparse_format == 'csr': row = ptr2index(indptr) col = self.partition.index[edge_start:edge_end] if not self.keep_inter_cluster_edges: edge_mask = (col >= node_start) & (col < node_end) row = row[edge_mask] col = col[edge_mask] - node_start else: col = ptr2index(indptr) row = self.partition.index[edge_start:edge_end] if not self.keep_inter_cluster_edges: edge_mask = (row >= node_start) & (row < node_end) col = col[edge_mask] row = row[edge_mask] - node_start out = copy.copy(self.data) for key, value in self.data.items(): if key == 'num_nodes': out.num_nodes = node_length elif self.data.is_node_attr(key): cat_dim = self.data.__cat_dim__(key, value) out[key] = narrow(value, cat_dim, node_start, node_length) elif self.data.is_edge_attr(key): cat_dim = self.data.__cat_dim__(key, value) out[key] = narrow(value, cat_dim, edge_start, edge_length) if not self.keep_inter_cluster_edges: out[key] = out[key][edge_mask] out.edge_index = torch.stack([row, col], dim=0) return out def __repr__(self) -> str: return f'{self.__class__.__name__}({self.num_parts})' class 
ClusterLoader(torch.utils.data.DataLoader): r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" `_ paper which merges partitioned subgraphs and their between-cluster links from a large-scale graph data object to form a mini-batch. .. note:: Use :class:`~torch_geometric.loader.ClusterData` and :class:`~torch_geometric.loader.ClusterLoader` in conjunction to form mini-batches of clusters. For an example of using Cluster-GCN, see `examples/cluster_gcn_reddit.py `_ or `examples/cluster_gcn_ppi.py `_. Args: cluster_data (torch_geometric.loader.ClusterData): The already partitioned data object. **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. """ def __init__(self, cluster_data, **kwargs): self.cluster_data = cluster_data iterator = range(len(cluster_data)) super().__init__(iterator, collate_fn=self._collate, **kwargs) def _collate(self, batch: List[int]) -> Data: if not isinstance(batch, torch.Tensor): batch = torch.tensor(batch) global_indptr = self.cluster_data.partition.indptr global_index = self.cluster_data.partition.index # Get all node-level and edge-level start and end indices for the # current mini-batch: node_start = self.cluster_data.partition.partptr[batch] node_end = self.cluster_data.partition.partptr[batch + 1] edge_start = global_indptr[node_start] edge_end = global_indptr[node_end] # Iterate over each partition in the batch and calculate new edge # connectivity. 
This is done by slicing the corresponding source and # destination indices for each partition and adjusting their indices to # start from zero: rows, cols, nodes, cumsum = [], [], [], 0 for i in range(batch.numel()): nodes.append(torch.arange(node_start[i], node_end[i])) indptr = global_indptr[node_start[i]:node_end[i] + 1] indptr = indptr - edge_start[i] if self.cluster_data.partition.sparse_format == 'csr': row = ptr2index(indptr) + cumsum col = global_index[edge_start[i]:edge_end[i]] else: col = ptr2index(indptr) + cumsum row = global_index[edge_start[i]:edge_end[i]] rows.append(row) cols.append(col) cumsum += indptr.numel() - 1 node = torch.cat(nodes, dim=0) row = torch.cat(rows, dim=0) col = torch.cat(cols, dim=0) # Map `col` vector to valid entries and remove any entries that do not # connect two nodes within the same mini-batch: if self.cluster_data.partition.sparse_format == 'csr': col, edge_mask = map_index(col, node) row = row[edge_mask] else: row, edge_mask = map_index(row, node) col = col[edge_mask] out = copy.copy(self.cluster_data.data) # Slice node-level and edge-level attributes according to its offsets: for key, value in self.cluster_data.data.items(): if key == 'num_nodes': out.num_nodes = cumsum elif self.cluster_data.data.is_node_attr(key): cat_dim = self.cluster_data.data.__cat_dim__(key, value) out[key] = torch.cat([ narrow(out[key], cat_dim, s, e - s) for s, e in zip(node_start, node_end) ], dim=cat_dim) elif self.cluster_data.data.is_edge_attr(key): cat_dim = self.cluster_data.data.__cat_dim__(key, value) value = torch.cat([ narrow(out[key], cat_dim, s, e - s) for s, e in zip(edge_start, edge_end) ], dim=cat_dim) out[key] = select(value, edge_mask, dim=cat_dim) out.edge_index = torch.stack([row, col], dim=0) return out ================================================ FILE: torch_geometric/loader/data_list_loader.py ================================================ from typing import List, Union import torch from torch_geometric.data import 
Dataset from torch_geometric.data.data import BaseData def collate_fn(data_list): return data_list class DataListLoader(torch.utils.data.DataLoader): r"""A data loader which batches data objects from a :class:`torch_geometric.data.dataset` to a :python:`Python` list. Data objects can be either of type :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData`. .. note:: This data loader should be used for multi-GPU support via :class:`torch_geometric.nn.DataParallel`. Args: dataset (Dataset): The dataset from which to load the data. batch_size (int, optional): How many samples per batch to load. (default: :obj:`1`) shuffle (bool, optional): If set to :obj:`True`, the data will be reshuffled at every epoch. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`drop_last` or :obj:`num_workers`. """ def __init__(self, dataset: Union[Dataset, List[BaseData]], batch_size: int = 1, shuffle: bool = False, **kwargs): # Remove for PyTorch Lightning: kwargs.pop('collate_fn', None) super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, **kwargs) ================================================ FILE: torch_geometric/loader/dataloader.py ================================================ from collections.abc import Mapping from typing import Any, List, Optional, Sequence, Union import torch.utils.data from torch.utils.data.dataloader import default_collate from torch_geometric.data import Batch, Dataset from torch_geometric.data.data import BaseData from torch_geometric.data.datapipes import DatasetAdapter from torch_geometric.typing import TensorFrame, torch_frame class Collater: def __init__( self, dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None, ): self.dataset = dataset self.follow_batch = follow_batch self.exclude_keys = exclude_keys def __call__(self, batch: 
List[Any]) -> Any: elem = batch[0] if isinstance(elem, BaseData): return Batch.from_data_list( batch, follow_batch=self.follow_batch, exclude_keys=self.exclude_keys, ) elif isinstance(elem, torch.Tensor): return default_collate(batch) elif isinstance(elem, TensorFrame): return torch_frame.cat(batch, dim=0) elif isinstance(elem, float): return torch.tensor(batch, dtype=torch.float) elif isinstance(elem, int): return torch.tensor(batch) elif isinstance(elem, str): return batch elif isinstance(elem, Mapping): return {key: self([data[key] for data in batch]) for key in elem} elif isinstance(elem, tuple) and hasattr(elem, '_fields'): return type(elem)(*(self(s) for s in zip(*batch))) elif isinstance(elem, Sequence) and not isinstance(elem, str): return [self(s) for s in zip(*batch)] raise TypeError(f"DataLoader found invalid type: '{type(elem)}'") class DataLoader(torch.utils.data.DataLoader): r"""A data loader which merges data objects from a :class:`torch_geometric.data.Dataset` to a mini-batch. Data objects can be either of type :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData`. Args: dataset (Dataset): The dataset from which to load the data. batch_size (int, optional): How many samples per batch to load. (default: :obj:`1`) shuffle (bool, optional): If set to :obj:`True`, the data will be reshuffled at every epoch. (default: :obj:`False`) follow_batch (List[str], optional): Creates assignment batch vectors for each key in the list. (default: :obj:`None`) exclude_keys (List[str], optional): Will exclude each key in the list. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`. 
""" def __init__( self, dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter], batch_size: int = 1, shuffle: bool = False, follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None, **kwargs, ): # Remove for PyTorch Lightning: kwargs.pop('collate_fn', None) # Save for PyTorch Lightning < 1.6: self.follow_batch = follow_batch self.exclude_keys = exclude_keys super().__init__( dataset, batch_size, shuffle, collate_fn=Collater(dataset, follow_batch, exclude_keys), **kwargs, ) ================================================ FILE: torch_geometric/loader/dense_data_loader.py ================================================ from typing import List, Union import torch from torch.utils.data.dataloader import default_collate from torch_geometric.data import Batch, Data, Dataset def collate_fn(data_list: List[Data]) -> Batch: batch = Batch() for key in data_list[0].keys(): batch[key] = default_collate([data[key] for data in data_list]) return batch class DenseDataLoader(torch.utils.data.DataLoader): r"""A data loader which batches data objects from a :class:`torch_geometric.data.dataset` to a :class:`torch_geometric.data.Batch` object by stacking all attributes in a new dimension. .. note:: To make use of this data loader, all graph attributes in the dataset need to have the same shape. In particular, this data loader should only be used when working with *dense* adjacency matrices. Args: dataset (Dataset): The dataset from which to load the data. batch_size (int, optional): How many samples per batch to load. (default: :obj:`1`) shuffle (bool, optional): If set to :obj:`True`, the data will be reshuffled at every epoch. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`drop_last` or :obj:`num_workers`. 
""" def __init__(self, dataset: Union[Dataset, List[Data]], batch_size: int = 1, shuffle: bool = False, **kwargs): # Remove for PyTorch Lightning: kwargs.pop('collate_fn', None) super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, **kwargs) ================================================ FILE: torch_geometric/loader/dynamic_batch_sampler.py ================================================ from typing import Iterator, List, Optional import torch from torch_geometric.data import Dataset class DynamicBatchSampler(torch.utils.data.sampler.Sampler): r"""Dynamically adds samples to a mini-batch up to a maximum size (either based on number of nodes or number of edges). When data samples have a wide range in sizes, specifying a mini-batch size in terms of number of samples is not ideal and can cause CUDA OOM errors. Within the :class:`DynamicBatchSampler`, the number of steps per epoch is ambiguous, depending on the order of the samples. By default the :meth:`__len__` will be undefined. This is fine for most cases but progress bars will be infinite. Alternatively, :obj:`num_steps` can be supplied to cap the number of mini-batches produced by the sampler. .. code-block:: python from torch_geometric.loader import DataLoader, DynamicBatchSampler sampler = DynamicBatchSampler(dataset, max_num=10000, mode="node") loader = DataLoader(dataset, batch_sampler=sampler, ...) Args: dataset (Dataset): Dataset to sample from. max_num (int): Size of mini-batch to aim for in number of nodes or edges. mode (str, optional): :obj:`"node"` or :obj:`"edge"` to measure batch size. (default: :obj:`"node"`) shuffle (bool, optional): If set to :obj:`True`, will have the data reshuffled at every epoch. (default: :obj:`False`) skip_too_big (bool, optional): If set to :obj:`True`, skip samples which cannot fit in a batch by itself. (default: :obj:`False`) num_steps (int, optional): The number of mini-batches to draw for a single epoch. 
If set to :obj:`None`, will iterate through all the underlying examples, but :meth:`__len__` will be :obj:`None` since it is ambiguous. (default: :obj:`None`) """ def __init__( self, dataset: Dataset, max_num: int, mode: str = 'node', shuffle: bool = False, skip_too_big: bool = False, num_steps: Optional[int] = None, ): if max_num <= 0: raise ValueError(f"`max_num` should be a positive integer value " f"(got {max_num})") if mode not in ['node', 'edge']: raise ValueError(f"`mode` choice should be either " f"'node' or 'edge' (got '{mode}')") self.dataset = dataset self.max_num = max_num self.mode = mode self.shuffle = shuffle self.skip_too_big = skip_too_big self.num_steps = num_steps self.max_steps = num_steps or len(dataset) def __iter__(self) -> Iterator[List[int]]: if self.shuffle: indices = torch.randperm(len(self.dataset)).tolist() else: indices = range(len(self.dataset)) samples: List[int] = [] current_num: int = 0 num_steps: int = 0 num_processed: int = 0 while (num_processed < len(self.dataset) and num_steps < self.max_steps): for i in indices[num_processed:]: data = self.dataset[i] num = data.num_nodes if self.mode == 'node' else data.num_edges if current_num + num > self.max_num: if current_num == 0: if self.skip_too_big: continue else: # Mini-batch filled: break samples.append(i) num_processed += 1 current_num += num yield samples samples: List[int] = [] current_num = 0 num_steps += 1 def __len__(self) -> int: if self.num_steps is None: raise ValueError(f"The length of '{self.__class__.__name__}' is " f"undefined since the number of steps per epoch " f"is ambiguous. 
Either specify `num_steps` or " f"use a static batch sampler.") return self.num_steps ================================================ FILE: torch_geometric/loader/graph_saint.py ================================================ import os.path as osp from typing import Optional import torch from tqdm import tqdm from torch_geometric.io import fs from torch_geometric.typing import SparseTensor class GraphSAINTSampler(torch.utils.data.DataLoader): r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph Sampling Based Inductive Learning Method" `_ paper. Given a graph in a :obj:`data` object, this class samples nodes and constructs subgraphs that can be processed in a mini-batch fashion. Normalization coefficients for each mini-batch are given via :obj:`node_norm` and :obj:`edge_norm` data attributes. .. note:: See :class:`~torch_geometric.loader.GraphSAINTNodeSampler`, :class:`~torch_geometric.loader.GraphSAINTEdgeSampler` and :class:`~torch_geometric.loader.GraphSAINTRandomWalkSampler` for currently supported samplers. For an example of using GraphSAINT sampling, see `examples/graph_saint.py `_. Args: data (torch_geometric.data.Data): The graph data object. batch_size (int): The approximate number of samples per batch. num_steps (int, optional): The number of iterations per epoch. (default: :obj:`1`) sample_coverage (int): How many samples per node should be used to compute normalization statistics. (default: :obj:`0`) save_dir (str, optional): If set, will save normalization statistics to the :obj:`save_dir` directory for faster re-use. (default: :obj:`None`) log (bool, optional): If set to :obj:`False`, will not log any pre-processing progress. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or :obj:`num_workers`. 
""" def __init__(self, data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs): # Remove for PyTorch Lightning: kwargs.pop('dataset', None) kwargs.pop('collate_fn', None) assert data.edge_index is not None assert 'node_norm' not in data assert 'edge_norm' not in data assert not data.edge_index.is_cuda self.num_steps = num_steps self._batch_size = batch_size self.sample_coverage = sample_coverage self.save_dir = save_dir self.log = log self.N = N = data.num_nodes self.E = data.num_edges self.adj = SparseTensor( row=data.edge_index[0], col=data.edge_index[1], value=torch.arange(self.E, device=data.edge_index.device), sparse_sizes=(N, N)) self.data = data super().__init__(self, batch_size=1, collate_fn=self._collate, **kwargs) if self.sample_coverage > 0: path = osp.join(save_dir or '', self._filename) if save_dir is not None and osp.exists(path): # pragma: no cover self.node_norm, self.edge_norm = fs.torch_load(path) else: self.node_norm, self.edge_norm = self._compute_norm() if save_dir is not None: # pragma: no cover torch.save((self.node_norm, self.edge_norm), path) @property def _filename(self): return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt' def __len__(self): return self.num_steps def _sample_nodes(self, batch_size): raise NotImplementedError def __getitem__(self, idx): node_idx = self._sample_nodes(self._batch_size).unique() adj, _ = self.adj.saint_subgraph(node_idx) return node_idx, adj def _collate(self, data_list): assert len(data_list) == 1 node_idx, adj = data_list[0] data = self.data.__class__() data.num_nodes = node_idx.size(0) row, col, edge_idx = adj.coo() data.edge_index = torch.stack([row, col], dim=0) for key, item in self.data: if key in ['edge_index', 'num_nodes']: continue if isinstance(item, torch.Tensor) and item.size(0) == self.N: data[key] = item[node_idx] elif isinstance(item, torch.Tensor) and item.size(0) == self.E: data[key] = item[edge_idx] 
else: data[key] = item if self.sample_coverage > 0: data.node_norm = self.node_norm[node_idx] data.edge_norm = self.edge_norm[edge_idx] return data def _compute_norm(self): node_count = torch.zeros(self.N, dtype=torch.float) edge_count = torch.zeros(self.E, dtype=torch.float) loader = torch.utils.data.DataLoader(self, batch_size=200, collate_fn=lambda x: x, num_workers=self.num_workers) if self.log: # pragma: no cover pbar = tqdm(total=self.N * self.sample_coverage) pbar.set_description('Compute GraphSAINT normalization') num_samples = total_sampled_nodes = 0 while total_sampled_nodes < self.N * self.sample_coverage: for data in loader: for node_idx, adj in data: edge_idx = adj.storage.value() node_count[node_idx] += 1 edge_count[edge_idx] += 1 total_sampled_nodes += node_idx.size(0) if self.log: # pragma: no cover pbar.update(node_idx.size(0)) num_samples += self.num_steps if self.log: # pragma: no cover pbar.close() row, _, edge_idx = self.adj.coo() t = torch.empty_like(edge_count).scatter_(0, edge_idx, node_count[row]) edge_norm = (t / edge_count).clamp_(0, 1e4) edge_norm[torch.isnan(edge_norm)] = 0.1 node_count[node_count == 0] = 0.1 node_norm = num_samples / node_count / self.N return node_norm, edge_norm class GraphSAINTNodeSampler(GraphSAINTSampler): r"""The GraphSAINT node sampler class (see :class:`~torch_geometric.loader.GraphSAINTSampler`). """ def _sample_nodes(self, batch_size): edge_sample = torch.randint(0, self.E, (batch_size, self.batch_size), dtype=torch.long) return self.adj.storage.row()[edge_sample] class GraphSAINTEdgeSampler(GraphSAINTSampler): r"""The GraphSAINT edge sampler class (see :class:`~torch_geometric.loader.GraphSAINTSampler`). """ def _sample_nodes(self, batch_size): row, col, _ = self.adj.coo() deg_in = 1. / self.adj.storage.colcount() deg_out = 1. / self.adj.storage.rowcount() prob = (1. / deg_in[row]) + (1. 
/ deg_out[col]) # Parallel multinomial sampling (without replacement) # https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 rand = torch.rand(batch_size, self.E).log() / (prob + 1e-10) edge_sample = rand.topk(self.batch_size, dim=-1).indices source_node_sample = col[edge_sample] target_node_sample = row[edge_sample] return torch.cat([source_node_sample, target_node_sample], -1) class GraphSAINTRandomWalkSampler(GraphSAINTSampler): r"""The GraphSAINT random walk sampler class (see :class:`~torch_geometric.loader.GraphSAINTSampler`). Args: walk_length (int): The length of each random walk. """ def __init__(self, data, batch_size: int, walk_length: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs): self.walk_length = walk_length super().__init__(data, batch_size, num_steps, sample_coverage, save_dir, log, **kwargs) @property def _filename(self): return (f'{self.__class__.__name__.lower()}_{self.walk_length}_' f'{self.sample_coverage}.pt') def _sample_nodes(self, batch_size): start = torch.randint(0, self.N, (batch_size, ), dtype=torch.long) node_idx = self.adj.random_walk(start.flatten(), self.walk_length) return node_idx.view(-1) ================================================ FILE: torch_geometric/loader/hgt_loader.py ================================================ from typing import Callable, Dict, List, Optional, Tuple, Union from torch import Tensor from torch_geometric.data import FeatureStore, GraphStore, HeteroData from torch_geometric.loader import NodeLoader from torch_geometric.sampler import HGTSampler from torch_geometric.typing import NodeType class HGTLoader(NodeLoader): r"""The Heterogeneous Graph Sampler from the `"Heterogeneous Graph Transformer" `_ paper. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible. 
:class:`~torch_geometric.data.HGTLoader` tries to (1) keep a similar number of nodes and edges for each type and (2) keep the sampled sub-graph dense to minimize the information loss and reduce the sample variance. Methodically, :class:`~torch_geometric.data.HGTLoader` keeps track of a node budget for each node type, which is then used to determine the sampling probability of a node. In particular, the probability of sampling a node is determined by the number of connections to already sampled nodes and their node degrees. With this, :class:`~torch_geometric.data.HGTLoader` will sample a fixed amount of neighbors for each node type in each iteration, as given by the :obj:`num_samples` argument. Sampled nodes are sorted based on the order in which they were sampled. In particular, the first :obj:`batch_size` nodes represent the set of original mini-batch nodes. .. note:: For an example of using :class:`~torch_geometric.data.HGTLoader`, see `examples/hetero/to_hetero_mag.py `_. .. code-block:: python from torch_geometric.loader import HGTLoader from torch_geometric.datasets import OGB_MAG hetero_data = OGB_MAG(path)[0] loader = HGTLoader( hetero_data, # Sample 512 nodes per type and per iteration for 4 iterations num_samples={key: [512] * 4 for key in hetero_data.node_types}, # Use a batch size of 128 for sampling training nodes of type paper batch_size=128, input_nodes=('paper', hetero_data['paper'].train_mask), ) sampled_hetero_data = next(iter(loader)) print(sampled_data.batch_size) >>> 128 Args: data (Any): A :class:`~torch_geometric.data.Data`, :class:`~torch_geometric.data.HeteroData`, or (:class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`) data object. num_samples (List[int] or Dict[str, List[int]]): The number of nodes to sample in each iteration and for each node type. If given as a list, will sample the same amount of nodes for each node type. 
input_nodes (str or Tuple[str, torch.Tensor]): The indices of nodes for which neighbors are sampled to create mini-batches. Needs to be passed as a tuple that holds the node type and corresponding node indices. Node indices need to be either given as a :obj:`torch.LongTensor` or :obj:`torch.BoolTensor`. If node indices are set to :obj:`None`, all nodes of this specific type will be considered. transform (callable, optional): A function/transform that takes in an a sampled mini-batch and returns a transformed version. (default: :obj:`None`) transform_sampler_output (callable, optional): A function/transform that takes in a :class:`torch_geometric.sampler.SamplerOutput` and returns a transformed version. (default: :obj:`None`) is_sorted (bool, optional): If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by column. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) filter_per_worker (bool, optional): If set to :obj:`True`, will filter the returned data in each worker's subprocess. If set to :obj:`False`, will filter the returned data in the main process. If set to :obj:`None`, will automatically infer the decision based on whether data partially lives on the GPU (:obj:`filter_per_worker=True`) or entirely on the CPU (:obj:`filter_per_worker=False`). There exists different trade-offs for setting this option. Specifically, setting this option to :obj:`True` for in-memory datasets will move all features to shared memory, which may result in too many open file handles. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. 
""" def __init__( self, data: Union[HeteroData, Tuple[FeatureStore, GraphStore]], num_samples: Union[List[int], Dict[NodeType, List[int]]], input_nodes: Union[NodeType, Tuple[NodeType, Optional[Tensor]]], is_sorted: bool = False, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, filter_per_worker: Optional[bool] = None, **kwargs, ): hgt_sampler = HGTSampler( data, num_samples=num_samples, is_sorted=is_sorted, share_memory=kwargs.get('num_workers', 0) > 0, ) super().__init__( data=data, node_sampler=hgt_sampler, input_nodes=input_nodes, transform=transform, transform_sampler_output=transform_sampler_output, filter_per_worker=filter_per_worker, **kwargs, ) ================================================ FILE: torch_geometric/loader/ibmb_loader.py ================================================ import logging import math from typing import ( Any, Callable, Iterator, List, NamedTuple, Optional, Tuple, Union, ) import numpy as np import torch from torch import Tensor from tqdm import tqdm from torch_geometric.data import Data from torch_geometric.typing import SparseTensor from torch_geometric.utils import get_ppr, is_undirected, subgraph try: import numba WITH_NUMBA = True except ImportError: # pragma: no cover WITH_NUMBA = False class OutputNodes(NamedTuple): seed_id: Tensor auxiliary_id: Tensor class _IBMBBaseLoader(torch.utils.data.DataLoader): def __init__(self, data: Data, **kwargs): kwargs.pop('collate_fn', None) batch_size = kwargs.get('batch_size', 1) output_nodes = self.get_output_nodes(self) if batch_size == 1: # Pre-process subgraphs: data_list = ... 
super().__init__(data_list, collate_fn=self._cache_fn, **kwargs) else: self.data = data super().__init__(output_nodes, collate_fn=self._collate_fn, **kwargs) def get_output_nodes(self) -> List[OutputNodes]: raise NotImplementedError def _cache_fn(self, data_list: List[Data]) -> Data: assert len(data_list) == 1 return data_list[0] def _collate_fn(self, output_nodes: List[OutputNodes]) -> Data: raise NotImplementedError def __repr__(self) -> str: return f'{self.__class__.__name__}()' ############################################################################### def get_partitions( edge_index: Union[Tensor, SparseTensor], num_partitions: int, indices: Tensor, num_nodes: int, output_weight: Optional[float] = None, ) -> List[Tensor]: assert isinstance( edge_index, (torch.LongTensor, SparseTensor)), f'Unsupported edge_index type {type(edge_index)}' if isinstance(edge_index, torch.LongTensor): edge_index = SparseTensor.from_edge_index( edge_index, sparse_sizes=(num_nodes, num_nodes)) if output_weight is not None and output_weight != 1: node_weight = torch.ones(num_nodes) node_weight[indices] = output_weight else: node_weight = None _, partptr, perm = edge_index.partition(num_parts=num_partitions, recursive=False, weighted=False, node_weight=node_weight) partitions = [] for i in range(len(partptr) - 1): partitions.append(perm[partptr[i]:partptr[i + 1]]) return partitions def get_pair_wise_distance( ys: List, num_classes: int, dist_type: str = 'kl', ) -> np.ndarray: num_batches = len(ys) counts = np.zeros((num_batches, num_classes), dtype=np.int32) for i in range(num_batches): unique, count = np.unique(ys[i], return_counts=True) counts[i, unique] = count counts += 1 counts = counts / counts.sum(1).reshape(-1, 1) pairwise_dist = np.zeros((num_batches, num_batches), dtype=np.float64) for i in range(0, num_batches - 1): for j in range(i + 1, num_batches): if dist_type == 'l1': pairwise_dist[i, j] = np.sum(np.abs(counts[i] - counts[j])) elif dist_type == 'kl': def 
kl_divergence(p: np.ndarray, q: np.ndarray): return (p * np.log(p / q)).sum() pairwise_dist[i, j] = kl_divergence(counts[i], counts[j]) + kl_divergence( counts[j], counts[i]) else: raise ValueError pairwise_dist += pairwise_dist.T pairwise_dist += 1e-5 # for numerical stability np.fill_diagonal(pairwise_dist, 0.) return pairwise_dist def indices_complete_check( loader: List[Tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]], output_indices: Union[Tensor, np.ndarray], ): if isinstance(output_indices, Tensor): output_indices = output_indices.cpu().numpy() outs = [] for out, aux in loader: if isinstance(out, Tensor): out = out.cpu().numpy() if isinstance(aux, Tensor): aux = aux.cpu().numpy() assert np.all(np.isin(out, aux)), "Not all output nodes are in aux nodes!" outs.append(out) outs = np.sort(np.concatenate(outs)) assert np.all( outs == np.sort(output_indices)), "Output nodes missing or duplicate!" def get_subgraph( out_indices: Tensor, graph: Data, return_edge_index_type: str, adj: SparseTensor, **kwargs, ): if return_edge_index_type == 'adj': assert adj is not None if return_edge_index_type == 'adj': subg = Data(x=graph.x[out_indices], y=graph.y[out_indices], edge_index=adj[out_indices, :][:, out_indices]) elif return_edge_index_type == 'edge_index': edge_index, edge_attr = subgraph(out_indices, graph.edge_index, graph.edge_attr, relabel_nodes=True, num_nodes=graph.num_nodes, return_edge_mask=False) subg = Data(x=graph.x[out_indices], y=graph.y[out_indices], edge_index=edge_index, edge_attr=edge_attr) else: raise NotImplementedError for k, v in kwargs.items(): subg[k] = v return subg def define_sampler( batch_order: str, ys: List[Union[Tensor, np.ndarray, List]], num_classes: int, dist_type: str = 'kl', ): if batch_order == 'rand': logging.info("Running with random order") sampler = torch.utils.data.RandomSampler(ys) elif batch_order in ['order', 'sample']: kl_div = get_pair_wise_distance(ys, num_classes, dist_type=dist_type) if batch_order == 'order': 
from python_tsp.heuristics import solve_tsp_simulated_annealing best_perm, _ = solve_tsp_simulated_annealing(kl_div) logging.info(f"Running with given order: {best_perm}") sampler = IBMBOrderedSampler(best_perm) else: logging.info("Running with weighted sampling") sampler = IBMBWeightedSampler(kl_div) else: raise ValueError return sampler def create_batchwise_out_aux_pairs( adj: SparseTensor, partitions: List[Union[torch.LongTensor, np.ndarray]], prime_indices: Union[torch.LongTensor, np.ndarray], topk: int, num_outnodeset_per_batch: int = 50, alpha: float = 0.2, ppr_iterations: int = 50, ) -> List[Tuple[np.ndarray, np.ndarray]]: def ppr_power_method( adj: SparseTensor, batch: List[Union[np.ndarray, torch.LongTensor]], topk: int, num_iter: int, alpha: float, ) -> List[np.ndarray]: topk_neighbors = [] logits = torch.zeros( adj.size(0), len(batch), device=adj.device()) # each column contains a set of output nodes for i, tele_set in enumerate(batch): logits[tele_set, i] = 1. / len(tele_set) new_logits = logits.clone() for _ in range(num_iter): new_logits = adj @ new_logits * (1 - alpha) + alpha * logits inds = new_logits.argsort(0) nonzeros = (new_logits > 0).sum(0) nonzeros = torch.minimum( nonzeros, torch.tensor([topk], dtype=torch.int64, device=adj.device())) for i in range(new_logits.shape[1]): topk_neighbors.append(inds[-nonzeros[i]:, i].cpu().numpy()) return topk_neighbors device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if isinstance(prime_indices, Tensor): prime_indices = prime_indices.cpu().numpy() adj = adj.to(device) cur_output_nodes = [] loader = [] pbar = tqdm(range(len(partitions))) pbar.set_description("Processing topic-sensitive PPR batches") for n in pbar: part = partitions[n] if isinstance(part, Tensor): part = part.cpu().numpy() primes_in_part, *_ = np.intersect1d(part, prime_indices, assume_unique=True, return_indices=True) if len(primes_in_part): # no output nodes in this partition cur_output_nodes.append(primes_in_part) # 
accumulate enough output nodes to make good use of GPU memory if len(cur_output_nodes ) >= num_outnodeset_per_batch or n == len(partitions) - 1: topk_neighbors = ppr_power_method(adj, cur_output_nodes, topk, ppr_iterations, alpha) for i in range(len(cur_output_nodes)): # force output nodes to be aux nodes auxiliary_nodes = np.union1d(cur_output_nodes[i], topk_neighbors[i]) loader.append((cur_output_nodes[i], auxiliary_nodes)) cur_output_nodes = [] if torch.cuda.is_available(): torch.cuda.empty_cache() return loader def get_pairs(ppr_mat: Any) -> np.ndarray: ppr_mat = ppr_mat + ppr_mat.transpose() ppr_mat = ppr_mat.tocoo() row, col, data = ppr_mat.row, ppr_mat.col, ppr_mat.data mask = (row > col) # lu row, col, data = row[mask], col[mask], data[mask] sort_arg = np.argsort(data)[::-1] # sort_arg = parallel_sort.parallel_argsort(data)[::-1] # map prime_nodes to arange ppr_pairs = np.vstack((row[sort_arg], col[sort_arg])).T return ppr_pairs _prime_orient_merge_numba: Optional[Callable] = None def prime_orient_merge( ppr_pairs: np.ndarray, primes_per_batch: int, num_nodes: int, ): if not WITH_NUMBA: # pragma: no cover raise ImportError("'prime_orient_merge' requires the 'numba' package") global _prime_orient_merge_numba if _prime_orient_merge_numba is None: _prime_orient_merge_numba = numba.njit(cache=True)(_prime_orient_merge) return _prime_orient_merge_numba(ppr_pairs, primes_per_batch, num_nodes) def _prime_orient_merge( ppr_pairs: np.ndarray, primes_per_batch: int, num_nodes: int, ): id_primes_list = list(np.arange(num_nodes, dtype=np.int32).reshape(-1, 1)) node_id_list = np.arange(num_nodes, dtype=np.int32) placeholder = np.zeros(0, dtype=np.int32) for i, j in ppr_pairs: id1, id2 = node_id_list[i], node_id_list[j] if id1 > id2: id1, id2 = id2, id1 if id1 != id2 and len(id_primes_list[id1]) + len( id_primes_list[id2]) <= primes_per_batch: id_primes_list[id1] = np.concatenate( (id_primes_list[id1], id_primes_list[id2])) node_id_list[id_primes_list[id2]] = id1 
id_primes_list[id2] = placeholder prime_lst = list() ids = np.unique(node_id_list) for _id in ids: prime_lst.append(list(id_primes_list[_id])) return list(prime_lst) def prime_post_process(loader, merge_max_size): from heapq import heapify, heappop, heappush h = [( len(p), p, ) for p in loader] heapify(h) while len(h) > 1: len1, p1 = heappop(h) len2, p2 = heappop(h) if len1 + len2 <= merge_max_size: heappush(h, (len1 + len2, p1 + p2)) else: heappush(h, ( len1, p1, )) heappush(h, ( len2, p2, )) break new_batch = [] while len(h): _, p = heappop(h) new_batch.append(p) return new_batch def topk_ppr_matrix( edge_index: Tensor, num_nodes: int, alpha: float, eps: float, output_node_indices: Union[np.ndarray, torch.LongTensor], topk: int, normalization='row', ) -> Tuple[Any, List[np.ndarray]]: neighbors, weights = get_ppr(edge_index, alpha, eps, output_node_indices, num_nodes) _, neighbor_counts = neighbors[0].unique(return_counts=True) ppr_matrix = SparseTensor( row=torch.arange( len(output_node_indices)).repeat_interleave(neighbor_counts), col=neighbors[1], value=weights, sparse_sizes=(len(output_node_indices), num_nodes)).to_scipy(layout='csr') neighbors = [ n.cpu().numpy() for n in torch.split(neighbors[1], neighbor_counts.cpu().tolist(), dim=0) ] weights = [ n.cpu().numpy() for n in torch.split(weights, neighbor_counts.cpu().tolist(), dim=0) ] def sparsify(neighbors: List[np.ndarray], weights: List[np.ndarray], topk: int): new_neighbors = [] for n, w in zip(neighbors, weights): idx_topk = np.argsort(w)[-topk:] new_neighbor = n[idx_topk] new_neighbors.append(new_neighbor) return new_neighbors neighbors = sparsify(neighbors, weights, topk) neighbors = [ np.union1d(nei, pr) for nei, pr in zip(neighbors, output_node_indices) ] _, out_degree = torch.unique(edge_index[0], sorted=True, return_counts=True) if normalization == 'sym': # Assume undirected (symmetric) adjacency matrix deg_sqrt = np.sqrt(np.maximum(out_degree, 1e-12)) deg_inv_sqrt = 1. 
/ deg_sqrt row, col = ppr_matrix.nonzero() ppr_matrix.data = deg_sqrt[output_node_indices[row]] * \ ppr_matrix.data * \ deg_inv_sqrt[col] elif normalization == 'col': # Assume undirected (symmetric) adjacency matrix deg_inv = 1. / np.maximum(out_degree, 1e-12) row, col = ppr_matrix.nonzero() ppr_matrix.data = out_degree[output_node_indices[row]] * \ ppr_matrix.data * \ deg_inv[col] elif normalization == 'row': pass else: raise ValueError(f"Unknown PPR normalization: {normalization}") return ppr_matrix, neighbors class IBMBBaseLoader(torch.utils.data.DataLoader): def __init__( self, data_list: Union[List[Data], List[Tuple]], graph: Data, adj: SparseTensor, return_edge_index_type: str, **kwargs, ): self.graph = graph self.adj = adj self.return_edge_index_type = return_edge_index_type if 'collate_fn' in kwargs: del kwargs['collate_fn'] super().__init__(data_list, collate_fn=self.collate_fn, **kwargs) def create_loader(self, *args, **kwargs): raise NotImplementedError @classmethod def prepare_cache( cls, graph: Data, batch_wise_out_aux_pairs: List[Tuple[np.ndarray, np.ndarray]], adj: Optional[SparseTensor], return_edge_index_type: str, ): subgraphs = [] pbar = tqdm(batch_wise_out_aux_pairs) pbar.set_description( f"Caching data with type {return_edge_index_type}") if return_edge_index_type == 'adj': assert adj is not None for out, aux in pbar: mask = torch.from_numpy(np.isin(aux, out)) if isinstance(aux, np.ndarray): aux = torch.from_numpy(aux) subg = get_subgraph(aux, graph, return_edge_index_type, adj, output_node_mask=mask) subgraphs.append(subg) return subgraphs @classmethod def create_adj_from_edge_index( cls, edge_index: Tensor, num_nodes: int, normalization: str, ): assert normalization in ['sym', 'rw'] adj = SparseTensor.from_edge_index( edge_index, sparse_sizes=(num_nodes, num_nodes), ) adj = adj.fill_value(1.) degree = adj.sum(0) degree[degree == 0.] 
= 1e-12 deg_inv = 1 / degree if normalization == 'sym': deg_inv_sqrt = deg_inv**0.5 adj = adj * deg_inv_sqrt.reshape(1, -1) adj = adj * deg_inv_sqrt.reshape(-1, 1) elif normalization == 'rw': adj = adj * deg_inv.reshape(-1, 1) return adj def collate_fn(self, data_list: List[Union[Data, Tuple]]): if len(data_list) == 1 and isinstance(data_list[0], Data): return data_list[0] out, aux = zip(*data_list) out = np.concatenate(out) aux = np.unique(np.concatenate(aux)) mask = torch.from_numpy(np.isin(aux, out)) aux = torch.from_numpy(aux) subg = get_subgraph(aux, self.graph, self.return_edge_index_type, self.adj, output_node_mask=mask) return subg def __repr__(self) -> str: return f'{self.__class__.__name__}()' class IBMBBatchLoader(IBMBBaseLoader): r"""The batch-wise influence-based data loader from the `"Influence-Based Mini-Batching for Graph Neural Networks" `__ paper. First, the METIS graph partitioning algorithm separates the graph into :obj:`num_partitions` many partitions. Afterwards, input/seed nodes and their auxiliary nodes (found via topic-sensitive PageRank) are used to form a mini-batch. If :obj:`batch_size` is set to :obj:`1`, mini-batches are pre-calculated and cached in memory. Otherwise, only input nodes and their auxiliary nodes are pre-computed, and mini-batches are collated on-the-fly. Args: data (torch_geometric.data.Data): A :class:`~torch_geometric.data.Data` object. batch_order (str): A string indicating the batch order type (one of :obj:`"order"`, :obj:`"sample"` or :obj:`"rand"`). If :obj:`"order"`, calculates the pair-wise KL divergence between every two batches to organize an optimal order. If :obj:`"sample"`, samples the next batch w.r.t. the last one in which a batch with higher KL divergence score is more likely to be sampled. If :obj:`"rand"`, batches are generated randomly. num_partitions (int): The number of partitions. input_nodes (torch.Tensor): A vector containing the set of seed nodes. 
batch_expand_ratio (float, optional): The ratio between the returned batch size and the original partition size. For example, set it to :obj:`2.0` in case you would like the batch to have double the number of nodes as the size of its partition. (default: :obj:`1.0`) metis_input_node_weight (float, optional): The weights on the input nodes for METIS graph partitioning. (default: :obj:`None`) alpha (float, optional): The teleport probability of the PageRank calculation. (default: :obj:`0.2`) approximate_ppr_iterations (int, optional): The number of power iterations for PageRank calculation. (default: :obj:`50`) return_edge_index_type (str, optional): A string indicating the output type of edge indices (one of :obj:`"edge_index"` or :obj:`"adj"`). If set to :obj:`"adj"`, the :obj:`edge_index` of the batch will be a :class:`torch_sparse.SparseTensor`, otherwise a :class:`torch.Tensor`. (default: :obj:`"edge_index"`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. 
""" def __init__( self, data: Data, batch_order: str, num_partitions: int, input_nodes: Tensor, batch_expand_ratio: Optional[float] = 1.0, metis_input_node_weight: Optional[float] = None, alpha: Optional[float] = 0.2, approximate_ppr_iterations: Optional[int] = 50, return_edge_index_type: str = 'edge_index', **kwargs, ): self.subgraphs = [] self.batch_wise_out_aux_pairs = [] assert is_undirected( data.edge_index, num_nodes=data.num_nodes), "Assume the graph to be undirected" assert batch_order in ['rand', 'sample', 'order' ], f"Unsupported batch order: {batch_order}" adj = self.create_adj_from_edge_index( data.edge_index, data.num_nodes, normalization='rw', ) self.cache_data = kwargs['batch_size'] == 1 self.num_partitions = num_partitions self.output_indices = input_nodes assert return_edge_index_type in ['adj', 'edge_index'] self.return_edge_index_type = return_edge_index_type self.batch_expand_ratio = batch_expand_ratio self.metis_output_weight = metis_input_node_weight self.num_outnodeset_per_batch = 50 self.alpha = alpha self.approximate_ppr_iterations = approximate_ppr_iterations self.create_loader(data, adj) if len(self.batch_wise_out_aux_pairs) > 2: # <= 2 order makes no sense ys = [ data.y[out].numpy() for out, _ in self.batch_wise_out_aux_pairs ] sampler = define_sampler(batch_order, ys, data.y.max().item() + 1) else: sampler = None if not self.cache_data: cached_data = data # need to cache the original graph if return_edge_index_type == 'adj': cached_adj = adj else: cached_adj = None else: cached_data = None cached_adj = None super().__init__( self.subgraphs if self.cache_data else self.batch_wise_out_aux_pairs, cached_data, cached_adj, return_edge_index_type, sampler=sampler, **kwargs, ) def create_loader(self, graph: Data, adj: SparseTensor): partitions = get_partitions( adj, self.num_partitions, self.output_indices, graph.num_nodes, self.metis_output_weight, ) # get output - auxiliary node pairs topk = math.ceil(self.batch_expand_ratio * 
graph.num_nodes / self.num_partitions) batch_wise_out_aux_pairs = create_batchwise_out_aux_pairs( adj, partitions, self.output_indices, topk, self.num_outnodeset_per_batch, self.alpha, self.approximate_ppr_iterations) indices_complete_check(batch_wise_out_aux_pairs, self.output_indices) self.batch_wise_out_aux_pairs = batch_wise_out_aux_pairs if self.cache_data: self.subgraphs = self.prepare_cache( graph, batch_wise_out_aux_pairs, adj, self.return_edge_index_type, ) class IBMBNodeLoader(IBMBBaseLoader): r"""The node-wise influence-based data loader from the `"Influence-Based Mini-Batching for Graph Neural Networks" `__ paper. First, the Personalized PageRank (PPR) score for each input node is computed, for which the :obj:`k` nodes with the highest scores are taken auxiliary nodes. Afterwards, input nodes are merged according to their pair-wise PPR scores. Similar to :class:`~torch_geometric.loader.IBMBBatchLoader`, subgraphs are cached in memory for :obj:`batch_size = 1`, and collated on-the-fly otherwise. Args: data (torch_geometric.data.Data): A :class:`~torch_geometric.data.Data` object. batch_order (str): A string indicating the batch order type (one of :obj:`"order"`, :obj:`"sample"` or :obj:`"rand"`). If :obj:`"order"`, calculates the pair-wise KL divergence between every two batches to organize an optimal order. If :obj:`"sample"`, samples the next batch w.r.t. the last one in which a batch with higher KL divergence score is more likely to be sampled. If :obj:`"rand"`, batches are generated randomly. input_nodes (torch.Tensor): A vector containing the set of seed nodes. num_auxiliary_nodes (int): The number of auxiliary nodes per input node. num_nodes_per_batch (int): The number of seed nodes per batch. alpha (float, optional): The teleport probability of the PageRank calculation. 
(default: :obj:`0.2`) eps (float, optional): The threshold for stopping the PPR calculation The smaller :obj`eps` is, the more accurate are the results of PPR calculation, but it also takes longer. (default: :obj:`1e-5`) return_edge_index_type (str, optional): A string indicating the output type of edge indices (one of :obj:`"edge_index"` or :obj:`"adj"`). If set to :obj:`"adj"`, the :obj:`edge_index` of the batch will be a :class:`torch_sparse.SparseTensor`, otherwise a :class:`torch.Tensor`. (default: :obj:`"edge_index"`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. """ def __init__( self, data: Data, batch_order: str, input_nodes: torch.Tensor, num_auxiliary_nodes: int, num_nodes_per_batch: int, alpha: float = 0.2, eps: float = 1e-5, return_edge_index_type: str = 'edge_index', **kwargs, ): self.subgraphs = [] self.node_wise_out_aux_pairs = [] assert is_undirected( data.edge_index, num_nodes=data.num_nodes), "Assume the graph to be undirected" assert batch_order in ['rand', 'sample', 'order' ], f"Unsupported batch order: {batch_order}" if return_edge_index_type == 'adj': adj = self.create_adj_from_edge_index(data.edge_index, data.num_nodes, normalization='rw') else: adj = None self.cache_data = kwargs['batch_size'] == 1 self._batchsize = kwargs['batch_size'] self.output_indices = input_nodes.numpy() assert return_edge_index_type in ['adj', 'edge_index'] self.return_edge_index_type = return_edge_index_type self.num_auxiliary_node_per_output = num_auxiliary_nodes self.num_output_nodes_per_batch = num_nodes_per_batch self.alpha = alpha self.eps = eps self.create_loader(data, adj) if len(self.node_wise_out_aux_pairs) > 2: # <= 2 order makes no sense ys = [ data.y[out].numpy() for out, _ in self.node_wise_out_aux_pairs ] sampler = define_sampler(batch_order, ys, data.y.max().item() + 1) else: sampler = None if not self.cache_data: cached_graph = 
data # need to cache the original graph cached_adj = adj else: cached_graph = None cached_adj = None super().__init__( self.subgraphs if self.cache_data else self.node_wise_out_aux_pairs, cached_graph, cached_adj, return_edge_index_type, sampler=sampler, **kwargs, ) def create_loader(self, graph: Data, adj: SparseTensor): logging.info("Start PPR calculation") ppr_matrix, neighbors = topk_ppr_matrix( graph.edge_index, graph.num_nodes, self.alpha, self.eps, torch.from_numpy(self.output_indices), self.num_auxiliary_node_per_output) ppr_matrix = ppr_matrix[:, self.output_indices] logging.info("Getting PPR pairs") ppr_pairs = get_pairs(ppr_matrix) output_list = prime_orient_merge( ppr_pairs, self.num_output_nodes_per_batch, len(self.output_indices), ) output_list = prime_post_process( output_list, self.num_output_nodes_per_batch, ) node_wise_out_aux_pairs = [] if isinstance(neighbors, list): neighbors = np.array(neighbors, dtype=object) def _union(inputs): return np.unique(np.concatenate(inputs)) for p in output_list: node_wise_out_aux_pairs.append( (self.output_indices[p], _union(neighbors[p]).astype(np.int64))) indices_complete_check(node_wise_out_aux_pairs, self.output_indices) self.node_wise_out_aux_pairs = node_wise_out_aux_pairs if self.cache_data: self.subgraphs = self.prepare_cache( graph, node_wise_out_aux_pairs, adj, self.return_edge_index_type, ) class IBMBOrderedSampler(torch.utils.data.Sampler[int]): r"""A sampler with given order, specially for IBMB loaders. Args: data_source (np.ndarray, torch.Tensor, List): A :obj:`np.ndarray`, :obj:`torch.Tensor`, or :obj:`List` data object. Contains the order of the batches. 
""" def __init__(self, data_source: Union[np.ndarray, torch.Tensor, List]) -> None: self.data_source = data_source super().__init__(data_source) def __iter__(self) -> Iterator[int]: return iter(self.data_source) def __len__(self) -> int: return len(self.data_source) class IBMBWeightedSampler(torch.utils.data.Sampler[int]): r"""A weighted sampler wrt the pair wise KL divergence. The very first batch after initialization is sampled randomly, with the next ones being sampled according to the last batch, including the first batch in the next round. Args: batch_kl_div (np.ndarray, torch.Tensor): A :obj:`np.ndarray` or :obj:`torch.Tensor`, each element [i, j] contains the pair wise KL divergence between batch i and j. """ def __init__(self, batch_kl_div: Union[np.ndarray, torch.Tensor]) -> None: data_source = np.arange(batch_kl_div.shape[0]) self.data_source = data_source self.batch_kl_div = batch_kl_div self.last_train_batch_id = 0 super().__init__(data_source) def __iter__(self) -> Iterator[int]: probs = self.batch_kl_div.copy() last = self.last_train_batch_id num_batches = probs.shape[0] fetch_idx = [] next_id = 0 while np.any(probs): next_id = np.random.choice(num_batches, size=None, replace=False, p=probs[last] / probs[last].sum()) last = next_id fetch_idx.append(next_id) probs[:, next_id] = 0. self.last_train_batch_id = next_id return iter(fetch_idx) def __len__(self) -> int: return len(self.data_source) ================================================ FILE: torch_geometric/loader/imbalanced_sampler.py ================================================ from typing import List, Optional, Union import torch from torch import Tensor from torch_geometric.data import Data, Dataset, InMemoryDataset class ImbalancedSampler(torch.utils.data.WeightedRandomSampler): r"""A weighted random sampler that randomly samples elements according to class distribution. 
As such, it will either remove samples from the majority class (under-sampling) or add more examples from the minority class (over-sampling). **Graph-level sampling:** .. code-block:: python from torch_geometric.loader import DataLoader, ImbalancedSampler sampler = ImbalancedSampler(dataset) loader = DataLoader(dataset, batch_size=64, sampler=sampler, ...) **Node-level sampling:** .. code-block:: python from torch_geometric.loader import NeighborLoader, ImbalancedSampler sampler = ImbalancedSampler(data, input_nodes=data.train_mask) loader = NeighborLoader(data, input_nodes=data.train_mask, batch_size=64, num_neighbors=[-1, -1], sampler=sampler, ...) You can also pass in the class labels directly as a :class:`torch.Tensor`: .. code-block:: python from torch_geometric.loader import NeighborLoader, ImbalancedSampler sampler = ImbalancedSampler(data.y) loader = NeighborLoader(data, input_nodes=data.train_mask, batch_size=64, num_neighbors=[-1, -1], sampler=sampler, ...) Args: dataset (Dataset or Data or Tensor): The dataset or class distribution from which to sample the data, given either as a :class:`~torch_geometric.data.Dataset`, :class:`~torch_geometric.data.Data`, or :class:`torch.Tensor` object. input_nodes (Tensor, optional): The indices of nodes that are used by the corresponding loader, *e.g.*, by :class:`~torch_geometric.loader.NeighborLoader`. If set to :obj:`None`, all nodes will be considered. This argument should only be set for node-level loaders and does not have any effect when operating on a set of graphs as given by :class:`~torch_geometric.data.Dataset`. (default: :obj:`None`) num_samples (int, optional): The number of samples to draw for a single epoch. If set to :obj:`None`, will sample as many elements as exist in the underlying data. 
(default: :obj:`None`) """ def __init__( self, dataset: Union[Dataset, Data, List[Data], Tensor], input_nodes: Optional[Tensor] = None, num_samples: Optional[int] = None, ): if isinstance(dataset, Data): y = dataset.y.view(-1) assert dataset.num_nodes == y.numel() y = y[input_nodes] if input_nodes is not None else y elif isinstance(dataset, Tensor): y = dataset.view(-1) y = y[input_nodes] if input_nodes is not None else y elif isinstance(dataset, InMemoryDataset): y = dataset.y.view(-1) assert len(dataset) == y.numel() else: ys = [data.y for data in dataset] if isinstance(ys[0], Tensor): y = torch.cat(ys, dim=0).view(-1) else: y = torch.tensor(ys).view(-1) assert len(dataset) == y.numel() assert y.dtype == torch.long # Require classification. num_samples = y.numel() if num_samples is None else num_samples class_weight = 1. / y.bincount() weight = class_weight[y] return super().__init__(weight, num_samples, replacement=True) ================================================ FILE: torch_geometric/loader/link_loader.py ================================================ from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.mixin import ( AffinityMixin, LogMemoryMixin, MultithreadingMixin, ) from torch_geometric.loader.utils import ( filter_custom_hetero_store, filter_custom_store, filter_data, filter_hetero_data, get_edge_label_index, infer_filter_per_worker, ) from torch_geometric.sampler import ( BaseSampler, EdgeSamplerInput, HeteroSamplerOutput, NegativeSampling, SamplerOutput, ) from torch_geometric.typing import InputEdges, OptTensor class LinkLoader( torch.utils.data.DataLoader, AffinityMixin, MultithreadingMixin, LogMemoryMixin, ): r"""A data loader that performs mini-batch sampling from link information, using a generic 
:class:`~torch_geometric.sampler.BaseSampler` implementation that defines a :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges` function and is supported on the provided input :obj:`data` object. .. note:: Negative sampling is currently implemented in an approximate way, *i.e.* negative edges may contain false negatives. Args: data (Any): A :class:`~torch_geometric.data.Data`, :class:`~torch_geometric.data.HeteroData`, or (:class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`) data object. link_sampler (torch_geometric.sampler.BaseSampler): The sampler implementation to be used with this loader. Needs to implement :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`. The sampler implementation must be compatible with the input :obj:`data` object. edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The edge indices, holding source and destination nodes to start sampling from. If set to :obj:`None`, all edges will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the edge type and corresponding edge indices. (default: :obj:`None`) edge_label (Tensor, optional): The labels of edge indices from which to start sampling from. Must be the same length as the :obj:`edge_label_index`. (default: :obj:`None`) edge_label_time (Tensor, optional): The timestamps of edge indices from which to start sampling from. Must be the same length as :obj:`edge_label_index`. If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, *i.e.*, neighbors have an earlier timestamp than the output edge. The :obj:`time_attr` needs to be set for this to work. (default: :obj:`None`) neg_sampling (NegativeSampling, optional): The negative sampling configuration. For negative sampling mode :obj:`"binary"`, samples can be accessed via the attributes :obj:`edge_label_index` and :obj:`edge_label` in the respective edge type of the returned mini-batch. 
In case :obj:`edge_label` does not exist, it will be automatically created and represents a binary classification task (:obj:`0` = negative edge, :obj:`1` = positive edge). In case :obj:`edge_label` does exist, it has to be a categorical label from :obj:`0` to :obj:`num_classes - 1`. After negative sampling, label :obj:`0` represents negative edges, and labels :obj:`1` to :obj:`num_classes` represent the labels of positive edges. Note that returned labels are of type :obj:`torch.float` for binary classification (to facilitate the ease-of-use of :meth:`F.binary_cross_entropy`) and of type :obj:`torch.long` for multi-class classification (to facilitate the ease-of-use of :meth:`F.cross_entropy`). For negative sampling mode :obj:`"triplet"`, samples can be accessed via the attributes :obj:`src_index`, :obj:`dst_pos_index` and :obj:`dst_neg_index` in the respective node types of the returned mini-batch. :obj:`edge_label` needs to be :obj:`None` for :obj:`"triplet"` negative sampling mode. If set to :obj:`None`, no negative sampling strategy is applied. (default: :obj:`None`) neg_sampling_ratio (int or float, optional): The ratio of sampled negative edges to the number of positive edges. Deprecated in favor of the :obj:`neg_sampling` argument. (default: :obj:`None`). transform (callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: :obj:`None`) transform_sampler_output (callable, optional): A function/transform that takes in a :class:`torch_geometric.sampler.SamplerOutput` and returns a transformed version. (default: :obj:`None`) filter_per_worker (bool, optional): If set to :obj:`True`, will filter the returned data in each worker's subprocess. If set to :obj:`False`, will filter the returned data in the main process. 
If set to :obj:`None`, will automatically infer the decision based on whether data partially lives on the GPU (:obj:`filter_per_worker=True`) or entirely on the CPU (:obj:`filter_per_worker=False`). There exists different trade-offs for setting this option. Specifically, setting this option to :obj:`True` for in-memory datasets will move all features to shared memory, which may result in too many open file handles. (default: :obj:`None`) custom_cls (HeteroData, optional): A custom :class:`~torch_geometric.data.HeteroData` class to return for mini-batches in case of remote backends. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. """ def __init__( self, data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], link_sampler: BaseSampler, edge_label_index: InputEdges = None, edge_label: OptTensor = None, edge_label_time: OptTensor = None, neg_sampling: Optional[NegativeSampling] = None, neg_sampling_ratio: Optional[Union[int, float]] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, filter_per_worker: Optional[bool] = None, custom_cls: Optional[HeteroData] = None, input_id: OptTensor = None, **kwargs, ): if filter_per_worker is None: filter_per_worker = infer_filter_per_worker(data) # Remove for PyTorch Lightning: kwargs.pop('dataset', None) kwargs.pop('collate_fn', None) # Save for PyTorch Lightning: self.edge_label_index = edge_label_index if neg_sampling_ratio is not None and neg_sampling_ratio != 0.0: # TODO: Deprecation warning. 
neg_sampling = NegativeSampling("binary", neg_sampling_ratio) # Get edge type (or `None` for homogeneous graphs): input_type, edge_label_index = get_edge_label_index( data, edge_label_index) self.data = data self.link_sampler = link_sampler self.neg_sampling = NegativeSampling.cast(neg_sampling) self.transform = transform self.transform_sampler_output = transform_sampler_output self.filter_per_worker = filter_per_worker self.custom_cls = custom_cls if (self.neg_sampling is not None and self.neg_sampling.is_binary() and edge_label is not None and edge_label.min() == 0): # Increment labels such that `zero` now denotes "negative". edge_label = edge_label + 1 if (self.neg_sampling is not None and self.neg_sampling.is_triplet() and edge_label is not None): raise ValueError("'edge_label' needs to be undefined for " "'triplet'-based negative sampling. Please use " "`src_index`, `dst_pos_index` and " "`neg_pos_index` of the returned mini-batch " "instead to differentiate between positive and " "negative samples.") self.input_data = EdgeSamplerInput( input_id=input_id, row=edge_label_index[0], col=edge_label_index[1], label=edge_label, time=edge_label_time, input_type=input_type, ) iterator = range(edge_label_index.size(1)) super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) def __call__( self, index: Union[Tensor, List[int]], ) -> Union[Data, HeteroData]: r"""Samples a subgraph from a batch of input edges.""" out = self.collate_fn(index) if not self.filter_per_worker: out = self.filter_fn(out) return out def collate_fn(self, index: Union[Tensor, List[int]]) -> Any: r"""Samples a subgraph from a batch of input edges.""" input_data: EdgeSamplerInput = self.input_data[index] out = self.link_sampler.sample_from_edges( input_data, neg_sampling=self.neg_sampling) if self.filter_per_worker: # Execute `filter_fn` in the worker process out = self.filter_fn(out) return out def filter_fn( self, out: Union[SamplerOutput, HeteroSamplerOutput], ) -> Union[Data, HeteroData]: 
r"""Joins the sampled nodes with their corresponding features, returning the resulting :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object to be used downstream. """ if self.transform_sampler_output: out = self.transform_sampler_output(out) if isinstance(out, SamplerOutput): if isinstance(self.data, Data): data = filter_data( # self.data, out.node, out.row, out.col, out.edge, self.link_sampler.edge_permutation) else: # Tuple[FeatureStore, GraphStore] # Hack to detect whether we are in a distributed setting. if (self.link_sampler.__class__.__name__ == 'DistNeighborSampler'): edge_index = torch.stack([out.row, out.col]) data = Data(edge_index=edge_index) # Metadata entries are populated in # `DistributedNeighborSampler._collate_fn()` data.x = out.metadata[-3] data.y = out.metadata[-2] data.edge_attr = out.metadata[-1] else: data = filter_custom_store( # *self.data, out.node, out.row, out.col, out.edge, self.custom_cls) if 'n_id' not in data: data.n_id = out.node if out.edge is not None and 'e_id' not in data: edge = out.edge.to(torch.long) perm = self.link_sampler.edge_permutation data.e_id = perm[out.edge] if perm is not None else out.edge data.batch = out.batch data.num_sampled_nodes = out.num_sampled_nodes data.num_sampled_edges = out.num_sampled_edges data.input_id = out.metadata[0] if self.neg_sampling is None or self.neg_sampling.is_binary(): data.edge_label_index = out.metadata[1] data.edge_label = out.metadata[2] data.edge_label_time = out.metadata[3] elif self.neg_sampling.is_triplet(): data.src_index = out.metadata[1] data.dst_pos_index = out.metadata[2] data.dst_neg_index = out.metadata[3] data.seed_time = out.metadata[4] # Sanity removals in case `edge_label_index` and # `edge_label_time` are attributes of the base `data` object: del data.edge_label_index # Sanity removals. 
del data.edge_label_time elif isinstance(out, HeteroSamplerOutput): if isinstance(self.data, HeteroData): data = filter_hetero_data( # self.data, out.node, out.row, out.col, out.edge, self.link_sampler.edge_permutation) else: # Tuple[FeatureStore, GraphStore] # Hack to detect whether we are in a distributed setting. if (self.link_sampler.__class__.__name__ == 'DistNeighborSampler'): import torch_geometric.distributed as dist data = dist.utils.filter_dist_store( *self.data, out.node, out.row, out.col, out.edge, self.custom_cls, out.metadata, self.input_data.input_type) else: data = filter_custom_hetero_store( # *self.data, out.node, out.row, out.col, out.edge, self.custom_cls) for key, node in out.node.items(): if 'n_id' not in data[key]: data[key].n_id = node for key, edge in (out.edge or {}).items(): if edge is not None and 'e_id' not in data[key]: edge = edge.to(torch.long) perm = self.link_sampler.edge_permutation if perm is not None and perm.get(key, None) is not None: edge = perm[key][edge] data[key].e_id = edge data.set_value_dict('batch', out.batch) data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes) data.set_value_dict('num_sampled_edges', out.num_sampled_edges) input_type = self.input_data.input_type data[input_type].input_id = out.metadata[0] if self.neg_sampling is None or self.neg_sampling.is_binary(): data[input_type].edge_label_index = out.metadata[1] data[input_type].edge_label = out.metadata[2] data[input_type].edge_label_time = out.metadata[3] elif self.neg_sampling.is_triplet(): data[input_type[0]].src_index = out.metadata[1] data[input_type[-1]].dst_pos_index = out.metadata[2] data[input_type[-1]].dst_neg_index = out.metadata[3] data[input_type[0]].seed_time = out.metadata[4] data[input_type[-1]].seed_time = out.metadata[4] # Sanity removals in case `edge_label_index` and # `edge_label_time` are attributes of the base `data` object: if input_type in data.edge_types: del data[input_type].edge_label_index del 
data[input_type].edge_label_time else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " f"type: '{type(out)}'") return data if self.transform is None else self.transform(data) def _get_iterator(self) -> Iterator: if self.filter_per_worker: return super()._get_iterator() # Execute `filter_fn` in the main process: return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/loader/link_neighbor_loader.py ================================================ from typing import Callable, Dict, List, Optional, Tuple, Union from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData from torch_geometric.loader.link_loader import LinkLoader from torch_geometric.sampler import NegativeSampling, NeighborSampler from torch_geometric.sampler.base import SubgraphType from torch_geometric.typing import EdgeType, InputEdges, OptTensor class LinkNeighborLoader(LinkLoader): r"""A link-based data loader derived as an extension of the node-based :class:`torch_geometric.loader.NeighborLoader`. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible. More specifically, this loader first selects a sample of edges from the set of input edges :obj:`edge_label_index` (which may or not be edges in the original graph) and then constructs a subgraph from all the nodes present in this list by sampling :obj:`num_neighbors` neighbors in each iteration. .. 
code-block:: python from torch_geometric.datasets import Planetoid from torch_geometric.loader import LinkNeighborLoader data = Planetoid(path, name='Cora')[0] loader = LinkNeighborLoader( data, # Sample 30 neighbors for each node for 2 iterations num_neighbors=[30] * 2, # Use a batch size of 128 for sampling training nodes batch_size=128, edge_label_index=data.edge_index, ) sampled_data = next(iter(loader)) print(sampled_data) >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368], train_mask=[1368], val_mask=[1368], test_mask=[1368], edge_label_index=[2, 128]) It is additionally possible to provide edge labels for sampled edges, which are then added to the batch: .. code-block:: python loader = LinkNeighborLoader( data, num_neighbors=[30] * 2, batch_size=128, edge_label_index=data.edge_index, edge_label=torch.ones(data.edge_index.size(1)) ) sampled_data = next(iter(loader)) print(sampled_data) >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368], train_mask=[1368], val_mask=[1368], test_mask=[1368], edge_label_index=[2, 128], edge_label=[128]) The rest of the functionality mirrors that of :class:`~torch_geometric.loader.NeighborLoader`, including support for heterogeneous graphs. In particular, the data loader will add the following attributes to the returned mini-batch: * :obj:`n_id` The global node index for every sampled node * :obj:`e_id` The global edge index for every sampled edge * :obj:`input_id`: The global index of the :obj:`edge_label_index` * :obj:`num_sampled_nodes`: The number of sampled nodes in each hop * :obj:`num_sampled_edges`: The number of sampled edges in each hop .. note:: Negative sampling is currently implemented in an approximate way, *i.e.* negative edges may contain false negatives. .. warning:: Note that the sampling scheme is independent from the edge we are making a prediction for. That is, by default supervision edges in :obj:`edge_label_index` **will not** get masked out during sampling. 
In case there exists an overlap between message passing edges in :obj:`data.edge_index` and supervision edges in :obj:`edge_label_index`, you might end up sampling an edge you are making a prediction for. You can generally avoid this behavior (if desired) by making :obj:`data.edge_index` and :obj:`edge_label_index` two disjoint sets of edges, *e.g.*, via the :class:`~torch_geometric.transforms.RandomLinkSplit` transformation and its :obj:`disjoint_train_ratio` argument. Args: data (Any): A :class:`~torch_geometric.data.Data`, :class:`~torch_geometric.data.HeteroData`, or (:class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`) data object. num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to :obj:`-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The edge indices for which neighbors are sampled to create mini-batches. If set to :obj:`None`, all edges will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the edge type and corresponding edge indices. (default: :obj:`None`) edge_label (Tensor, optional): The labels of edge indices for which neighbors are sampled. Must be the same length as the :obj:`edge_label_index`. If set to :obj:`None`, it is set to :obj:`torch.zeros(...)` internally. (default: :obj:`None`) edge_label_time (Tensor, optional): The timestamps for edge indices for which neighbors are sampled. Must be the same length as :obj:`edge_label_index`. If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, *i.e.*, neighbors have an earlier timestamp than the output edge. The :obj:`time_attr` needs to be set for this to work. 
(default: :obj:`None`) replace (bool, optional): If set to :obj:`True`, will sample with replacement. (default: :obj:`False`) subgraph_type (SubgraphType or str, optional): The type of the returned subgraph. If set to :obj:`"directional"`, the returned subgraph only holds the sampled (directed) edges which are necessary to compute representations for the sampled seed nodes. If set to :obj:`"bidirectional"`, sampled edges are converted to bidirectional edges. If set to :obj:`"induced"`, the returned subgraph contains the induced subgraph of all sampled nodes. (default: :obj:`"directional"`) disjoint (bool, optional): If set to :obj:`True`, each seed node will create its own disjoint subgraph. If set to :obj:`True`, mini-batch outputs will have a :obj:`batch` vector holding the mapping of nodes to their respective subgraph. Will get automatically set to :obj:`True` in case of temporal sampling. (default: :obj:`False`) temporal_strategy (str, optional): The sampling strategy when using temporal sampling (:obj:`"uniform"`, :obj:`"last"`). If set to :obj:`"uniform"`, will sample uniformly across neighbors that fulfill temporal constraints. If set to :obj:`"last"`, will sample the last :obj:`num_neighbors` that fulfill temporal constraints. (default: :obj:`"uniform"`) neg_sampling (NegativeSampling, optional): The negative sampling configuration. For negative sampling mode :obj:`"binary"`, samples can be accessed via the attributes :obj:`edge_label_index` and :obj:`edge_label` in the respective edge type of the returned mini-batch. In case :obj:`edge_label` does not exist, it will be automatically created and represents a binary classification task (:obj:`0` = negative edge, :obj:`1` = positive edge). In case :obj:`edge_label` does exist, it has to be a categorical label from :obj:`0` to :obj:`num_classes - 1`. After negative sampling, label :obj:`0` represents negative edges, and labels :obj:`1` to :obj:`num_classes` represent the labels of positive edges. 
Note that returned labels are of type :obj:`torch.float` for binary classification (to facilitate the ease-of-use of :meth:`F.binary_cross_entropy`) and of type :obj:`torch.long` for multi-class classification (to facilitate the ease-of-use of :meth:`F.cross_entropy`). For negative sampling mode :obj:`"triplet"`, samples can be accessed via the attributes :obj:`src_index`, :obj:`dst_pos_index` and :obj:`dst_neg_index` in the respective node types of the returned mini-batch. :obj:`edge_label` needs to be :obj:`None` for :obj:`"triplet"` negative sampling mode. If set to :obj:`None`, no negative sampling strategy is applied. (default: :obj:`None`) For example, use :obj:`neg_sampling=dict(mode='binary', amount=0.5)`. neg_sampling_ratio (int or float, optional): The ratio of sampled negative edges to the number of positive edges. Deprecated in favor of the :obj:`neg_sampling` argument. (default: :obj:`None`) time_attr (str, optional): The name of the attribute that denotes timestamps for either the nodes or edges in the graph. If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, *i.e.* neighbors have an earlier or equal timestamp than the center node. Only used if :obj:`edge_label_time` is set. (default: :obj:`None`) weight_attr (str, optional): The name of the attribute that denotes edge weights in the graph. If set, weighted/biased sampling will be used such that neighbors are more likely to get sampled the higher their edge weights are. Edge weights do not need to sum to one, but must be non-negative, finite and have a non-zero sum within local neighborhoods. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: :obj:`None`) transform_sampler_output (callable, optional): A function/transform that takes in a :class:`torch_geometric.sampler.SamplerOutput` and returns a transformed version. 
(default: :obj:`None`) is_sorted (bool, optional): If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by column. If :obj:`time_attr` is set, additionally requires that rows are sorted according to time within individual neighborhoods. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) filter_per_worker (bool, optional): If set to :obj:`True`, will filter the returned data in each worker's subprocess. If set to :obj:`False`, will filter the returned data in the main process. If set to :obj:`None`, will automatically infer the decision based on whether data partially lives on the GPU (:obj:`filter_per_worker=True`) or entirely on the CPU (:obj:`filter_per_worker=False`). There exists different trade-offs for setting this option. Specifically, setting this option to :obj:`True` for in-memory datasets will move all features to shared memory, which may result in too many open file handles. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. """ def __init__( self, data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: Union[List[int], Dict[EdgeType, List[int]]], edge_label_index: InputEdges = None, edge_label: OptTensor = None, edge_label_time: OptTensor = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = 'directional', disjoint: bool = False, temporal_strategy: str = 'uniform', neg_sampling: Optional[NegativeSampling] = None, neg_sampling_ratio: Optional[Union[int, float]] = None, time_attr: Optional[str] = None, weight_attr: Optional[str] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, is_sorted: bool = False, filter_per_worker: Optional[bool] = None, neighbor_sampler: Optional[NeighborSampler] = None, directed: bool = True, # Deprecated. 
**kwargs, ): if (edge_label_time is not None) != (time_attr is not None): raise ValueError( f"Received conflicting 'edge_label_time' and 'time_attr' " f"arguments: 'edge_label_time' is " f"{'set' if edge_label_time is not None else 'not set'} " f"while 'time_attr' is " f"{'set' if time_attr is not None else 'not set'}. " f"Both arguments must be provided for temporal sampling.") if neighbor_sampler is None: neighbor_sampler = NeighborSampler( data, num_neighbors=num_neighbors, replace=replace, subgraph_type=subgraph_type, disjoint=disjoint, temporal_strategy=temporal_strategy, time_attr=time_attr, weight_attr=weight_attr, is_sorted=is_sorted, share_memory=kwargs.get('num_workers', 0) > 0, directed=directed, ) super().__init__( data=data, link_sampler=neighbor_sampler, edge_label_index=edge_label_index, edge_label=edge_label, edge_label_time=edge_label_time, neg_sampling=neg_sampling, neg_sampling_ratio=neg_sampling_ratio, transform=transform, transform_sampler_output=transform_sampler_output, filter_per_worker=filter_per_worker, **kwargs, ) ================================================ FILE: torch_geometric/loader/mixin.py ================================================ import glob import logging import os import os.path as osp import warnings from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Union import psutil import torch from torch_geometric.data import HeteroData def get_numa_nodes_cores() -> Dict[str, Any]: """Parses numa nodes information into a dictionary. ..code-block:: {: [(, [, ...]), ...], ...} # For example: {0: [(0, [0, 4]), (1, [1, 5])], 1: [(2, [2, 6]), (3, [3, 7])]} If not available, returns an empty dictionary. 
""" numa_node_paths = glob.glob('/sys/devices/system/node/node[0-9]*') if not numa_node_paths: return {} nodes = {} try: for node_path in numa_node_paths: numa_node_id = int(osp.basename(node_path)[4:]) thread_siblings = {} for cpu_dir in glob.glob(osp.join(node_path, 'cpu[0-9]*')): cpu_id = int(osp.basename(cpu_dir)[3:]) if cpu_id > 0: with open(osp.join(cpu_dir, 'online')) as core_online_file: core_online = int( core_online_file.read().splitlines()[0]) else: core_online = 1 # cpu0 is always online (special case) if core_online == 1: with open(osp.join(cpu_dir, 'topology', 'core_id')) as core_id_file: core_id = int(core_id_file.read().strip()) if core_id in thread_siblings: thread_siblings[core_id].append(cpu_id) else: thread_siblings[core_id] = [cpu_id] nodes[numa_node_id] = sorted([(k, sorted(v)) for k, v in thread_siblings.items()]) except (OSError, ValueError, IndexError): Warning('Failed to read NUMA info') return {} return nodes class WorkerInitWrapper: r"""Wraps the :attr:`worker_init_fn` argument for :class:`torch.utils.data.DataLoader` workers. """ def __init__(self, func: Callable) -> None: self.func = func def __call__(self, worker_id: int) -> None: if self.func is not None: self.func(worker_id) class LogMemoryMixin: r"""A context manager to enable logging of memory consumption in :class:`~torch.utils.data.DataLoader` workers. """ def _mem_init_fn(self, worker_id: int) -> None: proc = psutil.Process(os.getpid()) memory = proc.memory_info().rss / (1024 * 1024) logging.debug(f"Worker {worker_id} @ PID {proc.pid}: {memory:.2f} MB") # Chain worker init functions: self._old_worker_init_fn(worker_id) @contextmanager def enable_memory_log(self) -> None: self._old_worker_init_fn = WorkerInitWrapper(self.worker_init_fn) try: self.worker_init_fn = self._mem_init_fn yield finally: self.worker_init_fn = self._old_worker_init_fn class MultithreadingMixin: r"""A context manager to enable multi-threading in :class:`~torch.utils.data.DataLoader` workers. 
It changes the default value of threads used in the loader from :obj:`1` to :obj:`worker_threads`. """ def _mt_init_fn(self, worker_id: int) -> None: try: torch.set_num_threads(int(self._worker_threads)) except IndexError as e: raise ValueError(f"Cannot set {self.worker_threads} threads " f"in worker {worker_id}") from e # Chain worker init functions: self._old_worker_init_fn(worker_id) @contextmanager def enable_multithreading( self, worker_threads: Optional[int] = None, ) -> None: r"""Enables multithreading in worker subprocesses. This option requires to change the start method from :obj:`"fork"` to :obj:`"spawn"`. .. code-block:: python def run(): loader = NeigborLoader(data, num_workers=3) with loader.enable_multithreading(10): for batch in loader: pass if __name__ == '__main__': torch.set_start_method('spawn') run() Args: worker_threads (int, optional): The number of threads to use in each worker process. By default, it uses half of all available CPU cores. (default: :obj:`torch.get_num_threads() // num_workers`) """ if worker_threads is None: worker_threads = torch.get_num_threads() // self.num_workers self._worker_threads = worker_threads if not self.num_workers > 0: raise ValueError(f"'enable_multithreading' needs to be performed " f"with at least one worker " f"(got {self.num_workers})") if worker_threads > torch.get_num_threads(): raise ValueError(f"'worker_threads' should be smaller than the " f"total available number of threads " f"{torch.get_num_threads()} " f"(got {worker_threads})") context = torch.multiprocessing.get_context()._name if context != 'spawn': raise ValueError(f"'enable_multithreading' can only be used with " f"the 'spawn' multiprocessing context " f"(got {context})") self._old_worker_init_fn = WorkerInitWrapper(self.worker_init_fn) try: logging.debug(f"Using {worker_threads} threads in each worker") self.worker_init_fn = self._mt_init_fn yield finally: self.worker_init_fn = self._old_worker_init_fn class AffinityMixin: r"""A context 
manager to enable CPU affinity for data loader workers (only used when running on CPU devices). Affinitization places data loader workers threads on specific CPU cores. In effect, it allows for more efficient local memory allocation and reduces remote memory calls. Every time a process or thread moves from one core to another, registers and caches need to be flushed and reloaded. This can become very costly if it happens often, and our threads may also no longer be close to their data, or be able to share data in a cache. See `here `__ for the accompanying tutorial. .. warning:: To correctly affinitize compute threads (*i.e.* with :obj:`KMP_AFFINITY`), please make sure that you exclude :obj:`loader_cores` from the list of cores available for the main process. This will cause core oversubsription and exacerbate performance. .. code-block:: python loader = NeigborLoader(data, num_workers=3) with loader.enable_cpu_affinity(loader_cores=[0, 1, 2]): for batch in loader: pass """ def _aff_init_fn(self, worker_id: int) -> None: try: worker_cores = self.loader_cores[worker_id] if not isinstance(worker_cores, List): worker_cores = [worker_cores] if torch.multiprocessing.get_context()._name == 'spawn': torch.set_num_threads(len(worker_cores)) psutil.Process().cpu_affinity(worker_cores) except IndexError as e: raise ValueError(f"Cannot use CPU affinity for worker ID " f"{worker_id} on CPU {self.loader_cores}") from e # Chain worker init functions: self._old_worker_init_fn(worker_id) @contextmanager def enable_cpu_affinity( self, loader_cores: Optional[Union[List[List[int]], List[int]]] = None, ) -> None: r"""Enables CPU affinity. Args: loader_cores ([int], optional): List of CPU cores to which data loader workers should affinitize to. By default, it will affinitize to :obj:`numa0` cores. If used with :obj:`"spawn"` multiprocessing context, it will automatically enable multithreading and use multiple cores per each worker. 
""" if not self.num_workers > 0: raise ValueError( f"'enable_cpu_affinity' should be used with at least one " f"worker (got {self.num_workers})") if loader_cores and len(loader_cores) != self.num_workers: raise ValueError( f"The number of loader cores (got {len(loader_cores)}) " f"in 'enable_cpu_affinity' should match with the number " f"of workers (got {self.num_workers})") if isinstance(self.data, HeteroData): warnings.warn( "Due to conflicting parallelization methods it is not advised " "to use affinitization with 'HeteroData' datasets. " "Use `enable_multithreading` for better performance.", stacklevel=2) self.loader_cores = loader_cores[:] if loader_cores else None if self.loader_cores is None: numa_info = get_numa_nodes_cores() if numa_info and len(numa_info[0]) > self.num_workers: # Take one thread per each node 0 core: node0_cores = [cpus[0] for core_id, cpus in numa_info[0]] node0_cores.sort() else: node0_cores = list(range(psutil.cpu_count(logical=False))) if len(node0_cores) < self.num_workers: raise ValueError( f"More workers (got {self.num_workers}) than available " f"cores (got {len(node0_cores)})") # Set default loader core IDs: if torch.multiprocessing.get_context()._name == 'spawn': work_thread_pool = int(len(node0_cores) / self.num_workers) self.loader_cores = [ list( range( work_thread_pool * i, work_thread_pool * (i + 1), )) for i in range(self.num_workers) ] else: self.loader_cores = node0_cores[:self.num_workers] self._old_worker_init_fn = WorkerInitWrapper(self.worker_init_fn) try: self.worker_init_fn = self._aff_init_fn logging.debug(f"{self.num_workers} data loader workers are " f"assigned to CPUs {self.loader_cores}") yield finally: self.worker_init_fn = self._old_worker_init_fn ================================================ FILE: torch_geometric/loader/neighbor_loader.py ================================================ from typing import Callable, Dict, List, Optional, Tuple, Union from torch_geometric.data import Data, FeatureStore, 
GraphStore, HeteroData from torch_geometric.loader.node_loader import NodeLoader from torch_geometric.sampler import NeighborSampler from torch_geometric.sampler.base import SubgraphType from torch_geometric.typing import EdgeType, InputNodes, OptTensor class NeighborLoader(NodeLoader): r"""A data loader that performs neighbor sampling as introduced in the `"Inductive Representation Learning on Large Graphs" `_ paper. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible. More specifically, :obj:`num_neighbors` denotes how many neighbors are sampled for each node in each iteration. :class:`~torch_geometric.loader.NeighborLoader` takes in this list of :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for each node involved in iteration :obj:`i - 1`. Sampled nodes are sorted based on the order in which they were sampled. In particular, the first :obj:`batch_size` nodes represent the set of original mini-batch nodes. .. code-block:: python from torch_geometric.datasets import Planetoid from torch_geometric.loader import NeighborLoader data = Planetoid(path, name='Cora')[0] loader = NeighborLoader( data, # Sample 30 neighbors for each node for 2 iterations num_neighbors=[30] * 2, # Use a batch size of 128 for sampling training nodes batch_size=128, input_nodes=data.train_mask, ) sampled_data = next(iter(loader)) print(sampled_data.batch_size) >>> 128 By default, the data loader will only include the edges that were originally sampled (:obj:`directed = True`). This option should only be used in case the number of hops is equivalent to the number of GNN layers. In case the number of GNN layers is greater than the number of hops, consider setting :obj:`directed = False`, which will include all edges between all sampled nodes (but is slightly slower as a result). 
Furthermore, :class:`~torch_geometric.loader.NeighborLoader` works for both **homogeneous** graphs stored via :class:`~torch_geometric.data.Data` as well as **heterogeneous** graphs stored via :class:`~torch_geometric.data.HeteroData`. When operating in heterogeneous graphs, up to :obj:`num_neighbors` neighbors will be sampled for each :obj:`edge_type`. However, more fine-grained control over the amount of sampled neighbors of individual edge types is possible: .. code-block:: python from torch_geometric.datasets import OGB_MAG from torch_geometric.loader import NeighborLoader hetero_data = OGB_MAG(path)[0] loader = NeighborLoader( hetero_data, # Sample 30 neighbors for each node and edge type for 2 iterations num_neighbors={key: [30] * 2 for key in hetero_data.edge_types}, # Use a batch size of 128 for sampling training nodes of type paper batch_size=128, input_nodes=('paper', hetero_data['paper'].train_mask), ) sampled_hetero_data = next(iter(loader)) print(sampled_hetero_data['paper'].batch_size) >>> 128 .. note:: For an example of using :class:`~torch_geometric.loader.NeighborLoader`, see `examples/hetero/to_hetero_mag.py `_. The :class:`~torch_geometric.loader.NeighborLoader` will return subgraphs where global node indices are mapped to local indices corresponding to this specific subgraph. However, often times it is desired to map the nodes of the current subgraph back to the global node indices. The :class:`~torch_geometric.loader.NeighborLoader` will include this mapping as part of the :obj:`data` object: .. code-block:: python loader = NeighborLoader(data, ...) sampled_data = next(iter(loader)) print(sampled_data.n_id) # Global node index of each node in batch. 
In particular, the data loader will add the following attributes to the returned mini-batch: * :obj:`batch_size` The number of seed nodes (first nodes in the batch) * :obj:`n_id` The global node index for every sampled node * :obj:`e_id` The global edge index for every sampled edge * :obj:`input_id`: The global index of the :obj:`input_nodes` * :obj:`num_sampled_nodes`: The number of sampled nodes in each hop * :obj:`num_sampled_edges`: The number of sampled edges in each hop Args: data (Any): A :class:`~torch_geometric.data.Data`, :class:`~torch_geometric.data.HeteroData`, or (:class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`) data object. num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to :obj:`-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): The indices of nodes for which neighbors are sampled to create mini-batches. Needs to be either given as a :obj:`torch.LongTensor` or :obj:`torch.BoolTensor`. If set to :obj:`None`, all nodes will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node indices. (default: :obj:`None`) input_time (torch.Tensor, optional): Optional values to override the timestamp for the input nodes given in :obj:`input_nodes`. If not set, will use the timestamps in :obj:`time_attr` as default (if present). The :obj:`time_attr` needs to be set for this to work. (default: :obj:`None`) replace (bool, optional): If set to :obj:`True`, will sample with replacement. (default: :obj:`False`) subgraph_type (SubgraphType or str, optional): The type of the returned subgraph. 
If set to :obj:`"directional"`, the returned subgraph only holds the sampled (directed) edges which are necessary to compute representations for the sampled seed nodes. If set to :obj:`"bidirectional"`, sampled edges are converted to bidirectional edges. If set to :obj:`"induced"`, the returned subgraph contains the induced subgraph of all sampled nodes. (default: :obj:`"directional"`) disjoint (bool, optional): If set to :obj: `True`, each seed node will create its own disjoint subgraph. If set to :obj:`True`, mini-batch outputs will have a :obj:`batch` vector holding the mapping of nodes to their respective subgraph. Will get automatically set to :obj:`True` in case of temporal sampling. (default: :obj:`False`) temporal_strategy (str, optional): The sampling strategy when using temporal sampling (:obj:`"uniform"`, :obj:`"last"`). If set to :obj:`"uniform"`, will sample uniformly across neighbors that fulfill temporal constraints. If set to :obj:`"last"`, will sample the last `num_neighbors` that fulfill temporal constraints. (default: :obj:`"uniform"`) time_attr (str, optional): The name of the attribute that denotes timestamps for either the nodes or edges in the graph. If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, *i.e.* neighbors have an earlier or equal timestamp than the center node. (default: :obj:`None`) weight_attr (str, optional): The name of the attribute that denotes edge weights in the graph. If set, weighted/biased sampling will be used such that neighbors are more likely to get sampled the higher their edge weights are. Edge weights do not need to sum to one, but must be non-negative, finite and have a non-zero sum within local neighborhoods. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. 
(default: :obj:`None`) transform_sampler_output (callable, optional): A function/transform that takes in a :class:`torch_geometric.sampler.SamplerOutput` and returns a transformed version. (default: :obj:`None`) is_sorted (bool, optional): If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by column. If :obj:`time_attr` is set, additionally requires that rows are sorted according to time within individual neighborhoods. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) filter_per_worker (bool, optional): If set to :obj:`True`, will filter the returned data in each worker's subprocess. If set to :obj:`False`, will filter the returned data in the main process. If set to :obj:`None`, will automatically infer the decision based on whether data partially lives on the GPU (:obj:`filter_per_worker=True`) or entirely on the CPU (:obj:`filter_per_worker=False`). There exists different trade-offs for setting this option. Specifically, setting this option to :obj:`True` for in-memory datasets will move all features to shared memory, which may result in too many open file handles. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. 
""" def __init__( self, data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: Union[List[int], Dict[EdgeType, List[int]]], input_nodes: InputNodes = None, input_time: OptTensor = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = 'directional', disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, weight_attr: Optional[str] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, is_sorted: bool = False, filter_per_worker: Optional[bool] = None, neighbor_sampler: Optional[NeighborSampler] = None, directed: bool = True, # Deprecated. **kwargs, ): if input_time is not None and time_attr is None: raise ValueError("Received conflicting 'input_time' and " "'time_attr' arguments: 'input_time' is set " "while 'time_attr' is not set.") if neighbor_sampler is None: neighbor_sampler = NeighborSampler( data, num_neighbors=num_neighbors, replace=replace, subgraph_type=subgraph_type, disjoint=disjoint, temporal_strategy=temporal_strategy, time_attr=time_attr, weight_attr=weight_attr, is_sorted=is_sorted, share_memory=kwargs.get('num_workers', 0) > 0, directed=directed, ) super().__init__( data=data, node_sampler=neighbor_sampler, input_nodes=input_nodes, input_time=input_time, transform=transform, transform_sampler_output=transform_sampler_output, filter_per_worker=filter_per_worker, **kwargs, ) ================================================ FILE: torch_geometric/loader/neighbor_sampler.py ================================================ from typing import Callable, List, NamedTuple, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.typing import SparseTensor class EdgeIndex(NamedTuple): edge_index: Tensor e_id: Optional[Tensor] size: Tuple[int, int] def to(self, *args, **kwargs): edge_index = self.edge_index.to(*args, **kwargs) e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None return 
EdgeIndex(edge_index, e_id, self.size) class Adj(NamedTuple): adj_t: SparseTensor e_id: Optional[Tensor] size: Tuple[int, int] def to(self, *args, **kwargs): adj_t = self.adj_t.to(*args, **kwargs) e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None return Adj(adj_t, e_id, self.size) class NeighborSampler(torch.utils.data.DataLoader): r"""The neighbor sampler from the `"Inductive Representation Learning on Large Graphs" `_ paper, which allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible. Given a GNN with :math:`L` layers and a specific mini-batch of nodes :obj:`node_idx` for which we want to compute embeddings, this module iteratively samples neighbors and constructs bipartite graphs that simulate the actual computation flow of GNNs. More specifically, :obj:`sizes` denotes how much neighbors we want to sample for each node in each layer. This module then takes in these :obj:`sizes` and iteratively samples :obj:`sizes[l]` for each node involved in layer :obj:`l`. In the next layer, sampling is repeated for the union of nodes that were already encountered. The actual computation graphs are then returned in reverse-mode, meaning that we pass messages from a larger set of nodes to a smaller one, until we reach the nodes for which we originally wanted to compute embeddings. Hence, an item returned by :class:`NeighborSampler` holds the current :obj:`batch_size`, the IDs :obj:`n_id` of all nodes involved in the computation, and a list of bipartite graph objects via the tuple :obj:`(edge_index, e_id, size)`, where :obj:`edge_index` represents the bipartite edges between source and target nodes, :obj:`e_id` denotes the IDs of original edges in the full graph, and :obj:`size` holds the shape of the bipartite graph. For each bipartite graph, target nodes are also included at the beginning of the list of source nodes so that one can easily apply skip-connections or add self-loops. .. 
warning::

        :class:`~torch_geometric.loader.NeighborSampler` is deprecated and will
        be removed in a future release.
        Use :class:`torch_geometric.loader.NeighborLoader` instead.

    .. note::

        For an example of using :obj:`NeighborSampler`, see
        `examples/reddit.py <https://github.com/pyg-team/pytorch_geometric/
        blob/master/examples/reddit.py>`_ or
        `examples/ogbn_train.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/ogbn_train.py>`_.

    Args:
        edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
            :class:`torch_sparse.SparseTensor` that defines the underlying
            graph connectivity/message passing flow.
            :obj:`edge_index` holds the indices of a (sparse) symmetric
            adjacency matrix.
            If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its shape
            must be defined as :obj:`[2, num_edges]`, where messages from nodes
            :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`
            (in case :obj:`flow="source_to_target"`).
            If :obj:`edge_index` is of type :class:`torch_sparse.SparseTensor`,
            its sparse indices :obj:`(row, col)` should relate to
            :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.
            The major difference between both formats is that we need to input
            the *transposed* sparse adjacency matrix.
        sizes ([int]): The number of neighbors to sample for each node in each
            layer. If set to :obj:`sizes[l] = -1`, all neighbors are included
            in layer :obj:`l`.
        node_idx (LongTensor or BoolTensor, optional): The nodes that should be
            considered for creating mini-batches. A boolean mask is converted
            into the indices of its :obj:`True` entries.
            If set to :obj:`None`, all nodes will be considered.
            (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes in the graph.
            (default: :obj:`None`)
        return_e_id (bool, optional): If set to :obj:`False`, will not return
            original edge indices of sampled edges. This is only useful when
            operating on graphs without edge features, in order to save
            memory. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            It is called as :obj:`transform(batch_size, n_id, adjs)`.
(default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. """ def __init__(self, edge_index: Union[Tensor, SparseTensor], sizes: List[int], node_idx: Optional[Tensor] = None, num_nodes: Optional[int] = None, return_e_id: bool = True, transform: Callable = None, **kwargs): edge_index = edge_index.to('cpu') # Remove for PyTorch Lightning: kwargs.pop('dataset', None) kwargs.pop('collate_fn', None) # Save for Pytorch Lightning < 1.6: self.edge_index = edge_index self.node_idx = node_idx self.num_nodes = num_nodes self.sizes = sizes self.return_e_id = return_e_id self.transform = transform self.is_sparse_tensor = isinstance(edge_index, SparseTensor) self.__val__ = None # Obtain a *transposed* `SparseTensor` instance. if not self.is_sparse_tensor: if (num_nodes is None and node_idx is not None and node_idx.dtype == torch.bool): num_nodes = node_idx.size(0) if (num_nodes is None and node_idx is not None and node_idx.dtype == torch.long): num_nodes = max(int(edge_index.max()), int(node_idx.max())) + 1 if num_nodes is None: num_nodes = int(edge_index.max()) + 1 value = torch.arange(edge_index.size(1)) if return_e_id else None self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1], value=value, sparse_sizes=(num_nodes, num_nodes)).t() else: adj_t = edge_index if return_e_id: self.__val__ = adj_t.storage.value() value = torch.arange(adj_t.nnz()) adj_t = adj_t.set_value(value, layout='coo') self.adj_t = adj_t self.adj_t.storage.rowptr() if node_idx is None: node_idx = torch.arange(self.adj_t.sparse_size(0)) elif node_idx.dtype == torch.bool: node_idx = node_idx.nonzero(as_tuple=False).view(-1) super().__init__( node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs) def sample(self, batch): if not isinstance(batch, Tensor): batch = torch.tensor(batch) batch_size: int = len(batch) adjs = [] n_id = batch for size in self.sizes: adj_t, 
n_id = self.adj_t.sample_adj(n_id, size, replace=False) e_id = adj_t.storage.value() size = adj_t.sparse_sizes()[::-1] if self.__val__ is not None: adj_t.set_value_(self.__val__[e_id], layout='coo') if self.is_sparse_tensor: adjs.append(Adj(adj_t, e_id, size)) else: row, col, _ = adj_t.coo() edge_index = torch.stack([col, row], dim=0) adjs.append(EdgeIndex(edge_index, e_id, size)) adjs = adjs[0] if len(adjs) == 1 else adjs[::-1] out = (batch_size, n_id, adjs) out = self.transform(*out) if self.transform is not None else out return out def __repr__(self) -> str: return f'{self.__class__.__name__}(sizes={self.sizes})' ================================================ FILE: torch_geometric/loader/node_loader.py ================================================ from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.mixin import ( AffinityMixin, LogMemoryMixin, MultithreadingMixin, ) from torch_geometric.loader.utils import ( filter_custom_hetero_store, filter_custom_store, filter_data, filter_hetero_data, get_input_nodes, infer_filter_per_worker, ) from torch_geometric.sampler import ( BaseSampler, HeteroSamplerOutput, NodeSamplerInput, SamplerOutput, ) from torch_geometric.typing import InputNodes, OptTensor class NodeLoader( torch.utils.data.DataLoader, AffinityMixin, MultithreadingMixin, LogMemoryMixin, ): r"""A data loader that performs mini-batch sampling from node information, using a generic :class:`~torch_geometric.sampler.BaseSampler` implementation that defines a :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes` function and is supported on the provided input :obj:`data` object. 
Args:
        data (Any): A :class:`~torch_geometric.data.Data`,
            :class:`~torch_geometric.data.HeteroData`, or
            (:class:`~torch_geometric.data.FeatureStore`,
            :class:`~torch_geometric.data.GraphStore`) data object.
        node_sampler (torch_geometric.sampler.BaseSampler): The sampler
            implementation to be used with this loader.
            Needs to implement
            :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes`.
            The sampler implementation must be compatible with the input
            :obj:`data` object.
        input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): The
            indices of seed nodes to start sampling from.
            Needs to be either given as a :obj:`torch.LongTensor` or
            :obj:`torch.BoolTensor`.
            If set to :obj:`None`, all nodes will be considered.
            In heterogeneous graphs, needs to be passed as a tuple that holds
            the node type and node indices. (default: :obj:`None`)
        input_time (torch.Tensor, optional): Optional values to override the
            timestamp for the input nodes given in :obj:`input_nodes`. If not
            set, will use the timestamps in :obj:`time_attr` as default (if
            present). The :obj:`time_attr` needs to be set for this to work.
            (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        transform_sampler_output (callable, optional): A function/transform
            that takes in a :class:`torch_geometric.sampler.SamplerOutput`
            (or :class:`torch_geometric.sampler.HeteroSamplerOutput`) and
            returns a transformed version. (default: :obj:`None`)
        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
            the returned data in each worker's subprocess.
            If set to :obj:`False`, will filter the returned data in the main
            process.
            If set to :obj:`None`, will automatically infer the decision based
            on whether data partially lives on the GPU
            (:obj:`filter_per_worker=True`) or entirely on the CPU
            (:obj:`filter_per_worker=False`).
            There exist different trade-offs for setting this option.
Specifically, setting this option to :obj:`True` for in-memory datasets will move all features to shared memory, which may result in too many open file handles. (default: :obj:`None`) custom_cls (HeteroData, optional): A custom :class:`~torch_geometric.data.HeteroData` class to return for mini-batches in case of remote backends. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. """ def __init__( self, data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], node_sampler: BaseSampler, input_nodes: InputNodes = None, input_time: OptTensor = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, filter_per_worker: Optional[bool] = None, custom_cls: Optional[HeteroData] = None, input_id: OptTensor = None, **kwargs, ): if filter_per_worker is None: filter_per_worker = infer_filter_per_worker(data) self.data = data self.node_sampler = node_sampler self.input_nodes = input_nodes self.input_time = input_time self.transform = transform self.transform_sampler_output = transform_sampler_output self.filter_per_worker = filter_per_worker self.custom_cls = custom_cls self.input_id = input_id kwargs.pop('dataset', None) kwargs.pop('collate_fn', None) # Get node type (or `None` for homogeneous graphs): input_type, input_nodes, input_id = get_input_nodes( data, input_nodes, input_id) self.input_data = NodeSamplerInput( input_id=input_id, node=input_nodes, time=input_time, input_type=input_type, ) iterator = range(input_nodes.size(0)) super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) def __call__( self, index: Union[Tensor, List[int]], ) -> Union[Data, HeteroData]: r"""Samples a subgraph from a batch of input nodes.""" out = self.collate_fn(index) if not self.filter_per_worker: out = self.filter_fn(out) return out def collate_fn(self, index: Union[Tensor, List[int]]) -> Any: 
r"""Samples a subgraph from a batch of input nodes.""" input_data: NodeSamplerInput = self.input_data[index] out = self.node_sampler.sample_from_nodes(input_data) if self.filter_per_worker: # Execute `filter_fn` in the worker process out = self.filter_fn(out) return out def filter_fn( self, out: Union[SamplerOutput, HeteroSamplerOutput], ) -> Union[Data, HeteroData]: r"""Joins the sampled nodes with their corresponding features, returning the resulting :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object to be used downstream. """ if self.transform_sampler_output: out = self.transform_sampler_output(out) if isinstance(out, SamplerOutput): if isinstance(self.data, Data): data = filter_data( # self.data, out.node, out.row, out.col, out.edge, self.node_sampler.edge_permutation) else: # Tuple[FeatureStore, GraphStore] # Hack to detect whether we are in a distributed setting. if (self.node_sampler.__class__.__name__ == 'DistNeighborSampler'): edge_index = torch.stack([out.row, out.col]) data = Data(edge_index=edge_index) # Metadata entries are populated in # `DistributedNeighborSampler._collate_fn()` data.x = out.metadata[-3] data.y = out.metadata[-2] data.edge_attr = out.metadata[-1] else: data = filter_custom_store( # *self.data, out.node, out.row, out.col, out.edge, self.custom_cls) if 'n_id' not in data: data.n_id = out.node if out.edge is not None and 'e_id' not in data: edge = out.edge.to(torch.long) perm = self.node_sampler.edge_permutation data.e_id = perm[edge] if perm is not None else edge data.batch = out.batch data.num_sampled_nodes = out.num_sampled_nodes data.num_sampled_edges = out.num_sampled_edges if out.orig_row is not None and out.orig_col is not None: data._orig_edge_index = torch.stack([ out.orig_row, out.orig_col, ], dim=0) data.input_id = out.metadata[0] data.seed_time = out.metadata[1] data.batch_size = out.metadata[0].size(0) elif isinstance(out, HeteroSamplerOutput): if isinstance(self.data, HeteroData): data 
= filter_hetero_data( # self.data, out.node, out.row, out.col, out.edge, self.node_sampler.edge_permutation) else: # Tuple[FeatureStore, GraphStore] # Hack to detect whether we are in a distributed setting. if (self.node_sampler.__class__.__name__ == 'DistNeighborSampler'): import torch_geometric.distributed as dist data = dist.utils.filter_dist_store( *self.data, out.node, out.row, out.col, out.edge, self.custom_cls, out.metadata, self.input_data.input_type) else: data = filter_custom_hetero_store( # *self.data, out.node, out.row, out.col, out.edge, self.custom_cls) for key, node in out.node.items(): if 'n_id' not in data[key]: data[key].n_id = node for key, edge in (out.edge or {}).items(): if edge is not None and 'e_id' not in data[key]: edge = edge.to(torch.long) perm = self.node_sampler.edge_permutation if perm is not None and perm.get(key, None) is not None: edge = perm[key][edge] data[key].e_id = edge data.set_value_dict('batch', out.batch) data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes) data.set_value_dict('num_sampled_edges', out.num_sampled_edges) if out.orig_row is not None and out.orig_col is not None: for key in out.orig_row.keys(): data[key]._orig_edge_index = torch.stack([ out.orig_row[key], out.orig_col[key], ], dim=0) input_type = self.input_data.input_type data[input_type].input_id = out.metadata[0] data[input_type].seed_time = out.metadata[1] data[input_type].batch_size = out.metadata[0].size(0) else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " f"type: '{type(out)}'") return data if self.transform is None else self.transform(data) def _get_iterator(self) -> Iterator: if self.filter_per_worker: return super()._get_iterator() # if not self.is_cuda_available and not self.cpu_affinity_enabled: # TODO: Add manual page for best CPU practices # link = ... 
# Warning('Dataloader CPU affinity opt is not enabled, consider ' # 'switching it on with enable_cpu_affinity() or see CPU ' # f'best practices for PyG [{link}])') # Execute `filter_fn` in the main process: return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/loader/prefetch.py ================================================ import warnings from contextlib import nullcontext from functools import partial from typing import Any, Optional import torch from torch.utils.data import DataLoader from torch_geometric.typing import WITH_IPEX class DeviceHelper: def __init__(self, device: Optional[torch.device] = None): with_cuda = torch.cuda.is_available() with_xpu = torch.xpu.is_available() if WITH_IPEX else False if device is None: if with_cuda: device = 'cuda' elif with_xpu: device = 'xpu' else: device = 'cpu' self.device = torch.device(device) self.is_gpu = self.device.type in ['cuda', 'xpu'] if ((self.device.type == 'cuda' and not with_cuda) or (self.device.type == 'xpu' and not with_xpu)): warnings.warn( f"Requested device '{self.device.type}' is not " f"available, falling back to CPU", stacklevel=2) self.device = torch.device('cpu') self.stream = None self.stream_context = nullcontext self.module = getattr(torch, self.device.type) if self.is_gpu else None def maybe_init_stream(self) -> None: if self.is_gpu: self.stream = self.module.Stream() self.stream_context = partial( self.module.stream, stream=self.stream, ) def maybe_wait_stream(self) -> None: if self.stream is not None: self.module.current_stream().wait_stream(self.stream) class PrefetchLoader: r"""A GPU prefetcher class for asynchronously transferring data of a :class:`torch.utils.data.DataLoader` from host memory to device memory. Args: loader (torch.utils.data.DataLoader): The data loader. device (torch.device, optional): The device to load the data to. 
(default: :obj:`None`) """ def __init__( self, loader: DataLoader, device: Optional[torch.device] = None, ): self.loader = loader self.device_helper = DeviceHelper(device) def non_blocking_transfer(self, batch: Any) -> Any: if not self.device_helper.is_gpu: return batch if isinstance(batch, (list, tuple)): return [self.non_blocking_transfer(v) for v in batch] if isinstance(batch, dict): return {k: self.non_blocking_transfer(v) for k, v in batch.items()} batch = batch.pin_memory() return batch.to(self.device_helper.device, non_blocking=True) def __iter__(self) -> Any: first = True self.device_helper.maybe_init_stream() batch = None for next_batch in self.loader: with self.device_helper.stream_context(): next_batch = self.non_blocking_transfer(next_batch) if not first: yield batch else: first = False self.device_helper.maybe_wait_stream() batch = next_batch yield batch def __len__(self) -> int: return len(self.loader) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.loader})' ================================================ FILE: torch_geometric/loader/random_node_loader.py ================================================ import math from typing import Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.hetero_data import to_homogeneous_edge_index class RandomNodeLoader(torch.utils.data.DataLoader): r"""A data loader that randomly samples nodes within a graph and returns their induced subgraph. .. note:: For an example of using :class:`~torch_geometric.loader.RandomNodeLoader`, see `examples/ogbn_proteins_deepgcn.py `_. Args: data (torch_geometric.data.Data or torch_geometric.data.HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. num_parts (int): The number of partitions. **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`. 
""" def __init__( self, data: Union[Data, HeteroData], num_parts: int, **kwargs, ): self.data = data self.num_parts = num_parts if isinstance(data, HeteroData): edge_index, node_dict, edge_dict = to_homogeneous_edge_index(data) self.node_dict, self.edge_dict = node_dict, edge_dict else: edge_index = data.edge_index self.edge_index = edge_index self.num_nodes = data.num_nodes super().__init__( range(self.num_nodes), batch_size=math.ceil(self.num_nodes / num_parts), collate_fn=self.collate_fn, **kwargs, ) def collate_fn(self, index): if not isinstance(index, Tensor): index = torch.tensor(index) if isinstance(self.data, Data): return self.data.subgraph(index) elif isinstance(self.data, HeteroData): node_dict = { key: index[(index >= start) & (index < end)] - start for key, (start, end) in self.node_dict.items() } return self.data.subgraph(node_dict) ================================================ FILE: torch_geometric/loader/shadow.py ================================================ import copy from typing import Optional import torch from torch import Tensor from torch_geometric.data import Batch, Data from torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor class ShaDowKHopSampler(torch.utils.data.DataLoader): r"""The ShaDow :math:`k`-hop sampler from the `"Decoupling the Depth and Scope of Graph Neural Networks" `_ paper. Given a graph in a :obj:`data` object, the sampler will create shallow, localized subgraphs. A deep GNN on this local graph then smooths the informative local signals. .. note:: For an example of using :class:`ShaDowKHopSampler`, see `examples/shadow.py `_. Args: data (torch_geometric.data.Data): The graph data object. depth (int): The depth/number of hops of the localized subgraph. num_neighbors (int): The number of neighbors to sample for each node in each hop. node_idx (LongTensor or BoolTensor, optional): The nodes that should be considered for creating mini-batches. If set to :obj:`None`, all nodes will be considered. 
replace (bool, optional): If set to :obj:`True`, will sample neighbors
            with replacement. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or
            :obj:`num_workers`.
    """
    def __init__(self, data: Data, depth: int, num_neighbors: int,
                 node_idx: Optional[Tensor] = None, replace: bool = False,
                 **kwargs):

        if not WITH_TORCH_SPARSE:
            raise ImportError(
                f"'{self.__class__.__name__}' requires 'torch-sparse'")

        self.data = copy.copy(data)
        self.depth = depth
        self.num_neighbors = num_neighbors
        self.replace = replace

        if data.edge_index is not None:
            self.is_sparse_tensor = False
            row, col = data.edge_index.cpu()
            # Store the position of every edge in `edge_index` as its sparse
            # value, so that sampled edges can be mapped back to per-edge
            # attributes in `__collate__`:
            self.adj_t = SparseTensor(
                row=row, col=col, value=torch.arange(col.size(0)),
                sparse_sizes=(data.num_nodes, data.num_nodes)).t()
        else:
            self.is_sparse_tensor = True
            self.adj_t = data.adj_t.cpu()

        # Boolean masks are converted into index vectors:
        if node_idx is None:
            node_idx = torch.arange(self.adj_t.sparse_size(0))
        elif node_idx.dtype == torch.bool:
            node_idx = node_idx.nonzero(as_tuple=False).view(-1)
        self.node_idx = node_idx

        super().__init__(node_idx.tolist(), collate_fn=self.__collate__,
                         **kwargs)

    def __collate__(self, n_id):
        # Samples a localized subgraph around every root node in `n_id` via
        # the `torch_sparse` kernel and merges them into a single `Batch`.
        n_id = torch.tensor(n_id)

        rowptr, col, value = self.adj_t.csr()
        out = torch.ops.torch_sparse.ego_k_hop_sample_adj(
            rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)
        rowptr, col, n_id, e_id, ptr, root_n_id = out

        adj_t = SparseTensor(rowptr=rowptr, col=col,
                             value=value[e_id] if value is not None else None,
                             sparse_sizes=(n_id.numel(), n_id.numel()),
                             is_sorted=True, trust_data=True)

        batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()),
                      ptr=ptr)
        batch.root_n_id = root_n_id

        if self.is_sparse_tensor:
            batch.adj_t = adj_t
        else:
            # Convert back to `edge_index`. The COO values are the edge
            # positions stored in `__init__`, so `e_id` now refers to the
            # original `edge_index` ordering:
            row, col, e_id = adj_t.t().coo()
            batch.edge_index = torch.stack([row, col], dim=0)

        # Slice all remaining attributes: node-level targets `y` are kept for
        # the root nodes only, other node-/edge-level tensors are gathered by
        # `n_id`/`e_id`, and everything else is copied as-is:
        for k, v in self.data:
            if k in ['edge_index', 'adj_t', 'num_nodes', 'batch', 'ptr']:
                continue
            if k == 'y' and v.size(0) == self.data.num_nodes:
                batch[k] = v[n_id][root_n_id]
            elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
                batch[k] = v[n_id]
            elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges:
                batch[k] = v[e_id]
            else:
                batch[k] = v

        return batch
================================================
FILE: torch_geometric/loader/temporal_dataloader.py
================================================
from typing import List

import torch

from torch_geometric.data import TemporalData


class TemporalDataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges successive events of a
    :class:`torch_geometric.data.TemporalData` to a mini-batch.

    Args:
        data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`
            from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        neg_sampling_ratio (float, optional): The ratio of sampled negative
            destination nodes to the number of positive destination nodes.
            (default: :obj:`0.0`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
""" def __init__( self, data: TemporalData, batch_size: int = 1, neg_sampling_ratio: float = 0.0, **kwargs, ): # Remove for PyTorch Lightning: kwargs.pop('dataset', None) kwargs.pop('collate_fn', None) kwargs.pop('shuffle', None) self.data = data self.events_per_batch = batch_size self.neg_sampling_ratio = neg_sampling_ratio if neg_sampling_ratio > 0: self.min_dst = int(data.dst.min()) self.max_dst = int(data.dst.max()) if kwargs.get('drop_last', False) and len(data) % batch_size != 0: arange = range(0, len(data) - batch_size, batch_size) else: arange = range(0, len(data), batch_size) super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs) def __call__(self, arange: List[int]) -> TemporalData: batch = self.data[arange[0]:arange[0] + self.events_per_batch] n_ids = [batch.src, batch.dst] if self.neg_sampling_ratio > 0: batch.neg_dst = torch.randint( low=self.min_dst, high=self.max_dst + 1, size=(round(self.neg_sampling_ratio * batch.dst.size(0)), ), dtype=batch.dst.dtype, device=batch.dst.device, ) n_ids += [batch.neg_dst] batch.n_id = torch.cat(n_ids, dim=0).unique() return batch ================================================ FILE: torch_geometric/loader/utils.py ================================================ import copy import logging import math from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch from torch import Tensor import torch_geometric.typing from torch_geometric.data import ( Data, FeatureStore, GraphStore, HeteroData, TensorAttr, remote_backend_utils, ) from torch_geometric.data.storage import EdgeStorage, NodeStorage from torch_geometric.typing import ( EdgeType, FeatureTensorType, InputEdges, InputNodes, NodeType, OptTensor, SparseTensor, TensorFrame, ) def index_select( value: FeatureTensorType, index: Tensor, dim: int = 0, ) -> Tensor: r"""Indexes the :obj:`value` tensor along dimension :obj:`dim` using the entries in :obj:`index`. Args: value (torch.Tensor or np.ndarray): The input tensor. 
index (torch.Tensor): The 1-D tensor containing the indices to index.
        dim (int, optional): The dimension in which to index.
            (default: :obj:`0`)

    Returns:
        A :class:`torch.Tensor` holding the selected entries (NumPy inputs are
        converted via :meth:`torch.from_numpy`; :class:`TensorFrame` inputs
        are indexed directly and only support :obj:`dim=0`).

    Raises:
        ValueError: If :obj:`value` is of an unsupported type.

    .. warning::

        :obj:`index` is casted to a :obj:`torch.int64` tensor internally, as
        `PyTorch currently only supports indexing
        `_ via
        :obj:`torch.int64`.
    """
    # PyTorch currently only supports indexing via `torch.int64`:
    # https://github.com/pytorch/pytorch/issues/61819
    index = index.to(torch.int64)

    if isinstance(value, Tensor):
        out: Optional[Tensor] = None
        if torch.utils.data.get_worker_info() is not None:
            # If we are in a background process, we write directly into a
            # shared memory tensor to avoid an extra copy:
            size = list(value.shape)
            size[dim] = index.numel()
            numel = math.prod(size)
            if torch_geometric.typing.WITH_PT20:
                storage = value.untyped_storage()._new_shared(
                    numel * value.element_size())
            else:
                storage = value.storage()._new_shared(numel)
            out = value.new(storage).view(size)

        return torch.index_select(value, dim, index, out=out)

    if isinstance(value, TensorFrame):
        assert dim == 0
        return value[index]

    elif isinstance(value, np.ndarray):
        return torch.from_numpy(np.take(value, index, axis=dim))

    raise ValueError(f"Encountered invalid feature tensor type "
                     f"(got '{type(value)}')")


def filter_node_store_(store: NodeStorage, out_store: NodeStorage,
                       index: Tensor):
    # Filters a node storage object to only hold the nodes in `index`.
    # `num_nodes` is updated to the number of selected nodes, and every
    # node-level attribute is indexed along its concatenation dimension
    # (moving `index` to the attribute's device first):
    for key, value in store.items():
        if key == 'num_nodes':
            out_store.num_nodes = index.numel()

        elif store.is_node_attr(key):
            if isinstance(value, (Tensor, TensorFrame)):
                index = index.to(value.device)
            elif isinstance(value, np.ndarray):
                index = index.cpu()
            dim = store._parent().__cat_dim__(key, value, store)
            out_store[key] = index_select(value, index, dim=dim)


def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
                       col: Tensor, index: OptTensor, perm: OptTensor = None):
    # Filters an edge storage object to only hold the edges in `index`,
    # which represents the new graph as denoted
by `(row, col)`:
    for key, value in store.items():
        if key == 'edge_index':
            edge_index = torch.stack([row, col], dim=0).to(value.device)
            # TODO Integrate `EdgeIndex` into `custom_store`.
            # edge_index = EdgeIndex(
            #     torch.stack([row, col], dim=0).to(value.device),
            #     sparse_size=out_store.size(),
            #     sort_order='col',
            #     # TODO Support `is_undirected`.
            # )
            out_store.edge_index = edge_index

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
            row = row.to(value.device())
            col = col.to(value.device())
            edge_attr = value.storage.value()
            if edge_attr is not None:
                if index is not None:
                    index = index.to(edge_attr.device)
                    edge_attr = index_select(edge_attr, index, dim=0)
                else:
                    # No edge IDs available: drop the sparse values (mirrors
                    # the `is_edge_attr` branch below):
                    edge_attr = None
            # `adj_t` is transposed, hence the reversed sizes and the swapped
            # `row`/`col` arguments below:
            sparse_sizes = out_store.size()[::-1]
            # TODO Currently, we set `is_sorted=False`, see:
            # https://github.com/pyg-team/pytorch_geometric/issues/4346
            out_store.adj_t = SparseTensor(row=col, col=row, value=edge_attr,
                                           sparse_sizes=sparse_sizes,
                                           is_sorted=False, trust_data=True)

        elif store.is_edge_attr(key):
            if index is None:
                out_store[key] = None
                continue

            dim = store._parent().__cat_dim__(key, value, store)
            if isinstance(value, (Tensor, TensorFrame)):
                index = index.to(value.device)
            elif isinstance(value, np.ndarray):
                index = index.cpu()
            if perm is None:
                out_store[key] = index_select(value, index, dim=dim)
            else:
                # Edge IDs are translated through `perm` before indexing the
                # attribute (presumably mapping the sampler's edge order back
                # to the stored attribute order -- confirm):
                if isinstance(value, (Tensor, TensorFrame)):
                    perm = perm.to(value.device)
                elif isinstance(value, np.ndarray):
                    perm = perm.cpu()
                out_store[key] = index_select(
                    value,
                    perm[index.to(torch.int64)],
                    dim=dim,
                )


def filter_data(data: Data, node: Tensor, row: Tensor, col: Tensor,
                edge: OptTensor, perm: OptTensor = None) -> Data:
    # Filters a data object to only hold nodes in `node` and edges in `edge`:
    out = copy.copy(data)
    filter_node_store_(data._store, out._store, node)
    filter_edge_store_(data._store, out._store, row, col, edge, perm)
    return out


def filter_hetero_data(
    data: HeteroData,
    node_dict: Dict[NodeType, Tensor],
    row_dict: Dict[EdgeType, Tensor],
col_dict: Dict[EdgeType, Tensor],
    edge_dict: Dict[EdgeType, OptTensor],
    perm_dict: Optional[Dict[EdgeType, OptTensor]] = None,
) -> HeteroData:
    # Filters a heterogeneous data object to only hold nodes in `node_dict`
    # and edges in `edge_dict` for each node and edge type, respectively.
    # NOTE: Node/edge types missing from the input dictionaries are inserted
    # into them as empty tensors, i.e., the caller's `node_dict`, `row_dict`,
    # `col_dict` and `edge_dict` are mutated in-place:
    out = copy.copy(data)

    for node_type in out.node_types:
        # Handle the case of disconnected graph sampling:
        if node_type not in node_dict:
            node_dict[node_type] = torch.empty(0, dtype=torch.long)

        filter_node_store_(data[node_type], out[node_type],
                           node_dict[node_type])

    for edge_type in out.edge_types:
        # Handle the case of disconnected graph sampling:
        if edge_type not in row_dict:
            row_dict[edge_type] = torch.empty(0, dtype=torch.long)
        if edge_type not in col_dict:
            col_dict[edge_type] = torch.empty(0, dtype=torch.long)
        if edge_type not in edge_dict:
            edge_dict[edge_type] = torch.empty(0, dtype=torch.long)

        filter_edge_store_(
            data[edge_type],
            out[edge_type],
            row_dict[edge_type],
            col_dict[edge_type],
            edge_dict[edge_type],
            perm_dict.get(edge_type, None) if perm_dict else None,
        )

    return out


def filter_custom_store(
    feature_store: FeatureStore,
    graph_store: GraphStore,
    node: Tensor,
    row: Tensor,
    col: Tensor,
    edge: OptTensor,
    custom_cls: Optional[Data] = None,
) -> Data:
    r"""Constructs a :class:`~torch_geometric.data.Data` object from a feature
    store and graph store instance.
    """
    # NOTE: `graph_store` and `edge` are not read here; edge features are not
    # supported yet (see the TODO below).
    # NOTE(review): `custom_cls` is called to build the output, so it is
    # presumably a `Data` subclass rather than an instance -- confirm.
    # Construct a new `Data` object:
    data = custom_cls() if custom_cls is not None else Data()

    data.edge_index = torch.stack([row, col], dim=0)

    # Filter node storage: request every tensor attribute of the feature
    # store at the sampled node indices.
    required_attrs = []
    for attr in feature_store.get_all_tensor_attrs():
        attr.index = node  # TODO Support edge features.
        required_attrs.append(attr)
        data.num_nodes = attr.index.size(0)

    # NOTE Here, we utilize `feature_store.multi_get_tensor` to give the
    # feature store full control over optimizing how it returns features
    # (since the call is synchronous, this amounts to giving the feature
    # store control over all iteration).
tensors = feature_store.multi_get_tensor(required_attrs)
    for i, attr in enumerate(required_attrs):
        data[attr.attr_name] = tensors[i]

    return data


def filter_custom_hetero_store(
    feature_store: FeatureStore,
    graph_store: GraphStore,
    node_dict: Dict[str, Tensor],
    row_dict: Dict[str, Tensor],
    col_dict: Dict[str, Tensor],
    edge_dict: Dict[str, OptTensor],
    custom_cls: Optional[HeteroData] = None,
) -> HeteroData:
    r"""Constructs a :class:`~torch_geometric.data.HeteroData` object from a
    feature store and graph store instance.
    """
    # NOTE(review): `custom_cls` is called to build the output, so it is
    # presumably a `HeteroData` subclass rather than an instance -- confirm.
    # Construct a new `HeteroData` object:
    data = custom_cls() if custom_cls is not None else HeteroData()

    # Filter node storage: only tensor attributes whose group (node type)
    # appears in `node_dict` are requested, at that type's sampled indices.
    required_attrs = []
    for attr in feature_store.get_all_tensor_attrs():
        if attr.group_name in node_dict:
            attr.index = node_dict[attr.group_name]
            required_attrs.append(attr)
            data[attr.group_name].num_nodes = attr.index.size(0)

    # NOTE Here, we utilize `feature_store.multi_get_tensor` to give the
    # feature store full control over optimizing how it returns features
    # (since the call is synchronous, this amounts to giving the feature
    # store control over all iteration).
tensors = feature_store.multi_get_tensor(required_attrs)
    for i, attr in enumerate(required_attrs):
        data[attr.group_name][attr.attr_name] = tensors[i]

    # Filter edge storage:
    # TODO support edge attributes
    # Only `edge_index` is materialized, for every edge type known to the
    # graph store that has sampled rows and columns:
    for attr in graph_store.get_all_edge_attrs():
        key = attr.edge_type
        if key in row_dict and key in col_dict:
            edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
            data[attr.edge_type].edge_index = edge_index

    return data


# Input Utilities #############################################################


def get_input_nodes(
    data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
    input_nodes: Union[InputNodes, TensorAttr],
    input_id: Optional[Tensor] = None,
) -> Tuple[Optional[str], Tensor, Optional[Tensor]]:
    r"""Normalizes the :obj:`input_nodes` argument of a loader into a tuple
    :obj:`(node_type, nodes, input_id)`.

    :obj:`node_type` is :obj:`None` for homogeneous inputs. If only a node
    type is given (or :obj:`None` for :class:`~torch_geometric.data.Data`),
    all nodes of that type are selected and :obj:`input_id` is :obj:`None`.
    Boolean masks are converted into index vectors.
    """
    def to_index(nodes, input_id) -> Tuple[Tensor, Optional[Tensor]]:
        # Boolean masks are converted into index vectors; in that case, the
        # index vector itself doubles as `input_id` unless one is given:
        if isinstance(nodes, Tensor) and nodes.dtype == torch.bool:
            nodes = nodes.nonzero(as_tuple=False).view(-1)
            if input_id is not None:
                assert input_id.numel() == nodes.numel()
            else:
                input_id = nodes
            return nodes, input_id

        if not isinstance(nodes, Tensor):
            nodes = torch.tensor(nodes, dtype=torch.long)

        if input_id is not None:
            assert input_id.numel() == nodes.numel()

        return nodes, input_id

    if isinstance(data, Data):
        if input_nodes is None:
            return None, torch.arange(data.num_nodes), None
        return None, *to_index(input_nodes, input_id)

    elif isinstance(data, HeteroData):
        assert input_nodes is not None

        if isinstance(input_nodes, str):
            return input_nodes, torch.arange(data[input_nodes].num_nodes), None

        assert isinstance(input_nodes, (list, tuple))
        assert len(input_nodes) == 2
        assert isinstance(input_nodes[0], str)

        node_type, input_nodes = input_nodes
        if input_nodes is None:
            return node_type, torch.arange(data[node_type].num_nodes), None
        return node_type, *to_index(input_nodes, input_id)

    else:  # Tuple[FeatureStore, GraphStore]
        feature_store, graph_store = data

        assert input_nodes is not None

        if isinstance(input_nodes, Tensor):
            return None, *to_index(input_nodes, input_id)

        if isinstance(input_nodes,
str): num_nodes = remote_backend_utils.num_nodes( # feature_store, graph_store, input_nodes) return input_nodes, torch.arange(num_nodes), None if isinstance(input_nodes, (list, tuple)): assert len(input_nodes) == 2 assert isinstance(input_nodes[0], str) node_type, input_nodes = input_nodes if input_nodes is None: num_nodes = remote_backend_utils.num_nodes( # feature_store, graph_store, input_nodes) return node_type, torch.arange(num_nodes), None return node_type, *to_index(input_nodes, input_id) def get_edge_label_index( data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], edge_label_index: InputEdges, ) -> Tuple[Optional[str], Tensor]: edge_type = None if isinstance(data, Data): if edge_label_index is None: return None, data.edge_index return None, edge_label_index assert edge_label_index is not None assert isinstance(edge_label_index, (list, tuple)) if isinstance(data, HeteroData): if isinstance(edge_label_index[0], str): edge_type = edge_label_index edge_type = data._to_canonical(*edge_type) assert edge_type in data.edge_types return edge_type, data[edge_type].edge_index assert len(edge_label_index) == 2 edge_type, edge_label_index = edge_label_index edge_type = data._to_canonical(*edge_type) if edge_label_index is None: return edge_type, data[edge_type].edge_index return edge_type, edge_label_index else: # Tuple[FeatureStore, GraphStore] _, graph_store = data # Need the edge index in COO for LinkNeighborLoader: def _get_edge_index(edge_type): row_dict, col_dict, _ = graph_store.coo([edge_type]) row = list(row_dict.values())[0] col = list(col_dict.values())[0] return torch.stack((row, col), dim=0) if isinstance(edge_label_index[0], str): edge_type = edge_label_index return edge_type, _get_edge_index(edge_type) assert len(edge_label_index) == 2 edge_type, edge_label_index = edge_label_index if edge_label_index is None: return edge_type, _get_edge_index(edge_type) return edge_type, edge_label_index def infer_filter_per_worker(data: Any) -> bool: out = 
True if isinstance(data, (Data, HeteroData)) and data.is_cuda: out = False logging.debug(f"Inferred 'filter_per_worker={out}' option for feature " f"fetching routines of the data loader") return out ================================================ FILE: torch_geometric/loader/zip_loader.py ================================================ from typing import Any, Iterator, List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.loader import LinkLoader, NodeLoader from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.utils import infer_filter_per_worker class ZipLoader(torch.utils.data.DataLoader): r"""A loader that returns a tuple of data objects by sampling from multiple :class:`NodeLoader` or :class:`LinkLoader` instances. Args: loaders (List[NodeLoader] or List[LinkLoader]): The loader instances. filter_per_worker (bool, optional): If set to :obj:`True`, will filter the returned data in each worker's subprocess. If set to :obj:`False`, will filter the returned data in the main process. If set to :obj:`None`, will automatically infer the decision based on whether data partially lives on the GPU (:obj:`filter_per_worker=True`) or entirely on the CPU (:obj:`filter_per_worker=False`). There exists different trade-offs for setting this option. Specifically, setting this option to :obj:`True` for in-memory datasets will move all features to shared memory, which may result in too many open file handles. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. 
""" def __init__( self, loaders: Union[List[NodeLoader], List[LinkLoader]], filter_per_worker: Optional[bool] = None, **kwargs, ): if filter_per_worker is None: filter_per_worker = infer_filter_per_worker(loaders[0].data) # Remove for PyTorch Lightning: kwargs.pop('dataset', None) kwargs.pop('collate_fn', None) for loader in loaders: if not callable(getattr(loader, 'collate_fn', None)): raise ValueError("'{loader.__class__.__name__}' does not have " "a 'collate_fn' method") if not callable(getattr(loader, 'filter_fn', None)): raise ValueError("'{loader.__class__.__name__}' does not have " "a 'filter_fn' method") loader.filter_per_worker = filter_per_worker iterator = range(min([len(loader.dataset) for loader in loaders])) super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) self.loaders = loaders self.filter_per_worker = filter_per_worker def __call__( self, index: Union[Tensor, List[int]], ) -> Union[Tuple[Data, ...], Tuple[HeteroData, ...]]: r"""Samples subgraphs from a batch of input IDs.""" out = self.collate_fn(index) if not self.filter_per_worker: out = self.filter_fn(out) return out def collate_fn(self, index: List[int]) -> Tuple[Any, ...]: if not isinstance(index, Tensor): index = torch.tensor(index, dtype=torch.long) return tuple(loader.collate_fn(index) for loader in self.loaders) def filter_fn( self, outs: Tuple[Any, ...], ) -> Tuple[Union[Data, HeteroData], ...]: loaders = self.loaders return tuple(loader.filter_fn(v) for loader, v in zip(loaders, outs)) def _get_iterator(self) -> Iterator: if self.filter_per_worker: return super()._get_iterator() # Execute `filter_fn` in the main process: return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}(loaders={self.loaders})' ================================================ FILE: torch_geometric/logging.py ================================================ import sys from typing import Any _wandb_initialized: bool = False def 
init_wandb(name: str, **kwargs: Any) -> None:
    r"""Starts a Weights & Biases run named after the current timestamp in
    project :obj:`name`, logging :obj:`kwargs` as the run config.

    Does nothing unless the script was invoked with the :obj:`--wandb`
    command-line flag. :obj:`wandb` is imported lazily, so it is only
    required when logging is enabled.
    """
    if '--wandb' not in sys.argv:
        return

    from datetime import datetime

    import wandb
    wandb.init(
        project=name,
        entity='pytorch-geometric',
        name=datetime.now().strftime('%Y-%m-%d_%H:%M'),
        config=kwargs,
    )

    global _wandb_initialized
    _wandb_initialized = True


def log(**kwargs: Any) -> None:
    r"""Prints all key-value pairs on a single line and, if
    :meth:`init_wandb` started a run, forwards them to :obj:`wandb.log`.
    """
    def _map(value: Any) -> Any:
        # Integers (but not booleans) are zero-padded to at least three
        # digits and floats are rounded to four decimals; any other value is
        # returned unchanged (hence the `Any` return type).
        if isinstance(value, int) and not isinstance(value, bool):
            return f'{value:03d}'
        if isinstance(value, float):
            return f'{value:.4f}'
        return value

    print(', '.join(f'{key}: {_map(value)}' for key, value in kwargs.items()))

    if _wandb_initialized:
        import wandb
        wandb.log(kwargs)
================================================
FILE: torch_geometric/metrics/__init__.py
================================================
# flake8: noqa

from .link_pred import (
    LinkPredMetric,
    LinkPredMetricCollection,
    LinkPredPrecision,
    LinkPredRecall,
    LinkPredF1,
    LinkPredMAP,
    LinkPredNDCG,
    LinkPredMRR,
    LinkPredHitRatio,
    LinkPredCoverage,
    LinkPredDiversity,
    LinkPredPersonalization,
    LinkPredAveragePopularity,
)

link_pred_metrics = [
    'LinkPredMetric',
    'LinkPredMetricCollection',
    'LinkPredPrecision',
    'LinkPredRecall',
    'LinkPredF1',
    'LinkPredMAP',
    'LinkPredNDCG',
    'LinkPredMRR',
    'LinkPredHitRatio',
    'LinkPredCoverage',
    'LinkPredDiversity',
    'LinkPredPersonalization',
    'LinkPredAveragePopularity',
]

__all__ = link_pred_metrics
================================================
FILE: torch_geometric/metrics/link_pred.py
================================================
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.utils import cumsum, scatter

try:
    import torchmetrics  # noqa
    WITH_TORCHMETRICS = True
    BaseMetric = torchmetrics.Metric
except Exception:
    # Fall back to a plain module if `torchmetrics` is unavailable:
    WITH_TORCHMETRICS = False
    BaseMetric = torch.nn.Module  # type: ignore


@dataclass(repr=False)
class LinkPredMetricData:
    r"""Holds the predictions and ground-truth of a single :meth:`update`
    call. Derived quantities are computed lazily and cached on the instance.
    """
    pred_index_mat: Tensor
    edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
edge_label_weight: Optional[Tensor] = None

    def __post_init__(self) -> None:
        # Filter out all non-positive weights - they should not be used as
        # ground-truth:
        if self.edge_label_weight is not None:
            pos_mask = self.edge_label_weight > 0
            self.edge_label_weight = self.edge_label_weight[pos_mask]
            if isinstance(self.edge_label_index, Tensor):
                self.edge_label_index = self.edge_label_index[:, pos_mask]
            else:
                self.edge_label_index = (
                    self.edge_label_index[0][pos_mask],
                    self.edge_label_index[1][pos_mask],
                )

    @property
    def pred_rel_mat(self) -> Tensor:
        r"""Returns a matrix indicating the relevance of the `k`-th prediction.
        If :obj:`edge_label_weight` is not given, relevance will be denoted as
        binary.
        """
        if hasattr(self, '_pred_rel_mat'):
            return self._pred_rel_mat  # type: ignore

        if self.edge_label_index[1].numel() == 0:
            self._pred_rel_mat = torch.zeros_like(
                self.pred_index_mat,
                dtype=torch.bool if self.edge_label_weight is None else
                torch.get_default_dtype(),
            )
            return self._pred_rel_mat

        # Flatten both prediction and ground-truth indices, and determine
        # overlaps afterwards via `torch.searchsorted`.
        # Example `i` is offset by `i * max_index`, so that every pair
        # `(example, index)` maps to a unique integer:
        max_index = max(
            self.pred_index_mat.max()
            if self.pred_index_mat.numel() > 0 else 0,
            self.edge_label_index[1].max()
            if self.edge_label_index[1].numel() > 0 else 0,
        ) + 1
        arange = torch.arange(
            start=0,
            end=max_index * self.pred_index_mat.size(0),  # type: ignore
            step=max_index,  # type: ignore
            device=self.pred_index_mat.device,
        ).view(-1, 1)
        flat_pred_index = (self.pred_index_mat + arange).view(-1)
        flat_label_index = max_index * self.edge_label_index[0]
        flat_label_index = flat_label_index + self.edge_label_index[1]
        flat_label_index, perm = flat_label_index.sort()
        edge_label_weight = self.edge_label_weight
        if edge_label_weight is not None:
            assert edge_label_weight.size() == self.edge_label_index[0].size()
            # Keep the weights aligned with the sorted label indices:
            edge_label_weight = edge_label_weight[perm]

        # `searchsorted` may return one past the last position for
        # predictions larger than every label, hence the clamp:
        pos = torch.searchsorted(flat_label_index, flat_pred_index)
        pos = pos.clamp(max=flat_label_index.size(0) - 1)  # Out-of-bounds.
pred_rel_mat = flat_label_index[pos] == flat_pred_index # Find matches if edge_label_weight is not None: pred_rel_mat = edge_label_weight[pos].where( pred_rel_mat, pred_rel_mat.new_zeros(1), ) pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size()) self._pred_rel_mat = pred_rel_mat return pred_rel_mat @property def label_count(self) -> Tensor: r"""The number of ground-truth labels for every example.""" if hasattr(self, '_label_count'): return self._label_count # type: ignore label_count = scatter( torch.ones_like(self.edge_label_index[0]), self.edge_label_index[0], dim=0, dim_size=self.pred_index_mat.size(0), reduce='sum', ) self._label_count = label_count return label_count @property def label_weight_sum(self) -> Tensor: r"""The sum of edge label weights for every example.""" if self.edge_label_weight is None: return self.label_count if hasattr(self, '_label_weight_sum'): return self._label_weight_sum # type: ignore label_weight_sum = scatter( self.edge_label_weight, self.edge_label_index[0], dim=0, dim_size=self.pred_index_mat.size(0), reduce='sum', ) self._label_weight_sum = label_weight_sum return label_weight_sum @property def edge_label_weight_pos(self) -> Optional[Tensor]: r"""Returns the position of edge label weights in descending order within example-wise buckets. """ if self.edge_label_weight is None: return None if hasattr(self, '_edge_label_weight_pos'): return self._edge_label_weight_pos # type: ignore # Get the permutation via two sorts: One globally on the weights, # followed by a (stable) sort on the example indices. 
perm1 = self.edge_label_weight.argsort(descending=True) perm2 = self.edge_label_index[0][perm1].argsort(stable=True) perm = perm1[perm2] # Invert the permutation to get the final position: pos = torch.empty_like(perm) pos[perm] = torch.arange(perm.size(0), device=perm.device) # Normalize position to zero within all buckets: pos = pos - cumsum(self.label_count)[self.edge_label_index[0]] self._edge_label_weight_pos = pos return pos class _LinkPredMetric(BaseMetric): r"""An abstract class for computing link prediction retrieval metrics. Args: k (int): The number of top-:math:`k` predictions to evaluate against. """ is_differentiable: bool = False full_state_update: bool = False higher_is_better: Optional[bool] = None def __init__(self, k: int) -> None: super().__init__() if k <= 0: raise ValueError(f"'k' needs to be a positive integer in " f"'{self.__class__.__name__}' (got {k})") self.k = k def update( self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None, ) -> None: r"""Updates the state variables based on the current mini-batch prediction. :meth:`update` can be repeated multiple times to accumulate the results of successive predictions, *e.g.*, inside a mini-batch training or evaluation loop. Args: pred_index_mat (torch.Tensor): The top-:math:`k` predictions of every example in the mini-batch with shape :obj:`[batch_size, k]`. edge_label_index (torch.Tensor): The ground-truth indices for every example in the mini-batch, given in COO format of shape :obj:`[2, num_ground_truth_indices]`. edge_label_weight (torch.Tensor, optional): The weight of the ground-truth indices for every example in the mini-batch of shape :obj:`[num_ground_truth_indices]`. If given, needs to be a vector of positive values. Required for weighted metrics, ignored otherwise. 
(default: :obj:`None`) """ raise NotImplementedError def compute(self) -> Tensor: r"""Computes the final metric value.""" raise NotImplementedError def reset(self) -> None: r"""Resets metric state variables to their default value.""" if WITH_TORCHMETRICS: super().reset() else: self._reset() def _reset(self) -> None: raise NotImplementedError def __repr__(self) -> str: return f'{self.__class__.__name__}(k={self.k})' class LinkPredMetric(_LinkPredMetric): r"""An abstract class for computing link prediction retrieval metrics. Args: k (int): The number of top-:math:`k` predictions to evaluate against. """ weighted: bool def __init__(self, k: int) -> None: super().__init__(k) self.accum: Tensor self.total: Tensor if WITH_TORCHMETRICS: self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total', torch.tensor(0), dist_reduce_fx='sum') else: self.register_buffer('accum', torch.tensor(0.), persistent=False) self.register_buffer('total', torch.tensor(0), persistent=False) def update( self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None, ) -> None: if self.weighted and edge_label_weight is None: raise ValueError(f"'edge_label_weight' is a required argument for " f"weighted '{self.__class__.__name__}' metrics") if not self.weighted: edge_label_weight = None data = LinkPredMetricData( pred_index_mat=pred_index_mat, edge_label_index=edge_label_index, edge_label_weight=edge_label_weight, ) self._update(data) def _update(self, data: LinkPredMetricData) -> None: metric = self._compute(data) self.accum += metric.sum() self.total += (data.label_count > 0).sum() def compute(self) -> Tensor: if self.total == 0: return torch.zeros_like(self.accum) return self.accum / self.total def _compute(self, data: LinkPredMetricData) -> Tensor: r"""Computes the specific metric. To be implemented separately for each metric class. 
Args: data (LinkPredMetricData): The mini-batch data for computing a link prediction metric per example. """ raise NotImplementedError def _reset(self) -> None: self.accum.zero_() self.total.zero_() def __repr__(self) -> str: weighted_repr = ', weighted=True' if self.weighted else '' return f'{self.__class__.__name__}(k={self.k}{weighted_repr})' class LinkPredMetricCollection(torch.nn.ModuleDict): r"""A collection of metrics to reduce and speed-up computation of link prediction metrics. .. code-block:: python from torch_geometric.metrics import ( LinkPredMAP, LinkPredMetricCollection, LinkPredPrecision, LinkPredRecall, ) metrics = LinkPredMetricCollection([ LinkPredMAP(k=10), LinkPredPrecision(k=100), LinkPredRecall(k=50), ]) metrics.update(pred_index_mat, edge_label_index) out = metrics.compute() metrics.reset() print(out) >>> {'LinkPredMAP@10': tensor(0.375), ... 'LinkPredPrecision@100': tensor(0.127), ... 'LinkPredRecall@50': tensor(0.483)} Args: metrics: The link prediction metrics. """ def __init__( self, metrics: Union[ List[LinkPredMetric], Dict[str, LinkPredMetric], ], ) -> None: super().__init__() if isinstance(metrics, (list, tuple)): metrics = { (f'{"Weighted" if getattr(metric, "weighted", False) else ""}' f'{metric.__class__.__name__}@{metric.k}'): metric for metric in metrics } assert len(metrics) > 0 assert isinstance(metrics, dict) for name, metric in metrics.items(): assert isinstance(metric, _LinkPredMetric) self[name] = metric @property def max_k(self) -> int: r"""The maximum number of top-:math:`k` predictions to evaluate against. """ return max([ metric.k # type: ignore[return-value] for metric in self.values() ]) # type: ignore[type-var] @property def weighted(self) -> bool: r"""Returns :obj:`True` in case the collection holds at least one weighted link prediction metric. 
""" return any( [getattr(metric, 'weighted', False) for metric in self.values()]) def update( # type: ignore self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None, ) -> None: r"""Updates the state variables based on the current mini-batch prediction. :meth:`update` can be repeated multiple times to accumulate the results of successive predictions, *e.g.*, inside a mini-batch training or evaluation loop. Args: pred_index_mat (torch.Tensor): The top-:math:`k` predictions of every example in the mini-batch with shape :obj:`[batch_size, k]`. edge_label_index (torch.Tensor): The ground-truth indices for every example in the mini-batch, given in COO format of shape :obj:`[2, num_ground_truth_indices]`. edge_label_weight (torch.Tensor, optional): The weight of the ground-truth indices for every example in the mini-batch of shape :obj:`[num_ground_truth_indices]`. If given, needs to be a vector of positive values. Required for weighted metrics, ignored otherwise. (default: :obj:`None`) """ if self.weighted and edge_label_weight is None: raise ValueError(f"'edge_label_weight' is a required argument for " f"weighted '{self.__class__.__name__}' metrics") data = LinkPredMetricData( # Share metric data across metrics. 
pred_index_mat=pred_index_mat, edge_label_index=edge_label_index, edge_label_weight=edge_label_weight, ) for metric in self.values(): if isinstance(metric, LinkPredMetric) and metric.weighted: metric._update(data) if WITH_TORCHMETRICS: metric._update_count += 1 data.edge_label_weight = None if hasattr(data, '_pred_rel_mat'): data._pred_rel_mat = data._pred_rel_mat != 0.0 if hasattr(data, '_label_weight_sum'): del data._label_weight_sum if hasattr(data, '_edge_label_weight_pos'): del data._edge_label_weight_pos for metric in self.values(): if isinstance(metric, LinkPredMetric) and not metric.weighted: metric._update(data) if WITH_TORCHMETRICS: metric._update_count += 1 for metric in self.values(): if not isinstance(metric, LinkPredMetric): metric.update( # type: ignore[operator] pred_index_mat, edge_label_index, edge_label_weight, ) def compute(self) -> Dict[str, Tensor]: r"""Computes the final metric values.""" return { name: metric.compute() # type: ignore[operator] for name, metric in self.items() } def reset(self) -> None: r"""Reset metric state variables to their default value.""" for metric in self.values(): metric.reset() # type: ignore[operator] def __repr__(self) -> str: names = [f' {name}: {metric},\n' for name, metric in self.items()] return f'{self.__class__.__name__}([\n{"".join(names)}])' class LinkPredPrecision(LinkPredMetric): r"""A link prediction metric to compute Precision @ :math:`k`, *i.e.* the proportion of recommendations within the top-:math:`k` that are actually relevant. A higher precision indicates the model's ability to surface relevant items early in the ranking. Args: k (int): The number of top-:math:`k` predictions to evaluate against. 
""" higher_is_better: bool = True weighted: bool = False def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] return pred_rel_mat.sum(dim=-1) / self.k class LinkPredRecall(LinkPredMetric): r"""A link prediction metric to compute Recall @ :math:`k`, *i.e.* the proportion of relevant items that appear within the top-:math:`k`. A higher recall indicates the model's ability to retrieve a larger proportion of relevant items. Args: k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True def __init__(self, k: int, weighted: bool = False): super().__init__(k=k) self.weighted = weighted def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7) class LinkPredF1(LinkPredMetric): r"""A link prediction metric to compute F1 @ :math:`k`. Args: k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True weighted: bool = False def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] isin_count = pred_rel_mat.sum(dim=-1) precision = isin_count / self.k recall = isin_count / data.label_count.clamp(min=1e-7) return 2 * precision * recall / (precision + recall).clamp(min=1e-7) class LinkPredMAP(LinkPredMetric): r"""A link prediction metric to compute MAP @ :math:`k` (Mean Average Precision), considering the order of relevant items within the top-:math:`k`. MAP @ :math:`k` can provide a more comprehensive view of ranking quality than precision alone. Args: k (int): The number of top-:math:`k` predictions to evaluate against. 
""" higher_is_better: bool = True weighted: bool = False def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] device = pred_rel_mat.device arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device) cum_precision = pred_rel_mat.cumsum(dim=1) / arange return ((cum_precision * pred_rel_mat).sum(dim=-1) / data.label_count.clamp(min=1e-7, max=self.k)) class LinkPredNDCG(LinkPredMetric): r"""A link prediction metric to compute the NDCG @ :math:`k` (Normalized Discounted Cumulative Gain). In particular, can account for the position of relevant items by considering relevance scores, giving higher weight to more relevant items appearing at the top. Args: k (int): The number of top-:math:`k` predictions to evaluate against. weighted (bool, optional): If set to :obj:`True`, assumes sorted lists of ground-truth items according to a relevance score as given by :obj:`edge_label_weight`. (default: :obj:`False`) """ higher_is_better: bool = True def __init__(self, k: int, weighted: bool = False): super().__init__(k=k) self.weighted = weighted dtype = torch.get_default_dtype() discount = torch.arange(2, k + 2, dtype=dtype).log2() self.discount: Tensor self.register_buffer('discount', discount, persistent=False) if not weighted: self.register_buffer('idcg', cumsum(1.0 / discount), persistent=False) else: self.idcg = None def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] discount = self.discount[:pred_rel_mat.size(1)].view(1, -1) dcg = (pred_rel_mat / discount).sum(dim=-1) if not self.weighted: assert self.idcg is not None idcg = self.idcg[data.label_count.clamp(max=self.k)] else: assert data.edge_label_weight is not None pos = data.edge_label_weight_pos assert pos is not None discount = torch.cat([ self.discount, self.discount.new_full((1, ), fill_value=float('inf')), ]) discount = discount[pos.clamp(max=self.k)] idcg = scatter( # Apply discount and aggregate: 
data.edge_label_weight / discount, data.edge_label_index[0], dim_size=data.pred_index_mat.size(0), reduce='sum', ) out = dcg / idcg out[out.isnan() | out.isinf()] = 0.0 return out class LinkPredMRR(LinkPredMetric): r"""A link prediction metric to compute the MRR @ :math:`k` (Mean Reciprocal Rank), *i.e.* the mean reciprocal rank of the first correct prediction (or zero otherwise). Args: k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True weighted: bool = False def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] device = pred_rel_mat.device arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device) return (pred_rel_mat / arange).max(dim=-1)[0] class LinkPredHitRatio(LinkPredMetric): r"""A link prediction metric to compute the hit ratio @ :math:`k`, *i.e.* the percentage of users for whom at least one relevant item is present within the top-:math:`k` recommendations. A high ratio signifies the model's effectiveness in satisfying a broad range of user preferences. """ higher_is_better: bool = True weighted: bool = False def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] return pred_rel_mat.max(dim=-1)[0].to(torch.get_default_dtype()) class LinkPredCoverage(_LinkPredMetric): r"""A link prediction metric to compute the Coverage @ :math:`k` of predictions, *i.e.* the percentage of unique items recommended across all users within the top-:math:`k`. Higher coverage indicates a wider exploration of the item catalog. Args: k (int): The number of top-:math:`k` predictions to evaluate against. num_dst_nodes (int): The total number of destination nodes. 
""" higher_is_better: bool = True def __init__(self, k: int, num_dst_nodes: int) -> None: super().__init__(k) self.num_dst_nodes = num_dst_nodes self.mask: Tensor mask = torch.zeros(num_dst_nodes, dtype=torch.bool) if WITH_TORCHMETRICS: self.add_state('mask', mask, dist_reduce_fx='max') else: self.register_buffer('mask', mask, persistent=False) def update( self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None, ) -> None: self.mask[pred_index_mat[:, :self.k].flatten()] = True def compute(self) -> Tensor: return self.mask.to(torch.get_default_dtype()).mean() def _reset(self) -> None: self.mask.zero_() def __repr__(self) -> str: return (f'{self.__class__.__name__}(k={self.k}, ' f'num_dst_nodes={self.num_dst_nodes})') class LinkPredDiversity(_LinkPredMetric): r"""A link prediction metric to compute the Diversity @ :math:`k` of predictions according to item categories. Diversity is computed as .. math:: div_{u@k} = 1 - \left( \frac{1}{k \cdot (k-1)} \right) \sum_{i \neq j} sim(i, j) where .. math:: sim(i,j) = \begin{cases} 1 & \quad \text{if } i,j \text{ share category,}\\ 0 & \quad \text{otherwise.} \end{cases} which measures the pair-wise inequality of recommendations according to item categories. Args: k (int): The number of top-:math:`k` predictions to evaluate against. category (torch.Tensor): A vector that assigns each destination node to a specific category. 
""" higher_is_better: bool = True def __init__(self, k: int, category: Tensor) -> None: super().__init__(k) self.accum: Tensor self.total: Tensor if WITH_TORCHMETRICS: self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total', torch.tensor(0), dist_reduce_fx='sum') else: self.register_buffer('accum', torch.tensor(0.), persistent=False) self.register_buffer('total', torch.tensor(0), persistent=False) self.category: Tensor self.register_buffer('category', category, persistent=False) def update( self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None, ) -> None: category = self.category[pred_index_mat[:, :self.k]] sim = (category.unsqueeze(-2) == category.unsqueeze(-1)).sum(dim=-1) div = 1 - 1 / (self.k * (self.k - 1)) * (sim - 1).sum(dim=-1) self.accum += div.sum() self.total += pred_index_mat.size(0) def compute(self) -> Tensor: if self.total == 0: return torch.zeros_like(self.accum) return self.accum / self.total def _reset(self) -> None: self.accum.zero_() self.total.zero_() class LinkPredPersonalization(_LinkPredMetric): r"""A link prediction metric to compute the Personalization @ :math:`k`, *i.e.* the dissimilarity of recommendations across different users. Higher personalization suggests that the model tailors recommendations to individual user preferences rather than providing generic results. Dissimilarity is defined by the average inverse cosine similarity between users' lists of recommendations. Args: k (int): The number of top-:math:`k` predictions to evaluate against. max_src_nodes (int, optional): The maximum source nodes to consider to compute pair-wise dissimilarity. If specified, Personalization @ :math:`k` is approximated to avoid computation blowup due to quadratic complexity. (default: :obj:`2**12`) batch_size (int, optional): The batch size to determine how many pairs of user recommendations should be processed at once. 
(default: :obj:`2**16`) """ higher_is_better: bool = True def __init__( self, k: int, max_src_nodes: Optional[int] = 2**12, batch_size: int = 2**16, ) -> None: super().__init__(k) self.max_src_nodes = max_src_nodes self.batch_size = batch_size self.preds: List[Tensor] self.total: Tensor if WITH_TORCHMETRICS: self.add_state('preds', default=[], dist_reduce_fx='cat') self.add_state('total', torch.tensor(0), dist_reduce_fx='sum') else: self.preds = [] self.register_buffer('total', torch.tensor(0), persistent=False) def update( self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None, ) -> None: # NOTE Move to CPU to avoid memory blowup. pred_index_mat = pred_index_mat[:, :self.k].cpu() if self.max_src_nodes is None: self.preds.append(pred_index_mat) self.total += pred_index_mat.size(0) elif self.total < self.max_src_nodes: remaining = int(self.max_src_nodes - self.total) pred_index_mat = pred_index_mat[:remaining] self.preds.append(pred_index_mat) self.total += pred_index_mat.size(0) def compute(self) -> Tensor: device = self.total.device score = torch.tensor(0.0, device=device) total = torch.tensor(0, device=device) if len(self.preds) == 0: return score pred = torch.cat(self.preds, dim=0) if pred.size(0) == 0: return score # Calculate all pairs of nodes (e.g., triu_indices with offset=1). # NOTE We do this in chunks to avoid memory blow-up, which leads to a # more efficient but trickier implementation. 
num_pairs = (pred.size(0) * (pred.size(0) - 1)) // 2 offset = torch.arange(pred.size(0) - 1, 0, -1, device=device) rowptr = cumsum(offset) for start in range(0, num_pairs, self.batch_size): end = min(start + self.batch_size, num_pairs) idx = torch.arange(start, end, device=device) # Find the corresponding row: row = torch.searchsorted(rowptr, idx, right=True) - 1 # Find the corresponding column: col = idx - rowptr[row] + (pred.size(0) - offset[row]) left = pred[row.cpu()].to(device) right = pred[col.cpu()].to(device) # Use offset to work around applying `isin` along a specific dim: i = max(int(left.max()), int(right.max())) + 1 idx = torch.arange(0, i * row.size(0), i, device=device) idx = idx.view(-1, 1) isin = torch.isin(left + idx, right + idx) # Compute personalization via average inverse cosine similarity: cos = isin.sum(dim=-1) / pred.size(1) score += (1 - cos).sum() total += cos.numel() return score / total def _reset(self) -> None: self.preds = [] self.total.zero_() class LinkPredAveragePopularity(_LinkPredMetric): r"""A link prediction metric to compute the Average Recommendation Popularity (ARP) @ :math:`k`, which provides insights into the model's tendency to recommend popular items by averaging the popularity scores of items within the top-:math:`k` recommendations. Args: k (int): The number of top-:math:`k` predictions to evaluate against. popularity (torch.Tensor): The popularity of every item in the training set, *e.g.*, the number of times an item has been rated. 
""" higher_is_better: bool = False def __init__(self, k: int, popularity: Tensor) -> None: super().__init__(k) self.accum: Tensor self.total: Tensor if WITH_TORCHMETRICS: self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total', torch.tensor(0), dist_reduce_fx='sum') else: self.register_buffer('accum', torch.tensor(0.), persistent=False) self.register_buffer('total', torch.tensor(0), persistent=False) self.popularity: Tensor self.register_buffer('popularity', popularity, persistent=False) def update( self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None, ) -> None: pred_index_mat = pred_index_mat[:, :self.k] popularity = self.popularity[pred_index_mat] popularity = popularity.to(self.accum.dtype).mean(dim=-1) self.accum += popularity.sum() self.total += popularity.numel() def compute(self) -> Tensor: if self.total == 0: return torch.zeros_like(self.accum) return self.accum / self.total def _reset(self) -> None: self.accum.zero_() self.total.zero_() ================================================ FILE: torch_geometric/nn/__init__.py ================================================ from .reshape import Reshape from .sequential import Sequential from .data_parallel import DataParallel from .to_hetero_transformer import to_hetero from .to_hetero_with_bases_transformer import to_hetero_with_bases from .to_fixed_size_transformer import to_fixed_size from .encoding import PositionalEncoding, TemporalEncoding from .summary import summary from .aggr import * # noqa from .attention import * # noqa from .conv import * # noqa from .pool import * # noqa from .glob import * # noqa from .norm import * # noqa from .unpool import * # noqa from .dense import * # noqa from .kge import * # noqa from .models import * # noqa from .functional import * # noqa __all__ = [ 'Reshape', 'Sequential', 'DataParallel', 'to_hetero', 'to_hetero_with_bases', 'to_fixed_size', 
'PositionalEncoding', 'TemporalEncoding', 'summary', ] ================================================ FILE: torch_geometric/nn/aggr/__init__.py ================================================ from .base import Aggregation from .multi import MultiAggregation from .basic import ( MeanAggregation, SumAggregation, MaxAggregation, MinAggregation, MulAggregation, VarAggregation, StdAggregation, SoftmaxAggregation, PowerMeanAggregation, ) from .quantile import MedianAggregation, QuantileAggregation from .lstm import LSTMAggregation from .gru import GRUAggregation from .set2set import Set2Set from .scaler import DegreeScalerAggregation from .equilibrium import EquilibriumAggregation from .sort import SortAggregation from .gmt import GraphMultisetTransformer from .attention import AttentionalAggregation from .mlp import MLPAggregation from .deep_sets import DeepSetsAggregation from .set_transformer import SetTransformerAggregation from .lcm import LCMAggregation from .variance_preserving import VariancePreservingAggregation from .patch_transformer import PatchTransformerAggregation __all__ = classes = [ 'Aggregation', 'MultiAggregation', 'SumAggregation', 'MeanAggregation', 'MaxAggregation', 'MinAggregation', 'MulAggregation', 'VarAggregation', 'StdAggregation', 'SoftmaxAggregation', 'PowerMeanAggregation', 'MedianAggregation', 'QuantileAggregation', 'LSTMAggregation', 'GRUAggregation', 'Set2Set', 'DegreeScalerAggregation', 'SortAggregation', 'GraphMultisetTransformer', 'AttentionalAggregation', 'EquilibriumAggregation', 'MLPAggregation', 'DeepSetsAggregation', 'SetTransformerAggregation', 'LCMAggregation', 'VariancePreservingAggregation', 'PatchTransformerAggregation', ] ================================================ FILE: torch_geometric/nn/aggr/attention.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.nn.aggr import Aggregation from torch_geometric.nn.inits import reset from 
torch_geometric.utils import softmax class AttentionalAggregation(Aggregation): r"""The soft attention aggregation layer from the `"Graph Matching Networks for Learning the Similarity of Graph Structured Objects" `_ paper. .. math:: \mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left( h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \cdot h_{\mathbf{\Theta}} ( \mathbf{x}_n ), where :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to \mathbb{R}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.* MLPs. Args: gate_nn (torch.nn.Module): A neural network :math:`h_{\mathrm{gate}}` that computes attention scores by mapping node features :obj:`x` of shape :obj:`[-1, in_channels]` to shape :obj:`[-1, 1]` (for node-level gating) or :obj:`[1, out_channels]` (for feature-level gating), *e.g.*, defined by :class:`torch.nn.Sequential`. nn (torch.nn.Module, optional): A neural network :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]` before combining them with the attention scores, *e.g.*, defined by :class:`torch.nn.Sequential`. 
(default: :obj:`None`) """ def __init__( self, gate_nn: torch.nn.Module, nn: Optional[torch.nn.Module] = None, ): super().__init__() from torch_geometric.nn import MLP self.gate_nn = self.gate_mlp = None if isinstance(gate_nn, MLP): self.gate_mlp = gate_nn else: self.gate_nn = gate_nn self.nn = self.mlp = None if isinstance(nn, MLP): self.mlp = nn else: self.nn = nn def reset_parameters(self): reset(self.gate_nn) reset(self.gate_mlp) reset(self.nn) reset(self.mlp) def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: if self.gate_mlp is not None: gate = self.gate_mlp(x, batch=index, batch_size=dim_size) else: gate = self.gate_nn(x) if self.mlp is not None: x = self.mlp(x, batch=index, batch_size=dim_size) elif self.nn is not None: x = self.nn(x) gate = softmax(gate, index, ptr, dim_size, dim) return self.reduce(gate * x, index, ptr, dim_size, dim) def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'gate_nn={self.gate_mlp or self.gate_nn}, ' f'nn={self.mlp or self.nn})') ================================================ FILE: torch_geometric/nn/aggr/base.py ================================================ from typing import Final, Optional, Tuple import torch from torch import Tensor from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.utils import scatter, segment, to_dense_batch class Aggregation(torch.nn.Module): r"""An abstract base class for implementing custom aggregations. Aggregation can be either performed via an :obj:`index` vector, which defines the mapping from input elements to their location in the output: | .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ master/docs/source/_figures/add.svg?sanitize=true :align: center :width: 400px | Notably, :obj:`index` does not have to be sorted (for most aggregation operators): .. 
code-block:: python # Feature matrix holding 10 elements with 64 features each: x = torch.randn(10, 64) # Assign each element to one of three sets: index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2]) output = aggr(x, index) # Output shape: [3, 64] Alternatively, aggregation can be achieved via a "compressed" index vector called :obj:`ptr`. Here, elements within the same set need to be grouped together in the input, and :obj:`ptr` defines their boundaries: .. code-block:: python # Feature matrix holding 10 elements with 64 features each: x = torch.randn(10, 64) # Define the boundary indices for three sets: ptr = torch.tensor([0, 4, 7, 10]) output = aggr(x, ptr=ptr) # Output shape: [3, 64] Note that at least one of :obj:`index` or :obj:`ptr` must be defined. Shapes: - **input:** node features :math:`(*, |\mathcal{V}|, F_{in})` or edge features :math:`(*, |\mathcal{E}|, F_{in})`, index vector :math:`(|\mathcal{V}|)` or :math:`(|\mathcal{E}|)`, - **output:** graph features :math:`(*, |\mathcal{G}|, F_{out})` or node features :math:`(*, |\mathcal{V}|, F_{out})` """ def __init__(self) -> None: super().__init__() self._deterministic: Final[bool] = ( torch.are_deterministic_algorithms_enabled() or torch.is_deterministic_algorithms_warn_only_enabled()) def forward( self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None, ) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. index (torch.Tensor, optional): The indices of elements for applying the aggregation. One of :obj:`index` or :obj:`ptr` must be defined. (default: :obj:`None`) ptr (torch.Tensor, optional): If given, computes the aggregation based on sorted inputs in CSR representation. One of :obj:`index` or :obj:`ptr` must be defined. (default: :obj:`None`) dim_size (int, optional): The size of the output tensor at dimension :obj:`dim` after aggregation. 
(default: :obj:`None`)
        dim (int, optional): The dimension in which to aggregate.
            (default: :obj:`-2`)
        max_num_elements (int, optional): The maximum number of elements
            within a single aggregation group. (default: :obj:`None`)
    """
    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        # No-op by default; subclasses holding parameters override this.

    @disable_dynamic_shapes(required_args=['dim_size'])
    def __call__(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        **kwargs,
    ) -> Tensor:
        # Validates the inputs and fills in missing grouping information
        # (`index` and/or `dim_size`) before handing over to the parent
        # class's `__call__` (which presumably dispatches to `forward`).

        if dim >= x.dim() or dim < -x.dim():
            raise ValueError(f"Encountered invalid dimension '{dim}' of "
                             f"source tensor with {x.dim()} dimensions")

        # Without any grouping information, all elements along `dim` are
        # assigned to a single group with id 0:
        if index is None and ptr is None:
            index = x.new_zeros(x.size(dim), dtype=torch.long)

        # `ptr` has `dim_size + 1` entries, so it fixes `dim_size`:
        if ptr is not None:
            if dim_size is None:
                dim_size = ptr.numel() - 1
            elif dim_size != ptr.numel() - 1:
                raise ValueError(f"Encountered invalid 'dim_size' (got "
                                 f"'{dim_size}' but expected "
                                 f"'{ptr.numel() - 1}')")

        # Otherwise infer `dim_size` from the largest group id; an empty
        # `index` yields zero groups:
        if index is not None and dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0

        try:
            return super().__call__(x, index=index, ptr=ptr,
                                    dim_size=dim_size, dim=dim, **kwargs)
        except (IndexError, RuntimeError) as e:
            # Turn a low-level failure caused by a too small, user-provided
            # `dim_size` into a descriptive error; re-raise anything else:
            if index is not None:
                if index.numel() > 0 and dim_size <= int(index.max()):
                    raise ValueError(f"Encountered invalid 'dim_size' (got "
                                     f"'{dim_size}' but expected "
                                     f">= '{int(index.max()) + 1}')") from e
            raise e

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

    # Assertions ##############################################################

    def assert_index_present(self, index: Optional[Tensor]):
        # TODO Currently, not all aggregators support `ptr`. This assert helps
        # to ensure that we require `index` to be passed to the computation:
        if index is None:
            raise NotImplementedError(
                "Aggregation requires 'index' to be specified")

    def assert_sorted_index(self, index: Optional[Tensor]):
        # `index` counts as sorted if every entry is <= its successor:
        if index is not None and not torch.all(index[:-1] <= index[1:]):
            raise ValueError("Can not perform aggregation since the 'index' "
                             "tensor is not sorted. Specifically, if you use "
                             "this aggregation as part of 'MessagePassing`, "
                             "ensure that 'edge_index' is sorted by "
                             "destination nodes, e.g., by calling "
                             "`data.sort(sort_by_row=False)`")

    def assert_two_dimensional_input(self, x: Tensor, dim: int):
        if x.dim() != 2:
            raise ValueError(f"Aggregation requires two-dimensional inputs "
                             f"(got '{x.dim()}')")

        if dim not in [-2, 0]:
            raise ValueError(f"Aggregation needs to perform aggregation in "
                             f"first dimension (got '{dim}')")

    # Helper methods ##########################################################

    def reduce(self, x: Tensor, index: Optional[Tensor] = None,
               ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
               dim: int = -2, reduce: str = 'sum') -> Tensor:
        # Use `segment` over `ptr` when no `index` is available or when
        # `self._deterministic` is set (presumably a determinism flag set
        # elsewhere); otherwise fall back to `scatter` over `index`.
        if ptr is not None:
            if index is None or self._deterministic:
                ptr = expand_left(ptr, dim, dims=x.dim())
                return segment(x, ptr, reduce=reduce)

        if index is None:
            raise RuntimeError("Aggregation requires 'index' to be specified")

        return scatter(x, index, dim, dim_size, reduce)

    def to_dense_batch(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        fill_value: float = 0.0,
        max_num_elements: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        # Pads the grouped two-dimensional input into a dense batch and
        # returns it together with a second tensor (used as a mask by
        # callers); `ptr` is accepted but not used.

        # TODO Currently, `to_dense_batch` can only operate on `index`:
        self.assert_index_present(index)
        self.assert_sorted_index(index)
        self.assert_two_dimensional_input(x, dim)

        return to_dense_batch(
            x,
            index,
            batch_size=dim_size,
            fill_value=fill_value,
            max_num_nodes=max_num_elements,
        )


###############################################################################


def expand_left(ptr: Tensor, dim: int, dims: int) -> Tensor:
    # Prepends singleton dimensions to `ptr` so that it lines up with
    # dimension `dim` of a `dims`-dimensional input (a negative `dim` is
    # counted from the end).
    for _ in range(dims + dim if dim < 0 else dim):
        ptr = ptr.unsqueeze(0)
    return ptr

================================================
FILE: torch_geometric/nn/aggr/basic.py
================================================
import math
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import softmax


class SumAggregation(Aggregation):
    r"""An aggregation operator that sums up features across a set of
    elements.

    .. math::
        \mathrm{sum}(\mathcal{X}) = \sum_{\mathbf{x}_i \in \mathcal{X}}
        \mathbf{x}_i.
    """
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')


class MeanAggregation(Aggregation):
    r"""An aggregation operator that averages features across a set of
    elements.

    .. math::
        \mathrm{mean}(\mathcal{X}) = \frac{1}{|\mathcal{X}|}
        \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.
    """
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='mean')


class MaxAggregation(Aggregation):
    r"""An aggregation operator that takes the feature-wise maximum across a
    set of elements.

    .. math::
        \mathrm{max}(\mathcal{X}) = \max_{\mathbf{x}_i \in \mathcal{X}}
        \mathbf{x}_i.
    """
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='max')


class MinAggregation(Aggregation):
    r"""An aggregation operator that takes the feature-wise minimum across a
    set of elements.

    .. math::
        \mathrm{min}(\mathcal{X}) = \min_{\mathbf{x}_i \in \mathcal{X}}
        \mathbf{x}_i.
    """
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='min')


class MulAggregation(Aggregation):
    r"""An aggregation operator that multiplies features across a set of
    elements.

    .. math::
        \mathrm{mul}(\mathcal{X}) = \prod_{\mathbf{x}_i \in \mathcal{X}}
        \mathbf{x}_i.
    """
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # TODO Currently, `mul` reduction can only operate on `index`:
        self.assert_index_present(index)
        # `ptr` is deliberately dropped so that `reduce` uses `scatter`:
        return self.reduce(x, index, None, dim_size, dim, reduce='mul')


class VarAggregation(Aggregation):
    r"""An aggregation operator that takes the feature-wise variance across a
    set of elements.

    .. math::
        \mathrm{var}(\mathcal{X}) = \mathrm{mean}(\{ \mathbf{x}_i^2 :
        \mathbf{x}_i \in \mathcal{X} \}) - \mathrm{mean}(\mathcal{X})^2.

    Args:
        semi_grad (bool, optional): If set to :obj:`True`, will turn off
            gradient calculation during :math:`E[X^2]` computation. Therefore,
            only semi-gradients are used during backpropagation.
            Useful for saving memory and accelerating backward computation.
            (default: :obj:`False`)
    """
    def __init__(self, semi_grad: bool = False):
        super().__init__()
        self.semi_grad = semi_grad

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # Var[X] = E[X^2] - E[X]^2, computed per group:
        mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')
        if self.semi_grad:
            # No gradient flows through the E[X^2] term:
            with torch.no_grad():
                mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')
        else:
            mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')
        return mean2 - mean * mean


class StdAggregation(Aggregation):
    r"""An aggregation operator that takes the feature-wise standard deviation
    across a set of elements.

    .. math::
        \mathrm{std}(\mathcal{X}) = \sqrt{\mathrm{var}(\mathcal{X})}.

    Args:
        semi_grad (bool, optional): If set to :obj:`True`, will turn off
            gradient calculation during :math:`E[X^2]` computation. Therefore,
            only semi-gradients are used during backpropagation.
            Useful for saving memory and accelerating backward computation.
            (default: :obj:`False`)
    """
    def __init__(self, semi_grad: bool = False):
        super().__init__()
        # The variance computation (including the `semi_grad` flag) is
        # delegated to an internal `VarAggregation`:
        self.var_aggr = VarAggregation(semi_grad)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        var = self.var_aggr(x, index, ptr, dim_size, dim)
        # Allow "undefined" gradient at `sqrt(0.0)`:
        # Clamping keeps `sqrt` away from 0 (where its derivative diverges);
        # entries at the clamp floor are then reset to exactly 0.
        out = var.clamp(min=1e-5).sqrt()
        out = out.masked_fill(out <= math.sqrt(1e-5), 0.0)
        return out


class SoftmaxAggregation(Aggregation):
    r"""The softmax aggregation operator based on a temperature term, as
    described in the `"DeeperGCN: All You Need to Train Deeper GCNs"
    `_ paper.

    .. math::
        \mathrm{softmax}(\mathcal{X}|t) = \sum_{\mathbf{x}_i\in\mathcal{X}}
        \frac{\exp(t\cdot\mathbf{x}_i)}{\sum_{\mathbf{x}_j\in\mathcal{X}}
        \exp(t\cdot\mathbf{x}_j)}\cdot\mathbf{x}_{i},

    where :math:`t` controls the softness of the softmax when aggregating over
    a set of features :math:`\mathcal{X}`.

    Args:
        t (float, optional): Initial inverse temperature for softmax
            aggregation. (default: :obj:`1.0`)
        learn (bool, optional): If set to :obj:`True`, will learn the value
            :obj:`t` for softmax aggregation dynamically.
            (default: :obj:`False`)
        semi_grad (bool, optional): If set to :obj:`True`, will turn off
            gradient calculation during softmax computation. Therefore, only
            semi-gradients are used during backpropagation. Useful for saving
            memory and accelerating backward computation when :obj:`t` is not
            learnable. (default: :obj:`False`)
        channels (int, optional): Number of channels to learn from :math:`t`.
            If set to a value greater than :obj:`1`, :math:`t` will be learned
            per input feature channel. This requires compatible shapes for the
            input to the forward calculation.
(default: :obj:`1`) """ def __init__(self, t: float = 1.0, learn: bool = False, semi_grad: bool = False, channels: int = 1): super().__init__() if learn and semi_grad: raise ValueError( f"Cannot enable 'semi_grad' in '{self.__class__.__name__}' in " f"case the temperature term 't' is learnable") if not learn and channels != 1: raise ValueError(f"Cannot set 'channels' greater than '1' in case " f"'{self.__class__.__name__}' is not trainable") self._init_t = t self.learn = learn self.semi_grad = semi_grad self.channels = channels self.t = Parameter(torch.empty(channels)) if learn else t self.reset_parameters() def reset_parameters(self): if isinstance(self.t, Tensor): self.t.data.fill_(self._init_t) def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: t = self.t if self.channels != 1: self.assert_two_dimensional_input(x, dim) assert isinstance(t, Tensor) t = t.view(-1, self.channels) alpha = x if not isinstance(t, (int, float)) or t != 1: alpha = x * t if not self.learn and self.semi_grad: with torch.no_grad(): alpha = softmax(alpha, index, ptr, dim_size, dim) else: alpha = softmax(alpha, index, ptr, dim_size, dim) return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum') def __repr__(self) -> str: return (f'{self.__class__.__name__}(learn={self.learn})') class PowerMeanAggregation(Aggregation): r"""The powermean aggregation operator based on a power term, as described in the `"DeeperGCN: All You Need to Train Deeper GCNs" `_ paper. .. math:: \mathrm{powermean}(\mathcal{X}|p) = \left(\frac{1}{|\mathcal{X}|} \sum_{\mathbf{x}_i\in\mathcal{X}}\mathbf{x}_i^{p}\right)^{1/p}, where :math:`p` controls the power of the powermean when aggregating over a set of features :math:`\mathcal{X}`. Args: p (float, optional): Initial power for powermean aggregation. 
(default: :obj:`1.0`) learn (bool, optional): If set to :obj:`True`, will learn the value :obj:`p` for powermean aggregation dynamically. (default: :obj:`False`) channels (int, optional): Number of channels to learn from :math:`p`. If set to a value greater than :obj:`1`, :math:`p` will be learned per input feature channel. This requires compatible shapes for the input to the forward calculation. (default: :obj:`1`) clamp_min (float, optional): Lower-bound of the range to be clamped to. There is no lower bound if set to :obj:`None`. clamp_max (float, optional): Upper-bound of the range to be clamped to. There is no upper bound if set to :obj:`None`. """ def __init__( self, p: float = 1.0, learn: bool = False, channels: int = 1, clamp_min: Optional[float] = 1e-4, clamp_max: Optional[float] = 100., ) -> None: super().__init__() if not learn and channels != 1: raise ValueError(f"Cannot set 'channels' greater than '1' in case " f"'{self.__class__.__name__}' is not trainable") self._init_p = p self.learn = learn self.channels = channels self.p = Parameter(torch.empty(channels)) if learn else p self.reset_parameters() self.min_value = clamp_min self.max_value = clamp_max def reset_parameters(self): if isinstance(self.p, Tensor): self.p.data.fill_(self._init_p) def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: p = self.p if self.channels != 1: assert isinstance(p, Tensor) self.assert_two_dimensional_input(x, dim) p = p.view(-1, self.channels) if not isinstance(p, (int, float)) or p != 1: x = x.clamp(min=self.min_value, max=self.max_value).pow(p) out = self.reduce(x, index, ptr, dim_size, dim, reduce='mean') if not isinstance(p, (int, float)) or p != 1: out = out.clamp(min=self.min_value, max=self.max_value).pow(1. 
/ p) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}(learn={self.learn})') ================================================ FILE: torch_geometric/nn/aggr/deep_sets.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.nn.aggr import Aggregation from torch_geometric.nn.inits import reset class DeepSetsAggregation(Aggregation): r"""Performs Deep Sets aggregation in which the elements to aggregate are first transformed by a Multi-Layer Perceptron (MLP) :math:`\phi_{\mathbf{\Theta}}`, summed, and then transformed by another MLP :math:`\rho_{\mathbf{\Theta}}`, as suggested in the `"Graph Neural Networks with Adaptive Readouts" `_ paper. Args: local_nn (torch.nn.Module, optional): The neural network :math:`\phi_{\mathbf{\Theta}}`, *e.g.*, defined by :class:`torch.nn.Sequential` or :class:`torch_geometric.nn.models.MLP`. (default: :obj:`None`) global_nn (torch.nn.Module, optional): The neural network :math:`\rho_{\mathbf{\Theta}}`, *e.g.*, defined by :class:`torch.nn.Sequential` or :class:`torch_geometric.nn.models.MLP`. 
(default: :obj:`None`)
    """
    def __init__(
        self,
        local_nn: Optional[torch.nn.Module] = None,
        global_nn: Optional[torch.nn.Module] = None,
    ):
        super().__init__()

        # Imported locally (presumably to avoid a circular import):
        from torch_geometric.nn import MLP

        # `MLP` instances are kept apart from generic modules because
        # `forward` calls them with extra `batch`/`batch_size` arguments:
        self.local_nn = self.local_mlp = None
        if isinstance(local_nn, MLP):
            self.local_mlp = local_nn
        else:
            self.local_nn = local_nn

        self.global_nn = self.global_mlp = None
        if isinstance(global_nn, MLP):
            self.global_mlp = global_nn
        else:
            self.global_nn = global_nn

    def reset_parameters(self):
        reset(self.local_nn)
        reset(self.local_mlp)
        reset(self.global_nn)
        reset(self.global_mlp)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # phi: transform every element individually.
        if self.local_mlp is not None:
            x = self.local_mlp(x, batch=index, batch_size=dim_size)
        if self.local_nn is not None:
            x = self.local_nn(x)

        # Sum within each group.
        x = self.reduce(x, index, ptr, dim_size, dim, reduce='sum')

        # rho: transform every aggregated group representation.
        # NOTE(review): after `reduce`, `x` has `dim_size` entries along `dim`
        # while `index` still has one entry per input element, so the `batch`
        # passed to `global_mlp` does not line up with `x` -- confirm this is
        # harmless for the normalization layers used by the MLP.
        if self.global_mlp is not None:
            x = self.global_mlp(x, batch=index, batch_size=dim_size)
        elif self.global_nn is not None:
            x = self.global_nn(x)

        return x

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'local_nn={self.local_mlp or self.local_nn}, '
                f'global_nn={self.global_mlp or self.global_nn})')

================================================
FILE: torch_geometric/nn/aggr/equilibrium.py
================================================
from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.inits import reset
from torch_geometric.utils import scatter


class ResNetPotential(torch.nn.Module):
    # Residual network scoring `(x, y)` pairs; `EquilibriumAggregation` uses
    # it as its potential function, and `forward` reduces it to a scalar.
    def __init__(self, in_channels: int, out_channels: int,
                 num_layers: List[int]):

        super().__init__()
        sizes = [in_channels] + num_layers + [out_channels]
        # Hidden blocks: Linear -> LayerNorm -> Tanh for every hidden size.
        self.layers = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Linear(in_size, out_size),
                                torch.nn.LayerNorm(out_size), torch.nn.Tanh())
            for in_size, out_size in zip(sizes[:-2], sizes[1:-1])
        ])
        # Final plain linear projection to `out_channels`.
        self.layers.append(torch.nn.Linear(sizes[-2], sizes[-1]))

        # One residual projection of the raw input per layer above:
        self.res_trans = torch.nn.ModuleList([
            torch.nn.Linear(in_channels, layer_size)
            for layer_size in num_layers + [out_channels]
        ])

    def forward(self, x: Tensor, y: Tensor, index: Optional[Tensor],
                dim_size: Optional[int] = None) -> Tensor:
        # Pair every element with the `y` of its group (or with the single
        # broadcast `y` if there is no `index`):
        if index is None:
            inp = torch.cat([x, y.expand(x.size(0), -1)], dim=1)
        else:
            inp = torch.cat([x, y[index]], dim=1)

        h = inp
        for layer, res in zip(self.layers, self.res_trans):
            h = layer(h)
            h = res(inp) + h

        # Scalar energy: mean within each group, then summed over groups.
        if index is None:
            return h.mean()

        if dim_size is None:
            dim_size = int(index.max().item() + 1)

        return scatter(h, index, 0, dim_size, reduce='mean').sum()


class MomentumOptimizer(torch.nn.Module):
    r"""Provides an inner loop optimizer for the implicitly defined output
    layer. It is based on an unrolled momentum algorithm in which the
    gradient is evaluated at the current iterate.

    Args:
        learning_rate (float): learning rate for optimizer. The effective
            learning rate is :obj:`softplus(learning_rate)`.
        momentum (float): momentum for optimizer. The effective momentum is
            :obj:`sigmoid(momentum)`.
        learnable (bool): If :obj:`True` then the :obj:`learning_rate` and
            :obj:`momentum` will be learnable parameters. If :obj:`False` they
            are fixed. (default: :obj:`True`)
    """
    def __init__(self, learning_rate: float = 0.1, momentum: float = 0.9,
                 learnable: bool = True):
        super().__init__()

        self._initial_lr = learning_rate
        self._initial_mom = momentum
        self._lr = torch.nn.Parameter(Tensor([learning_rate]),
                                      requires_grad=learnable)
        self._mom = torch.nn.Parameter(Tensor([momentum]),
                                       requires_grad=learnable)
        self.softplus = torch.nn.Softplus()
        self.sigmoid = torch.nn.Sigmoid()

    def reset_parameters(self):
        self._lr.data.fill_(self._initial_lr)
        self._mom.data.fill_(self._initial_mom)

    @property
    def learning_rate(self):
        # Softplus keeps the effective step size positive.
        return self.softplus(self._lr)

    @property
    def momentum(self):
        # Sigmoid keeps the effective momentum inside (0, 1).
        return self.sigmoid(self._mom)

    def forward(
        self,
        x: Tensor,
        y: Tensor,
        index: Optional[Tensor],
        dim_size: Optional[int],
        func: Callable[[Tensor, Tensor, Optional[Tensor], Optional[int]],
                       Tensor],
        iterations: int = 5,
    ) -> Tensor:
        # Minimizes `func(x, y, index, dim_size)` w.r.t. `y` for
        # `iterations` momentum steps and returns the final `y`.
        momentum_buffer = torch.zeros_like(y)
        for _ in range(iterations):
            val = func(x, y, index, dim_size)
            # `create_graph=True` keeps the unrolled steps differentiable so
            # that outer-loop gradients can flow through them:
            grad = torch.autograd.grad(val, y, create_graph=True,
                                       retain_graph=True)[0]
            delta = self.learning_rate * grad
            momentum_buffer = self.momentum * momentum_buffer - delta
            y = y + momentum_buffer
        return y


class EquilibriumAggregation(Aggregation):
    r"""The equilibrium aggregation layer from the `"Equilibrium Aggregation:
    Encoding Sets via Optimization" `_ paper.

    The output of this layer :math:`\mathbf{y}` is defined implicitly via
    a potential function :math:`F(\mathbf{x}, \mathbf{y})`, a regularization
    term :math:`R(\mathbf{y})`, and the condition

    .. math::
        \mathbf{y} = \arg\min_\mathbf{y} R(\mathbf{y}) + \sum_{i}
        F(\mathbf{x}_i, \mathbf{y}).

    The given implementation uses a ResNet-like model for the potential
    function and a simple :math:`L_2` norm :math:`R(\mathbf{y}) =
    \textrm{softplus}(\lambda) \cdot {\| \mathbf{y} \|}^2_2` for the
    regularizer with learnable weight :math:`\lambda`.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
num_layers (List[int): List of hidden channels in the potential function. grad_iter (int): The number of steps to take in the internal gradient descent. (default: :obj:`5`) lamb (float): The initial regularization constant. (default: :obj:`0.1`) """ def __init__(self, in_channels: int, out_channels: int, num_layers: List[int], grad_iter: int = 5, lamb: float = 0.1): super().__init__() self.potential = ResNetPotential(in_channels + out_channels, 1, num_layers) self.optimizer = MomentumOptimizer() self.initial_lamb = lamb self.lamb = torch.nn.Parameter(Tensor(1), requires_grad=True) self.softplus = torch.nn.Softplus() self.grad_iter = grad_iter self.output_dim = out_channels self.reset_parameters() def reset_parameters(self): self.lamb.data.fill_(self.initial_lamb) reset(self.optimizer) reset(self.potential) def init_output(self, dim_size: int) -> Tensor: return torch.zeros(dim_size, self.output_dim, requires_grad=True, device=self.lamb.device).float() def reg(self, y: Tensor) -> Tensor: return self.softplus(self.lamb) * y.square().sum(dim=-1).mean() def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor], dim_size: Optional[int] = None): return self.potential(x, y, index, dim_size) + self.reg(y) def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: self.assert_index_present(index) dim_size = int(index.max()) + 1 if dim_size is None else dim_size with torch.enable_grad(): y = self.optimizer(x, self.init_output(dim_size), index, dim_size, self.energy, iterations=self.grad_iter) return y def __repr__(self) -> str: return (f'{self.__class__.__name__}()') ================================================ FILE: torch_geometric/nn/aggr/fused.py ================================================ import math from typing import Dict, List, Optional, Tuple, Union from torch import Tensor from torch_geometric.nn.aggr.base import Aggregation from torch_geometric.nn.aggr.basic 
import ( MaxAggregation, MeanAggregation, MinAggregation, MulAggregation, StdAggregation, SumAggregation, VarAggregation, ) from torch_geometric.nn.resolver import aggregation_resolver from torch_geometric.utils import scatter class FusedAggregation(Aggregation): r"""Helper class to fuse computation of multiple aggregations together. Used internally in :class:`~torch_geometric.nn.aggr.MultiAggregation` to speed-up computation. Currently, the following optimizations are performed: * :class:`MeanAggregation` will share the output with :class:`SumAggregation` in case it is present as well. * :class:`VarAggregation` will share the output with either :class:`MeanAggregation` or :class:`SumAggregation` in case one of them is present as well. * :class:`StdAggregation` will share the output with either :class:`VarAggregation`, :class:`MeanAggregation` or :class:`SumAggregation` in case one of them is present as well. In addition, temporary values such as the count per group index are shared as well. Benchmarking results on PyTorch 1.12 (summed over 1000 runs): +------------------------------+---------+---------+ | Aggregators | Vanilla | Fusion | +==============================+=========+=========+ | :obj:`[sum, mean]` | 0.3325s | 0.1996s | +------------------------------+---------+---------+ | :obj:`[sum, mean, min, max]` | 0.7139s | 0.5037s | +------------------------------+---------+---------+ | :obj:`[sum, mean, var]` | 0.6849s | 0.3871s | +------------------------------+---------+---------+ | :obj:`[sum, mean, var, std]` | 1.0955s | 0.3973s | +------------------------------+---------+---------+ Args: aggrs (list): The list of aggregation schemes to use. """ # We can fuse all aggregations together that rely on `scatter` directives. FUSABLE_AGGRS = { SumAggregation, MeanAggregation, MinAggregation, MaxAggregation, MulAggregation, VarAggregation, StdAggregation, } # All aggregations that rely on computing the degree of indices. 
DEGREE_BASED_AGGRS = { MeanAggregation, VarAggregation, StdAggregation, } # Map aggregations to `reduce` options in `scatter` directives. REDUCE = { 'SumAggregation': 'sum', 'MeanAggregation': 'sum', 'MinAggregation': 'min', 'MaxAggregation': 'max', 'MulAggregation': 'mul', 'VarAggregation': 'pow_sum', 'StdAggregation': 'pow_sum', } def __init__(self, aggrs: List[Union[Aggregation, str]]): super().__init__() if not isinstance(aggrs, (list, tuple)): raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should " f"be a list or tuple (got '{type(aggrs)}').") if len(aggrs) == 0: raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should " f"not be empty.") aggrs = [aggregation_resolver(aggr) for aggr in aggrs] aggr_classes = [aggr.__class__ for aggr in aggrs] self.aggr_names = [cls.__name__ for cls in aggr_classes] self.aggr_index: Dict[str, int] = { name: i for i, name in enumerate(self.aggr_names) } for cls in aggr_classes: if cls not in self.FUSABLE_AGGRS: raise ValueError(f"Received aggregation '{cls.__name__}' in " f"'{self.__class__.__name__}' which is not " f"fusable") self.semi_grad = False for aggr in aggrs: if hasattr(aggr, 'semi_grad'): self.semi_grad = self.semi_grad or aggr.semi_grad # Check whether we need to compute degree information: self.need_degree = False for cls in aggr_classes: if cls in self.DEGREE_BASED_AGGRS: self.need_degree = True # Determine which reduction to use for each aggregator: # An entry of `None` means that this operator re-uses intermediate # outputs from other aggregators. 
reduce_ops: List[Optional[str]] = [] # Determine which `(Aggregator, index)` to use as intermediate output: lookup_ops: List[Optional[Tuple[str, int]]] = [] for name in self.aggr_names: if name == 'MeanAggregation': # Directly use output of `SumAggregation`: if 'SumAggregation' in self.aggr_index: reduce_ops.append(None) lookup_ops.append(( 'SumAggregation', self.aggr_index['SumAggregation'], )) else: reduce_ops.append(self.REDUCE[name]) lookup_ops.append(None) elif name == 'VarAggregation': if 'MeanAggregation' in self.aggr_index: reduce_ops.append(self.REDUCE[name]) lookup_ops.append(( 'MeanAggregation', self.aggr_index['MeanAggregation'], )) elif 'SumAggregation' in self.aggr_index: reduce_ops.append(self.REDUCE[name]) lookup_ops.append(( 'SumAggregation', self.aggr_index['SumAggregation'], )) else: reduce_ops.append(self.REDUCE[name]) lookup_ops.append(None) elif name == 'StdAggregation': # Directly use output of `VarAggregation`: if 'VarAggregation' in self.aggr_index: reduce_ops.append(None) lookup_ops.append(( 'VarAggregation', self.aggr_index['VarAggregation'], )) elif 'MeanAggregation' in self.aggr_index: reduce_ops.append(self.REDUCE[name]) lookup_ops.append(( 'MeanAggregation', self.aggr_index['MeanAggregation'], )) elif 'SumAggregation' in self.aggr_index: reduce_ops.append(self.REDUCE[name]) lookup_ops.append(( 'SumAggregation', self.aggr_index['SumAggregation'], )) else: reduce_ops.append(self.REDUCE[name]) lookup_ops.append(None) else: reduce_ops.append(self.REDUCE[name]) lookup_ops.append(None) self.reduce_ops: List[Optional[str]] = reduce_ops self.lookup_ops: List[Optional[Tuple[str, int]]] = lookup_ops def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> List[Tensor]: # Assert two-dimensional input for now to simplify computation: # TODO refactor this to support any dimension. 
self.assert_index_present(index) self.assert_two_dimensional_input(x, dim) assert index is not None if dim_size is None: if ptr is not None: dim_size = ptr.numel() - 1 else: dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 count: Optional[Tensor] = None if self.need_degree: count = x.new_zeros(dim_size) count.scatter_add_(0, index, x.new_ones(x.size(0))) count = count.clamp_(min=1).view(-1, 1) ####################################################################### outs: List[Optional[Tensor]] = [] # Iterate over all reduction ops to compute first results: for reduce in self.reduce_ops: if reduce is None: outs.append(None) continue assert isinstance(reduce, str) if reduce == 'pow_sum': if self.semi_grad: out = scatter(x.detach() * x.detach(), index, 0, dim_size, reduce='sum') else: out = scatter(x * x, index, 0, dim_size, reduce='sum') else: out = scatter(x, index, 0, dim_size, reduce=reduce) outs.append(out) ####################################################################### # Compute `MeanAggregation` first to be able to re-use it: i = self.aggr_index.get('MeanAggregation') if i is not None: assert count is not None if self.lookup_ops[i] is None: sum_ = outs[i] else: lookup_op = self.lookup_ops[i] assert lookup_op is not None tmp_aggr, j = lookup_op assert tmp_aggr == 'SumAggregation' sum_ = outs[j] assert sum_ is not None outs[i] = sum_ / count # Compute `VarAggregation` second to be able to re-use it: if 'VarAggregation' in self.aggr_index: i = self.aggr_index['VarAggregation'] assert count is not None if self.lookup_ops[i] is None: sum_ = scatter(x, index, 0, dim_size, reduce='sum') mean = sum_ / count else: lookup_op = self.lookup_ops[i] assert lookup_op is not None tmp_aggr, j = lookup_op if tmp_aggr == 'SumAggregation': sum_ = outs[j] assert sum_ is not None mean = sum_ / count elif tmp_aggr == 'MeanAggregation': mean = outs[j] else: raise NotImplementedError pow_sum = outs[i] assert pow_sum is not None assert mean is not None outs[i] = 
(pow_sum / count) - (mean * mean) # Compute `StdAggregation` last: if 'StdAggregation' in self.aggr_index: i = self.aggr_index['StdAggregation'] var: Optional[Tensor] = None pow_sum: Optional[Tensor] = None mean: Optional[Tensor] = None if self.lookup_ops[i] is None: pow_sum = outs[i] sum_ = scatter(x, index, 0, dim_size, reduce='sum') assert count is not None mean = sum_ / count else: lookup_op = self.lookup_ops[i] assert lookup_op is not None tmp_aggr, j = lookup_op if tmp_aggr == 'VarAggregation': var = outs[j] elif tmp_aggr == 'SumAggregation': pow_sum = outs[i] sum_ = outs[j] assert sum_ is not None assert count is not None mean = sum_ / count elif tmp_aggr == 'MeanAggregation': pow_sum = outs[i] mean = outs[j] else: raise NotImplementedError if var is None: assert pow_sum is not None assert count is not None assert mean is not None var = (pow_sum / count) - (mean * mean) # Allow "undefined" gradient at `sqrt(0.0)`: out = var.clamp(min=1e-5).sqrt() out = out.masked_fill(out <= math.sqrt(1e-5), 0.0) outs[i] = out ####################################################################### vals: List[Tensor] = [] for out in outs: assert out is not None vals.append(out) return vals ================================================ FILE: torch_geometric/nn/aggr/gmt.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.nn.aggr import Aggregation from torch_geometric.nn.aggr.utils import ( PoolingByMultiheadAttention, SetAttentionBlock, ) class GraphMultisetTransformer(Aggregation): r"""The Graph Multiset Transformer pooling operator from the `"Accurate Learning of Graph Representations with Graph Multiset Pooling" `_ paper. 
The :class:`GraphMultisetTransformer` aggregates elements into
    :math:`k` representative elements via attention-based pooling, computes
    the interaction among them via :obj:`num_encoder_blocks` self-attention
    blocks, and finally pools the representative elements via attention-based
    pooling into a single cluster.

    .. note::

        :class:`GraphMultisetTransformer` requires sorted indices :obj:`index`
        as input. Specifically, if you use this aggregation as part of
        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that
        :obj:`edge_index` is sorted by destination nodes, either by manually
        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`
        or by calling :meth:`torch_geometric.data.Data.sort`.

    Args:
        channels (int): Size of each input sample.
        k (int): Number of :math:`k` representative nodes after pooling.
        num_encoder_blocks (int, optional): Number of Set Attention Blocks
            (SABs) between the two pooling blocks. (default: :obj:`1`)
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        layer_norm (bool, optional): If set to :obj:`True`, will apply layer
            normalization. (default: :obj:`False`)
        dropout (float, optional): Dropout probability of attention weights.
            (default: :obj:`0.0`)
    """
    def __init__(
        self,
        channels: int,
        k: int,
        num_encoder_blocks: int = 1,
        heads: int = 1,
        layer_norm: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.k = k
        self.heads = heads
        self.layer_norm = layer_norm
        self.dropout = dropout

        # Attention-based pooling of each group into `k` representatives:
        self.pma1 = PoolingByMultiheadAttention(channels, k, heads, layer_norm,
                                                dropout)
        # Self-attention blocks among the `k` representatives:
        self.encoders = torch.nn.ModuleList([
            SetAttentionBlock(channels, heads, layer_norm, dropout)
            for _ in range(num_encoder_blocks)
        ])
        # Attention-based pooling of the representatives into one element:
        self.pma2 = PoolingByMultiheadAttention(channels, 1, heads, layer_norm,
                                                dropout)

    def reset_parameters(self):
        self.pma1.reset_parameters()
        for encoder in self.encoders:
            encoder.reset_parameters()
        self.pma2.reset_parameters()

    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        # Pad the grouped input into a dense batch; `mask` presumably marks
        # the non-padded entries:
        x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                      max_num_elements=max_num_elements)

        x = self.pma1(x, mask)

        for encoder in self.encoders:
            x = encoder(x)

        x = self.pma2(x)

        # `pma2` pools into a single element per group; presumably `x` is
        # `[batch, 1, channels]` here, so drop that singleton dimension:
        return x.squeeze(1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'k={self.k}, heads={self.heads}, '
                f'layer_norm={self.layer_norm}, '
                f'dropout={self.dropout})')

================================================
FILE: torch_geometric/nn/aggr/gru.py
================================================
from typing import Optional

from torch import Tensor
from torch.nn import GRU

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation


class GRUAggregation(Aggregation):
    r"""Performs GRU aggregation in which the elements to aggregate are
    interpreted as a sequence, as described in the `"Graph Neural Networks
    with Adaptive Readouts" `_ paper.

    .. note::

        :class:`GRUAggregation` requires sorted indices :obj:`index` as input.
Specifically, if you use this aggregation as part of :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that :obj:`edge_index` is sorted by destination nodes, either by manually sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index` or by calling :meth:`torch_geometric.data.Data.sort`. .. warning:: :class:`GRUAggregation` is not a permutation-invariant operator. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. **kwargs (optional): Additional arguments of :class:`torch.nn.GRU`. """ def __init__(self, in_channels: int, out_channels: int, **kwargs): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.gru = GRU(in_channels, out_channels, batch_first=True, **kwargs) self.reset_parameters() def reset_parameters(self): self.gru.reset_parameters() @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements']) def forward( self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None, ) -> Tensor: x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, max_num_elements=max_num_elements) return self.gru(x)[0][:, -1] def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/aggr/lcm.py ================================================ from math import ceil, log2 from typing import Optional import torch from torch import Tensor from torch.nn import GRUCell, Linear from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.nn.aggr import Aggregation class LCMAggregation(Aggregation): r"""The Learnable Commutative Monoid aggregation from the `"Learnable Commutative Monoids for Graph Neural Networks" `_ paper, in which the elements are aggregated using a binary tree reduction with :math:`\mathcal{O}(\log 
|\mathcal{V}|)` depth. .. note:: :class:`LCMAggregation` requires sorted indices :obj:`index` as input. Specifically, if you use this aggregation as part of :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that :obj:`edge_index` is sorted by destination nodes, either by manually sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index` or by calling :meth:`torch_geometric.data.Data.sort`. .. warning:: :class:`LCMAggregation` is not a permutation-invariant operator. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. project (bool, optional): If set to :obj:`True`, the layer will apply a linear transformation followed by an activation function before aggregation. (default: :obj:`True`) """ def __init__( self, in_channels: int, out_channels: int, project: bool = True, ): super().__init__() if in_channels != out_channels and not project: raise ValueError(f"Inputs of '{self.__class__.__name__}' must be " f"projected if `in_channels != out_channels`") self.in_channels = in_channels self.out_channels = out_channels self.project = project if self.project: self.lin = Linear(in_channels, out_channels) else: self.lin = None self.gru_cell = GRUCell(out_channels, out_channels) def reset_parameters(self): if self.project: self.lin.reset_parameters() self.gru_cell.reset_parameters() @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements']) def forward( self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None, ) -> Tensor: if self.project: x = self.lin(x).relu() x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, max_num_elements=max_num_elements) x = x.permute(1, 0, 2) # [num_neighbors, num_nodes, num_features] _, num_nodes, num_features = x.size() depth = ceil(log2(x.size(0))) for _ in range(depth): half_size = ceil(x.size(0) / 2) if x.size(0) % 2 == 1: # This level of the tree has an 
odd number of nodes, so the # remaining unmatched node gets moved to the next level. x, remainder = x[:-1], x[-1:] else: remainder = None left_right = x.view(-1, 2, num_nodes, num_features) right_left = left_right.flip(dims=[1]) left_right = left_right.reshape(-1, num_features) right_left = right_left.reshape(-1, num_features) # Execute the GRUCell for all (left, right) pairs in the current # level of the tree in parallel: out = self.gru_cell(left_right, right_left) out = out.view(-1, 2, num_nodes, num_features) out = out.mean(dim=1) if remainder is not None: out = torch.cat([out, remainder], dim=0) x = out.view(half_size, num_nodes, num_features) assert x.size(0) == 1 return x.squeeze(0) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, project={self.project})') ================================================ FILE: torch_geometric/nn/aggr/lstm.py ================================================ from typing import Optional from torch import Tensor from torch.nn import LSTM from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.nn.aggr import Aggregation class LSTMAggregation(Aggregation): r"""Performs LSTM-style aggregation in which the elements to aggregate are interpreted as a sequence, as described in the `"Inductive Representation Learning on Large Graphs" `_ paper. .. note:: :class:`LSTMAggregation` requires sorted indices :obj:`index` as input. Specifically, if you use this aggregation as part of :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that :obj:`edge_index` is sorted by destination nodes, either by manually sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index` or by calling :meth:`torch_geometric.data.Data.sort`. .. warning:: :class:`LSTMAggregation` is not a permutation-invariant operator. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. 
**kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`. """ def __init__(self, in_channels: int, out_channels: int, **kwargs): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs) self.reset_parameters() def reset_parameters(self): self.lstm.reset_parameters() @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements']) def forward( self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None, ) -> Tensor: x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, max_num_elements=max_num_elements) return self.lstm(x)[0][:, -1] def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/aggr/mlp.py ================================================ from typing import Optional from torch import Tensor from torch_geometric.nn.aggr import Aggregation class MLPAggregation(Aggregation): r"""Performs MLP aggregation in which the elements to aggregate are flattened into a single vectorial representation, and are then processed by a Multi-Layer Perceptron (MLP), as described in the `"Graph Neural Networks with Adaptive Readouts" `_ paper. .. note:: :class:`MLPAggregation` requires sorted indices :obj:`index` as input. Specifically, if you use this aggregation as part of :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that :obj:`edge_index` is sorted by destination nodes, either by manually sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index` or by calling :meth:`torch_geometric.data.Data.sort`. .. warning:: :class:`MLPAggregation` is not a permutation-invariant operator. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. 
max_num_elements (int): The maximum number of elements to aggregate per group. **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.models.MLP`. """ def __init__( self, in_channels: int, out_channels: int, max_num_elements: int, **kwargs, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.max_num_elements = max_num_elements from torch_geometric.nn import MLP self.mlp = MLP( in_channels=in_channels * max_num_elements, out_channels=out_channels, **kwargs, ) self.reset_parameters() def reset_parameters(self): self.mlp.reset_parameters() def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, max_num_elements=self.max_num_elements) return self.mlp(x.view(-1, x.size(1) * x.size(2)), index, dim_size) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, ' f'max_num_elements={self.max_num_elements})') ================================================ FILE: torch_geometric/nn/aggr/multi.py ================================================ import copy from typing import Any, Dict, List, Optional, Union import torch from torch import Tensor from torch.nn import Linear, MultiheadAttention from torch_geometric.nn.aggr import Aggregation from torch_geometric.nn.aggr.fused import FusedAggregation from torch_geometric.nn.dense import HeteroDictLinear from torch_geometric.nn.resolver import aggregation_resolver class MultiAggregation(Aggregation): r"""Performs aggregations with one or more aggregators and combines aggregated results, as described in the `"Principal Neighbourhood Aggregation for Graph Nets" `_ and `"Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions" `_ papers. Args: aggrs (list): The list of aggregation schemes to use. 
aggrs_kwargs (dict, optional): Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: :obj:`None`) mode (str, optional): The combine mode to use for combining aggregated results from multiple aggregations (:obj:`"cat"`, :obj:`"proj"`, :obj:`"sum"`, :obj:`"mean"`, :obj:`"max"`, :obj:`"min"`, :obj:`"logsumexp"`, :obj:`"std"`, :obj:`"var"`, :obj:`"attn"`). (default: :obj:`"cat"`) mode_kwargs (dict, optional): Arguments passed for the combine :obj:`mode`. When :obj:`"proj"` or :obj:`"attn"` is used as the combine :obj:`mode`, :obj:`in_channels` (int or tuple) and :obj:`out_channels` (int) are needed to be specified respectively for the size of each input sample to combine from the respective aggregation outputs and the size of each output sample after combination. When :obj:`"attn"` mode is used, :obj:`num_heads` (int) is needed to be specified for the number of parallel attention heads. (default: :obj:`None`) """ fused_out_index: List[int] is_fused_aggr: List[bool] def __init__( self, aggrs: List[Union[Aggregation, str]], aggrs_kwargs: Optional[List[Dict[str, Any]]] = None, mode: Optional[str] = 'cat', mode_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__() if not isinstance(aggrs, (list, tuple)): raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should " f"be a list or tuple (got '{type(aggrs)}').") if len(aggrs) == 0: raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should " f"not be empty.") if aggrs_kwargs is None: aggrs_kwargs = [{}] * len(aggrs) elif len(aggrs) != len(aggrs_kwargs): raise ValueError(f"'aggrs_kwargs' with invalid length passed to " f"'{self.__class__.__name__}' " f"(got '{len(aggrs_kwargs)}', " f"expected '{len(aggrs)}'). 
Ensure that both " f"'aggrs' and 'aggrs_kwargs' are consistent.") self.aggrs = torch.nn.ModuleList([ aggregation_resolver(aggr, **aggr_kwargs) for aggr, aggr_kwargs in zip(aggrs, aggrs_kwargs) ]) # Divide the set into fusable and non-fusable aggregations: fused_aggrs: List[Aggregation] = [] self.fused_out_index: List[int] = [] self.is_fused_aggr: List[bool] = [] for i, aggr in enumerate(self.aggrs): if aggr.__class__ in FusedAggregation.FUSABLE_AGGRS: fused_aggrs.append(aggr) self.fused_out_index.append(i) self.is_fused_aggr.append(True) else: self.is_fused_aggr.append(False) if len(fused_aggrs) > 0: self.fused_aggr = FusedAggregation(fused_aggrs) else: self.fused_aggr = None self.mode = mode mode_kwargs = copy.copy(mode_kwargs) or {} self.in_channels = mode_kwargs.pop('in_channels', None) self.out_channels = mode_kwargs.pop('out_channels', None) if mode == 'proj' or mode == 'attn': if len(aggrs) == 1: raise ValueError("Multiple aggregations are required for " "'proj' or 'attn' combine mode.") if (self.in_channels and self.out_channels) is None: raise ValueError( f"Combine mode '{mode}' must have `in_channels` " f"and `out_channels` specified.") if isinstance(self.in_channels, int): self.in_channels = [self.in_channels] * len(aggrs) if mode == 'proj': self.lin = Linear( sum(self.in_channels), self.out_channels, **mode_kwargs, ) elif mode == 'attn': channels = {str(k): v for k, v, in enumerate(self.in_channels)} self.lin_heads = HeteroDictLinear(channels, self.out_channels) num_heads = mode_kwargs.pop('num_heads', 1) self.multihead_attn = MultiheadAttention( self.out_channels, num_heads, **mode_kwargs, ) dense_combine_modes = [ 'sum', 'mean', 'max', 'min', 'logsumexp', 'std', 'var' ] if mode in dense_combine_modes: self.dense_combine = getattr(torch, mode) def reset_parameters(self): for aggr in self.aggrs: aggr.reset_parameters() if self.mode == 'proj': self.lin.reset_parameters() if self.mode == 'attn': self.lin_heads.reset_parameters() 
self.multihead_attn._reset_parameters() def get_out_channels(self, in_channels: int) -> int: if self.out_channels is not None: return self.out_channels # TODO Support having customized `out_channels` in each aggregation. if self.mode == 'cat': return in_channels * len(self.aggrs) return in_channels def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: # `FusedAggregation` is currently limited to two-dimensional inputs: if index is None or x.dim() != 2 or self.fused_aggr is None: outs = [aggr(x, index, ptr, dim_size, dim) for aggr in self.aggrs] return self.combine(outs) outs: List[Tensor] = [x] * len(self.aggrs) # Fill with dummy tensors. fused_outs = self.fused_aggr(x, index, ptr, dim_size, dim) for i, out in zip(self.fused_out_index, fused_outs): outs[i] = out for i, aggr in enumerate(self.aggrs): if not self.is_fused_aggr[i]: outs[i] = aggr(x, index, ptr, dim_size, dim) return self.combine(outs) def combine(self, inputs: List[Tensor]) -> Tensor: if len(inputs) == 1: return inputs[0] if self.mode == 'cat': return torch.cat(inputs, dim=-1) if hasattr(self, 'lin'): return self.lin(torch.cat(inputs, dim=-1)) if hasattr(self, 'multihead_attn'): x_dict = {str(k): v for k, v, in enumerate(inputs)} x_dict = self.lin_heads(x_dict) xs = [x_dict[str(key)] for key in range(len(inputs))] x = torch.stack(xs, dim=0) attn_out, _ = self.multihead_attn(x, x, x) return torch.mean(attn_out, dim=0) if hasattr(self, 'dense_combine'): out = self.dense_combine(torch.stack(inputs, dim=0), dim=0) return out if isinstance(out, Tensor) else out[0] raise ValueError(f"Combine mode '{self.mode}' is not supported.") def __repr__(self) -> str: aggrs = ',\n'.join([f' {aggr}' for aggr in self.aggrs]) + ',\n' return f'{self.__class__.__name__}([\n{aggrs}], mode={self.mode})' ================================================ FILE: torch_geometric/nn/aggr/patch_transformer.py 
================================================ import math from typing import List, Optional, Union import torch from torch import Tensor from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.nn.aggr import Aggregation from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock from torch_geometric.nn.encoding import PositionalEncoding from torch_geometric.utils import scatter class PatchTransformerAggregation(Aggregation): r"""Performs patch transformer aggregation in which the elements to aggregate are processed by multi-head attention blocks across patches, as described in the `"Simplifying Temporal Heterogeneous Network for Continuous-Time Link Prediction" `_ paper. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. patch_size (int): Number of elements in a patch. hidden_channels (int): Intermediate size of each sample. num_transformer_blocks (int, optional): Number of transformer blocks (default: :obj:`1`). heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) dropout (float, optional): Dropout probability of attention weights. (default: :obj:`0.0`) aggr (str or list[str], optional): The aggregation module, *e.g.*, :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`) device (torch.device, optional): The device of the module. 
(default: :obj:`None`) """ def __init__( self, in_channels: int, out_channels: int, patch_size: int, hidden_channels: int, num_transformer_blocks: int = 1, heads: int = 1, dropout: float = 0.0, aggr: Union[str, List[str]] = 'mean', device: Optional[torch.device] = None, ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.patch_size = patch_size self.aggrs = [aggr] if isinstance(aggr, str) else aggr assert len(self.aggrs) > 0 for aggr in self.aggrs: assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std'] self.lin = torch.nn.Linear(in_channels, hidden_channels, device=device) self.pad_projector = torch.nn.Linear( patch_size * hidden_channels, hidden_channels, device=device, ) self.pe = PositionalEncoding(hidden_channels, device=device) self.blocks = torch.nn.ModuleList([ MultiheadAttentionBlock( channels=hidden_channels, heads=heads, layer_norm=True, dropout=dropout, device=device, ) for _ in range(num_transformer_blocks) ]) self.fc = torch.nn.Linear( hidden_channels * len(self.aggrs), out_channels, device=device, ) def reset_parameters(self) -> None: self.lin.reset_parameters() self.pad_projector.reset_parameters() self.pe.reset_parameters() for block in self.blocks: block.reset_parameters() self.fc.reset_parameters() @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements']) def forward( self, x: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None, ) -> Tensor: if max_num_elements is None: if ptr is not None: count = ptr.diff() else: count = scatter(torch.ones_like(index), index, dim=0, dim_size=dim_size, reduce='sum') max_num_elements = int(count.max()) + 1 # Set `max_num_elements` to a multiple of `patch_size`: max_num_elements = (math.floor(max_num_elements / self.patch_size) * self.patch_size) x = self.lin(x) # TODO If groups are heavily unbalanced, this will create a lot of # "empty" patches. 
Try to figure out a way to fix this. # [batch_size, num_patches * patch_size, hidden_channels] x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, max_num_elements=max_num_elements) # [batch_size, num_patches, patch_size * hidden_channels] x = x.view(x.size(0), max_num_elements // self.patch_size, self.patch_size * x.size(-1)) # [batch_size, num_patches, hidden_channels] x = self.pad_projector(x) x = x + self.pe(torch.arange(x.size(1), device=x.device)) # [batch_size, num_patches, hidden_channels] for block in self.blocks: x = block(x, x) # [batch_size, hidden_channels] outs: List[Tensor] = [] for aggr in self.aggrs: out = getattr(torch, aggr)(x, dim=1) outs.append(out[0] if isinstance(out, tuple) else out) out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0] # [batch_size, out_channels] return self.fc(out) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, patch_size={self.patch_size})') ================================================ FILE: torch_geometric/nn/aggr/quantile.py ================================================ from typing import List, Optional, Union import torch from torch import Tensor from torch_geometric.nn.aggr import Aggregation from torch_geometric.utils import cumsum class QuantileAggregation(Aggregation): r"""An aggregation operator that returns the feature-wise :math:`q`-th quantile of a set :math:`\mathcal{X}`. That is, for every feature :math:`d`, it computes .. math:: {\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases} x_{\pi_i,d} & i = q \cdot n, \\ f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot n < i + 1,\\ \end{cases} where :math:`x_{\pi_1,d} \le \dots \le x_{\pi_i,d} \le \dots \le x_{\pi_n,d}` and :math:`f(a, b)` is an interpolation function defined by :obj:`interpolation`. Args: q (float or list): The quantile value(s) :math:`q`. Can be a scalar or a list of scalars in the range :math:`[0, 1]`. If more than a quantile is passed, the results are concatenated. 
interpolation (str): Interpolation method applied if the quantile point :math:`q\cdot n` lies between two values :math:`a \le b`. Can be one of the following: * :obj:`"lower"`: Returns the one with lowest value. * :obj:`"higher"`: Returns the one with highest value. * :obj:`"midpoint"`: Returns the average of the two values. * :obj:`"nearest"`: Returns the one whose index is nearest to the quantile point. * :obj:`"linear"`: Returns a linear combination of the two elements, defined as :math:`f(a, b) = a + (b - a)\cdot(q\cdot n - i)`. (default: :obj:`"linear"`) fill_value (float, optional): The default value in the case no entry is found for a given index (default: :obj:`0.0`). """ interpolations = {'linear', 'lower', 'higher', 'nearest', 'midpoint'} def __init__(self, q: Union[float, List[float]], interpolation: str = 'linear', fill_value: float = 0.0): super().__init__() qs = [q] if not isinstance(q, (list, tuple)) else q if len(qs) == 0: raise ValueError("Provide at least one quantile value for `q`.") if not all(0. <= quantile <= 1. for quantile in qs): raise ValueError("`q` must be in the range [0, 1].") if interpolation not in self.interpolations: raise ValueError(f"Invalid interpolation method " f"got ('{interpolation}')") self._q = q self.register_buffer('q', torch.tensor(qs).view(-1, 1)) self.interpolation = interpolation self.fill_value = fill_value def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: dim = x.dim() + dim if dim < 0 else dim self.assert_index_present(index) assert index is not None # Required for TorchScript. 
count = torch.bincount(index, minlength=dim_size or 0) ptr = cumsum(count)[:-1] # In case there exists dangling indices (`dim_size > index.max()`), we # need to clamp them to prevent out-of-bound issues: if dim_size is not None: ptr = ptr.clamp(max=x.size(dim) - 1) q_point = self.q * (count - 1) + ptr q_point = q_point.t().reshape(-1) shape = [1] * x.dim() shape[dim] = -1 index = index.view(shape).expand_as(x) # Two sorts: the first one on the value, # the second (stable) on the indices: x, x_perm = torch.sort(x, dim=dim) index = index.take_along_dim(x_perm, dim=dim) index, index_perm = torch.sort(index, dim=dim, stable=True) x = x.take_along_dim(index_perm, dim=dim) # Compute the quantile interpolations: if self.interpolation == 'lower': quantile = x.index_select(dim, q_point.floor().long()) elif self.interpolation == 'higher': quantile = x.index_select(dim, q_point.ceil().long()) elif self.interpolation == 'nearest': quantile = x.index_select(dim, q_point.round().long()) else: l_quant = x.index_select(dim, q_point.floor().long()) r_quant = x.index_select(dim, q_point.ceil().long()) if self.interpolation == 'linear': q_frac = q_point.frac().view(shape) quantile = l_quant + (r_quant - l_quant) * q_frac else: # 'midpoint' quantile = 0.5 * l_quant + 0.5 * r_quant # If the number of elements is zero, fill with pre-defined value: repeats = self.q.numel() mask = (count == 0).repeat_interleave( repeats, output_size=repeats * count.numel()).view(shape) out = quantile.masked_fill(mask, self.fill_value) if self.q.numel() > 1: shape = list(out.shape) shape = (shape[:dim] + [shape[dim] // self.q.numel(), -1] + shape[dim + 2:]) out = out.view(shape) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}(q={self._q})') class MedianAggregation(QuantileAggregation): r"""An aggregation operator that returns the feature-wise median of a set. That is, for every feature :math:`d`, it computes .. 
math:: {\mathrm{median}(\mathcal{X})}_d = x_{\pi_i,d} where :math:`x_{\pi_1,d} \le x_{\pi_2,d} \le \dots \le x_{\pi_n,d}` and :math:`i = \lfloor \frac{n}{2} \rfloor`. .. note:: If the median lies between two values, the lowest one is returned. To compute the midpoint (or other kind of interpolation) of the two values, use :class:`QuantileAggregation` instead. Args: fill_value (float, optional): The default value in the case no entry is found for a given index (default: :obj:`0.0`). """ def __init__(self, fill_value: float = 0.0): super().__init__(0.5, 'lower', fill_value) def __repr__(self) -> str: return f"{self.__class__.__name__}()" ================================================ FILE: torch_geometric/nn/aggr/scaler.py ================================================ from typing import Any, Dict, List, Optional, Union import torch from torch import Tensor from torch_geometric.nn.aggr import Aggregation, MultiAggregation from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver from torch_geometric.utils import degree class DegreeScalerAggregation(Aggregation): r"""Combines one or more aggregators and transforms its output with one or more scalers as introduced in the `"Principal Neighbourhood Aggregation for Graph Nets" `_ paper. The scalers are normalised by the in-degree of the training set and so must be provided at time of construction. See :class:`torch_geometric.nn.conv.PNAConv` for more information. Args: aggr (str or [str] or Aggregation): The aggregation scheme to use. See :class:`~torch_geometric.nn.conv.MessagePassing` for more information. scaler (str or list): Set of scaling function identifiers, namely one or more of :obj:`"identity"`, :obj:`"amplification"`, :obj:`"attenuation"`, :obj:`"linear"` and :obj:`"inverse_linear"`. deg (Tensor): Histogram of in-degrees of nodes in the training set, used by scalers to normalize. train_norm (bool, optional): Whether normalization parameters are trainable. 
(default: :obj:`False`) aggr_kwargs (Dict[str, Any], optional): Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: :obj:`None`) """ def __init__( self, aggr: Union[str, List[str], Aggregation], scaler: Union[str, List[str]], deg: Tensor, train_norm: bool = False, aggr_kwargs: Optional[List[Dict[str, Any]]] = None, ): super().__init__() if isinstance(aggr, (str, Aggregation)): self.aggr = aggr_resolver(aggr, **(aggr_kwargs or {})) elif isinstance(aggr, (tuple, list)): self.aggr = MultiAggregation(aggr, aggr_kwargs) else: raise ValueError(f"Only strings, list, tuples and instances of" f"`torch_geometric.nn.aggr.Aggregation` are " f"valid aggregation schemes (got '{type(aggr)}')") self.scaler = [scaler] if isinstance(aggr, str) else scaler deg = deg.to(torch.float) N = int(deg.sum()) bin_degree = torch.arange(deg.numel(), device=deg.device) self.init_avg_deg_lin = float((bin_degree * deg).sum()) / N self.init_avg_deg_log = float(((bin_degree + 1).log() * deg).sum()) / N if train_norm: self.avg_deg_lin = torch.nn.Parameter(torch.empty(1)) self.avg_deg_log = torch.nn.Parameter(torch.empty(1)) else: self.register_buffer('avg_deg_lin', torch.empty(1)) self.register_buffer('avg_deg_log', torch.empty(1)) self.reset_parameters() def reset_parameters(self): self.avg_deg_lin.data.fill_(self.init_avg_deg_lin) self.avg_deg_log.data.fill_(self.init_avg_deg_log) def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: # TODO Currently, `degree` can only operate on `index`: self.assert_index_present(index) out = self.aggr(x, index, ptr, dim_size, dim) assert index is not None deg = degree(index, num_nodes=dim_size, dtype=out.dtype) size = [1] * len(out.size()) size[dim] = -1 deg = deg.view(size) outs = [] for scaler in self.scaler: if scaler == 'identity': out_scaler = out elif scaler == 'amplification': out_scaler = out * (torch.log(deg 
+ 1) / self.avg_deg_log) elif scaler == 'attenuation': # Clamp minimum degree to one to avoid dividing by zero: out_scaler = out * (self.avg_deg_log / torch.log(deg.clamp(min=1) + 1)) elif scaler == 'linear': out_scaler = out * (deg / self.avg_deg_lin) elif scaler == 'inverse_linear': # Clamp minimum degree to one to avoid dividing by zero: out_scaler = out * (self.avg_deg_lin / deg.clamp(min=1)) else: raise ValueError(f"Unknown scaler '{scaler}'") outs.append(out_scaler) return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0] ================================================ FILE: torch_geometric/nn/aggr/set2set.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.nn.aggr import Aggregation from torch_geometric.utils import softmax class Set2Set(Aggregation): r"""The Set2Set aggregation operator based on iterative content-based attention, as described in the `"Order Matters: Sequence to sequence for Sets" `_ paper. .. math:: \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1}) \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t) \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t, where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice the dimensionality as the input. Args: in_channels (int): Size of each input sample. processing_steps (int): Number of iterations :math:`T`. **kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`. 
""" def __init__(self, in_channels: int, processing_steps: int, **kwargs): super().__init__() self.in_channels = in_channels self.out_channels = 2 * in_channels self.processing_steps = processing_steps self.lstm = torch.nn.LSTM(self.out_channels, in_channels, **kwargs) self.reset_parameters() def reset_parameters(self): self.lstm.reset_parameters() def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: self.assert_index_present(index) self.assert_two_dimensional_input(x, dim) h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))), x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1)))) q_star = x.new_zeros(dim_size, self.out_channels) for _ in range(self.processing_steps): q, h = self.lstm(q_star.unsqueeze(0), h) q = q.view(dim_size, self.in_channels) e = (x * q[index]).sum(dim=-1, keepdim=True) a = softmax(e, index, ptr, dim_size, dim) r = self.reduce(a * x, index, ptr, dim_size, dim, reduce='sum') q_star = torch.cat([q, r], dim=-1) return q_star def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/aggr/set_transformer.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.nn.aggr import Aggregation from torch_geometric.nn.aggr.utils import ( PoolingByMultiheadAttention, SetAttentionBlock, ) class SetTransformerAggregation(Aggregation): r"""Performs "Set Transformer" aggregation in which the elements to aggregate are processed by multi-head attention blocks, as described in the `"Graph Neural Networks with Adaptive Readouts" `_ paper. .. note:: :class:`SetTransformerAggregation` requires sorted indices :obj:`index` as input. 
Specifically, if you use this aggregation as part of :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that :obj:`edge_index` is sorted by destination nodes, either by manually sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index` or by calling :meth:`torch_geometric.data.Data.sort`. Args: channels (int): Size of each input sample. num_seed_points (int, optional): Number of seed points. (default: :obj:`1`) num_encoder_blocks (int, optional): Number of Set Attention Blocks (SABs) in the encoder. (default: :obj:`1`). num_decoder_blocks (int, optional): Number of Set Attention Blocks (SABs) in the decoder. (default: :obj:`1`). heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) concat (bool, optional): If set to :obj:`False`, the seed embeddings are averaged instead of concatenated. (default: :obj:`True`) layer_norm (str, optional): If set to :obj:`True`, will apply layer normalization. (default: :obj:`False`) dropout (float, optional): Dropout probability of attention weights. 
(default: :obj:`0`) """ def __init__( self, channels: int, num_seed_points: int = 1, num_encoder_blocks: int = 1, num_decoder_blocks: int = 1, heads: int = 1, concat: bool = True, layer_norm: bool = False, dropout: float = 0.0, ): super().__init__() self.channels = channels self.num_seed_points = num_seed_points self.heads = heads self.concat = concat self.layer_norm = layer_norm self.dropout = dropout self.encoders = torch.nn.ModuleList([ SetAttentionBlock(channels, heads, layer_norm, dropout) for _ in range(num_encoder_blocks) ]) self.pma = PoolingByMultiheadAttention(channels, num_seed_points, heads, layer_norm, dropout) self.decoders = torch.nn.ModuleList([ SetAttentionBlock(channels, heads, layer_norm, dropout) for _ in range(num_decoder_blocks) ]) def reset_parameters(self): for encoder in self.encoders: encoder.reset_parameters() self.pma.reset_parameters() for decoder in self.decoders: decoder.reset_parameters() @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements']) def forward( self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None, ) -> Tensor: x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim, max_num_elements=max_num_elements) for encoder in self.encoders: x = encoder(x, mask) x = self.pma(x, mask) for decoder in self.decoders: x = decoder(x) x = x.nan_to_num() return x.flatten(1, 2) if self.concat else x.mean(dim=1) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.channels}, ' f'num_seed_points={self.num_seed_points}, ' f'heads={self.heads}, ' f'layer_norm={self.layer_norm}, ' f'dropout={self.dropout})') ================================================ FILE: torch_geometric/nn/aggr/sort.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.nn.aggr 
import Aggregation class SortAggregation(Aggregation): r"""The pooling operator from the `"An End-to-End Deep Learning Architecture for Graph Classification" `_ paper, where node features are sorted in descending order based on their last feature channel. The first :math:`k` nodes form the output of the layer. .. note:: :class:`SortAggregation` requires sorted indices :obj:`index` as input. Specifically, if you use this aggregation as part of :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that :obj:`edge_index` is sorted by destination nodes, either by manually sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index` or by calling :meth:`torch_geometric.data.Data.sort`. Args: k (int): The number of nodes to hold for each graph. """ def __init__(self, k: int): super().__init__() self.k = k @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements']) def forward( self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None, ) -> Tensor: fill_value = x.detach().min() - 1 batch_x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, fill_value=fill_value, max_num_elements=max_num_elements) B, N, D = batch_x.size() _, perm = batch_x[:, :, -1].sort(dim=-1, descending=True) arange = torch.arange(B, dtype=torch.long, device=perm.device) * N perm = perm + arange.view(-1, 1) batch_x = batch_x.view(B * N, D) batch_x = batch_x[perm] batch_x = batch_x.view(B, N, D) if N >= self.k: batch_x = batch_x[:, :self.k].contiguous() else: expand_batch_x = batch_x.new_full((B, self.k - N, D), fill_value) batch_x = torch.cat([batch_x, expand_batch_x], dim=1) batch_x[batch_x == fill_value] = 0 x = batch_x.view(B, self.k * D) return x def __repr__(self) -> str: return (f'{self.__class__.__name__}(k={self.k})') ================================================ FILE: torch_geometric/nn/aggr/utils.py ================================================ from 
typing import Optional import torch from torch import Tensor from torch.nn import LayerNorm, Linear, MultiheadAttention, Parameter class MultiheadAttentionBlock(torch.nn.Module): r"""The Multihead Attention Block (MAB) from the `"Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks" `_ paper. .. math:: \mathrm{MAB}(\mathbf{x}, \mathbf{y}) &= \mathrm{LayerNorm}(\mathbf{h} + \mathrm{ReLU}(\mathbf{W} \mathbf{h})) \mathbf{h} &= \mathrm{LayerNorm}(\mathbf{x} + \mathrm{Multihead}(\mathbf{x}, \mathbf{y}, \mathbf{y})) Args: channels (int): Size of each input sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) layer_norm (bool, optional): If set to :obj:`False`, will not apply layer normalization. (default: :obj:`True`) dropout (float, optional): Dropout probability of attention weights. (default: :obj:`0`) device (torch.device, optional): The device of the module. (default: :obj:`None`) """ def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True, dropout: float = 0.0, device: Optional[torch.device] = None): super().__init__() self.channels = channels self.heads = heads self.dropout = dropout self.attn = MultiheadAttention( channels, heads, batch_first=True, dropout=dropout, device=device, ) self.lin = Linear(channels, channels, device=device) self.layer_norm1 = LayerNorm(channels, device=device) if layer_norm else None self.layer_norm2 = LayerNorm(channels, device=device) if layer_norm else None def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.attn._reset_parameters() self.lin.reset_parameters() if self.layer_norm1 is not None: self.layer_norm1.reset_parameters() if self.layer_norm2 is not None: self.layer_norm2.reset_parameters() def forward(self, x: Tensor, y: Tensor, x_mask: Optional[Tensor] = None, y_mask: Optional[Tensor] = None) -> Tensor: """""" # noqa: D419 if y_mask is not None: y_mask = ~y_mask out, _ = self.attn(x, y, y, y_mask, need_weights=False) if x_mask is not None: out[~x_mask] = 0.
out = out + x if self.layer_norm1 is not None: out = self.layer_norm1(out) out = out + self.lin(out).relu() if self.layer_norm2 is not None: out = self.layer_norm2(out) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.channels}, ' f'heads={self.heads}, ' f'layer_norm={self.layer_norm1 is not None}, ' f'dropout={self.dropout})') class SetAttentionBlock(torch.nn.Module): r"""The Set Attention Block (SAB) from the `"Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks" `_ paper. .. math:: \mathrm{SAB}(\mathbf{X}) = \mathrm{MAB}(\mathbf{X}, \mathbf{X}) Args: channels (int): Size of each input sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) layer_norm (bool, optional): If set to :obj:`False`, will not apply layer normalization. (default: :obj:`True`) dropout (float, optional): Dropout probability of attention weights. (default: :obj:`0`) """ def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True, dropout: float = 0.0): super().__init__() self.mab = MultiheadAttentionBlock(channels, heads, layer_norm, dropout) def reset_parameters(self): self.mab.reset_parameters() def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: return self.mab(x, x, mask, mask) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.mab.channels}, ' f'heads={self.mab.heads}, ' f'layer_norm={self.mab.layer_norm1 is not None}, ' f'dropout={self.mab.dropout})') class InducedSetAttentionBlock(torch.nn.Module): r"""The Induced Set Attention Block (ISAB) from the `"Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks" `_ paper. .. math:: \mathrm{ISAB}(\mathbf{X}) &= \mathrm{MAB}(\mathbf{x}, \mathbf{h}) \mathbf{h} &= \mathrm{MAB}(\mathbf{I}, \mathbf{x}) where :math:`\mathbf{I}` denotes :obj:`num_induced_points` learnable vectors. Args: channels (int): Size of each input sample. num_induced_points (int): Number of induced points.
heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) layer_norm (bool, optional): If set to :obj:`False`, will not apply layer normalization. (default: :obj:`True`) dropout (float, optional): Dropout probability of attention weights. (default: :obj:`0`) """ def __init__(self, channels: int, num_induced_points: int, heads: int = 1, layer_norm: bool = True, dropout: float = 0.0): super().__init__() self.ind = Parameter(torch.empty(1, num_induced_points, channels)) self.mab1 = MultiheadAttentionBlock(channels, heads, layer_norm, dropout) self.mab2 = MultiheadAttentionBlock(channels, heads, layer_norm, dropout) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.ind) self.mab1.reset_parameters() self.mab2.reset_parameters() def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: h = self.mab1(self.ind.expand(x.size(0), -1, -1), x, y_mask=mask) return self.mab2(x, h, x_mask=mask) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.ind.size(2)}, ' f'num_induced_points={self.ind.size(1)}, ' f'heads={self.mab1.heads}, ' f'layer_norm={self.mab1.layer_norm1 is not None}, ' f'dropout={self.mab1.dropout})') class PoolingByMultiheadAttention(torch.nn.Module): r"""The Pooling by Multihead Attention (PMA) layer from the `"Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks" `_ paper. .. math:: \mathrm{PMA}(\mathbf{X}) = \mathrm{MAB}(\mathbf{S}, \mathrm{ReLU}(\mathbf{W} \mathbf{x})) where :math:`\mathbf{S}` denotes :obj:`num_seed_points` learnable vectors. Args: channels (int): Size of each input sample. num_seed_points (int, optional): Number of seed points. (default: :obj:`1`) heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) layer_norm (bool, optional): If set to :obj:`False`, will not apply layer normalization. (default: :obj:`True`) dropout (float, optional): Dropout probability of attention weights.
(default: :obj:`0`) """ def __init__(self, channels: int, num_seed_points: int = 1, heads: int = 1, layer_norm: bool = True, dropout: float = 0.0): super().__init__() self.lin = Linear(channels, channels) self.seed = Parameter(torch.empty(1, num_seed_points, channels)) self.mab = MultiheadAttentionBlock(channels, heads, layer_norm, dropout) self.reset_parameters() def reset_parameters(self): self.lin.reset_parameters() torch.nn.init.xavier_uniform_(self.seed) self.mab.reset_parameters() def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: x = self.lin(x).relu() return self.mab(self.seed.expand(x.size(0), -1, -1), x, y_mask=mask) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.seed.size(2)}, ' f'num_seed_points={self.seed.size(1)}, ' f'heads={self.mab.heads}, ' f'layer_norm={self.mab.layer_norm1 is not None}, ' f'dropout={self.mab.dropout})') ================================================ FILE: torch_geometric/nn/aggr/variance_preserving.py ================================================ from typing import Optional from torch import Tensor from torch_geometric.nn.aggr import Aggregation from torch_geometric.utils import degree from torch_geometric.utils._scatter import broadcast class VariancePreservingAggregation(Aggregation): r"""Performs the Variance Preserving Aggregation (VPA) from the `"GNN-VPA: A Variance-Preserving Aggregation Strategy for Graph Neural Networks" `_ paper. ..
math:: \mathrm{vpa}(\mathcal{X}) = \frac{1}{\sqrt{|\mathcal{X}|}} \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i """ def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: out = self.reduce(x, index, ptr, dim_size, dim, reduce='sum') if ptr is not None: count = ptr.diff().to(out.dtype) else: count = degree(index, dim_size, dtype=out.dtype) count = count.sqrt().clamp(min=1.0) count = broadcast(count, ref=out, dim=dim) return out / count ================================================ FILE: torch_geometric/nn/attention/__init__.py ================================================ from .performer import PerformerAttention from .qformer import QFormer from .sgformer import SGFormerAttention from .polynormer import PolynormerAttention __all__ = classes = [ 'PerformerAttention', 'QFormer', 'SGFormerAttention', 'PolynormerAttention', ] ================================================ FILE: torch_geometric/nn/attention/performer.py ================================================ import math from typing import Callable, Optional import torch from torch import Tensor def _orthogonal_matrix(dim: int) -> Tensor: r"""Get an orthogonal matrix by applying QR decomposition.""" # Random matrix from normal distribution mat = torch.randn((dim, dim)) # QR decomposition to two orthogonal matrices q, _ = torch.linalg.qr(mat.cpu(), mode='reduced') return q.t() def orthogonal_matrix(num_rows: int, num_cols: int) -> Tensor: r"""Generate an orthogonal matrix with `num_rows` rows and `num_cols` columns. 
""" num_full_blocks = int(num_rows / num_cols) blocks = [] for _ in range(num_full_blocks): q = _orthogonal_matrix(num_cols) blocks.append(q) remain_rows = num_rows - num_full_blocks * num_cols if remain_rows > 0: q = _orthogonal_matrix(num_cols) blocks.append(q[:remain_rows]) mat = torch.cat(blocks) # multiplier = torch.randn((num_rows, num_cols)).norm(dim=1) # scaler = torch.diag(multiplier) # mat = scaler @ mat return mat def linear_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor: r"""Efficient attention mechanism from the `"Rethinking Attention with Performers" `_ paper. .. math:: \mathbf{\hat{D}}^{-1}(\mathbf{Q}'((\mathbf{K}')^{\top} \mathbf{V})) """ D_inv = 1.0 / (q @ k.sum(dim=-2).unsqueeze(-1)) kv = k.transpose(-2, -1) @ v qkv = q @ kv out = torch.einsum('...L,...Ld->...Ld', D_inv.squeeze(-1), qkv) return out def generalized_kernel( x: Tensor, mat: Tensor, kernel: Callable = torch.nn.ReLU(), epsilon: float = 0.001, ) -> Tensor: batch_size, num_heads = x.size()[:2] projection = mat.t().expand(batch_size, num_heads, -1, -1) x = x @ projection out = kernel(x) + epsilon return out class PerformerProjection(torch.nn.Module): r"""The fast attention that uses a projection matrix from the `"Rethinking Attention with Performers" `_ paper. This class projects :math:`\mathbf{Q}` and :math:`\mathbf{K}` matrices with specified kernel. Args: num_cols (int): Projection matrix number of columns. kernel (Callable, optional): Kernels for generalized attention. If not specified, `ReLU` kernel will be used. 
(default: :obj:`torch.nn.ReLU()`) """ def __init__(self, num_cols: int, kernel: Callable = torch.nn.ReLU()): super().__init__() num_rows = int(num_cols * math.log(num_cols)) self.num_rows = num_rows self.num_cols = num_cols # Generate an orthogonal projection matrix # with the shape (num_rows, num_cols) projection_matrix = orthogonal_matrix(self.num_rows, self.num_cols) self.register_buffer('projection_matrix', projection_matrix) assert kernel is not None self.kernel = kernel def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: q = generalized_kernel(q, self.projection_matrix, self.kernel) k = generalized_kernel(k, self.projection_matrix, self.kernel) out = linear_attention(q, k, v) return out class PerformerAttention(torch.nn.Module): r"""The linear scaled attention mechanism from the `"Rethinking Attention with Performers" `_ paper. Args: channels (int): Size of each input sample. heads (int, optional): Number of parallel attention heads. head_channels (int, optional): Size of each attention head. (default: :obj:`64.`) kernel (Callable, optional): Kernels for generalized attention. If not specified, `ReLU` kernel will be used. (default: :obj:`torch.nn.ReLU()`) qkv_bias (bool, optional): If specified, add bias to query, key and value in the self attention. (default: :obj:`False`) attn_out_bias (bool, optional): If specified, add bias to the attention output. (default: :obj:`True`) dropout (float, optional): Dropout probability of the final attention output. 
(default: :obj:`0.0`) """ def __init__( self, channels: int, heads: int, head_channels: int = 64, kernel: Callable = torch.nn.ReLU(), qkv_bias: bool = False, attn_out_bias: bool = True, dropout: float = 0.0, ): super().__init__() assert channels % heads == 0 if head_channels is None: head_channels = channels // heads self.heads = heads self.head_channels = head_channels self.kernel = kernel self.fast_attn = PerformerProjection(head_channels, kernel) inner_channels = head_channels * heads self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) self.attn_out = torch.nn.Linear(inner_channels, channels, bias=attn_out_bias) self.dropout = torch.nn.Dropout(dropout) def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) """ B, N, *_ = x.shape q, k, v = self.q(x), self.k(x), self.v(x) # Reshape and permute q, k and v to proper shape # (B, N, num_heads * head_channels) to (b, num_heads, n, head_channels) q, k, v = map( lambda t: t.reshape(B, N, self.heads, self.head_channels).permute( 0, 2, 1, 3), (q, k, v)) if mask is not None: mask = mask[:, None, :, None] v.masked_fill_(~mask, 0.) out = self.fast_attn(q, k, v) out = out.permute(0, 2, 1, 3).reshape(B, N, -1) out = self.attn_out(out) out = self.dropout(out) return out @torch.no_grad() def redraw_projection_matrix(self): r"""As described in the paper, periodically redraw examples to improve overall approximation of attention. 
""" num_rows = self.fast_attn.num_rows num_cols = self.fast_attn.num_cols projection_matrix = orthogonal_matrix(num_rows, num_cols) self.fast_attn.projection_matrix.copy_(projection_matrix) del projection_matrix def _reset_parameters(self): self.q.reset_parameters() self.k.reset_parameters() self.v.reset_parameters() self.attn_out.reset_parameters() self.redraw_projection_matrix() def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'heads={self.heads}, ' f'head_channels={self.head_channels} ' f'kernel={self.kernel})') ================================================ FILE: torch_geometric/nn/attention/polynormer.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor class PolynormerAttention(torch.nn.Module): r"""The polynomial-expressive attention mechanism from the `"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time" `_ paper. Args: channels (int): Size of each input sample. heads (int, optional): Number of parallel attention heads. head_channels (int, optional): Size of each attention head. (default: :obj:`64.`) beta (float, optional): Polynormer beta initialization. (default: :obj:`0.9`) qkv_bias (bool, optional): If specified, add bias to query, key and value in the self attention. (default: :obj:`False`) qk_shared (bool optional): Whether weight of query and key are shared. (default: :obj:`True`) dropout (float, optional): Dropout probability of the final attention output. 
(default: :obj:`0.0`) """ def __init__( self, channels: int, heads: int, head_channels: int = 64, beta: float = 0.9, qkv_bias: bool = False, qk_shared: bool = True, dropout: float = 0.0, ) -> None: super().__init__() self.head_channels = head_channels self.heads = heads self.beta = beta self.qk_shared = qk_shared inner_channels = heads * head_channels self.h_lins = torch.nn.Linear(channels, inner_channels) if not self.qk_shared: self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) self.lns = torch.nn.LayerNorm(inner_channels) self.lin_out = torch.nn.Linear(inner_channels, inner_channels) self.dropout = torch.nn.Dropout(dropout) def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) Returns: torch.Tensor: Output of shape :obj:`[B, N, heads * head_channels]`. """ B, N, *_ = x.shape h = self.h_lins(x) k = self.k(x).sigmoid().view(B, N, self.head_channels, self.heads) if self.qk_shared: q = k else: q = F.sigmoid(self.q(x)).view(B, N, self.head_channels, self.heads) v = self.v(x).view(B, N, self.head_channels, self.heads) if mask is not None: mask = mask[:, :, None, None] v.masked_fill_(~mask, 0.)
# numerator kv = torch.einsum('bndh, bnmh -> bdmh', k, v) num = torch.einsum('bndh, bdmh -> bnmh', q, kv) # denominator k_sum = torch.einsum('bndh -> bdh', k) den = torch.einsum('bndh, bdh -> bnh', q, k_sum).unsqueeze(2) # linear global attention based on kernel trick x = (num / (den + 1e-6)).reshape(B, N, -1) x = self.lns(x) * (h + self.beta) x = F.relu(self.lin_out(x)) x = self.dropout(x) return x def reset_parameters(self) -> None: r"""Resets all learnable parameters of the module.""" self.h_lins.reset_parameters() if not self.qk_shared: self.q.reset_parameters() self.k.reset_parameters() self.v.reset_parameters() self.lns.reset_parameters() self.lin_out.reset_parameters() def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'heads={self.heads}, ' f'head_channels={self.head_channels})') ================================================ FILE: torch_geometric/nn/attention/qformer.py ================================================ from typing import Callable import torch class QFormer(torch.nn.Module): r"""The Querying Transformer (Q-Former) from `"BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models" `_ paper. Args: input_dim (int): The number of features in the input. hidden_dim (int): The dimension of the feed-forward network in the encoder layer. output_dim (int): The final output dimension. num_heads (int): The number of multi-attention-heads. num_layers (int): The number of sub-encoder-layers in the encoder. dropout (float, optional): The dropout value in each encoder layer. (default: :obj:`0.0`) activation (Callable, optional): The activation function of the intermediate feed-forward layer, passed to :class:`torch.nn.TransformerEncoderLayer`. (default: :obj:`torch.nn.ReLU()`) .. note:: This is a simplified version of the original Q-Former implementation.
""" def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, num_heads: int, num_layers: int, dropout: float = 0.0, activation: Callable = torch.nn.ReLU(), ) -> None: super().__init__() self.num_layers = num_layers self.num_heads = num_heads self.layer_norm = torch.nn.LayerNorm(input_dim) self.encoder_layer = torch.nn.TransformerEncoderLayer( d_model=input_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout, activation=activation, batch_first=True, ) self.encoder = torch.nn.TransformerEncoder( self.encoder_layer, num_layers=num_layers, ) self.project = torch.nn.Linear(input_dim, output_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: r"""Forward pass. Args: x (torch.Tensor): Input sequence to the encoder layer. :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, sequence length :math:`N`, and feature dimension :math:`F`. """ x = self.layer_norm(x) x = self.encoder(x) out = self.project(x) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'num_heads={self.num_heads}, ' f'num_layers={self.num_layers})') ================================================ FILE: torch_geometric/nn/attention/sgformer.py ================================================ from typing import Optional import torch from torch import Tensor class SGFormerAttention(torch.nn.Module): r"""The simple global attention mechanism from the `"SGFormer: Simplifying and Empowering Transformers for Large-Graph Representations" `_ paper. Args: channels (int): Size of each input sample. heads (int, optional): Number of parallel attention heads. (default: :obj:`1`) head_channels (int, optional): Size of each attention head. If set to :obj:`None`, will use :obj:`channels // heads`. (default: :obj:`64`) qkv_bias (bool, optional): If specified, add bias to query, key and value in the self attention.
(default: :obj:`False`) """ def __init__( self, channels: int, heads: int = 1, head_channels: int = 64, qkv_bias: bool = False, ) -> None: super().__init__() assert channels % heads == 0 if head_channels is None: head_channels = channels // heads self.heads = heads self.head_channels = head_channels inner_channels = head_channels * heads self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias) def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) """ B, N, *_ = x.shape qs, ks, vs = self.q(x), self.k(x), self.v(x) # reshape and permute q, k and v to proper shape # (b, n, num_heads * head_channels) to (b, n, num_heads, head_channels) qs, ks, vs = map( lambda t: t.reshape(B, N, self.heads, self.head_channels), (qs, ks, vs)) if mask is not None: mask = mask[:, :, None, None] vs.masked_fill_(~mask, 0.) 
# replace 0's with epsilon epsilon = 1e-6 qs[qs == 0] = epsilon ks[ks == 0] = epsilon # normalize input, shape not changed qs, ks = map( lambda t: t / torch.linalg.norm(t, ord=2, dim=-1, keepdim=True), (qs, ks)) # numerator kvs = torch.einsum("blhm,blhd->bhmd", ks, vs) attention_num = torch.einsum("bnhm,bhmd->bnhd", qs, kvs) attention_num += N * vs # denominator all_ones = torch.ones([B, N]).to(ks.device) ks_sum = torch.einsum("blhm,bl->bhm", ks, all_ones) attention_normalizer = torch.einsum("bnhm,bhm->bnh", qs, ks_sum) # attentive aggregated results attention_normalizer = torch.unsqueeze(attention_normalizer, len(attention_normalizer.shape)) attention_normalizer += torch.ones_like(attention_normalizer) * N attn_output = attention_num / attention_normalizer return attn_output.mean(dim=2) def reset_parameters(self): self.q.reset_parameters() self.k.reset_parameters() self.v.reset_parameters() def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'heads={self.heads}, ' f'head_channels={self.head_channels})') ================================================ FILE: torch_geometric/nn/conv/__init__.py ================================================ from .message_passing import MessagePassing from .simple_conv import SimpleConv from .gcn_conv import GCNConv from .cheb_conv import ChebConv from .sage_conv import SAGEConv from .cugraph.sage_conv import CuGraphSAGEConv from .graph_conv import GraphConv from .gravnet_conv import GravNetConv from .gated_graph_conv import GatedGraphConv from .res_gated_graph_conv import ResGatedGraphConv from .gat_conv import GATConv from .cugraph.gat_conv import CuGraphGATConv from .fused_gat_conv import FusedGATConv from .gatv2_conv import GATv2Conv from .transformer_conv import TransformerConv from .agnn_conv import AGNNConv from .tag_conv import TAGConv from .gin_conv import GINConv, GINEConv from .arma_conv import ARMAConv from .sg_conv import SGConv from .appnp import APPNP from .mf_conv import MFConv from .rgcn_conv import 
RGCNConv, FastRGCNConv from .cugraph.rgcn_conv import CuGraphRGCNConv from .rgat_conv import RGATConv from .signed_conv import SignedConv from .dna_conv import DNAConv from .point_conv import PointNetConv from .gmm_conv import GMMConv from .spline_conv import SplineConv from .nn_conv import NNConv from .cg_conv import CGConv from .edge_conv import EdgeConv, DynamicEdgeConv from .x_conv import XConv from .ppf_conv import PPFConv from .feast_conv import FeaStConv from .point_transformer_conv import PointTransformerConv from .hypergraph_conv import HypergraphConv from .le_conv import LEConv from .pna_conv import PNAConv from .cluster_gcn_conv import ClusterGCNConv from .gen_conv import GENConv from .gcn2_conv import GCN2Conv from .pan_conv import PANConv from .wl_conv import WLConv from .wl_conv_continuous import WLConvContinuous from .film_conv import FiLMConv from .supergat_conv import SuperGATConv from .fa_conv import FAConv from .eg_conv import EGConv from .pdn_conv import PDNConv from .general_conv import GeneralConv from .hgt_conv import HGTConv from .heat_conv import HEATConv from .hetero_conv import HeteroConv from .han_conv import HANConv from .lg_conv import LGConv from .ssg_conv import SSGConv from .point_gnn_conv import PointGNNConv from .gps_conv import GPSConv from .antisymmetric_conv import AntiSymmetricConv from .dir_gnn_conv import DirGNNConv from .mixhop_conv import MixHopConv from .meshcnn_conv import MeshCNNConv import torch_geometric.nn.conv.utils # noqa __all__ = [ 'MessagePassing', 'SimpleConv', 'GCNConv', 'ChebConv', 'SAGEConv', 'CuGraphSAGEConv', 'GraphConv', 'GravNetConv', 'GatedGraphConv', 'ResGatedGraphConv', 'GATConv', 'CuGraphGATConv', 'FusedGATConv', 'GATv2Conv', 'TransformerConv', 'AGNNConv', 'TAGConv', 'GINConv', 'GINEConv', 'ARMAConv', 'SGConv', 'SSGConv', 'APPNP', 'MFConv', 'RGCNConv', 'FastRGCNConv', 'CuGraphRGCNConv', 'RGATConv', 'SignedConv', 'DNAConv', 'PointNetConv', 'GMMConv', 'SplineConv', 'NNConv', 'CGConv', 'EdgeConv', 
'DynamicEdgeConv', 'XConv', 'PPFConv', 'FeaStConv', 'PointTransformerConv', 'HypergraphConv', 'LEConv', 'PNAConv', 'ClusterGCNConv', 'GENConv', 'GCN2Conv', 'PANConv', 'WLConv', 'WLConvContinuous', 'FiLMConv', 'SuperGATConv', 'FAConv', 'EGConv', 'PDNConv', 'GeneralConv', 'HGTConv', 'HEATConv', 'HeteroConv', 'HANConv', 'LGConv', 'PointGNNConv', 'GPSConv', 'AntiSymmetricConv', 'DirGNNConv', 'MixHopConv', 'MeshCNNConv', ] classes = __all__ ECConv = NNConv PointConv = PointNetConv ================================================ FILE: torch_geometric/nn/conv/agnn_conv.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse from torch_geometric.utils import add_self_loops, remove_self_loops, softmax class AGNNConv(MessagePassing): r"""The graph attentional propagation layer from the `"Attention-based Graph Neural Network for Semi-Supervised Learning" `_ paper. .. math:: \mathbf{X}^{\prime} = \mathbf{P} \mathbf{X}, where the propagation matrix :math:`\mathbf{P}` is computed as .. math:: P_{i,j} = \frac{\exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_j))} {\sum_{k \in \mathcal{N}(i)\cup \{ i \}} \exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_k))} with trainable parameter :math:`\beta`. Args: requires_grad (bool, optional): If set to :obj:`False`, :math:`\beta` will not be trainable. (default: :obj:`True`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`.
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F)`, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F)` """ def __init__(self, requires_grad: bool = True, add_self_loops: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.requires_grad = requires_grad self.add_self_loops = add_self_loops if requires_grad: self.beta = Parameter(torch.empty(1)) else: self.register_buffer('beta', torch.ones(1)) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" super().reset_parameters() if self.requires_grad: self.beta.data.fill_(1) def forward(self, x: Tensor, edge_index: Adj) -> Tensor: r"""Runs the forward pass of the module.""" if self.add_self_loops: if isinstance(edge_index, Tensor): edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(self.node_dim)) elif isinstance(edge_index, SparseTensor): edge_index = torch_sparse.set_diag(edge_index) x_norm = F.normalize(x, p=2., dim=-1) # propagate_type: (x: Tensor, x_norm: Tensor) return self.propagate(edge_index, x=x, x_norm=x_norm) def message(self, x_j: Tensor, x_norm_i: Tensor, x_norm_j: Tensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor: alpha = self.beta * (x_norm_i * x_norm_j).sum(dim=-1) alpha = softmax(alpha, index, ptr, size_i) return x_j * alpha.view(-1, 1) ================================================ FILE: torch_geometric/nn/conv/antisymmetric_conv.py ================================================ import math from typing import Any, Callable, Dict, Optional, Union import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import GCNConv, MessagePassing from torch_geometric.nn.inits import zeros from torch_geometric.nn.resolver import activation_resolver from torch_geometric.typing import Adj class AntiSymmetricConv(torch.nn.Module): r"""The anti-symmetric graph convolutional operator from the `"Anti-Symmetric DGN: a stable architecture for Deep Graph Networks"
`_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{x}_i + \epsilon \cdot \sigma \left( (\mathbf{W}-\mathbf{W}^T-\gamma \mathbf{I}) \mathbf{x}_i + \Phi(\mathbf{X}, \mathcal{N}_i) + \mathbf{b}\right), where :math:`\Phi(\mathbf{X}, \mathcal{N}_i)` denotes a :class:`~torch_geometric.nn.conv.MessagePassing` layer. Args: in_channels (int): Size of each input sample. phi (MessagePassing, optional): The message passing module :math:`\Phi`. If set to :obj:`None`, will use a :class:`~torch_geometric.nn.conv.GCNConv` layer as default. (default: :obj:`None`) num_iters (int, optional): The number of times the anti-symmetric deep graph network operator is called. (default: :obj:`1`) epsilon (float, optional): The discretization step size :math:`\epsilon`. (default: :obj:`0.1`) gamma (float, optional): The strength of the diffusion :math:`\gamma`. It regulates the stability of the method. (default: :obj:`0.1`) act (str or Callable, optional): The non-linear activation function :math:`\sigma`, *e.g.*, :obj:`"tanh"` or :obj:`"relu"`. If set to :obj:`None`, no activation is applied. (default: :obj:`"tanh"`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias.
(default: :obj:`True`) Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{in})` """ def __init__( self, in_channels: int, phi: Optional[MessagePassing] = None, num_iters: int = 1, epsilon: float = 0.1, gamma: float = 0.1, act: Union[str, Callable, None] = 'tanh', act_kwargs: Optional[Dict[str, Any]] = None, bias: bool = True, ): super().__init__() self.in_channels = in_channels self.num_iters = num_iters self.gamma = gamma self.epsilon = epsilon self.act = activation_resolver(act, **(act_kwargs or {})) if phi is None: phi = GCNConv(in_channels, in_channels, bias=False) self.W = Parameter(torch.empty(in_channels, in_channels)) self.register_buffer('eye', torch.eye(in_channels)) self.phi = phi if bias: self.bias = Parameter(torch.empty(in_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" torch.nn.init.kaiming_uniform_(self.W, a=math.sqrt(5)) self.phi.reset_parameters() zeros(self.bias) def forward(self, x: Tensor, edge_index: Adj, *args, **kwargs) -> Tensor: r"""Runs the forward pass of the module.""" antisymmetric_W = self.W - self.W.t() - self.gamma * self.eye for _ in range(self.num_iters): h = self.phi(x, edge_index, *args, **kwargs) h = x @ antisymmetric_W.t() + h if self.bias is not None: h += self.bias if self.act is not None: h = self.act(h) x = x + self.epsilon * h return x def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'{self.in_channels}, ' f'phi={self.phi}, ' f'num_iters={self.num_iters}, ' f'epsilon={self.epsilon}, ' f'gamma={self.gamma})') ================================================ FILE: torch_geometric/nn/conv/appnp.py ================================================ from typing import Optional import torch.nn.functional as F from torch import Tensor from
torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor from torch_geometric.utils import is_torch_sparse_tensor, spmm, to_edge_index from torch_geometric.utils.sparse import set_sparse_value class APPNP(MessagePassing): r"""The approximate personalized propagation of neural predictions layer from the `"Predict then Propagate: Graph Neural Networks meet Personalized PageRank" `_ paper. .. math:: \mathbf{X}^{(0)} &= \mathbf{X} \mathbf{X}^{(k)} &= (1 - \alpha) \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X}^{(k-1)} + \alpha \mathbf{X}^{(0)} \mathbf{X}^{\prime} &= \mathbf{X}^{(K)}, where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with inserted self-loops and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. The adjacency matrix can include other values than :obj:`1` representing edge weights via the optional :obj:`edge_weight` tensor. Args: K (int): Number of iterations :math:`K`. alpha (float): Teleport probability :math:`\alpha`. dropout (float, optional): Dropout probability of edges during training. (default: :obj:`0`) cached (bool, optional): If set to :obj:`True`, the layer will cache the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the cached version for further executions. This parameter should only be set to :obj:`True` in transductive learning scenarios. (default: :obj:`False`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) normalize (bool, optional): Whether to add self-loops and apply symmetric normalization. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F)`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F)` """ _cached_edge_index: Optional[OptPairTensor] _cached_adj_t: Optional[SparseTensor] def __init__(self, K: int, alpha: float, dropout: float = 0., cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.K = K self.alpha = alpha self.dropout = dropout self.cached = cached self.add_self_loops = add_self_loops self.normalize = normalize self._cached_edge_index = None self._cached_adj_t = None def reset_parameters(self): super().reset_parameters() self._cached_edge_index = None self._cached_adj_t = None def forward( self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, ) -> Tensor: if self.normalize: if isinstance(edge_index, Tensor): cache = self._cached_edge_index if cache is None: edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) if self.cached: self._cached_edge_index = (edge_index, edge_weight) else: edge_index, edge_weight = cache[0], cache[1] elif isinstance(edge_index, SparseTensor): cache = self._cached_adj_t if cache is None: edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) if self.cached: self._cached_adj_t = edge_index else: edge_index = cache h = x for _ in range(self.K): if self.dropout > 0 and self.training: if isinstance(edge_index, Tensor): if is_torch_sparse_tensor(edge_index): _, edge_weight = to_edge_index(edge_index) edge_weight = F.dropout(edge_weight, p=self.dropout) edge_index = set_sparse_value(edge_index, edge_weight) else: assert edge_weight is not None edge_weight = F.dropout(edge_weight, p=self.dropout) else: value = 
edge_index.storage.value() assert value is not None value = F.dropout(value, p=self.dropout) edge_index = edge_index.set_value(value, layout='coo') # propagate_type: (x: Tensor, edge_weight: OptTensor) x = self.propagate(edge_index, x=x, edge_weight=edge_weight) x = x * (1 - self.alpha) x = x + self.alpha * h return x def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return f'{self.__class__.__name__}(K={self.K}, alpha={self.alpha})' ================================================ FILE: torch_geometric/nn/conv/arma_conv.py ================================================ from typing import Callable, Optional import torch import torch.nn.functional as F from torch import Tensor, nn from torch.nn import Parameter, ReLU from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils import spmm class ARMAConv(MessagePassing): r"""The ARMA graph convolutional operator from the `"Graph Neural Networks with Convolutional ARMA Filters" `_ paper. .. math:: \mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K \mathbf{X}_k^{(T)}, with :math:`\mathbf{X}_k^{(T)}` being recursively defined by .. math:: \mathbf{X}_k^{(t+1)} = \sigma \left( \mathbf{\hat{L}} \mathbf{X}_k^{(t)} \mathbf{W} + \mathbf{X}^{(0)} \mathbf{V} \right), where :math:`\mathbf{\hat{L}} = \mathbf{I} - \mathbf{L} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}` denotes the modified Laplacian :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}`. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. 
out_channels (int): Size of each output sample :math:`\mathbf{x}^{(t+1)}`. num_stacks (int, optional): Number of parallel stacks :math:`K`. (default: :obj:`1`). num_layers (int, optional): Number of layers :math:`T`. (default: :obj:`1`) act (callable, optional): Activation function :math:`\sigma`. (default: :meth:`torch.nn.ReLU()`) shared_weights (int, optional): If set to :obj:`True` the layers in each stack will share the same parameters. (default: :obj:`False`) dropout (float, optional): Dropout probability of the skip connection. (default: :obj:`0.`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__(self, in_channels: int, out_channels: int, num_stacks: int = 1, num_layers: int = 1, shared_weights: bool = False, act: Optional[Callable] = ReLU(), dropout: float = 0., bias: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.num_stacks = num_stacks self.num_layers = num_layers self.act = act self.shared_weights = shared_weights self.dropout = dropout K, T, F_in, F_out = num_stacks, num_layers, in_channels, out_channels T = 1 if self.shared_weights else T self.weight = Parameter(torch.empty(max(1, T - 1), K, F_out, F_out)) if in_channels > 0: self.init_weight = Parameter(torch.empty(K, F_in, F_out)) self.root_weight = Parameter(torch.empty(T, K, F_in, F_out)) else: self.init_weight = torch.nn.parameter.UninitializedParameter() self.root_weight = torch.nn.parameter.UninitializedParameter() self._hook = self.register_forward_pre_hook( self.initialize_parameters) if bias: self.bias 
= Parameter(torch.empty(T, K, 1, F_out)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() glorot(self.weight) if not isinstance(self.init_weight, torch.nn.UninitializedParameter): glorot(self.init_weight) glorot(self.root_weight) zeros(self.bias) def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: if isinstance(edge_index, Tensor): edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), add_self_loops=False, flow=self.flow, dtype=x.dtype) elif isinstance(edge_index, SparseTensor): edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), add_self_loops=False, flow=self.flow, dtype=x.dtype) x = x.unsqueeze(-3) out = x for t in range(self.num_layers): if t == 0: out = out @ self.init_weight else: out = out @ self.weight[0 if self.shared_weights else t - 1] # propagate_type: (x: Tensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=out, edge_weight=edge_weight) root = F.dropout(x, p=self.dropout, training=self.training) root = root @ self.root_weight[0 if self.shared_weights else t] out = out + root if self.bias is not None: out = out + self.bias[0 if self.shared_weights else t] if self.act is not None: out = self.act(out) return out.mean(dim=-3) def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) @torch.no_grad() def initialize_parameters(self, module, input): if isinstance(self.init_weight, nn.parameter.UninitializedParameter): F_in, F_out = input[0].size(-1), self.out_channels T, K = self.weight.size(0) + 1, self.weight.size(1) self.init_weight.materialize((K, F_in, F_out)) self.root_weight.materialize((T, K, F_in, F_out)) glorot(self.init_weight) glorot(self.root_weight) module._hook.remove() delattr(module, '_hook') def 
__repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, num_stacks={self.num_stacks}, ' f'num_layers={self.num_layers})') ================================================ FILE: torch_geometric/nn/conv/cg_conv.py ================================================ from typing import Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import BatchNorm1d, Linear from torch_geometric.nn.conv import MessagePassing from torch_geometric.typing import Adj, OptTensor, PairTensor class CGConv(MessagePassing): r"""The crystal graph convolutional operator from the `"Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right) where :math:`\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]` denotes the concatenation of central node features, neighboring node features and edge features. In addition, :math:`\sigma` and :math:`g` denote the sigmoid and softplus functions, respectively. Args: channels (int or tuple): Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities. dim (int, optional): Edge feature dimensionality. (default: :obj:`0`) aggr (str, optional): The aggregation operator to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`) batch_norm (bool, optional): If set to :obj:`True`, will make use of batch normalization. (default: :obj:`False`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F)` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F)` or :math:`(|\mathcal{V_t}|, F_{t})` if bipartite """ def __init__(self, channels: Union[int, Tuple[int, int]], dim: int = 0, aggr: str = 'add', batch_norm: bool = False, bias: bool = True, **kwargs): super().__init__(aggr=aggr, **kwargs) self.channels = channels self.dim = dim self.batch_norm = batch_norm if isinstance(channels, int): channels = (channels, channels) self.lin_f = Linear(sum(channels) + dim, channels[1], bias=bias) self.lin_s = Linear(sum(channels) + dim, channels[1], bias=bias) if batch_norm: self.bn = BatchNorm1d(channels[1]) else: self.bn = None self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin_f.reset_parameters() self.lin_s.reset_parameters() if self.bn is not None: self.bn.reset_parameters() def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor = None) -> Tensor: if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: PairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, edge_attr=edge_attr) out = out if self.bn is None else self.bn(out) out = out + x[1] return out def message(self, x_i, x_j, edge_attr: OptTensor) -> Tensor: if edge_attr is None: z = torch.cat([x_i, x_j], dim=-1) else: z = torch.cat([x_i, x_j, edge_attr], dim=-1) return self.lin_f(z).sigmoid() * F.softplus(self.lin_s(z)) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.channels}, dim={self.dim})' ================================================ FILE: torch_geometric/nn/conv/cheb_conv.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from 
torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros from torch_geometric.typing import OptTensor from torch_geometric.utils import get_laplacian class ChebConv(MessagePassing): r"""The chebyshev spectral graph convolutional operator from the `"Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering" `_ paper. .. math:: \mathbf{X}^{\prime} = \sum_{k=1}^{K} \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)} where :math:`\mathbf{Z}^{(k)}` is computed recursively by .. math:: \mathbf{Z}^{(1)} &= \mathbf{X} \mathbf{Z}^{(2)} &= \mathbf{\hat{L}} \cdot \mathbf{X} \mathbf{Z}^{(k)} &= 2 \cdot \mathbf{\hat{L}} \cdot \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)} and :math:`\mathbf{\hat{L}}` denotes the scaled and normalized Laplacian :math:`\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}`. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. K (int): Chebyshev filter size :math:`K`. normalization (str, optional): The normalization scheme for the graph Laplacian (default: :obj:`"sym"`): 1. :obj:`None`: No normalization :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` 2. :obj:`"sym"`: Symmetric normalization :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}` 3. :obj:`"rw"`: Random-walk normalization :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` :obj:`\lambda_max` should be a :class:`torch.Tensor` of size :obj:`[num_graphs]` in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute :obj:`lambda_max` via the :class:`torch_geometric.transforms.LaplacianLambdaMax` transform. bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)*, batch vector :math:`(|\mathcal{V}|)` *(optional)*, maximum :obj:`lambda` value :math:`(|\mathcal{G}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__( self, in_channels: int, out_channels: int, K: int, normalization: Optional[str] = 'sym', bias: bool = True, **kwargs, ): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) assert K > 0 assert normalization in [None, 'sym', 'rw'], 'Invalid normalization' self.in_channels = in_channels self.out_channels = out_channels self.normalization = normalization self.lins = torch.nn.ModuleList([ Linear(in_channels, out_channels, bias=False, weight_initializer='glorot') for _ in range(K) ]) if bias: self.bias = Parameter(Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() for lin in self.lins: lin.reset_parameters() zeros(self.bias) def __norm__( self, edge_index: Tensor, num_nodes: Optional[int], edge_weight: OptTensor, normalization: Optional[str], lambda_max: OptTensor = None, dtype: Optional[int] = None, batch: OptTensor = None, ): edge_index, edge_weight = get_laplacian(edge_index, edge_weight, normalization, dtype, num_nodes) assert edge_weight is not None if lambda_max is None: lambda_max = 2.0 * edge_weight.max() elif not isinstance(lambda_max, Tensor): lambda_max = torch.tensor(lambda_max, dtype=dtype, device=edge_index.device) assert lambda_max is not None if batch is not None and lambda_max.numel() > 1: lambda_max = lambda_max[batch[edge_index[0]]] edge_weight = (2.0 * edge_weight) / lambda_max edge_weight.masked_fill_(edge_weight == float('inf'), 0) loop_mask = edge_index[0] == edge_index[1] edge_weight[loop_mask] -= 1 return edge_index, edge_weight def forward( self, x: Tensor, edge_index: Tensor, edge_weight: 
OptTensor = None, batch: OptTensor = None, lambda_max: OptTensor = None, ) -> Tensor: edge_index, norm = self.__norm__( edge_index, x.size(self.node_dim), edge_weight, self.normalization, lambda_max, dtype=x.dtype, batch=batch, ) Tx_0 = x Tx_1 = x # Dummy. out = self.lins[0](Tx_0) # propagate_type: (x: Tensor, norm: Tensor) if len(self.lins) > 1: Tx_1 = self.propagate(edge_index, x=x, norm=norm) out = out + self.lins[1](Tx_1) for lin in self.lins[2:]: Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm) Tx_2 = 2. * Tx_2 - Tx_0 out = out + lin.forward(Tx_2) Tx_0, Tx_1 = Tx_1, Tx_2 if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, norm: Tensor) -> Tensor: return norm.view(-1, 1) * x_j def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, K={len(self.lins)}, ' f'normalization={self.normalization})') ================================================ FILE: torch_geometric/nn/conv/cluster_gcn_conv.py ================================================ import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse from torch_geometric.utils import ( add_self_loops, degree, is_torch_sparse_tensor, remove_self_loops, spmm, to_edge_index, ) from torch_geometric.utils.sparse import set_sparse_value class ClusterGCNConv(MessagePassing): r"""The ClusterGCN graph convolutional operator from the `"Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" `_ paper. .. math:: \mathbf{X}^{\prime} = \left( \mathbf{\hat{A}} + \lambda \cdot \textrm{diag}(\mathbf{\hat{A}}) \right) \mathbf{X} \mathbf{W}_1 + \mathbf{X} \mathbf{W}_2 where :math:`\mathbf{\hat{A}} = {(\mathbf{D} + \mathbf{I})}^{-1}(\mathbf{A} + \mathbf{I})`. 
Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. diag_lambda (float, optional): Diagonal enhancement value :math:`\lambda`. (default: :obj:`0.`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__(self, in_channels: int, out_channels: int, diag_lambda: float = 0., add_self_loops: bool = True, bias: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.diag_lambda = diag_lambda self.add_self_loops = add_self_loops self.lin_out = Linear(in_channels, out_channels, bias=bias, weight_initializer='glorot') self.lin_root = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot') self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin_out.reset_parameters() self.lin_root.reset_parameters() def forward(self, x: Tensor, edge_index: Adj) -> Tensor: num_nodes = x.size(self.node_dim) edge_weight: OptTensor = None if isinstance(edge_index, SparseTensor): assert edge_index.size(0) == edge_index.size(1) if self.add_self_loops: edge_index = torch_sparse.set_diag(edge_index) col, row, _ = edge_index.coo() # Transposed. deg_inv = 1. / torch_sparse.sum(edge_index, dim=1).clamp_(1.) 
edge_weight = deg_inv[col] edge_weight[row == col] += self.diag_lambda * deg_inv edge_index = edge_index.set_value(edge_weight, layout='coo') elif is_torch_sparse_tensor(edge_index): assert edge_index.size(0) == edge_index.size(1) if edge_index.layout == torch.sparse_csc: raise NotImplementedError("Sparse CSC matrices are not yet " "supported in 'gcn_norm'") if self.add_self_loops: edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) col_and_row, value = to_edge_index(edge_index) col, row = col_and_row[0], col_and_row[1] deg_inv = 1. / degree(col, num_nodes=edge_index.size(0)).clamp_(1.) edge_weight = deg_inv[col] edge_weight[row == col] += self.diag_lambda * deg_inv edge_index = set_sparse_value(edge_index, edge_weight) else: if self.add_self_loops: edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) row, col = edge_index[0], edge_index[1] deg_inv = 1. / degree(col, num_nodes=num_nodes).clamp_(1.) edge_weight = deg_inv[col] edge_weight[row == col] += self.diag_lambda * deg_inv # propagate_type: (x: Tensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=x, edge_weight=edge_weight) out = self.lin_out(out) + self.lin_root(x) return out def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, diag_lambda={self.diag_lambda})') ================================================ FILE: torch_geometric/nn/conv/collect.jinja ================================================ from typing import List, NamedTuple, Optional, Union import torch from torch import Tensor from torch_geometric import EdgeIndex from torch_geometric.index import ptr2index from torch_geometric.utils import is_torch_sparse_tensor from torch_geometric.typing import 
SparseTensor class CollectArgs(NamedTuple): {%- if collect_param_dict|length > 0 %} {%- for param in collect_param_dict.values() %} {{param.name}}: {{param.type_repr}} {%- endfor %} {%- else %} pass {%- endif %} def {{collect_name}}( self, edge_index: Union[Tensor, SparseTensor], {%- for param in signature.param_dict.values() %} {{param.name}}: {{param.type_repr}}, {%- endfor %} size: List[Optional[int]], ) -> CollectArgs: i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1) # Collect special arguments: if isinstance(edge_index, Tensor): if is_torch_sparse_tensor(edge_index): {%- if 'edge_index' in collect_param_dict %} raise ValueError("Cannot collect 'edge_indices' for sparse matrices") {%- endif %} adj_t = edge_index if adj_t.layout == torch.sparse_coo: edge_index_i = adj_t.indices()[0] edge_index_j = adj_t.indices()[1] ptr = None elif adj_t.layout == torch.sparse_csr: ptr = adj_t.crow_indices() edge_index_j = adj_t.col_indices() edge_index_i = ptr2index(ptr, output_size=edge_index_j.numel()) else: raise ValueError(f"Received invalid layout '{adj_t.layout}'") {%- if 'edge_weight' in collect_param_dict %} if edge_weight is None: edge_weight = adj_t.values() {%- elif 'edge_attr' in collect_param_dict %} if edge_attr is None: _value = adj_t.values() edge_attr = None if _value.dim() == 1 else _value {%- elif 'edge_type' in collect_param_dict %} if edge_type is None: edge_type = adj_t.values() {%- endif %} else: {%- if 'adj_t' in collect_param_dict %} raise ValueError("Cannot collect 'adj_t' for edge indices") {%- endif %} edge_index_i = edge_index[i] edge_index_j = edge_index[j] ptr = None if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): if i == 0 and edge_index.is_sorted_by_row: (ptr, _), _ = edge_index.get_csr() elif i == 1 and edge_index.is_sorted_by_col: (ptr, _), _ = edge_index.get_csc() elif isinstance(edge_index, SparseTensor): {%- if 'edge_index' in collect_param_dict %} raise ValueError("Cannot collect 'edge_indices' for sparse 
matrices") {%- endif %} adj_t = edge_index edge_index_i, edge_index_j, _value = adj_t.coo() ptr, _, _ = adj_t.csr() {%- if 'edge_weight' in collect_param_dict %} if edge_weight is None: edge_weight = _value {%- elif 'edge_attr' in collect_param_dict %} if edge_attr is None: edge_attr = None if _value is None or _value.dim() == 1 else _value {%- elif 'edge_type' in collect_param_dict %} if edge_type is None: edge_type = _value {%- endif %} else: raise NotImplementedError {%- if 'edge_weight' in collect_param_dict and collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %} if torch.jit.is_scripting(): assert edge_weight is not None {%- elif 'edge_attr' in collect_param_dict and collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %} if torch.jit.is_scripting(): assert edge_attr is not None {%- elif 'edge_type' in collect_param_dict and collect_param_dict['edge_type'].type_repr.endswith('Tensor') %} if torch.jit.is_scripting(): assert edge_type is not None {%- endif %} # Collect user-defined arguments: {%- for name in collect_param_dict %} {%- if (name.endswith('_i') or name.endswith('_j')) and name not in ['edge_index_i', 'edge_index_j', 'size_i', 'size_j'] %} # ({{loop.index}}) - Collect `{{name}}`: if isinstance({{name[:-2]}}, (tuple, list)): assert len({{name[:-2]}}) == 2 _{{name[:-2]}}_0, _{{name[:-2]}}_1 = {{name[:-2]}}[0], {{name[:-2]}}[1] if isinstance(_{{name[:-2]}}_0, Tensor): self._set_size(size, 0, _{{name[:-2]}}_0) {%- if name.endswith('_j') %} {{name}} = self._index_select(_{{name[:-2]}}_0, edge_index_{{name[-1]}}) else: {{name}} = None {%- endif %} if isinstance(_{{name[:-2]}}_1, Tensor): self._set_size(size, 1, _{{name[:-2]}}_1) {%- if name.endswith('_i') %} {{name}} = self._index_select(_{{name[:-2]}}_1, edge_index_{{name[-1]}}) else: {{name}} = None {%- endif %} elif isinstance({{name[:-2]}}, Tensor): self._set_size(size, {{name[-1]}}, {{name[:-2]}}) {{name}} = self._index_select({{name[:-2]}}, edge_index_{{name[-1]}}) else: 
{{name}} = None {%- endif %} {%- endfor %} # Collect default arguments: {%- for name, param in collect_param_dict.items() %} {%- if name not in signature.param_dict and not name.endswith('_i') and not name.endswith('_j') and name not in ['edge_index', 'adj_t', 'size', 'ptr', 'index', 'dim_size'] and '_empty' not in param.default.__name__ %} {{name}} = {{param.default}} {%- endif %} {%- endfor %} index = edge_index_i size_i = size[i] if size[i] is not None else size[j] size_j = size[j] if size[j] is not None else size[i] dim_size = size_i return CollectArgs( {%- for name in collect_param_dict %} {{name}}, {%- endfor %} ) ================================================ FILE: torch_geometric/nn/conv/cugraph/__init__.py ================================================ from .base import CuGraphModule from .sage_conv import CuGraphSAGEConv from .gat_conv import CuGraphGATConv from .rgcn_conv import CuGraphRGCNConv __all__ = [ 'CuGraphModule', 'CuGraphSAGEConv', 'CuGraphGATConv', 'CuGraphRGCNConv', ] ================================================ FILE: torch_geometric/nn/conv/cugraph/base.py ================================================ from typing import Any, Optional import torch from torch import Tensor from torch_geometric import EdgeIndex try: # pragma: no cover LEGACY_MODE = False from pylibcugraphops.pytorch import CSC, HeteroCSC HAS_PYLIBCUGRAPHOPS = True except ImportError: HAS_PYLIBCUGRAPHOPS = False try: # pragma: no cover from pylibcugraphops import ( make_fg_csr, make_fg_csr_hg, make_mfg_csr, make_mfg_csr_hg, ) LEGACY_MODE = True except ImportError: pass class CuGraphModule(torch.nn.Module): # pragma: no cover r"""An abstract base class for implementing :obj:`cugraph`-based message passing layers. 
""" def __init__(self): super().__init__() if not HAS_PYLIBCUGRAPHOPS and not LEGACY_MODE: raise ModuleNotFoundError(f"'{self.__class__.__name__}' requires " f"'pylibcugraphops>=23.02'") def reset_parameters(self): r"""Resets all learnable parameters of the module.""" def get_cugraph( self, edge_index: EdgeIndex, max_num_neighbors: Optional[int] = None, ) -> Any: r"""Constructs a :obj:`cugraph` graph object from CSC representation. Supports both bipartite and non-bipartite graphs. Args: edge_index (EdgeIndex): The edge indices. max_num_neighbors (int, optional): The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) """ if not isinstance(edge_index, EdgeIndex): raise ValueError(f"'edge_index' needs to be of type 'EdgeIndex' " f"(got {type(edge_index)})") edge_index = edge_index.sort_by('col')[0] num_src_nodes = edge_index.get_sparse_size(0) (colptr, row), _ = edge_index.get_csc() if not row.is_cuda: raise RuntimeError(f"'{self.__class__.__name__}' requires GPU-" f"based processing (got CPU tensor)") if num_src_nodes != colptr.numel() - 1: # Bipartite graph: if max_num_neighbors is None: max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) if LEGACY_MODE: dst_nodes = torch.arange(colptr.numel() - 1, device=row.device) return make_mfg_csr(dst_nodes, colptr, row, max_num_neighbors, num_src_nodes) return CSC(colptr, row, num_src_nodes, dst_max_in_degree=max_num_neighbors) if LEGACY_MODE: return make_fg_csr(colptr, row) return CSC(colptr, row, num_src_nodes=num_src_nodes) def get_typed_cugraph( self, edge_index: EdgeIndex, edge_type: Tensor, num_edge_types: Optional[int] = None, max_num_neighbors: Optional[int] = None, ) -> Any: r"""Constructs a typed :obj:`cugraph` graph object from a CSC representation where each edge corresponds to a given edge type. Supports both bipartite and non-bipartite graphs. 
Args: edge_index (EdgeIndex): The edge indices. edge_type (torch.Tensor): The edge type. num_edge_types (int, optional): The maximum number of edge types. When not given, will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) max_num_neighbors (int, optional): The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) """ if num_edge_types is None: num_edge_types = int(edge_type.max()) + 1 if not isinstance(edge_index, EdgeIndex): raise ValueError(f"'edge_index' needs to be of type 'EdgeIndex' " f"(got {type(edge_index)})") edge_index, perm = edge_index.sort_by('col') edge_type = edge_type[perm] num_src_nodes = edge_index.get_sparse_size(0) (colptr, row), _ = edge_index.get_csc() edge_type = edge_type.int() if num_src_nodes != colptr.numel() - 1: # Bipartite graph: if max_num_neighbors is None: max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) if LEGACY_MODE: dst_nodes = torch.arange(colptr.numel() - 1, device=row.device) return make_mfg_csr_hg(dst_nodes, colptr, row, max_num_neighbors, num_src_nodes, n_node_types=0, n_edge_types=num_edge_types, out_node_types=None, in_node_types=None, edge_types=edge_type) return HeteroCSC(colptr, row, edge_type, num_src_nodes, num_edge_types, dst_max_in_degree=max_num_neighbors) if LEGACY_MODE: return make_fg_csr_hg(colptr, row, n_node_types=0, n_edge_types=num_edge_types, node_types=None, edge_types=edge_type) return HeteroCSC(colptr, row, edge_type, num_src_nodes, num_edge_types) def forward( self, x: Tensor, edge_index: EdgeIndex, max_num_neighbors: Optional[int] = None, ) -> Tensor: r"""Runs the forward pass of the module. Args: x (torch.Tensor): The node features. edge_index (EdgeIndex): The edge indices. max_num_neighbors (int, optional): The maximum number of neighbors of a target node. 
It is only effective when operating in a bipartite graph. When not given, the value will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) """ raise NotImplementedError ================================================ FILE: torch_geometric/nn/conv/cugraph/gat_conv.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import Linear, Parameter from torch_geometric import EdgeIndex from torch_geometric.nn.conv.cugraph import CuGraphModule from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE from torch_geometric.nn.inits import zeros try: if LEGACY_MODE: from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg else: from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg except ImportError: pass class CuGraphGATConv(CuGraphModule): # pragma: no cover r"""The graph attentional operator from the `"Graph Attention Networks" `_ paper. :class:`CuGraphGATConv` is an optimized version of :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`cugraph-ops` package that fuses message passing computation for accelerated execution and lower memory footprint. See :ref:`install-cugraph` for how to set up :obj:`cugraph-ops`.
""" def __init__( self, in_channels: int, out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, bias: bool = True, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.concat = concat self.negative_slope = negative_slope self.lin = Linear(in_channels, heads * out_channels, bias=False) self.att = Parameter(torch.empty(2 * heads * out_channels)) if bias and concat: self.bias = Parameter(torch.empty(heads * out_channels)) elif bias and not concat: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): self.lin.reset_parameters() gain = torch.nn.init.calculate_gain('relu') torch.nn.init.xavier_normal_( self.att.view(2, self.heads, self.out_channels), gain=gain) zeros(self.bias) def forward( self, x: Tensor, edge_index: EdgeIndex, edge_attr: Tensor, max_num_neighbors: Optional[int] = None, ) -> Tensor: r"""Runs the forward pass of the module. Args: x (torch.Tensor): The node features. edge_index (EdgeIndex): The edge indices. edge_attr (torch.Tensor): The edge features, forwarded to the fused attention kernel as :obj:`edge_feat`. max_num_neighbors (int, optional): The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, the value will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) """ graph = self.get_cugraph(edge_index, max_num_neighbors) x = self.lin(x) if LEGACY_MODE: out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU', self.negative_slope, False, self.concat, edge_feat=edge_attr) else: out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU', self.negative_slope, self.concat, edge_feat=edge_attr) if self.bias is not None: out = out + self.bias return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads})') ================================================ FILE: torch_geometric/nn/conv/cugraph/rgcn_conv.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import Parameter from torch_geometric import EdgeIndex from torch_geometric.nn.conv.cugraph import CuGraphModule from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE from torch_geometric.nn.inits import glorot, zeros try: if LEGACY_MODE: from
pylibcugraphops.torch.autograd import \ agg_hg_basis_n2n_post as RGCNConvAgg else: from pylibcugraphops.pytorch.operators import \ agg_hg_basis_n2n_post as RGCNConvAgg except ImportError: pass class CuGraphRGCNConv(CuGraphModule): # pragma: no cover r"""The relational graph convolutional operator from the `"Modeling Relational Data with Graph Convolutional Networks" `_ paper. :class:`CuGraphRGCNConv` is an optimized version of :class:`~torch_geometric.nn.conv.RGCNConv` based on the :obj:`cugraph-ops` package that fuses message passing computation for accelerated execution and lower memory footprint. See :ref:`install-cugraph` for how to set up :obj:`cugraph-ops`. """ def __init__(self, in_channels: int, out_channels: int, num_relations: int, num_bases: Optional[int] = None, aggr: str = 'mean', root_weight: bool = True, bias: bool = True): super().__init__() if aggr not in ['sum', 'add', 'mean']: raise ValueError(f"Aggregation function must be either 'mean' " f"or 'sum' (got '{aggr}')") self.in_channels = in_channels self.out_channels = out_channels self.num_relations = num_relations self.num_bases = num_bases self.aggr = aggr self.root_weight = root_weight dim_root_weight = 1 if root_weight else 0 if num_bases is not None: self.weight = Parameter( torch.empty(num_bases + dim_root_weight, in_channels, out_channels)) self.comp = Parameter(torch.empty(num_relations, num_bases)) else: self.weight = Parameter( torch.empty(num_relations + dim_root_weight, in_channels, out_channels)) self.register_parameter('comp', None) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): end = -1 if self.root_weight else None glorot(self.weight[:end]) glorot(self.comp) if self.root_weight: glorot(self.weight[-1]) zeros(self.bias) def forward( self, x: Tensor, edge_index: EdgeIndex, edge_type: Tensor, max_num_neighbors: Optional[int] = None, ) -> Tensor: r"""Runs the forward pass of the
module. Args: x (torch.Tensor): The node features. edge_index (EdgeIndex): The edge indices. edge_type (torch.Tensor): The edge type. max_num_neighbors (int, optional): The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, the value will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) """ graph = self.get_typed_cugraph(edge_index, edge_type, self.num_relations, max_num_neighbors) out = RGCNConvAgg(x, self.comp, graph, concat_own=self.root_weight, norm_by_out_degree=bool(self.aggr == 'mean')) out = out @ self.weight.view(-1, self.out_channels) if self.bias is not None: out = out + self.bias return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, num_relations={self.num_relations})') ================================================ FILE: torch_geometric/nn/conv/cugraph/sage_conv.py ================================================ from typing import Optional import torch.nn.functional as F from torch import Tensor from torch.nn import Linear from torch_geometric import EdgeIndex from torch_geometric.nn.conv.cugraph import CuGraphModule from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE try: if LEGACY_MODE: from pylibcugraphops.torch.autograd import \ agg_concat_n2n as SAGEConvAgg else: from pylibcugraphops.pytorch.operators import \ agg_concat_n2n as SAGEConvAgg except ImportError: pass class CuGraphSAGEConv(CuGraphModule): # pragma: no cover r"""The GraphSAGE operator from the `"Inductive Representation Learning on Large Graphs" `_ paper. :class:`CuGraphSAGEConv` is an optimized version of :class:`~torch_geometric.nn.conv.SAGEConv` based on the :obj:`cugraph-ops` package that fuses message passing computation for accelerated execution and lower memory footprint. See :ref:`install-cugraph` for how to set up :obj:`cugraph-ops`.
""" def __init__( self, in_channels: int, out_channels: int, aggr: str = 'mean', normalize: bool = False, root_weight: bool = True, project: bool = False, bias: bool = True, ): super().__init__() if aggr not in ['mean', 'sum', 'min', 'max']: raise ValueError(f"Aggregation function must be either 'mean', " f"'sum', 'min' or 'max' (got '{aggr}')") self.in_channels = in_channels self.out_channels = out_channels self.aggr = aggr self.normalize = normalize self.root_weight = root_weight self.project = project if self.project: self.pre_lin = Linear(in_channels, in_channels, bias=True) if self.root_weight: self.lin = Linear(2 * in_channels, out_channels, bias=bias) else: self.lin = Linear(in_channels, out_channels, bias=bias) self.reset_parameters() def reset_parameters(self): if self.project: self.pre_lin.reset_parameters() self.lin.reset_parameters() def forward( self, x: Tensor, edge_index: EdgeIndex, max_num_neighbors: Optional[int] = None, ) -> Tensor: r"""Runs the forward pass of the module. Args: x (torch.Tensor): The node features. edge_index (EdgeIndex): The edge indices. max_num_neighbors (int, optional): The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, the value will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) """ graph = self.get_cugraph(edge_index, max_num_neighbors) if self.project: x = self.pre_lin(x).relu() out = SAGEConvAgg(x, graph, self.aggr) if self.root_weight: out = self.lin(out) else: out = self.lin(out[:, :self.in_channels]) if self.normalize: out = F.normalize(out, p=2., dim=-1) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, aggr={self.aggr})') ================================================ FILE: torch_geometric/nn/conv/dir_gnn_conv.py ================================================ import copy import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing class DirGNNConv(torch.nn.Module): r"""A generic wrapper for computing graph convolution on directed graphs as described in the `"Edge Directionality Improves Learning on Heterophilic Graphs" `_ paper. :class:`DirGNNConv` will pass messages both from source nodes to target nodes and from target nodes to source nodes.
Args: conv (MessagePassing): The underlying :class:`~torch_geometric.nn.conv.MessagePassing` layer to use. It is deep-copied once per edge direction, and the copies get their :obj:`add_self_loops` and :obj:`root_weight` attributes (if present) set to :obj:`False`. alpha (float, optional): The alpha coefficient used to weight the aggregations of in- and out-edges as part of a convex combination. (default: :obj:`0.5`) root_weight (bool, optional): If set to :obj:`True`, the layer will add transformed root node features to the output. In this case, :obj:`conv` needs to expose :obj:`in_channels` and :obj:`out_channels` attributes. (default: :obj:`True`) """ def __init__( self, conv: MessagePassing, alpha: float = 0.5, root_weight: bool = True, ): super().__init__() self.alpha = alpha self.root_weight = root_weight self.conv_in = copy.deepcopy(conv) self.conv_out = copy.deepcopy(conv) if hasattr(conv, 'add_self_loops'): self.conv_in.add_self_loops = False self.conv_out.add_self_loops = False if hasattr(conv, 'root_weight'): self.conv_in.root_weight = False self.conv_out.root_weight = False if root_weight: self.lin = torch.nn.Linear(conv.in_channels, conv.out_channels) else: self.lin = None self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.conv_in.reset_parameters() self.conv_out.reset_parameters() if self.lin is not None: self.lin.reset_parameters() def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: """""" # noqa: D419 x_in = self.conv_in(x, edge_index) x_out = self.conv_out(x, edge_index.flip([0])) out = self.alpha * x_out + (1 - self.alpha) * x_in if self.root_weight: out = out + self.lin(x) return out def __repr__(self) -> str: return f'{self.__class__.__name__}({self.conv_in}, alpha={self.alpha})' ================================================ FILE: torch_geometric/nn/conv/dna_conv.py ================================================ import math from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.inits import kaiming_uniform, uniform from
torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor class Linear(torch.nn.Module): def __init__(self, in_channels, out_channels, groups=1, bias=True): super().__init__() assert in_channels % groups == 0 and out_channels % groups == 0 self.in_channels = in_channels self.out_channels = out_channels self.groups = groups self.weight = Parameter( torch.empty(groups, in_channels // groups, out_channels // groups)) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): kaiming_uniform(self.weight, fan=self.weight.size(1), a=math.sqrt(5)) uniform(self.weight.size(1), self.bias) def forward(self, src): # Input: [*, in_channels] # Output: [*, out_channels] if self.groups > 1: size = src.size()[:-1] src = src.view(-1, self.groups, self.in_channels // self.groups) src = src.transpose(0, 1).contiguous() out = torch.matmul(src, self.weight) out = out.transpose(1, 0).contiguous() out = out.view(size + (self.out_channels, )) else: out = torch.matmul(src, self.weight.squeeze(0)) if self.bias is not None: out = out + self.bias return out def __repr__(self) -> str: # pragma: no cover return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, groups={self.groups})') def restricted_softmax(src, dim: int = -1, margin: float = 0.): src_max = torch.clamp(src.max(dim=dim, keepdim=True)[0], min=0.) 
out = (src - src_max).exp() out = out / (out.sum(dim=dim, keepdim=True) + (margin - src_max).exp()) return out class Attention(torch.nn.Module): def __init__(self, dropout=0): super().__init__() self.dropout = dropout def forward(self, query, key, value): return self.compute_attention(query, key, value) def compute_attention(self, query, key, value): # query: [*, query_entries, dim_k] # key: [*, key_entries, dim_k] # value: [*, key_entries, dim_v] # Output: [*, query_entries, dim_v] assert query.dim() == key.dim() == value.dim() >= 2 assert query.size(-1) == key.size(-1) assert key.size(-2) == value.size(-2) # Score: [*, query_entries, key_entries] score = torch.matmul(query, key.transpose(-2, -1)) score = score / math.sqrt(key.size(-1)) score = restricted_softmax(score, dim=-1) score = F.dropout(score, p=self.dropout, training=self.training) return torch.matmul(score, value) def __repr__(self) -> str: # pragma: no cover return f'{self.__class__.__name__}(dropout={self.dropout})' class MultiHead(Attention): def __init__(self, in_channels, out_channels, heads=1, groups=1, dropout=0, bias=True): super().__init__(dropout) self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.groups = groups self.bias = bias assert in_channels % heads == 0 and out_channels % heads == 0 assert in_channels % groups == 0 and out_channels % groups == 0 assert max(groups, self.heads) % min(groups, self.heads) == 0 self.lin_q = Linear(in_channels, out_channels, groups, bias) self.lin_k = Linear(in_channels, out_channels, groups, bias) self.lin_v = Linear(in_channels, out_channels, groups, bias) self.reset_parameters() def reset_parameters(self): self.lin_q.reset_parameters() self.lin_k.reset_parameters() self.lin_v.reset_parameters() def forward(self, query, key, value): # query: [*, query_entries, in_channels] # key: [*, key_entries, in_channels] # value: [*, key_entries, in_channels] # Output: [*, query_entries, out_channels] assert query.dim() == 
key.dim() == value.dim() >= 2 assert query.size(-1) == key.size(-1) == value.size(-1) assert key.size(-2) == value.size(-2) query = self.lin_q(query) key = self.lin_k(key) value = self.lin_v(value) # query: [*, heads, query_entries, out_channels // heads] # key: [*, heads, key_entries, out_channels // heads] # value: [*, heads, key_entries, out_channels // heads] size = query.size()[:-2] out_channels_per_head = self.out_channels // self.heads query_size = size + (query.size(-2), self.heads, out_channels_per_head) query = query.view(query_size).transpose(-2, -3) key_size = size + (key.size(-2), self.heads, out_channels_per_head) key = key.view(key_size).transpose(-2, -3) value_size = size + (value.size(-2), self.heads, out_channels_per_head) value = value.view(value_size).transpose(-2, -3) # Output: [*, heads, query_entries, out_channels // heads] out = self.compute_attention(query, key, value) # Output: [*, query_entries, heads, out_channels // heads] out = out.transpose(-3, -2).contiguous() # Output: [*, query_entries, out_channels] out = out.view(size + (query.size(-2), self.out_channels)) return out def __repr__(self) -> str: # pragma: no cover return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads}, ' f'groups={self.groups}, dropout={self.dropout}, ' f'bias={self.bias})') class DNAConv(MessagePassing): r"""The dynamic neighborhood aggregation operator from the `"Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks" `_ paper. .. math:: \mathbf{x}_v^{(t)} = h_{\mathbf{\Theta}}^{(t)} \left( \mathbf{x}_{v \leftarrow v}^{(t)}, \left\{ \mathbf{x}_{v \leftarrow w}^{(t)} : w \in \mathcal{N}(v) \right\} \right) based on (multi-head) dot-product attention .. 
math:: \mathbf{x}_{v \leftarrow w}^{(t)} = \textrm{Attention} \left( \mathbf{x}^{(t-1)}_v \, \mathbf{\Theta}_Q^{(t)}, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_K^{(t)}, \, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_V^{(t)} \right) with :math:`\mathbf{\Theta}_Q^{(t)}, \mathbf{\Theta}_K^{(t)}, \mathbf{\Theta}_V^{(t)}` denoting (grouped) projection matrices for query, key and value information, respectively. :math:`h^{(t)}_{\mathbf{\Theta}}` is implemented as a non-trainable version of :class:`torch_geometric.nn.conv.GCNConv`. .. note:: In contrast to other layers, this operator expects node features as shape :obj:`[num_nodes, num_layers, channels]`. Args: channels (int): Size of each input/output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) groups (int, optional): Number of groups to use for all linear projections. (default: :obj:`1`) dropout (float, optional): Dropout probability of attention coefficients. (default: :obj:`0.`) cached (bool, optional): If set to :obj:`True`, the layer will cache the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the cached version for further executions. This parameter should only be set to :obj:`True` in transductive learning scenarios. (default: :obj:`False`) normalize (bool, optional): Whether to add self-loops and apply symmetric normalization. (default: :obj:`True`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, L, F)` where :math:`L` is the number of layers, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F)` """ _cached_edge_index: Optional[OptPairTensor] _cached_adj_t: Optional[SparseTensor] def __init__(self, channels: int, heads: int = 1, groups: int = 1, dropout: float = 0., cached: bool = False, normalize: bool = True, add_self_loops: bool = True, bias: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(node_dim=0, **kwargs) self.bias = bias self.cached = cached self.normalize = normalize self.add_self_loops = add_self_loops self._cached_edge_index = None self._cached_adj_t = None self.multi_head = MultiHead(channels, channels, heads, groups, dropout, bias) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.multi_head.reset_parameters() self._cached_edge_index = None self._cached_adj_t = None def forward( self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, ) -> Tensor: r"""Runs the forward pass of the module. Args: x (torch.Tensor): The input node features of shape :obj:`[num_nodes, num_layers, channels]`. edge_index (torch.Tensor or SparseTensor): The edge indices. edge_weight (torch.Tensor, optional): The edge weights. 
(default: :obj:`None`) """ if x.dim() != 3: raise ValueError('Feature shape must be [num_nodes, num_layers, ' 'channels].') if self.normalize: if isinstance(edge_index, Tensor): cache = self._cached_edge_index if cache is None: edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) if self.cached: self._cached_edge_index = (edge_index, edge_weight) else: edge_index, edge_weight = cache[0], cache[1] elif isinstance(edge_index, SparseTensor): cache = self._cached_adj_t if cache is None: edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) if self.cached: self._cached_adj_t = edge_index else: edge_index = cache # propagate_type: (x: Tensor, edge_weight: OptTensor) return self.propagate(edge_index, x=x, edge_weight=edge_weight) def message(self, x_i: Tensor, x_j: Tensor, edge_weight: Tensor) -> Tensor: x_i = x_i[:, -1:] # [num_edges, 1, channels] out = self.multi_head(x_i, x_j, x_j) # [num_edges, 1, channels] return edge_weight.view(-1, 1) * out.squeeze(1) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.multi_head.in_channels}, ' f'heads={self.multi_head.heads}, ' f'groups={self.multi_head.groups})') ================================================ FILE: torch_geometric/nn/conv/edge_conv.py ================================================ from typing import Callable, Optional, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor if torch_geometric.typing.WITH_TORCH_CLUSTER: from torch_cluster import knn else: knn = None class EdgeConv(MessagePassing): r"""The edge convolutional operator from the `"Dynamic Graph CNN for Learning on Point Clouds" `_ paper. .. 
math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i), where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* a MLP. Args: nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that maps pair-wise concatenated node features :obj:`x` of shape :obj:`[-1, 2 * in_channels]` to shape :obj:`[-1, out_channels]`, *e.g.*, defined by :class:`torch.nn.Sequential`. aggr (str, optional): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"max"`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, nn: Callable, aggr: str = 'max', **kwargs): super().__init__(aggr=aggr, **kwargs) self.nn = nn self.reset_parameters() def reset_parameters(self): super().reset_parameters() reset(self.nn) def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor: if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: PairTensor) return self.propagate(edge_index, x=x) def message(self, x_i: Tensor, x_j: Tensor) -> Tensor: return self.nn(torch.cat([x_i, x_j - x_i], dim=-1)) def __repr__(self) -> str: return f'{self.__class__.__name__}(nn={self.nn})' class DynamicEdgeConv(MessagePassing): r"""The dynamic edge convolutional operator from the `"Dynamic Graph CNN for Learning on Point Clouds" `_ paper (see :class:`torch_geometric.nn.conv.EdgeConv`), where the graph is dynamically constructed using nearest neighbors in the feature space. 
Args: nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that maps pair-wise concatenated node features :obj:`x` of shape `:obj:`[-1, 2 * in_channels]` to shape :obj:`[-1, out_channels]`, *e.g.* defined by :class:`torch.nn.Sequential`. k (int): Number of nearest neighbors. aggr (str, optional): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"max"`) num_workers (int): Number of workers to use for k-NN computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))` if bipartite, batch vector :math:`(|\mathcal{V}|)` or :math:`((|\mathcal{V}|), (|\mathcal{V}|))` if bipartite *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, nn: Callable, k: int, aggr: str = 'max', num_workers: int = 1, **kwargs): super().__init__(aggr=aggr, flow='source_to_target', **kwargs) if knn is None: raise ImportError('`DynamicEdgeConv` requires `torch-cluster`.') self.nn = nn self.k = k self.num_workers = num_workers self.reset_parameters() def reset_parameters(self): reset(self.nn) def forward( self, x: Union[Tensor, PairTensor], batch: Union[OptTensor, Optional[PairTensor]] = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) if x[0].dim() != 2: raise ValueError("Static graphs not supported in DynamicEdgeConv") b: PairOptTensor = (None, None) if isinstance(batch, Tensor): b = (batch, batch) elif isinstance(batch, tuple): assert batch is not None b = (batch[0], batch[1]) edge_index = knn(x[0], x[1], self.k, b[0], b[1]).flip([0]) # propagate_type: (x: PairTensor) return self.propagate(edge_index, x=x) def message(self, x_i: Tensor, x_j: Tensor) -> Tensor: return 
self.nn(torch.cat([x_i, x_j - x_i], dim=-1)) def __repr__(self) -> str: return f'{self.__class__.__name__}(nn={self.nn}, k={self.k})' ================================================ FILE: torch_geometric/nn/conv/edge_updater.jinja ================================================ import typing from typing import Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric import is_compiling from torch_geometric.utils import is_sparse from torch_geometric.typing import Size, SparseTensor {% for module in modules %} from {{module}} import * {%- endfor %} {% include "collect.jinja" %} def edge_updater( self, edge_index: Union[Tensor, SparseTensor], {%- for param in signature.param_dict.values() %} {{param.name}}: {{param.type_repr}}, {%- endfor %} size: Size = None, ) -> {{signature.return_type_repr}}: mutable_size = self._check_input(edge_index, size) kwargs = self.{{collect_name}}( edge_index, {%- for name in signature.param_dict %} {{name}}, {%- endfor %} mutable_size, ) # Begin Edge Update Forward Pre Hook ####################################### if not torch.jit.is_scripting() and not is_compiling(): for hook in self._edge_update_forward_pre_hooks.values(): hook_kwargs = dict( {%- for name in collect_param_dict %} {{name}}=kwargs.{{name}}, {%- endfor %} ) res = hook(self, (edge_index, size, hook_kwargs)) if res is not None: edge_index, size, hook_kwargs = res kwargs = CollectArgs( {%- for name in collect_param_dict %} {{name}}=hook_kwargs['{{name}}'], {%- endfor %} ) # End Edge Update Forward Pre Hook ######################################### out = self.edge_update( {%- for name in collect_param_dict %} {{name}}=kwargs.{{name}}, {%- endfor %} ) # Begin Edge Update Forward Hook ########################################### if not torch.jit.is_scripting() and not is_compiling(): for hook in self._edge_update_forward_hooks.values(): hook_kwargs = dict( {%- for name in collect_param_dict %} {{name}}=kwargs.{{name}}, {%- endfor %} ) res 
= hook(self, (edge_index, size, hook_kwargs), out) out = res if res is not None else out # End Edge Update Forward Hook ############################################# return out ================================================ FILE: torch_geometric/nn/conv/eg_conv.py ================================================ from typing import List, Optional, Tuple import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse from torch_geometric.utils import add_remaining_self_loops, scatter, spmm class EGConv(MessagePassing): r"""The Efficient Graph Convolution from the `"Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions" `_ paper. Its node-wise formulation is given by: .. math:: \mathbf{x}_i^{\prime} = {\LARGE ||}_{h=1}^H \sum_{\oplus \in \mathcal{A}} \sum_{b = 1}^B w_{i, h, \oplus, b} \; \underset{j \in \mathcal{N}(i) \cup \{i\}}{\bigoplus} \mathbf{W}_b \mathbf{x}_{j} with :math:`\mathbf{W}_b` denoting a basis weight, :math:`\oplus` denoting an aggregator, and :math:`w` denoting per-vertex weighting coefficients across different heads, bases and aggregators. EGC retains :math:`\mathcal{O}(|\mathcal{V}|)` memory usage, making it a sensible alternative to :class:`~torch_geometric.nn.conv.GCNConv`, :class:`~torch_geometric.nn.conv.SAGEConv` or :class:`~torch_geometric.nn.conv.GINConv`. .. note:: For an example of using :obj:`EGConv`, see `examples/egc.py `_. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. aggregators (List[str], optional): Aggregators to be used. 
Supported aggregators are :obj:`"sum"`, :obj:`"mean"`, :obj:`"symnorm"`, :obj:`"max"`, :obj:`"min"`, :obj:`"std"`, :obj:`"var"`. Multiple aggregators can be used to improve the performance. (default: :obj:`["symnorm"]`) num_heads (int, optional): Number of heads :math:`H` to use. Must have :obj:`out_channels % num_heads == 0`. It is recommended to set :obj:`num_heads >= num_bases`. (default: :obj:`8`) num_bases (int, optional): Number of basis weights :math:`B` to use. (default: :obj:`4`) cached (bool, optional): If set to :obj:`True`, the layer will cache the computation of the edge index with added self loops on first execution, along with caching the calculation of the symmetric normalized edge weights if the :obj:`"symnorm"` aggregator is being used. This parameter should only be set to :obj:`True` in transductive learning scenarios. (default: :obj:`False`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ _cached_edge_index: Optional[Tuple[Tensor, OptTensor]] _cached_adj_t: Optional[SparseTensor] def __init__( self, in_channels: int, out_channels: int, aggregators: Optional[List[str]] = None, num_heads: int = 8, num_bases: int = 4, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs, ): super().__init__(node_dim=0, **kwargs) if out_channels % num_heads != 0: raise ValueError(f"'out_channels' (got {out_channels}) must be " f"divisible by the number of heads " f"(got {num_heads})") self.in_channels = in_channels self.out_channels = out_channels self.num_heads = num_heads self.num_bases = num_bases self.cached = cached self.add_self_loops = add_self_loops self.aggregators = aggregators or ['symnorm'] for a in self.aggregators: if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']: raise ValueError(f"Unsupported aggregator: '{a}'") self.bases_lin = Linear(in_channels, (out_channels // num_heads) * num_bases, bias=False, weight_initializer='glorot') self.comb_lin = Linear(in_channels, num_heads * num_bases * len(self.aggregators)) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.bases_lin.reset_parameters() self.comb_lin.reset_parameters() zeros(self.bias) self._cached_adj_t = None self._cached_edge_index = None def forward(self, x: Tensor, edge_index: Adj) -> Tensor: symnorm_weight: OptTensor = None if "symnorm" in self.aggregators: if isinstance(edge_index, Tensor): cache = self._cached_edge_index if cache is None: edge_index, symnorm_weight = gcn_norm( # yapf: disable edge_index, None, num_nodes=x.size(self.node_dim), improved=False, add_self_loops=self.add_self_loops, flow=self.flow, dtype=x.dtype) if self.cached: 
self._cached_edge_index = (edge_index, symnorm_weight) else: edge_index, symnorm_weight = cache elif isinstance(edge_index, SparseTensor): cache = self._cached_adj_t if cache is None: edge_index = gcn_norm( # yapf: disable edge_index, None, num_nodes=x.size(self.node_dim), improved=False, add_self_loops=self.add_self_loops, flow=self.flow, dtype=x.dtype) if self.cached: self._cached_adj_t = edge_index else: edge_index = cache elif self.add_self_loops: if isinstance(edge_index, Tensor): cache = self._cached_edge_index if self.cached and cache is not None: edge_index = cache[0] else: edge_index, _ = add_remaining_self_loops(edge_index) if self.cached: self._cached_edge_index = (edge_index, None) elif isinstance(edge_index, SparseTensor): cache = self._cached_adj_t if self.cached and cache is not None: edge_index = cache else: edge_index = torch_sparse.fill_diag(edge_index, 1.0) if self.cached: self._cached_adj_t = edge_index # [num_nodes, (out_channels // num_heads) * num_bases] bases = self.bases_lin(x) # [num_nodes, num_heads * num_bases * num_aggrs] weightings = self.comb_lin(x) # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases] # propagate_type: (x: Tensor, symnorm_weight: OptTensor) aggregated = self.propagate(edge_index, x=bases, symnorm_weight=symnorm_weight) weightings = weightings.view(-1, self.num_heads, self.num_bases * len(self.aggregators)) aggregated = aggregated.view( -1, len(self.aggregators) * self.num_bases, self.out_channels // self.num_heads, ) # [num_nodes, num_heads, out_channels // num_heads] out = torch.matmul(weightings, aggregated) out = out.view(-1, self.out_channels) if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor) -> Tensor: return x_j def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None, symnorm_weight: OptTensor = None) -> Tensor: outs = [] for aggr in self.aggregators: if aggr == 'symnorm': assert symnorm_weight is not None out = scatter(inputs * 
symnorm_weight.view(-1, 1), index, 0, dim_size, reduce='sum') elif aggr == 'var' or aggr == 'std': mean = scatter(inputs, index, 0, dim_size, reduce='mean') mean_squares = scatter(inputs * inputs, index, 0, dim_size, reduce='mean') out = mean_squares - mean * mean if aggr == 'std': out = out.clamp(min=1e-5).sqrt() else: out = scatter(inputs, index, 0, dim_size, reduce=aggr) outs.append(out) return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0] def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: adj_t_2 = adj_t if len(self.aggregators) > 1 and 'symnorm' in self.aggregators: if isinstance(adj_t, SparseTensor): adj_t_2 = adj_t.set_value(None) else: adj_t_2 = adj_t.clone() adj_t_2.values().fill_(1.0) outs = [] for aggr in self.aggregators: if aggr == 'symnorm': out = spmm(adj_t, x, reduce='sum') elif aggr in ['var', 'std']: mean = spmm(adj_t_2, x, reduce='mean') mean_sq = spmm(adj_t_2, x * x, reduce='mean') out = mean_sq - mean * mean if aggr == 'std': out = torch.sqrt(out.relu_() + 1e-5) else: out = spmm(adj_t_2, x, reduce=aggr) outs.append(out) return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0] def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, aggregators={self.aggregators})') ================================================ FILE: torch_geometric/nn/conv/fa_conv.py ================================================ import typing from typing import Optional, Tuple, Union import torch.nn.functional as F from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import PairTensor # noqa from torch_geometric.typing import ( Adj, NoneType, OptPairTensor, OptTensor, SparseTensor, ) from torch_geometric.utils import is_torch_sparse_tensor from torch_geometric.utils.sparse import set_sparse_value if typing.TYPE_CHECKING: from typing import 
overload else: from torch.jit import _overload_method as overload class FAConv(MessagePassing): r"""The Frequency Adaptive Graph Convolution operator from the `"Beyond Low-Frequency Information in Graph Convolutional Networks" `_ paper. .. math:: \mathbf{x}^{\prime}_i= \epsilon \cdot \mathbf{x}^{(0)}_i + \sum_{j \in \mathcal{N}(i)} \frac{\alpha_{i,j}}{\sqrt{d_i d_j}} \mathbf{x}_{j} where :math:`\mathbf{x}^{(0)}_i` and :math:`d_i` denote the initial feature representation and node degree of node :math:`i`, respectively. The attention coefficients :math:`\alpha_{i,j}` are computed as .. math:: \mathbf{\alpha}_{i,j} = \textrm{tanh}(\mathbf{a}^{\top}[\mathbf{x}_i, \mathbf{x}_j]) based on the trainable parameter vector :math:`\mathbf{a}`. Args: channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. eps (float, optional): :math:`\epsilon`-value. (default: :obj:`0.1`) dropout (float, optional): Dropout probability of the normalized coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: :obj:`0`). cached (bool, optional): If set to :obj:`True`, the layer will cache the computation of :math:`\sqrt{d_i d_j}` on first execution, and will use the cached version for further executions. This parameter should only be set to :obj:`True` in transductive learning scenarios. (default: :obj:`False`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) normalize (bool, optional): Whether to add self-loops (if :obj:`add_self_loops` is :obj:`True`) and compute symmetric normalization coefficients on the fly. If set to :obj:`False`, :obj:`edge_weight` needs to be provided in the layer's :meth:`forward` method. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F)`, initial node features :math:`(|\mathcal{V}|, F)`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F)` or :math:`((|\mathcal{V}|, F), ((2, |\mathcal{E}|), (|\mathcal{E}|)))` if :obj:`return_attention_weights=True` """ _cached_edge_index: Optional[OptPairTensor] _cached_adj_t: Optional[SparseTensor] _alpha: OptTensor def __init__(self, channels: int, eps: float = 0.1, dropout: float = 0.0, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.channels = channels self.eps = eps self.dropout = dropout self.cached = cached self.add_self_loops = add_self_loops self.normalize = normalize self._cached_edge_index = None self._cached_adj_t = None self._alpha = None self.att_l = Linear(channels, 1, bias=False) self.att_r = Linear(channels, 1, bias=False) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.att_l.reset_parameters() self.att_r.reset_parameters() self._cached_edge_index = None self._cached_adj_t = None @overload def forward( self, x: Tensor, x_0: Tensor, edge_index: Adj, edge_weight: OptTensor = None, return_attention_weights: NoneType = None, ) -> Tensor: pass @overload def forward( # noqa: F811 self, x: Tensor, x_0: Tensor, edge_index: Tensor, edge_weight: OptTensor = None, return_attention_weights: bool = None, ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: pass @overload def forward( # noqa: F811 self, x: Tensor, x_0: Tensor, edge_index: SparseTensor, edge_weight: OptTensor = None, return_attention_weights: bool = None, ) -> Tuple[Tensor, SparseTensor]: pass def forward( # noqa: F811 self, x: Tensor, x_0: Tensor, edge_index: Adj, edge_weight: OptTensor = None, return_attention_weights: Optional[bool] = None, ) -> Union[ Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]], Tuple[Tensor, 
SparseTensor], ]: r"""Runs the forward pass of the module. Args: x (torch.Tensor): The node features. x_0 (torch.Tensor): The initial input node features. edge_index (torch.Tensor or SparseTensor): The edge indices. edge_weight (torch.Tensor, optional): The edge weights. (default: :obj:`None`) return_attention_weights (bool, optional): Will additionally return the tuple :obj:`(edge_index, attention_weights)` whenever it is set to a value, regardless of its actual value (might be `True` or `False`), holding the computed attention weights for each edge. (default: :obj:`None`) """ if self.normalize: if isinstance(edge_index, Tensor): assert edge_weight is None cache = self._cached_edge_index if cache is None: edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, None, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) if self.cached: self._cached_edge_index = (edge_index, edge_weight) else: edge_index, edge_weight = cache[0], cache[1] elif isinstance(edge_index, SparseTensor): assert not edge_index.has_value() cache = self._cached_adj_t if cache is None: edge_index = gcn_norm( # yapf: disable edge_index, None, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) if self.cached: self._cached_adj_t = edge_index else: edge_index = cache else: if isinstance(edge_index, Tensor) and not is_torch_sparse_tensor(edge_index): assert edge_weight is not None elif isinstance(edge_index, SparseTensor): assert edge_index.has_value() alpha_l = self.att_l(x) alpha_r = self.att_r(x) # propagate_type: (x: Tensor, alpha: PairTensor, # edge_weight: OptTensor) out = self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r), edge_weight=edge_weight) alpha = self._alpha self._alpha = None if self.eps != 0.0: out = out + self.eps * x_0 if isinstance(return_attention_weights, bool): assert alpha is not None if isinstance(edge_index, Tensor): if is_torch_sparse_tensor(edge_index): # TODO TorchScript requires to return a tuple adj = 
set_sparse_value(edge_index, alpha) return out, (adj, alpha) else: return out, (edge_index, alpha) elif isinstance(edge_index, SparseTensor): return out, edge_index.set_value(alpha, layout='coo') else: return out def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor, edge_weight: OptTensor) -> Tensor: assert edge_weight is not None alpha = (alpha_j + alpha_i).tanh().squeeze(-1) self._alpha = alpha alpha = F.dropout(alpha, p=self.dropout, training=self.training) return x_j * (alpha * edge_weight).view(-1, 1) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.channels}, eps={self.eps})' ================================================ FILE: torch_geometric/nn/conv/feast_conv.py ================================================ from typing import Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import normal from torch_geometric.typing import Adj, PairTensor, SparseTensor, torch_sparse from torch_geometric.utils import add_self_loops, remove_self_loops class FeaStConv(MessagePassing): r"""The (translation-invariant) feature-steered convolutional operator from the `"FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \sum_{h=1}^H q_h(\mathbf{x}_i, \mathbf{x}_j) \mathbf{W}_h \mathbf{x}_j with :math:`q_h(\mathbf{x}_i, \mathbf{x}_j) = \mathrm{softmax}_j (\mathbf{u}_h^{\top} (\mathbf{x}_j - \mathbf{x}_i) + c_h)`, where :math:`H` denotes the number of attention heads, and :math:`\mathbf{W}_h`, :math:`\mathbf{u}_h` and :math:`c_h` are trainable parameters. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. 
heads (int, optional): Number of attention heads :math:`H`. (default: :obj:`1`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V_t}|, F_{out})` if bipartite """ def __init__(self, in_channels: int, out_channels: int, heads: int = 1, add_self_loops: bool = True, bias: bool = True, **kwargs): kwargs.setdefault('aggr', 'mean') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.add_self_loops = add_self_loops self.lin = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='uniform') self.u = Linear(in_channels, heads, bias=False, weight_initializer='uniform') self.c = Parameter(torch.empty(heads)) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin.reset_parameters() self.u.reset_parameters() normal(self.c, mean=0, std=0.1) normal(self.bias, mean=0, std=0.1) def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor: if isinstance(x, Tensor): x = (x, x) if self.add_self_loops: if isinstance(edge_index, Tensor): edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops(edge_index, num_nodes=x[1].size(0)) elif isinstance(edge_index, SparseTensor): edge_index = torch_sparse.set_diag(edge_index) # propagate_type: (x: PairTensor) out = self.propagate(edge_index, x=x) if self.bias 
is not None: out = out + self.bias return out def message(self, x_i: Tensor, x_j: Tensor) -> Tensor: q = self.u(x_j - x_i) + self.c # Translation invariance. q = F.softmax(q, dim=1) x_j = self.lin(x_j).view(x_j.size(0), self.heads, -1) return (x_j * q.view(-1, self.heads, 1)).sum(dim=1) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads})') ================================================ FILE: torch_geometric/nn/conv/film_conv.py ================================================ import copy from typing import Callable, Optional, Tuple, Union from torch import Tensor from torch.nn import ModuleList, ReLU from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import reset from torch_geometric.typing import ( Adj, OptTensor, PairTensor, SparseTensor, torch_sparse, ) class FiLMConv(MessagePassing): r"""The FiLM graph convolutional operator from the `"GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}(i)} \sigma \left( \boldsymbol{\gamma}_{r,i} \odot \mathbf{W}_r \mathbf{x}_j + \boldsymbol{\beta}_{r,i} \right) where :math:`\boldsymbol{\beta}_{r,i}, \boldsymbol{\gamma}_{r,i} = g(\mathbf{x}_i)` with :math:`g` being a single linear layer by default. Self-loops are automatically added to the input graph and represented as its own relation type. .. note:: For an example of using FiLM, see `examples/gcn.py `_. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. num_relations (int, optional): Number of relations. 
(default: :obj:`1`) nn (torch.nn.Module, optional): The neural network :math:`g` that maps node features :obj:`x_i` of shape :obj:`[-1, in_channels]` to shape :obj:`[-1, 2 * out_channels]`. If set to :obj:`None`, :math:`g` will be implemented as a single linear layer. (default: :obj:`None`) act (callable, optional): Activation function :math:`\sigma`. (default: :meth:`torch.nn.ReLU()`) aggr (str, optional): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"mean"`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge types :math:`(|\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V_t}|, F_{out})` if bipartite """ def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int = 1, nn: Optional[Callable] = None, act: Optional[Callable] = ReLU(), aggr: str = 'mean', **kwargs, ): super().__init__(aggr=aggr, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.num_relations = max(num_relations, 1) self.act = act if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.lins = ModuleList() self.films = ModuleList() for _ in range(num_relations): self.lins.append(Linear(in_channels[0], out_channels, bias=False)) if nn is None: film = Linear(in_channels[1], 2 * out_channels) else: film = copy.deepcopy(nn) self.films.append(film) self.lin_skip = Linear(in_channels[1], self.out_channels, bias=False) if nn is None: self.film_skip = Linear(in_channels[1], 2 * self.out_channels, bias=False) else: self.film_skip = copy.deepcopy(nn) self.reset_parameters() def reset_parameters(self): super().reset_parameters() for lin, film in zip(self.lins, self.films): lin.reset_parameters() 
reset(film) self.lin_skip.reset_parameters() reset(self.film_skip) def forward( self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_type: OptTensor = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) beta, gamma = self.film_skip(x[1]).split(self.out_channels, dim=-1) out = gamma * self.lin_skip(x[1]) + beta if self.act is not None: out = self.act(out) # propagate_type: (x: Tensor, beta: Tensor, gamma: Tensor) if self.num_relations <= 1: beta, gamma = self.films[0](x[1]).split(self.out_channels, dim=-1) out = out + self.propagate(edge_index, x=self.lins[0](x[0]), beta=beta, gamma=gamma) else: for i, (lin, film) in enumerate(zip(self.lins, self.films)): beta, gamma = film(x[1]).split(self.out_channels, dim=-1) if isinstance(edge_index, SparseTensor): _edge_type = edge_index.storage.value() assert _edge_type is not None mask = _edge_type == i adj_t = torch_sparse.masked_select_nnz( edge_index, mask, layout='coo') out = out + self.propagate(adj_t, x=lin(x[0]), beta=beta, gamma=gamma) else: assert edge_type is not None mask = edge_type == i out = out + self.propagate(edge_index[:, mask], x=lin( x[0]), beta=beta, gamma=gamma) return out def message(self, x_j: Tensor, beta_i: Tensor, gamma_i: Tensor) -> Tensor: out = gamma_i * x_j + beta_i if self.act is not None: out = self.act(out) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, num_relations={self.num_relations})') ================================================ FILE: torch_geometric/nn/conv/fused_gat_conv.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch_geometric.index import index2ptr from torch_geometric.nn.conv import GATConv from torch_geometric.utils import sort_edge_index class FusedGATConv(GATConv): # pragma: no cover r"""The fused graph attention operator from the `"Understanding GNN Computational Graph: A Coordinated Computation, IO, and 
Memory Perspective" `_ paper. :class:`FusedGATConv` is an optimized version of :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`dgNN` package that fuses message passing computation for accelerated execution and lower memory footprint. .. note:: This implementation is based on the :obj:`dgNN` package. See `here `__ for instructions on how to install. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.add_self_loops: raise ValueError(f"'{self.__class__.__name__}' does not support " f"adding self-loops. Please add them manually " f"in a pre-processing step and set " f"`add_self_loops=False`.") if self.edge_dim is not None: raise ValueError(f"'{self.__class__.__name__}' does not support " f"edge features. Set `edge_dim=None` in order " f"to proceed.") from dgNN.operators import GATConvFuse self.op = GATConvFuse @staticmethod def to_graph_format( edge_index: Tensor, size: Optional[Tuple[int, int]] = None, ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor]: r"""Converts an :obj:`edge_index` representation of a graph to the desired input format of :class:`FusedGATConv`. Args: edge_index (torch.Tensor): The edge indices. size ((int, int), optional): The shape of :obj:`edge_index` in each dimension. (default: :obj:`None`) """ edge_index = edge_index.to(torch.int) edge_index = sort_edge_index(edge_index, sort_by_row=True) rowptr = index2ptr(edge_index[0], size=size[0] if size else None) col = edge_index[1] device = edge_index.device perm = torch.arange(edge_index.size(1), dtype=torch.int, device=device) edge_index, perm = sort_edge_index(edge_index, perm, sort_by_row=False) row = edge_index[0] colptr = index2ptr(edge_index[1], size=size[1] if size else None) return (rowptr, col), (row, colptr), perm def forward( self, x: Tensor, csr: Tuple[Tensor, Tensor], csc: Tuple[Tensor, Tensor], perm: Tensor, ) -> Tensor: r"""Runs the forward pass of the module. Args: x (torch.Tensor): The node features. 
csr ((torch.Tensor, torch.Tensor)): A tuple containing the CSR representation of a graph, given as a tuple of :obj:`(rowptr, col)`. csc ((torch.Tensor, torch.Tensor)): A tuple containing the CSC representation of a graph, given as a tuple of :obj:`(row, colptr)`. perm (torch.Tensor): Permutation tensor to map the CSR representation to the CSC representation. .. note:: Use the :meth:`~torch_geometric.nn.conv.FusedGATConv.to_graph_format` method to obtain the :obj:`(csr, csc, perm)` graph format from an existing :obj:`edge_index` representation. """ H, C = self.heads, self.out_channels assert x.dim() == 2, "Static graphs not supported in 'GATConv'" x = self.lin_src(x).view(-1, H, C) alpha_src = (x * self.att_src).sum(dim=-1) alpha_dst = (x * self.att_dst).sum(dim=-1) dropout = self.dropout if self.training else 0.0 (rowptr, col), (row, colptr) = csr, csc out = self.op(alpha_dst, alpha_src, rowptr, col, colptr, row, perm, self.negative_slope, x, dropout) if self.concat: out = out.view(-1, self.heads * self.out_channels) else: out = out.mean(dim=1) if self.bias is not None: out += self.bias return out ================================================ FILE: torch_geometric/nn/conv/gat_conv.py ================================================ import typing from typing import Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import ( Adj, NoneType, OptPairTensor, OptTensor, Size, SparseTensor, torch_sparse, ) from torch_geometric.utils import ( add_self_loops, is_torch_sparse_tensor, remove_self_loops, softmax, ) from torch_geometric.utils.sparse import set_sparse_value if typing.TYPE_CHECKING: from typing import overload else: from torch.jit import _overload_method as overload class GATConv(MessagePassing): r"""The graph 
attentional operator from the `"Graph Attention Networks" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j}\mathbf{\Theta}_t\mathbf{x}_{j}, where the attention coefficients :math:`\alpha_{i,j}` are computed as .. math:: \alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t}\mathbf{\Theta}_{t}\mathbf{x}_k \right)\right)}. If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, the attention coefficients :math:`\alpha_{i,j}` are computed as .. math:: \alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,j} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_k + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k} \right)\right)}. If the graph is not bipartite, :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities in case of a bipartite graph. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) concat (bool, optional): If set to :obj:`False`, the multi-head attentions are averaged instead of concatenated. (default: :obj:`True`) negative_slope (float, optional): LeakyReLU angle of the negative slope. 
(default: :obj:`0.2`) dropout (float, optional): Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: :obj:`0`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) edge_dim (int, optional): Edge feature dimensionality (in case there are any). (default: :obj:`None`) fill_value (float or torch.Tensor or str, optional): The way to generate edge features of self-loops (in case :obj:`edge_dim != None`). If given as :obj:`float` or :class:`torch.Tensor`, edge features of self-loops will be directly given by :obj:`fill_value`. If given as :obj:`str`, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or :math:`((|\mathcal{V}_t|, H * F_{out})` if bipartite. 
If :obj:`return_attention_weights=True`, then
          :math:`((|\mathcal{V}|, H * F_{out}),
          ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))`
          or :math:`((|\mathcal{V}_t|, H * F_{out}), ((2, |\mathcal{E}|),
          (|\mathcal{E}|, H)))` if bipartite
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        add_self_loops: bool = True,
        edge_dim: Optional[int] = None,
        fill_value: Union[float, Tensor, str] = 'mean',
        bias: bool = True,
        residual: bool = False,
        **kwargs,
    ):
        # Sum aggregation unless the caller passes its own 'aggr'; node
        # features are indexed along dimension 0 (``node_dim=0``).
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.edge_dim = edge_dim
        self.fill_value = fill_value
        self.residual = residual

        # In case we are operating in bipartite graphs, we apply separate
        # transformations 'lin_src' and 'lin_dst' to source and target nodes:
        self.lin = self.lin_src = self.lin_dst = None
        if isinstance(in_channels, int):
            self.lin = Linear(in_channels, heads * out_channels, bias=False,
                              weight_initializer='glorot')
        else:
            self.lin_src = Linear(in_channels[0], heads * out_channels, False,
                                  weight_initializer='glorot')
            self.lin_dst = Linear(in_channels[1], heads * out_channels, False,
                                  weight_initializer='glorot')

        # The learnable parameters to compute attention coefficients:
        self.att_src = Parameter(torch.empty(1, heads, out_channels))
        self.att_dst = Parameter(torch.empty(1, heads, out_channels))

        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,
                                   weight_initializer='glorot')
            self.att_edge = Parameter(torch.empty(1, heads, out_channels))
        else:
            self.lin_edge = None
            self.register_parameter('att_edge', None)

        # The number of output channels:
        total_out_channels = out_channels * (heads if concat else 1)

        if residual:
            # The skip connection maps the (target) node features directly to
            # the output width so it can be added to the aggregated output.
            self.res = Linear(
                in_channels
                if isinstance(in_channels, int) else in_channels[1],
                total_out_channels,
                bias=False,
                weight_initializer='glorot',
            )
        else:
            self.register_parameter('res', None)

        if bias:
            self.bias = Parameter(torch.empty(total_out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters: the linear transformations that
        exist for this configuration, the attention vectors (via
        :func:`glorot`) and the bias (via :func:`zeros`)."""
        super().reset_parameters()
        if self.lin is not None:
            self.lin.reset_parameters()
        if self.lin_src is not None:
            self.lin_src.reset_parameters()
        if self.lin_dst is not None:
            self.lin_dst.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        if self.res is not None:
            self.res.reset_parameters()
        glorot(self.att_src)
        glorot(self.att_dst)
        glorot(self.att_edge)
        zeros(self.bias)

    @overload
    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            size ((int, int), optional): The shape of the adjacency matrix.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to a truthy
                value (*e.g.*, :obj:`True`), will additionally return the
                tuple :obj:`(edge_index, attention_weights)` holding the
                computed attention weights for each edge (including any
                self-loops added by this layer). If :obj:`edge_index` is a
                :class:`SparseTensor`, a :class:`SparseTensor` carrying the
                attention weights as values is returned instead of the tuple.
                With :obj:`None` or :obj:`False`, only the node features are
                returned. (default: :obj:`None`)
        """
        # Features are handled per attention head: ``H`` heads with ``C``
        # channels each, i.e. transformed features are viewed as [-1, H, C].
        H, C = self.heads, self.out_channels

        res: Optional[Tensor] = None

        # We first transform the input node features. If a tuple is passed, we
        # transform source and target node features via separate weights:
        if isinstance(x, Tensor):
            assert x.dim() == 2, "Static graphs not supported in 'GATConv'"

            if self.res is not None:
                res = self.res(x)

            if self.lin is not None:
                x_src = x_dst = self.lin(x).view(-1, H, C)
            else:
                # If the module is initialized as bipartite, transform source
                # and destination node features separately:
                assert self.lin_src is not None and self.lin_dst is not None
                x_src = self.lin_src(x).view(-1, H, C)
                x_dst = self.lin_dst(x).view(-1, H, C)

        else:  # Tuple of source and target node features:
            x_src, x_dst = x
            assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"

            if x_dst is not None and self.res is not None:
                res = self.res(x_dst)

            if self.lin is not None:
                # If the module is initialized as non-bipartite, we expect that
                # source and destination node features have the same shape and
                # that their transformations are shared:
                x_src = self.lin(x_src).view(-1, H, C)
                if x_dst is not None:
                    x_dst = self.lin(x_dst).view(-1, H, C)
            else:
                assert self.lin_src is not None and self.lin_dst is not None

                x_src = self.lin_src(x_src).view(-1, H, C)
                if x_dst is not None:
                    x_dst = self.lin_dst(x_dst).view(-1, H, C)

        x = (x_src, x_dst)

        # Next, we compute node-level attention coefficients, both for source
        # and target nodes (if present):
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
        alpha = (alpha_src, alpha_dst)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                # Drop existing self-loops first so every node ends up with
                # exactly one, whose edge features come from ``fill_value``.
                edge_index, edge_attr = remove_self_loops(
                    edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr, fill_value=self.fill_value,
                    num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                if self.edge_dim is None:
                    edge_index = torch_sparse.set_diag(edge_index)
                else:
                    raise NotImplementedError(
                        "The usage of 'edge_attr' and 'add_self_loops' "
                        "simultaneously is currently not yet supported for "
                        "'edge_index' in a 'SparseTensor' form")

        # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
        alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr,
                                  size=size)

        # propagate_type: (x: OptPairTensor, alpha: Tensor)
        out = self.propagate(edge_index, x=x, alpha=alpha, size=size)

        # Either concatenate the heads into H * C channels or average them.
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if res is not None:
            out = out + res

        if self.bias is not None:
            out = out + self.bias

        # Only a truthy flag triggers the extra return value (see docstring).
        if return_attention_weights:
            if isinstance(edge_index, Tensor):
                if is_torch_sparse_tensor(edge_index):
                    # TODO TorchScript requires to return a tuple
                    adj = set_sparse_value(edge_index, alpha)
                    return out, (adj, alpha)
                else:
                    return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')

        return out

    def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
                    edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                    dim_size: Optional[int]) -> Tensor:
        r"""Turns node-level attention scores into per-edge attention
        coefficients: optional edge-feature term, LeakyReLU, :func:`softmax`
        grouped by :obj:`index`, then dropout. Empty edge sets are returned
        unchanged."""
        # Given edge-level attention coefficients for source and target nodes,
        # we simply need to sum them up to "emulate" concatenation:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        if index.numel() == 0:
            return alpha
        if edge_attr is not None and self.lin_edge is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            edge_attr = self.lin_edge(edge_attr)
            edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
            alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
            alpha = alpha + alpha_edge

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, dim_size)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        # Scale each per-head source feature vector by its attention weight;
        # the trailing axis lets ``alpha`` broadcast over the channels.
        return alpha.unsqueeze(-1) * x_j

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
================================================
FILE: torch_geometric/nn/conv/gated_graph_conv.py
================================================
import torch
from torch import Tensor
from torch.nn import Parameter as Param

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import uniform
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import spmm


class GatedGraphConv(MessagePassing):
    r"""The gated graph convolution operator from the `"Gated Graph Sequence
    Neural Networks" `_ paper.

    .. math::
        \mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}

        \mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot
        \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}

        \mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)},
        \mathbf{h}_i^{(l)})

    up to representation :math:`\mathbf{h}_i^{(L)}`.
    The number of input channels of :math:`\mathbf{x}_i` needs to be less than
    or equal to :obj:`out_channels`.
    :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to
    target node :obj:`i`. (default: :obj:`1`)

    Args:
        out_channels (int): Size of each output sample.
        num_layers (int): The sequence length :math:`L`.
        aggr (str, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias.
(default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__(self, out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs): super().__init__(aggr=aggr, **kwargs) self.out_channels = out_channels self.num_layers = num_layers self.weight = Param(Tensor(num_layers, out_channels, out_channels)) self.rnn = torch.nn.GRUCell(out_channels, out_channels, bias=bias) self.reset_parameters() def reset_parameters(self): super().reset_parameters() uniform(self.out_channels, self.weight) self.rnn.reset_parameters() def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: if x.size(-1) > self.out_channels: raise ValueError('The number of input channels is not allowed to ' 'be larger than the number of output channels') if x.size(-1) < self.out_channels: zero = x.new_zeros(x.size(0), self.out_channels - x.size(-1)) x = torch.cat([x, zero], dim=1) for i in range(self.num_layers): m = torch.matmul(x, self.weight[i]) # propagate_type: (x: Tensor, edge_weight: OptTensor) m = self.propagate(edge_index, x=m, edge_weight=edge_weight) x = self.rnn(m, x) return x def message(self, x_j: Tensor, edge_weight: OptTensor): return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.out_channels}, ' f'num_layers={self.num_layers})') ================================================ FILE: torch_geometric/nn/conv/gatv2_conv.py ================================================ import typing from typing import Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import 
Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import (
    add_self_loops,
    is_torch_sparse_tensor,
    remove_self_loops,
    softmax,
)
from torch_geometric.utils.sparse import set_sparse_value

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload


class GATv2Conv(MessagePassing):
    r"""The GATv2 operator from the `"How Attentive are Graph Attention
    Networks?" `_ paper, which fixes the
    static attention problem of the standard
    :class:`~torch_geometric.nn.conv.GATConv` layer.
    Since the linear layers in the standard GAT are applied right after each
    other, the ranking of attended nodes is unconditioned on the query node.
    In contrast, in :class:`GATv2Conv`, every node can attend to any other
    node.

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}}
        \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(
        \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(
        \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k
        \right)\right)}.

    If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`,
    the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(
        \mathbf{\Theta}_{s} \mathbf{x}_i
        + \mathbf{\Theta}_{t} \mathbf{x}_j
        + \mathbf{\Theta}_{e} \mathbf{e}_{i,j}
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(
        \mathbf{\Theta}_{s} \mathbf{x}_i
        + \mathbf{\Theta}_{t} \mathbf{x}_k
        + \mathbf{\Theta}_{e} \mathbf{e}_{i,k}
        \right)\right)}.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities in case of a bipartite graph.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). (default: :obj:`None`)
        fill_value (float or torch.Tensor or str, optional): The way to
            generate edge features of self-loops
            (in case :obj:`edge_dim != None`).
            If given as :obj:`float` or :class:`torch.Tensor`, edge features of
            self-loops will be directly given by :obj:`fill_value`.
            If given as :obj:`str`, edge features of self-loops are computed by
            aggregating all features of edges that point to the specific node,
            according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`).
            (default: :obj:`"mean"`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        share_weights (bool, optional): If set to :obj:`True`, the same matrix
            will be applied to the source and the target node of every edge,
            *i.e.* :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`.
            (default: :obj:`False`)
        residual (bool, optional): If set to :obj:`True`, the layer will add
            a learnable skip-connection. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or
          :math:`(|\mathcal{V}_t|, H * F_{out})` if bipartite.
          If :obj:`return_attention_weights=True`, then
          :math:`((|\mathcal{V}|, H * F_{out}),
          ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))`
          or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|),
          (|\mathcal{E}|, H)))` if bipartite
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        add_self_loops: bool = True,
        edge_dim: Optional[int] = None,
        fill_value: Union[float, Tensor, str] = 'mean',
        bias: bool = True,
        share_weights: bool = False,
        residual: bool = False,
        **kwargs,
    ):
        # Node features are indexed along dimension 0 (``node_dim=0``).
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.edge_dim = edge_dim
        self.fill_value = fill_value
        self.residual = residual
        self.share_weights = share_weights

        # ``lin_l`` transforms source-side features and ``lin_r`` the
        # target-side ones; with ``share_weights`` both names refer to the
        # very same module.
        if isinstance(in_channels, int):
            self.lin_l = Linear(in_channels, heads * out_channels, bias=bias,
                                weight_initializer='glorot')
            if share_weights:
                self.lin_r = self.lin_l
            else:
                self.lin_r = Linear(in_channels, heads * out_channels,
                                    bias=bias, weight_initializer='glorot')
        else:
            self.lin_l = Linear(in_channels[0], heads * out_channels,
                                bias=bias, weight_initializer='glorot')
            if share_weights:
                self.lin_r = self.lin_l
            else:
                self.lin_r = Linear(in_channels[1], heads * out_channels,
                                    bias=bias, weight_initializer='glorot')

        # A single attention vector per head, applied after the LeakyReLU in
        # ``edge_update``:
        self.att = Parameter(torch.empty(1, heads, out_channels))

        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,
                                   weight_initializer='glorot')
        else:
            self.lin_edge = None

        # The number of output channels:
        total_out_channels = out_channels * (heads if concat else 1)

        if residual:
            self.res = Linear(
                in_channels
                if isinstance(in_channels, int) else in_channels[1],
                total_out_channels,
                bias=False,
                weight_initializer='glorot',
            )
        else:
            self.register_parameter('res', None)

        if bias:
            self.bias = Parameter(torch.empty(total_out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters: the linear transformations, the
        attention vector (via :func:`glorot`) and the bias (via
        :func:`zeros`)."""
        super().reset_parameters()
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        if self.res is not None:
            self.res.reset_parameters()
        glorot(self.att)
        zeros(self.bias)

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to a truthy
                value (*e.g.*, :obj:`True`), will additionally return the
                tuple :obj:`(edge_index, attention_weights)` holding the
                computed attention weights for each edge (including any
                self-loops added by this layer). If :obj:`edge_index` is a
                :class:`SparseTensor`, a :class:`SparseTensor` carrying the
                attention weights as values is returned instead of the tuple.
                With :obj:`None` or :obj:`False`, only the node features are
                returned. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        res: Optional[Tensor] = None

        # ``x_l`` is produced by ``lin_l`` and ``x_r`` by ``lin_r``; both are
        # passed on as the ``x`` pair to ``edge_updater`` and ``propagate``.
        x_l: OptTensor = None
        x_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2

            if self.res is not None:
                res = self.res(x)

            x_l = self.lin_l(x).view(-1, H, C)
            if self.share_weights:
                x_r = x_l
            else:
                x_r = self.lin_r(x).view(-1, H, C)
        else:
            x_l, x_r = x[0], x[1]
            assert x[0].dim() == 2

            if x_r is not None and self.res is not None:
                res = self.res(x_r)

            x_l = self.lin_l(x_l).view(-1, H, C)
            if x_r is not None:
                x_r = self.lin_r(x_r).view(-1, H, C)

        assert x_l is not None
        assert x_r is not None

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # Self-loops are only added for node indices that exist on
                # both sides of a (possibly bipartite) graph.
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                edge_index, edge_attr = remove_self_loops(
                    edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr, fill_value=self.fill_value,
                    num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                if self.edge_dim is None:
                    edge_index = torch_sparse.set_diag(edge_index)
                else:
                    raise NotImplementedError(
                        "The usage of 'edge_attr' and 'add_self_loops' "
                        "simultaneously is currently not yet supported for "
                        "'edge_index' in a 'SparseTensor' form")

        # edge_updater_type: (x: PairTensor, edge_attr: OptTensor)
        alpha = self.edge_updater(edge_index, x=(x_l, x_r),
                                  edge_attr=edge_attr)

        # propagate_type: (x: PairTensor, alpha: Tensor)
        out = self.propagate(edge_index, x=(x_l, x_r), alpha=alpha)

        # Either concatenate the heads into H * C channels or average them.
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if res is not None:
            out = out + res

        if self.bias is not None:
            out = out + self.bias

        # Only a truthy flag triggers the extra return value (see docstring).
        if return_attention_weights:
            if isinstance(edge_index, Tensor):
                if is_torch_sparse_tensor(edge_index):
                    # TODO TorchScript requires to return a tuple
                    adj = set_sparse_value(edge_index, alpha)
                    return out, (adj, alpha)
                else:
                    return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')

        return out

    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
                    index: Tensor, ptr: OptTensor,
                    dim_size: Optional[int]) -> Tensor:
        r"""Computes GATv2 attention coefficients per edge: the LeakyReLU is
        applied to the summed (transformed) endpoint and edge features
        *before* the dot product with :obj:`att`, followed by
        :func:`softmax` grouped by :obj:`index` and dropout."""
        x = x_i + x_j

        if edge_attr is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            assert self.lin_edge is not None
            edge_attr = self.lin_edge(edge_attr)
            edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
            x = x + edge_attr

        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, dim_size)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        # Scale each per-head source feature vector by its attention weight.
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
================================================
FILE: torch_geometric/nn/conv/gcn2_conv.py
================================================
from math import log
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.inits import glorot
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor
from torch_geometric.utils import
spmm class GCN2Conv(MessagePassing): r"""The graph convolutional operator with initial residual connections and identity mapping (GCNII) from the `"Simple and Deep Graph Convolutional Networks" `_ paper. .. math:: \mathbf{X}^{\prime} = \left( (1 - \alpha) \mathbf{\hat{P}}\mathbf{X} + \alpha \mathbf{X^{(0)}}\right) \left( (1 - \beta) \mathbf{I} + \beta \mathbf{\Theta} \right) with :math:`\mathbf{\hat{P}} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}`, where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with inserted self-loops and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix, and :math:`\mathbf{X}^{(0)}` being the initial feature representation. Here, :math:`\alpha` models the strength of the initial residual connection, while :math:`\beta` models the strength of the identity mapping. The adjacency matrix can include other values than :obj:`1` representing edge weights via the optional :obj:`edge_weight` tensor. Args: channels (int): Size of each input and output sample. alpha (float): The strength of the initial residual connection :math:`\alpha`. theta (float, optional): The hyperparameter :math:`\theta` to compute the strength of the identity mapping :math:`\beta = \log \left( \frac{\theta}{\ell} + 1 \right)`. (default: :obj:`None`) layer (int, optional): The layer :math:`\ell` in which this module is executed. (default: :obj:`None`) shared_weights (bool, optional): If set to :obj:`False`, will use different weight matrices for the smoothed representation and the initial residual ("GCNII*"). (default: :obj:`True`) cached (bool, optional): If set to :obj:`True`, the layer will cache the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the cached version for further executions. This parameter should only be set to :obj:`True` in transductive learning scenarios. 
(default: :obj:`False`) normalize (bool, optional): Whether to add self-loops and apply symmetric normalization. (default: :obj:`True`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F)`, initial node features :math:`(|\mathcal{V}|, F)`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F)` """ _cached_edge_index: Optional[OptPairTensor] _cached_adj_t: Optional[SparseTensor] def __init__(self, channels: int, alpha: float, theta: float = None, layer: int = None, shared_weights: bool = True, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.channels = channels self.alpha = alpha self.beta = 1. 
if theta is not None or layer is not None: assert theta is not None and layer is not None self.beta = log(theta / layer + 1) self.cached = cached self.normalize = normalize self.add_self_loops = add_self_loops self._cached_edge_index = None self._cached_adj_t = None self.weight1 = Parameter(torch.empty(channels, channels)) if shared_weights: self.register_parameter('weight2', None) else: self.weight2 = Parameter(torch.empty(channels, channels)) self.reset_parameters() def reset_parameters(self): super().reset_parameters() glorot(self.weight1) glorot(self.weight2) self._cached_edge_index = None self._cached_adj_t = None def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: if self.normalize: if isinstance(edge_index, Tensor): cache = self._cached_edge_index if cache is None: edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) if self.cached: self._cached_edge_index = (edge_index, edge_weight) else: edge_index, edge_weight = cache[0], cache[1] elif isinstance(edge_index, SparseTensor): cache = self._cached_adj_t if cache is None: edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) if self.cached: self._cached_adj_t = edge_index else: edge_index = cache # propagate_type: (x: Tensor, edge_weight: OptTensor) x = self.propagate(edge_index, x=x, edge_weight=edge_weight) x.mul_(1 - self.alpha) x_0 = self.alpha * x_0[:x.size(0)] if self.weight2 is None: out = x.add_(x_0) out = torch.addmm(out, out, self.weight1, beta=1. - self.beta, alpha=self.beta) else: out = torch.addmm(x, x, self.weight1, beta=1. - self.beta, alpha=self.beta) out = out + torch.addmm(x_0, x_0, self.weight2, beta=1. 
- self.beta, alpha=self.beta) return out def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.channels}, ' f'alpha={self.alpha}, beta={self.beta})') ================================================ FILE: torch_geometric/nn/conv/gcn_conv.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros from torch_geometric.typing import ( Adj, OptPairTensor, OptTensor, SparseTensor, torch_sparse, ) from torch_geometric.utils import add_remaining_self_loops from torch_geometric.utils import add_self_loops as add_self_loops_fn from torch_geometric.utils import ( is_torch_sparse_tensor, scatter, spmm, to_edge_index, ) from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_geometric.utils.sparse import set_sparse_value @torch.jit._overload def gcn_norm( # noqa: F811 edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype): # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor # noqa pass @torch.jit._overload def gcn_norm( # noqa: F811 edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype): # type: (SparseTensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> SparseTensor # noqa pass def gcn_norm( # noqa: F811 edge_index: Adj, edge_weight: OptTensor = None, num_nodes: Optional[int] = None, improved: bool = False, add_self_loops: bool = True, flow: str = "source_to_target", dtype: Optional[torch.dtype] = None, ): fill_value = 2. if improved else 1. 
if isinstance(edge_index, SparseTensor): assert edge_index.size(0) == edge_index.size(1) adj_t = edge_index if not adj_t.has_value(): adj_t = adj_t.fill_value(1., dtype=dtype) if add_self_loops: adj_t = torch_sparse.fill_diag(adj_t, fill_value) deg = torch_sparse.sum(adj_t, dim=1) deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.) adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(-1, 1)) adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(1, -1)) return adj_t if is_torch_sparse_tensor(edge_index): assert edge_index.size(0) == edge_index.size(1) if edge_index.layout == torch.sparse_csc: raise NotImplementedError("Sparse CSC matrices are not yet " "supported in 'gcn_norm'") adj_t = edge_index if add_self_loops: adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes) edge_index, value = to_edge_index(adj_t) col, row = edge_index[0], edge_index[1] deg = scatter(value, col, 0, dim_size=num_nodes, reduce='sum') deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) value = deg_inv_sqrt[row] * value * deg_inv_sqrt[col] return set_sparse_value(adj_t, value), None assert flow in ['source_to_target', 'target_to_source'] num_nodes = maybe_num_nodes(edge_index, num_nodes) if add_self_loops: edge_index, edge_weight = add_remaining_self_loops( edge_index, edge_weight, fill_value, num_nodes) if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device) row, col = edge_index[0], edge_index[1] idx = col if flow == 'source_to_target' else row deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum') deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] return edge_index, edge_weight class GCNConv(MessagePassing): r"""The graph convolutional operator from the `"Semi-supervised Classification with Graph Convolutional Networks" `_ paper. .. 
math::
        \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},

    where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
    adjacency matrix with inserted self-loops and
    :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
    The adjacency matrix can include other values than :obj:`1` representing
    edge weights via the optional :obj:`edge_weight` tensor.

    Its node-wise formulation is given by:

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in
        \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j
        \hat{d}_i}} \mathbf{x}_j

    with :math:`\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}`, where
    :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to
    target node :obj:`i`. (default: :obj:`1.0`)

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        improved (bool, optional): If set to :obj:`True`, the layer computes
            :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
            (default: :obj:`False`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
            \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
            cached version for further executions.
            This parameter should only be set to :obj:`True` in transductive
            learning scenarios. (default: :obj:`False`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. By default, self-loops will be added
            in case :obj:`normalize` is set to :obj:`True`, and not added
            otherwise. Setting it to :obj:`True` while :obj:`normalize` is
            :obj:`False` raises a :class:`ValueError`. (default: :obj:`None`)
        normalize (bool, optional): Whether to add self-loops and compute
            symmetric normalization coefficients on-the-fly.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`
          or sparse matrix :math:`(|\mathcal{V}|, |\mathcal{V}|)`,
          edge weights :math:`(|\mathcal{E}|)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """
    # Normalized graph structure, populated only when ``cached=True``; one
    # slot for tensor-based and one for SparseTensor-based adjacency.
    _cached_edge_index: Optional[OptPairTensor]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: Optional[bool] = None,
        normalize: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        # Self-loops default to following ``normalize``; they are only
        # inserted inside ``gcn_norm``, so they require normalization.
        if add_self_loops is None:
            add_self_loops = normalize

        if add_self_loops and not normalize:
            raise ValueError(f"'{self.__class__.__name__}' does not support "
                             f"adding self-loops to the graph when no "
                             f"on-the-fly normalization is applied")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self._cached_edge_index = None
        self._cached_adj_t = None

        self.lin = Linear(in_channels, out_channels, bias=False,
                          weight_initializer='glorot')

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters and clears the normalization
        cache."""
        super().reset_parameters()
        self.lin.reset_parameters()
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The input node features (a tuple of features
                is rejected, since bipartite message passing is unsupported).
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_weight (torch.Tensor, optional): The edge weights.
                (default: :obj:`None`)

        Raises:
            ValueError: If :obj:`x` is a tuple or list.
        """
        if isinstance(x, (tuple, list)):
            raise ValueError(f"'{self.__class__.__name__}' received a tuple "
                             f"of node features as input while this layer "
                             f"does not support bipartite message passing. "
                             f"Please try other layers such as 'SAGEConv' or "
                             f"'GraphConv' instead")

        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        # Transform first, then propagate with the normalized weights.
        x = self.lin(x)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        return spmm(adj_t, x, reduce=self.aggr)
================================================
FILE: torch_geometric/nn/conv/gen_conv.py
================================================
from typing import List, Optional, Tuple, Union

from torch import Tensor
from torch.nn import (
    BatchNorm1d,
    Dropout,
    InstanceNorm1d,
    LayerNorm,
    ReLU,
    Sequential,
)

from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.nn.norm import MessageNorm
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size


class MLP(Sequential):
    def __init__(self, channels: List[int], norm: Optional[str] = None,
                 bias: bool = True, dropout: float = 0.):
        m = []
        for i in range(1, len(channels)):
m.append(Linear(channels[i - 1], channels[i], bias=bias)) if i < len(channels) - 1: if norm and norm == 'batch': m.append(BatchNorm1d(channels[i], affine=True)) elif norm and norm == 'layer': m.append(LayerNorm(channels[i], elementwise_affine=True)) elif norm and norm == 'instance': m.append(InstanceNorm1d(channels[i], affine=False)) elif norm: raise NotImplementedError( f'Normalization layer "{norm}" not supported.') m.append(ReLU()) m.append(Dropout(dropout)) super().__init__(*m) class GENConv(MessagePassing): r"""The GENeralized Graph Convolution (GENConv) from the `"DeeperGCN: All You Need to Train Deeper GCNs" `_ paper. :class:`GENConv` supports both :math:`\textrm{softmax}` (see :class:`~torch_geometric.nn.aggr.SoftmaxAggregation`) and :math:`\textrm{powermean}` (see :class:`~torch_geometric.nn.aggr.PowerMeanAggregation`) aggregation. Its message construction is given by: .. math:: \mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_i + \mathrm{AGG} \left( \left\{ \mathrm{ReLU} \left( \mathbf{x}_j + \mathbf{e_{ji}} \right) +\epsilon : j \in \mathcal{N}(i) \right\} \right) \right) .. note:: For an example of using :obj:`GENConv`, see `examples/ogbn_proteins_deepgcn.py `_. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. aggr (str or Aggregation, optional): The aggregation scheme to use. Any aggregation of :obj:`torch_geometric.nn.aggr` can be used, (:obj:`"softmax"`, :obj:`"powermean"`, :obj:`"add"`, :obj:`"mean"`, :obj:`max`). (default: :obj:`"softmax"`) t (float, optional): Initial inverse temperature for softmax aggregation. (default: :obj:`1.0`) learn_t (bool, optional): If set to :obj:`True`, will learn the value :obj:`t` for softmax aggregation dynamically. (default: :obj:`False`) p (float, optional): Initial power for power mean aggregation. 
(default: :obj:`1.0`) learn_p (bool, optional): If set to :obj:`True`, will learn the value :obj:`p` for power mean aggregation dynamically. (default: :obj:`False`) msg_norm (bool, optional): If set to :obj:`True`, will use message normalization. (default: :obj:`False`) learn_msg_scale (bool, optional): If set to :obj:`True`, will learn the scaling factor of message normalization. (default: :obj:`False`) norm (str, optional): Norm layer of MLP layers (:obj:`"batch"`, :obj:`"layer"`, :obj:`"instance"`) (default: :obj:`batch`) num_layers (int, optional): The number of MLP layers. (default: :obj:`2`) expansion (int, optional): The expansion factor of hidden channels in MLP layers. (default: :obj:`2`) eps (float, optional): The epsilon value of the message construction function. (default: :obj:`1e-7`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) edge_dim (int, optional): Edge feature dimensionality. If set to :obj:`None`, Edge feature dimensionality is expected to match the `out_channels`. Other-wise, edge features are linearly transformed to match `out_channels` of node feature dimensionality. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.GenMessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge attributes :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: Optional[Union[str, List[str], Aggregation]] = 'softmax', t: float = 1.0, learn_t: bool = False, p: float = 1.0, learn_p: bool = False, msg_norm: bool = False, learn_msg_scale: bool = False, norm: str = 'batch', num_layers: int = 2, expansion: int = 2, eps: float = 1e-7, bias: bool = False, edge_dim: Optional[int] = None, **kwargs, ): # Backward compatibility: semi_grad = True if aggr == 'softmax_sg' else False aggr = 'softmax' if aggr == 'softmax_sg' else aggr aggr = 'powermean' if aggr == 'power' else aggr # Override args of aggregator if `aggr_kwargs` is specified if 'aggr_kwargs' not in kwargs: if aggr == 'softmax': kwargs['aggr_kwargs'] = dict(t=t, learn=learn_t, semi_grad=semi_grad) elif aggr == 'powermean': kwargs['aggr_kwargs'] = dict(p=p, learn=learn_p) super().__init__(aggr=aggr, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.eps = eps if isinstance(in_channels, int): in_channels = (in_channels, in_channels) if in_channels[0] != out_channels: self.lin_src = Linear(in_channels[0], out_channels, bias=bias) if edge_dim is not None and edge_dim != out_channels: self.lin_edge = Linear(edge_dim, out_channels, bias=bias) if isinstance(self.aggr_module, MultiAggregation): aggr_out_channels = self.aggr_module.get_out_channels(out_channels) else: aggr_out_channels = out_channels if aggr_out_channels != out_channels: self.lin_aggr_out = Linear(aggr_out_channels, out_channels, bias=bias) if in_channels[1] != out_channels: self.lin_dst = Linear(in_channels[1], out_channels, bias=bias) channels = 
[out_channels] for _ in range(num_layers - 1): channels.append(out_channels * expansion) channels.append(out_channels) self.mlp = MLP(channels, norm=norm, bias=bias) if msg_norm: self.msg_norm = MessageNorm(learn_msg_scale) def reset_parameters(self): super().reset_parameters() reset(self.mlp) if hasattr(self, 'msg_norm'): self.msg_norm.reset_parameters() if hasattr(self, 'lin_src'): self.lin_src.reset_parameters() if hasattr(self, 'lin_edge'): self.lin_edge.reset_parameters() if hasattr(self, 'lin_aggr_out'): self.lin_aggr_out.reset_parameters() if hasattr(self, 'lin_dst'): self.lin_dst.reset_parameters() def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None) -> Tensor: if isinstance(x, Tensor): x = (x, x) if hasattr(self, 'lin_src'): x = (self.lin_src(x[0]), x[1]) # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) if hasattr(self, 'lin_aggr_out'): out = self.lin_aggr_out(out) if hasattr(self, 'msg_norm'): h = x[1] if x[1] is not None else x[0] assert h is not None out = self.msg_norm(h, out) x_dst = x[1] if x_dst is not None: if hasattr(self, 'lin_dst'): x_dst = self.lin_dst(x_dst) out = out + x_dst return self.mlp(out) def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor: if edge_attr is not None and hasattr(self, 'lin_edge'): edge_attr = self.lin_edge(edge_attr) if edge_attr is not None: assert x_j.size(-1) == edge_attr.size(-1) msg = x_j if edge_attr is None else x_j + edge_attr return msg.relu() + self.eps def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, aggr={self.aggr})') ================================================ FILE: torch_geometric/nn/conv/general_conv.py ================================================ from typing import Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from 
torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot from torch_geometric.typing import ( Adj, Optional, OptPairTensor, OptTensor, Size, ) from torch_geometric.utils import softmax class GeneralConv(MessagePassing): r"""A general GNN layer adapted from the `"Design Space for Graph Neural Networks" `_ paper. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. in_edge_channels (int, optional): Size of each input edge. (default: :obj:`None`) aggr (str, optional): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"mean"`) skip_linear (bool, optional): Whether apply linear function in skip connection. (default: :obj:`False`) directed_msg (bool, optional): If message passing is directed; otherwise, message passing is bi-directed. (default: :obj:`True`) heads (int, optional): Number of message passing ensembles. If :obj:`heads > 1`, the GNN layer will output an ensemble of multiple messages. If attention is used (:obj:`attention=True`), this corresponds to multi-head attention. (default: :obj:`1`) attention (bool, optional): Whether to add attention to message computation. (default: :obj:`False`) attention_type (str, optional): Type of attention: :obj:`"additive"`, :obj:`"dot_product"`. (default: :obj:`"additive"`) l2_normalize (bool, optional): If set to :obj:`True`, output features will be :math:`\ell_2`-normalized, *i.e.*, :math:`\frac{\mathbf{x}^{\prime}_i} {\| \mathbf{x}^{\prime}_i \|_2}`. (default: :obj:`False`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge attributes :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: Optional[int], in_edge_channels: Optional[int] = None, aggr: str = "add", skip_linear: str = False, directed_msg: bool = True, heads: int = 1, attention: bool = False, attention_type: str = "additive", l2_normalize: bool = False, bias: bool = True, **kwargs, ): kwargs.setdefault('aggr', aggr) super().__init__(node_dim=0, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.in_edge_channels = in_edge_channels self.aggr = aggr self.skip_linear = skip_linear self.directed_msg = directed_msg self.heads = heads self.attention = attention self.attention_type = attention_type self.normalize_l2 = l2_normalize if isinstance(in_channels, int): in_channels = (in_channels, in_channels) if self.directed_msg: self.lin_msg = Linear(in_channels[0], out_channels * self.heads, bias=bias) else: self.lin_msg = Linear(in_channels[0], out_channels * self.heads, bias=bias) self.lin_msg_i = Linear(in_channels[0], out_channels * self.heads, bias=bias) if self.skip_linear or self.in_channels != self.out_channels: self.lin_self = Linear(in_channels[1], out_channels, bias=bias) else: self.lin_self = torch.nn.Identity() if self.in_edge_channels is not None: self.lin_edge = Linear(in_edge_channels, out_channels * self.heads, bias=bias) # TODO: A general torch_geometric.nn.AttentionLayer if self.attention: if self.attention_type == 'additive': self.att_msg = Parameter( torch.empty(1, self.heads, self.out_channels)) elif self.attention_type == 'dot_product': scaler = torch.tensor(out_channels, dtype=torch.float).sqrt() 
self.register_buffer('scaler', scaler) else: raise ValueError( f"Attention type '{self.attention_type}' not supported") self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin_msg.reset_parameters() if hasattr(self.lin_self, 'reset_parameters'): self.lin_self.reset_parameters() if self.in_edge_channels is not None: self.lin_edge.reset_parameters() if self.attention and self.attention_type == 'additive': glorot(self.att_msg) def forward( self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None, ) -> Tensor: if isinstance(x, Tensor): x: OptPairTensor = (x, x) x_self = x[1] # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, size=size, edge_attr=edge_attr) out = out.mean(dim=1) # todo: other approach to aggregate heads out = out + self.lin_self(x_self) if self.normalize_l2: out = F.normalize(out, p=2, dim=-1) return out def message_basic(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor): if self.directed_msg: x_j = self.lin_msg(x_j) else: x_j = self.lin_msg(x_j) + self.lin_msg_i(x_i) if edge_attr is not None: x_j = x_j + self.lin_edge(edge_attr) return x_j def message(self, x_i: Tensor, x_j: Tensor, edge_index_i: Tensor, size_i: Tensor, edge_attr: Tensor) -> Tensor: x_j_out = self.message_basic(x_i, x_j, edge_attr) x_j_out = x_j_out.view(-1, self.heads, self.out_channels) if self.attention: if self.attention_type == 'dot_product': x_i_out = self.message_basic(x_j, x_i, edge_attr) x_i_out = x_i_out.view(-1, self.heads, self.out_channels) alpha = (x_i_out * x_j_out).sum(dim=-1) / self.scaler else: alpha = (x_j_out * self.att_msg).sum(dim=-1) alpha = F.leaky_relu(alpha, negative_slope=0.2) alpha = softmax(alpha, edge_index_i, num_nodes=size_i) alpha = alpha.view(-1, self.heads, 1) return x_j_out * alpha else: return x_j_out ================================================ FILE: torch_geometric/nn/conv/gin_conv.py 
================================================ from typing import Callable, Optional, Union import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import reset from torch_geometric.typing import ( Adj, OptPairTensor, OptTensor, Size, SparseTensor, ) from torch_geometric.utils import spmm class GINConv(MessagePassing): r"""The graph isomorphism operator from the `"How Powerful are Graph Neural Networks?" `_ paper. .. math:: \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) or .. math:: \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right), here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* an MLP. Args: nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`, *e.g.*, defined by :class:`torch.nn.Sequential`. eps (float, optional): (Initial) :math:`\epsilon`-value. (default: :obj:`0.`) train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` will be a trainable parameter. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.nn = nn self.initial_eps = eps if train_eps: self.eps = torch.nn.Parameter(torch.empty(1)) else: self.register_buffer('eps', torch.empty(1)) self.reset_parameters() def reset_parameters(self): super().reset_parameters() reset(self.nn) self.eps.data.fill_(self.initial_eps) def forward( self, x: Union[Tensor, OptPairTensor], edge_index: Adj, size: Size = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: OptPairTensor) out = self.propagate(edge_index, x=x, size=size) x_r = x[1] if x_r is not None: out = out + (1 + self.eps) * x_r return self.nn(out) def message(self, x_j: Tensor) -> Tensor: return x_j def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor: if isinstance(adj_t, SparseTensor): adj_t = adj_t.set_value(None, layout=None) return spmm(adj_t, x[0], reduce=self.aggr) def __repr__(self) -> str: return f'{self.__class__.__name__}(nn={self.nn})' class GINEConv(MessagePassing): r"""The modified :class:`GINConv` operator from the `"Strategies for Pre-training Graph Neural Networks" `_ paper. .. math:: \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU} ( \mathbf{x}_j + \mathbf{e}_{j,i} ) \right) that is able to incorporate edge features :math:`\mathbf{e}_{j,i}` into the aggregation procedure. 
Args: nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`, *e.g.*, defined by :class:`torch.nn.Sequential`. eps (float, optional): (Initial) :math:`\epsilon`-value. (default: :obj:`0.`) train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` will be a trainable parameter. (default: :obj:`False`) edge_dim (int, optional): Edge feature dimensionality. If set to :obj:`None`, node and edge feature dimensionality is expected to match. Other-wise, edge features are linearly transformed to match node feature dimensionality. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, nn: torch.nn.Module, eps: float = 0., train_eps: bool = False, edge_dim: Optional[int] = None, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.nn = nn self.initial_eps = eps if train_eps: self.eps = torch.nn.Parameter(torch.empty(1)) else: self.register_buffer('eps', torch.empty(1)) if edge_dim is not None: if isinstance(self.nn, torch.nn.Sequential): nn = self.nn[0] if hasattr(nn, 'in_features'): in_channels = nn.in_features elif hasattr(nn, 'in_channels'): in_channels = nn.in_channels else: raise ValueError("Could not infer input channels from `nn`.") self.lin = Linear(edge_dim, in_channels) else: self.lin = None self.reset_parameters() def reset_parameters(self): reset(self.nn) self.eps.data.fill_(self.initial_eps) if self.lin is not None: self.lin.reset_parameters() def forward( self, x: Union[Tensor, 
OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) x_r = x[1] if x_r is not None: out = out + (1 + self.eps) * x_r return self.nn(out) def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: if self.lin is None and x_j.size(-1) != edge_attr.size(-1): raise ValueError("Node and edge feature dimensionalities do not " "match. Consider setting the 'edge_dim' " "attribute of 'GINEConv'") if self.lin is not None: edge_attr = self.lin(edge_attr) return (x_j + edge_attr).relu() def __repr__(self) -> str: return f'{self.__class__.__name__}(nn={self.nn})' ================================================ FILE: torch_geometric/nn/conv/gmm_conv.py ================================================ from typing import Tuple, Union import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size class GMMConv(MessagePassing): r"""The gaussian mixture model convolutional operator from the `"Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \frac{1}{K} \sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{i,j}) \odot \mathbf{\Theta}_k \mathbf{x}_j, where .. math:: \mathbf{w}_k(\mathbf{e}) = \exp \left( -\frac{1}{2} {\left( \mathbf{e} - \mathbf{\mu}_k \right)}^{\top} \Sigma_k^{-1} \left( \mathbf{e} - \mathbf{\mu}_k \right) \right) denotes a weighting function based on trainable mean vector :math:`\mathbf{\mu}_k` and diagonal covariance matrix :math:`\mathbf{\Sigma}_k`. .. 
note:: The edge attribute :math:`\mathbf{e}_{ij}` is usually given by :math:`\mathbf{e}_{ij} = \mathbf{p}_j - \mathbf{p}_i`, where :math:`\mathbf{p}_i` denotes the position of node :math:`i` (see :class:`torch_geometric.transform.Cartesian`). Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. dim (int): Pseudo-coordinate dimensionality. kernel_size (int): Number of kernels :math:`K`. separate_gaussians (bool, optional): If set to :obj:`True`, will learn separate GMMs for every pair of input and output channel, inspired by traditional CNNs. (default: :obj:`False`) aggr (str, optional): The aggregation operator to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"mean"`) root_weight (bool, optional): If set to :obj:`False`, the layer will not add transformed root node features to the output. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, dim: int, kernel_size: int, separate_gaussians: bool = False, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs): super().__init__(aggr=aggr, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.dim = dim self.kernel_size = kernel_size self.separate_gaussians = separate_gaussians self.root_weight = root_weight if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.rel_in_channels = in_channels[0] if in_channels[0] > 0: self.g = Parameter( Tensor(in_channels[0], out_channels * kernel_size)) if not self.separate_gaussians: self.mu = Parameter(Tensor(kernel_size, dim)) self.sigma = Parameter(Tensor(kernel_size, dim)) if self.separate_gaussians: self.mu = Parameter( Tensor(in_channels[0], out_channels, kernel_size, dim)) self.sigma = Parameter( Tensor(in_channels[0], out_channels, kernel_size, dim)) else: self.g = torch.nn.parameter.UninitializedParameter() self.mu = torch.nn.parameter.UninitializedParameter() self.sigma = torch.nn.parameter.UninitializedParameter() self._hook = self.register_forward_pre_hook( self.initialize_parameters) if root_weight: self.root = Linear(in_channels[1], out_channels, bias=False, weight_initializer='glorot') if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() if not isinstance(self.g, torch.nn.UninitializedParameter): glorot(self.g) glorot(self.mu) glorot(self.sigma) if self.root_weight: 
self.root.reset_parameters() zeros(self.bias) def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None): if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) if not self.separate_gaussians: out: OptPairTensor = (torch.matmul(x[0], self.g), x[1]) out = self.propagate(edge_index, x=out, edge_attr=edge_attr, size=size) else: out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) x_r = x[1] if x_r is not None and self.root is not None: out = out + self.root(x_r) if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: EPS = 1e-15 F, M = self.rel_in_channels, self.out_channels (E, D), K = edge_attr.size(), self.kernel_size if not self.separate_gaussians: gaussian = -0.5 * (edge_attr.view(E, 1, D) - self.mu.view(1, K, D)).pow(2) gaussian = gaussian / (EPS + self.sigma.view(1, K, D).pow(2)) gaussian = torch.exp(gaussian.sum(dim=-1)) # [E, K] return (x_j.view(E, K, M) * gaussian.view(E, K, 1)).sum(dim=-2) else: gaussian = -0.5 * (edge_attr.view(E, 1, 1, 1, D) - self.mu.view(1, F, M, K, D)).pow(2) gaussian = gaussian / (EPS + self.sigma.view(1, F, M, K, D).pow(2)) gaussian = torch.exp(gaussian.sum(dim=-1)) # [E, F, M, K] gaussian = gaussian * self.g.view(1, F, M, K) gaussian = gaussian.sum(dim=-1) # [E, F, M] return (x_j.view(E, F, 1) * gaussian).sum(dim=-2) # [E, M] @torch.no_grad() def initialize_parameters(self, module, input): if isinstance(self.g, torch.nn.parameter.UninitializedParameter): x = input[0][0] if isinstance(input, tuple) else input[0] in_channels = x.size(-1) out_channels, kernel_size = self.out_channels, self.kernel_size self.g.materialize((in_channels, out_channels * kernel_size)) if not self.separate_gaussians: self.mu.materialize((kernel_size, self.dim)) self.sigma.materialize((kernel_size, self.dim)) else: self.mu.materialize( (in_channels, out_channels, kernel_size, 
self.dim)) self.sigma.materialize( (in_channels, out_channels, kernel_size, self.dim)) glorot(self.g) glorot(self.mu) glorot(self.sigma) module._hook.remove() delattr(module, '_hook') def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, dim={self.dim})') ================================================ FILE: torch_geometric/nn/conv/gps_conv.py ================================================ import inspect from typing import Any, Dict, Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Dropout, Linear, Sequential from torch_geometric.nn.attention import PerformerAttention from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset from torch_geometric.nn.resolver import ( activation_resolver, normalization_resolver, ) from torch_geometric.typing import Adj from torch_geometric.utils import to_dense_batch class GPSConv(torch.nn.Module): r"""The general, powerful, scalable (GPS) graph transformer layer from the `"Recipe for a General, Powerful, Scalable Graph Transformer" `_ paper. The GPS layer is based on a 3-part recipe: 1. Inclusion of positional (PE) and structural encodings (SE) to the input features (done in a pre-processing step via :class:`torch_geometric.transforms`). 2. A local message passing layer (MPNN) that operates on the input graph. 3. A global attention layer that operates on the entire graph. .. note:: For an example of using :class:`GPSConv`, see `examples/graph_gps.py `_. Args: channels (int): Size of each input sample. conv (MessagePassing, optional): The local message passing layer. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) dropout (float, optional): Dropout probability of intermediate embeddings. (default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. 
(default: :obj:`"relu"`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`"batch_norm"`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. (default: :obj:`None`) attn_type (str): Global attention type, :obj:`multihead` or :obj:`performer`. (default: :obj:`multihead`) attn_kwargs (Dict[str, Any], optional): Arguments passed to the attention layer. (default: :obj:`None`) """ def __init__( self, channels: int, conv: Optional[MessagePassing], heads: int = 1, dropout: float = 0.0, act: str = 'relu', act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[str] = 'batch_norm', norm_kwargs: Optional[Dict[str, Any]] = None, attn_type: str = 'multihead', attn_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__() self.channels = channels self.conv = conv self.heads = heads self.dropout = dropout self.attn_type = attn_type attn_kwargs = attn_kwargs or {} if attn_type == 'multihead': self.attn = torch.nn.MultiheadAttention( channels, heads, batch_first=True, **attn_kwargs, ) elif attn_type == 'performer': self.attn = PerformerAttention( channels=channels, heads=heads, **attn_kwargs, ) else: # TODO: Support BigBird raise ValueError(f'{attn_type} is not supported') self.mlp = Sequential( Linear(channels, channels * 2), activation_resolver(act, **(act_kwargs or {})), Dropout(dropout), Linear(channels * 2, channels), Dropout(dropout), ) norm_kwargs = norm_kwargs or {} self.norm1 = normalization_resolver(norm, channels, **norm_kwargs) self.norm2 = normalization_resolver(norm, channels, **norm_kwargs) self.norm3 = normalization_resolver(norm, channels, **norm_kwargs) self.norm_with_batch = False if self.norm1 is not None: signature = inspect.signature(self.norm1.forward) self.norm_with_batch = 'batch' in signature.parameters 
def reset_parameters(self): r"""Resets all learnable parameters of the module.""" if self.conv is not None: self.conv.reset_parameters() self.attn._reset_parameters() reset(self.mlp) if self.norm1 is not None: self.norm1.reset_parameters() if self.norm2 is not None: self.norm2.reset_parameters() if self.norm3 is not None: self.norm3.reset_parameters() def forward( self, x: Tensor, edge_index: Adj, batch: Optional[torch.Tensor] = None, **kwargs, ) -> Tensor: r"""Runs the forward pass of the module.""" hs = [] if self.conv is not None: # Local MPNN. h = self.conv(x, edge_index, **kwargs) h = F.dropout(h, p=self.dropout, training=self.training) h = h + x if self.norm1 is not None: if self.norm_with_batch: h = self.norm1(h, batch=batch) else: h = self.norm1(h) hs.append(h) # Global attention transformer-style model. h, mask = to_dense_batch(x, batch) if isinstance(self.attn, torch.nn.MultiheadAttention): h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False) elif isinstance(self.attn, PerformerAttention): h = self.attn(h, mask=mask) h = h[mask] h = F.dropout(h, p=self.dropout, training=self.training) h = h + x # Residual connection. if self.norm2 is not None: if self.norm_with_batch: h = self.norm2(h, batch=batch) else: h = self.norm2(h) hs.append(h) out = sum(hs) # Combine local and global outputs. 
out = out + self.mlp(out) if self.norm3 is not None: if self.norm_with_batch: out = self.norm3(out, batch=batch) else: out = self.norm3(out) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.channels}, ' f'conv={self.conv}, heads={self.heads}, ' f'attn_type={self.attn_type})') ================================================ FILE: torch_geometric/nn/conv/graph_conv.py ================================================ from typing import Final, Tuple, Union import torch from torch import Tensor from torch_geometric import EdgeIndex from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size from torch_geometric.utils import spmm class GraphConv(MessagePassing): r"""The graph neural network operator from the `"Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target node :obj:`i` (default: :obj:`1`) Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. aggr (str, optional): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = True def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: str = 'add', bias: bool = True, **kwargs, ): super().__init__(aggr=aggr, **kwargs) self.in_channels = in_channels self.out_channels = out_channels if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.lin_rel = Linear(in_channels[0], out_channels, bias=bias) self.lin_root = Linear(in_channels[1], out_channels, bias=False) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin_rel.reset_parameters() self.lin_root.reset_parameters() def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight: OptTensor = None, size: Size = None) -> Tensor: if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: OptPairTensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size) out = self.lin_rel(out) x_r = x[1] if x_r is not None: out = out + self.lin_root(x_r) return out def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate( self, edge_index: Adj, x: OptPairTensor, edge_weight: OptTensor, ) -> Tensor: if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): return edge_index.matmul( other=x[0], input_value=edge_weight, reduce=self.aggr, transpose=True, ) return spmm(edge_index, x[0], reduce=self.aggr) ================================================ FILE: torch_geometric/nn/conv/gravnet_conv.py ================================================ import warnings from 
typing import Optional, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import OptPairTensor # noqa from torch_geometric.typing import OptTensor, PairOptTensor, PairTensor if torch_geometric.typing.WITH_TORCH_CLUSTER: from torch_cluster import knn else: knn = None class GravNetConv(MessagePassing): r"""The GravNet operator from the `"Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks" `_ paper, where the graph is dynamically constructed using nearest neighbors. The neighbors are constructed in a learnable low-dimensional projection of the feature space. A second projection of the input feature space is then propagated from the neighbors to each vertex using distance weights that are derived by applying a Gaussian function to the distances. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): The number of output channels. space_dimensions (int): The dimensionality of the space used to construct the neighbors; referred to as :math:`S` in the paper. propagate_dimensions (int): The number of features to be propagated between the vertices; referred to as :math:`F_{\textrm{LR}}` in the paper. k (int): The number of nearest neighbors. **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))` if bipartite, batch vector :math:`(|\mathcal{V}|)` or :math:`((|\mathcal{V}_s|), (|\mathcal{V}_t|))` if bipartite *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, in_channels: int, out_channels: int, space_dimensions: int, propagate_dimensions: int, k: int, num_workers: Optional[int] = None, **kwargs): super().__init__(aggr=['mean', 'max'], flow='source_to_target', **kwargs) if knn is None: raise ImportError('`GravNetConv` requires `torch-cluster`.') if num_workers is not None: warnings.warn( "'num_workers' attribute in '{self.__class__.__name__}' is " "deprecated and will be removed in a future release", stacklevel=2) self.in_channels = in_channels self.out_channels = out_channels self.k = k self.lin_s = Linear(in_channels, space_dimensions) self.lin_h = Linear(in_channels, propagate_dimensions) self.lin_out1 = Linear(in_channels, out_channels, bias=False) self.lin_out2 = Linear(2 * propagate_dimensions, out_channels) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin_s.reset_parameters() self.lin_h.reset_parameters() self.lin_out1.reset_parameters() self.lin_out2.reset_parameters() def forward( self, x: Union[Tensor, PairTensor], batch: Union[OptTensor, Optional[PairTensor]] = None, ) -> Tensor: is_bipartite: bool = True if isinstance(x, Tensor): x = (x, x) is_bipartite = False if x[0].dim() != 2: raise ValueError("Static graphs not supported in 'GravNetConv'") b: PairOptTensor = (None, None) if isinstance(batch, Tensor): b = (batch, batch) elif isinstance(batch, tuple): assert batch is not None b = (batch[0], batch[1]) h_l: Tensor = self.lin_h(x[0]) s_l: Tensor = self.lin_s(x[0]) s_r: Tensor = self.lin_s(x[1]) if is_bipartite else s_l edge_index = knn(s_l, s_r, self.k, b[0], b[1]).flip([0]) 
edge_weight = (s_l[edge_index[0]] - s_r[edge_index[1]]).pow(2).sum(-1) edge_weight = torch.exp(-10. * edge_weight) # 10 gives a better spread # propagate_type: (x: OptPairTensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=(h_l, None), edge_weight=edge_weight, size=(s_l.size(0), s_r.size(0))) return self.lin_out1(x[1]) + self.lin_out2(out) def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return x_j * edge_weight.unsqueeze(1) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, k={self.k})') ================================================ FILE: torch_geometric/nn/conv/han_conv.py ================================================ from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor, nn from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense import Linear from torch_geometric.nn.inits import glorot, reset from torch_geometric.typing import PairTensor # noqa from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType, OptTensor from torch_geometric.utils import softmax def group( xs: List[Tensor], q: nn.Parameter, k_lin: nn.Module, ) -> Tuple[OptTensor, OptTensor]: if len(xs) == 0: return None, None else: num_edge_types = len(xs) out = torch.stack(xs) if out.numel() == 0: return out.view(0, out.size(-1)), None attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1) attn = F.softmax(attn_score, dim=0) out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0) return out, attn class HANConv(MessagePassing): r"""The Heterogenous Graph Attention Operator from the `"Heterogenous Graph Attention Network" `_ paper. .. note:: For an example of using HANConv, see `examples/hetero/han_imdb.py `_. Args: in_channels (int or Dict[str, int]): Size of each input sample of every node type, or :obj:`-1` to derive the size from the first input(s) to the forward method. 
out_channels (int): Size of each output sample. metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata of the heterogeneous graph, *i.e.* its node and edge types given by a list of strings and a list of string triplets, respectively. See :meth:`torch_geometric.data.HeteroData.metadata` for more information. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) negative_slope (float, optional): LeakyReLU angle of the negative slope. (default: :obj:`0.2`) dropout (float, optional): Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: :obj:`0`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. """ def __init__( self, in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Metadata, heads: int = 1, negative_slope=0.2, dropout: float = 0.0, **kwargs, ): super().__init__(aggr='add', node_dim=0, **kwargs) if not isinstance(in_channels, dict): in_channels = {node_type: in_channels for node_type in metadata[0]} self.heads = heads self.in_channels = in_channels self.out_channels = out_channels self.negative_slope = negative_slope self.metadata = metadata self.dropout = dropout self.k_lin = nn.Linear(out_channels, out_channels) self.q = nn.Parameter(torch.empty(1, out_channels)) self.proj = nn.ModuleDict() for node_type, in_channels in self.in_channels.items(): self.proj[node_type] = Linear(in_channels, out_channels) self.lin_src = nn.ParameterDict() self.lin_dst = nn.ParameterDict() dim = out_channels // heads for edge_type in metadata[1]: edge_type = '__'.join(edge_type) self.lin_src[edge_type] = nn.Parameter(torch.empty(1, heads, dim)) self.lin_dst[edge_type] = nn.Parameter(torch.empty(1, heads, dim)) self.reset_parameters() def reset_parameters(self): super().reset_parameters() reset(self.proj) glorot(self.lin_src) glorot(self.lin_dst) self.k_lin.reset_parameters() glorot(self.q) 
# HANConv (continued): remaining methods of the class whose __init__ precedes.
    def forward(
        self, x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        return_semantic_attention_weights: bool = False,
    ) -> Union[Dict[NodeType, OptTensor], Tuple[Dict[NodeType, OptTensor],
                                                Dict[NodeType, OptTensor]]]:
        r"""Runs the forward pass of the module.

        Node-level attention is computed separately per edge type. The
        resulting per-edge-type outputs of each destination node type are
        then fused by semantic-level attention (see :func:`group`).

        Args:
            x_dict (Dict[str, torch.Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A
                dictionary holding graph connectivity information for each
                individual edge type, either as a :class:`torch.Tensor` of
                shape :obj:`[2, num_edges]` or a
                :class:`torch_sparse.SparseTensor`.
            return_semantic_attention_weights (bool, optional): If set to
                :obj:`True`, will additionally return the semantic-level
                attention weights for each destination node type.
                (default: :obj:`False`)

        Returns:
            The output dictionary per node type. A node type that receives no
            messages maps to :obj:`None`. If requested, a second dictionary
            holds the semantic attention weights.
        """
        # Number of heads and per-head feature size:
        H, D = self.heads, self.out_channels // self.heads
        x_node_dict, out_dict = {}, {}

        # Iterate over node types:
        for node_type, x in x_dict.items():
            # Project each node type into the shared `[N, H, D]` space:
            x_node_dict[node_type] = self.proj[node_type](x).view(-1, H, D)
            out_dict[node_type] = []

        # Iterate over edge types:
        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            edge_type = '__'.join(edge_type)
            lin_src = self.lin_src[edge_type]
            lin_dst = self.lin_dst[edge_type]
            x_src = x_node_dict[src_type]
            x_dst = x_node_dict[dst_type]
            # Per-node, per-head attention logits of shape `[N, H]`:
            alpha_src = (x_src * lin_src).sum(dim=-1)
            alpha_dst = (x_dst * lin_dst).sum(dim=-1)
            # propagate_type: (x: PairTensor, alpha: PairTensor)
            out = self.propagate(edge_index, x=(x_src, x_dst),
                                 alpha=(alpha_src, alpha_dst))

            out = F.relu(out)
            out_dict[dst_type].append(out)

        # iterate over node types:
        # Semantic-level attention fuses the per-edge-type outputs.
        semantic_attn_dict = {}
        for node_type, outs in out_dict.items():
            out, attn = group(outs, self.q, self.k_lin)
            out_dict[node_type] = out
            semantic_attn_dict[node_type] = attn

        if return_semantic_attention_weights:
            return out_dict, semantic_attn_dict

        return out_dict

    def message(self, x_j: Tensor, alpha_i: Tensor, alpha_j: Tensor,
                index: Tensor, ptr: Optional[Tensor],
                size_i: Optional[int]) -> Tensor:
        r"""Computes attention-weighted source features per edge.

        The summed source/destination logits pass through LeakyReLU. They are
        then normalized with :func:`softmax` over edges grouped by
        :obj:`index`, and dropout is applied during training.
        """
        alpha = alpha_j + alpha_i
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        out = x_j * alpha.view(-1, self.heads, 1)
        # Merge heads back into a flat `[num_edges, out_channels]` tensor:
        return out.view(-1, self.out_channels)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.out_channels}, '
                f'heads={self.heads})')


================================================
FILE: torch_geometric/nn/conv/heat_conv.py
================================================
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import HeteroLinear, Linear
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import softmax


class HEATConv(MessagePassing):
    r"""The heterogeneous edge-enhanced graph attentional operator from the
    `"Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent
    Trajectory Prediction" `_ paper.

    :class:`HEATConv` enhances :class:`~torch_geometric.nn.conv.GATConv` by:

    1. type-specific transformations of nodes of different types
    2. edge type and edge feature incorporation, in which edges are assumed to
       have different types but contain the same kind of attributes

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        num_node_types (int): The number of node types.
        num_edge_types (int): The number of edge types.
        edge_type_emb_dim (int): The embedding size of edge types.
        edge_dim (int): Edge feature dimensionality.
        edge_attr_emb_dim (int): The embedding size of edge features.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          node types :math:`(|\mathcal{V}|)`,
          edge types :math:`(|\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|,
          F_{out})`, where :math:`F_{out}` is
          :obj:`heads * out_channels` if :obj:`concat=True` and
          :obj:`out_channels` otherwise
    """
    def __init__(self, in_channels: int, out_channels: int,
                 num_node_types: int, num_edge_types: int,
                 edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int,
                 heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.0,
                 root_weight: bool = True, bias: bool = True, **kwargs):

        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.root_weight = root_weight

        # Node-type-specific input transformation:
        self.hetero_lin = HeteroLinear(in_channels, out_channels,
                                       num_node_types, bias=bias)

        self.edge_type_emb = Embedding(num_edge_types, edge_type_emb_dim)
        self.edge_attr_emb = Linear(edge_dim, edge_attr_emb_dim, bias=False)

        # Scores [x_i || x_j || edge_type_emb || edge_attr_emb] per head:
        self.att = Linear(
            2 * out_channels + edge_type_emb_dim + edge_attr_emb_dim,
            self.heads, bias=False)

        # Message transform applied to [x_j || edge_attr_emb]:
        self.lin = Linear(out_channels + edge_attr_emb_dim, out_channels,
                          bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        super().reset_parameters()
        self.hetero_lin.reset_parameters()
        self.edge_type_emb.reset_parameters()
        self.edge_attr_emb.reset_parameters()
        self.att.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor,
                edge_type: Tensor, edge_attr: OptTensor = None) -> Tensor:
        r"""Runs the forward pass of the module."""
        x = self.hetero_lin(x, node_type)

        edge_type_emb = F.leaky_relu(self.edge_type_emb(edge_type),
                                     self.negative_slope)

        # NOTE(review): `edge_attr` is typed as optional here, but `message`
        # passes it unconditionally to `self.edge_attr_emb`. Passing `None`
        # therefore looks likely to fail; confirm before relying on it.
        # propagate_type: (x: Tensor, edge_type_emb: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb,
                             edge_attr=edge_attr)

        # `out` carries a per-head dimension (see `message`):
        if self.concat:
            if self.root_weight:
                # Broadcast the root features across all heads:
                out = out + x.view(-1, 1, self.out_channels)
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)
            if self.root_weight:
                out = out + x

        return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_type_emb: Tensor,
                edge_attr: Tensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        r"""Computes per-head attention-weighted messages for each edge."""
        edge_attr = F.leaky_relu(self.edge_attr_emb(edge_attr),
                                 self.negative_slope)

        alpha = torch.cat([x_i, x_j, edge_type_emb, edge_attr], dim=-1)
        alpha = F.leaky_relu(self.att(alpha), self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # `[E, 1, out_channels] * [E, heads, 1]` -> `[E, heads, out_channels]`:
        out = self.lin(torch.cat([x_j, edge_attr], dim=-1)).unsqueeze(-2)
        return out * alpha.unsqueeze(-1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')


================================================
FILE: torch_geometric/nn/conv/hetero_conv.py
================================================
import warnings
from typing import Dict, List, Optional

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.module_dict import ModuleDict
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops


def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
    r"""Combines the outputs that different relations produce for one
    destination node type.

    Returns :obj:`None` for an empty list. With :obj:`aggr=None`, the outputs
    are stacked along dimension 1. A single output is returned as-is, and
    :obj:`"cat"` concatenates along the last dimension. Any other value names
    a :mod:`torch` reduction applied over the stacked outputs.
    """
    if len(xs) == 0:
        return None
    elif aggr is None:
        return torch.stack(xs, dim=1)
    elif len(xs) == 1:
        return xs[0]
    elif aggr == "cat":
        return torch.cat(xs, dim=-1)
    else:
        out = torch.stack(xs, dim=0)
        out = getattr(torch, aggr)(out, dim=0)
        # Reductions such as `torch.max`/`torch.min` with `dim` return a
        # (values, indices) tuple; keep only the values.
        out = out[0] if isinstance(out, tuple) else out
        return out


class HeteroConv(torch.nn.Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.
    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will be
    aggregated according to :attr:`aggr`.
    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is
    especially useful if you want to apply different message passing modules
    for different edge types.

    .. code-block:: python

        hetero_conv = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
            ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
        }, aggr='sum')

        out_dict = hetero_conv(x_dict, edge_index_dict)

        print(list(out_dict.keys()))
        >>> ['paper', 'author']

    Args:
        convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary
            holding a bipartite
            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
            individual edge type.
        aggr (str, optional): The aggregation scheme to use for grouping node
            embeddings generated by different relations
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"cat"`, :obj:`None`). (default: :obj:`"sum"`)
    """
    def __init__(
        self,
        convs: Dict[EdgeType, MessagePassing],
        aggr: Optional[str] = "sum",
    ):
        super().__init__()

        # Per-edge-type validation of each layer (presumably guards against
        # self-loop insertion on bipartite edge types; TODO confirm).
        for edge_type, module in convs.items():
            check_add_self_loops(module, [edge_type])

        # Warn about node types that only ever act as a message source:
        src_node_types = {key[0] for key in convs.keys()}
        dst_node_types = {key[-1] for key in convs.keys()}
        if len(src_node_types - dst_node_types) > 0:
            warnings.warn(
                f"There exist node types ({src_node_types - dst_node_types}) "
                f"whose representations do not get updated during message "
                f"passing as they do not occur as destination type in any "
                f"edge type. This may lead to unexpected behavior.",
                stacklevel=2)

        self.convs = ModuleDict(convs)
        self.aggr = aggr

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.convs.values():
            conv.reset_parameters()

    def forward(
        self,
        *args_dict,
        **kwargs_dict,
    ) -> Dict[NodeType, Tensor]:
        r"""Runs the forward pass of the module.

        Args:
            x_dict (Dict[str, torch.Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A
                dictionary holding graph connectivity information for each
                individual edge type, either as a :class:`torch.Tensor` of
                shape :obj:`[2, num_edges]` or a
                :class:`torch_sparse.SparseTensor`.
            *args_dict (optional): Additional forward arguments of individual
                :class:`torch_geometric.nn.conv.MessagePassing` layers.
            **kwargs_dict (optional): Additional forward arguments of
                individual :class:`torch_geometric.nn.conv.MessagePassing`
                layers.
                For example, if a specific GNN layer at edge type
                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
                forward argument, then you can pass them to
                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
                :obj:`edge_attr_dict = { edge_type: edge_attr }`.

        Each dictionary argument is resolved per edge type as follows. An
        entry keyed by the edge type is used directly. Otherwise, for
        :obj:`src == dst`, the node-type entry is used. Otherwise, a
        :obj:`(src_value, dst_value)` pair is built. Edge types for which no
        argument has an edge-type-keyed entry are skipped.
        """
        out_dict: Dict[str, List[Tensor]] = {}

        for edge_type, conv in self.convs.items():
            src, rel, dst = edge_type

            # Becomes True once any argument provides an edge-type-level
            # value; otherwise this relation is skipped below.
            has_edge_level_arg = False

            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    has_edge_level_arg = True
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append((
                        value_dict.get(src, None),
                        value_dict.get(dst, None),
                    ))

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                if not arg.endswith('_dict'):
                    raise ValueError(
                        f"Keyword arguments in '{self.__class__.__name__}' "
                        f"need to end with '_dict' (got '{arg}')")

                arg = arg[:-5]  # `{*}_dict`
                if edge_type in value_dict:
                    has_edge_level_arg = True
                    kwargs[arg] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[arg] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[arg] = (
                        value_dict.get(src, None),
                        value_dict.get(dst, None),
                    )

            if not has_edge_level_arg:
                continue

            out = conv(*args, **kwargs)

            if dst not in out_dict:
                out_dict[dst] = [out]
            else:
                out_dict[dst].append(out)

        # Reduce the per-relation outputs of each destination node type:
        for key, value in out_dict.items():
            out_dict[key] = group(value, self.aggr)

        return out_dict

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'


================================================
FILE: torch_geometric/nn/conv/hgt_conv.py
================================================
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import HeteroDictLinear, HeteroLinear
from torch_geometric.nn.inits import ones
from torch_geometric.nn.parameter_dict import ParameterDict
from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
from torch_geometric.utils import softmax
from torch_geometric.utils.hetero import construct_bipartite_edge_index


class HGTConv(MessagePassing):
r"""The Heterogeneous Graph Transformer (HGT) operator from the `"Heterogeneous Graph Transformer" `_ paper. .. note:: For an example of using HGT, see `examples/hetero/hgt_dblp.py `_. Args: in_channels (int or Dict[str, int]): Size of each input sample of every node type, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata of the heterogeneous graph, *i.e.* its node and edge types given by a list of strings and a list of string triplets, respectively. See :meth:`torch_geometric.data.HeteroData.metadata` for more information. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. """ def __init__( self, in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Metadata, heads: int = 1, **kwargs, ): super().__init__(aggr='add', node_dim=0, **kwargs) if out_channels % heads != 0: raise ValueError(f"'out_channels' (got {out_channels}) must be " f"divisible by the number of heads (got {heads})") if not isinstance(in_channels, dict): in_channels = {node_type: in_channels for node_type in metadata[0]} self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.node_types = metadata[0] self.edge_types = metadata[1] self.edge_types_map = { edge_type: i for i, edge_type in enumerate(metadata[1]) } self.dst_node_types = {key[-1] for key in self.edge_types} self.kqv_lin = HeteroDictLinear(self.in_channels, self.out_channels * 3) self.out_lin = HeteroDictLinear(self.out_channels, self.out_channels, types=self.node_types) dim = out_channels // heads num_types = heads * len(self.edge_types) self.k_rel = HeteroLinear(dim, dim, num_types, bias=False, is_sorted=True) self.v_rel = HeteroLinear(dim, dim, num_types, bias=False, is_sorted=True) self.skip = ParameterDict({ node_type: 
Parameter(torch.empty(1)) for node_type in self.node_types }) self.p_rel = ParameterDict() for edge_type in self.edge_types: edge_type = '__'.join(edge_type) self.p_rel[edge_type] = Parameter(torch.empty(1, heads)) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.kqv_lin.reset_parameters() self.out_lin.reset_parameters() self.k_rel.reset_parameters() self.v_rel.reset_parameters() ones(self.skip) ones(self.p_rel) def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]: """Concatenates a dictionary of features.""" cumsum = 0 outs: List[Tensor] = [] offset: Dict[str, int] = {} for key, x in x_dict.items(): outs.append(x) offset[key] = cumsum cumsum += x.size(0) return torch.cat(outs, dim=0), offset def _construct_src_node_feat( self, k_dict: Dict[str, Tensor], v_dict: Dict[str, Tensor], edge_index_dict: Dict[EdgeType, Adj] ) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]: """Constructs the source node representations.""" cumsum = 0 num_edge_types = len(self.edge_types) H, D = self.heads, self.out_channels // self.heads # Flatten into a single tensor with shape [num_edge_types * heads, D]: ks: List[Tensor] = [] vs: List[Tensor] = [] type_list: List[Tensor] = [] offset: Dict[EdgeType] = {} for edge_type in edge_index_dict.keys(): src = edge_type[0] N = k_dict[src].size(0) offset[edge_type] = cumsum cumsum += N # construct type_vec for curr edge_type with shape [H, D] edge_type_offset = self.edge_types_map[edge_type] type_vec = torch.arange(H, dtype=torch.long).view(-1, 1).repeat( 1, N) * num_edge_types + edge_type_offset type_list.append(type_vec) ks.append(k_dict[src]) vs.append(v_dict[src]) ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D) vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D) type_vec = torch.cat(type_list, dim=1).flatten() k = self.k_rel(ks, type_vec).view(H, -1, D).transpose(0, 1) v = self.v_rel(vs, type_vec).view(H, -1, D).transpose(0, 1) return k, v, offset def forward( self, x_dict: 
Dict[NodeType, Tensor], edge_index_dict: Dict[EdgeType, Adj] # Support both. ) -> Dict[NodeType, Optional[Tensor]]: r"""Runs the forward pass of the module. Args: x_dict (Dict[str, torch.Tensor]): A dictionary holding input node features for each individual node type. edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A dictionary holding graph connectivity information for each individual edge type, either as a :class:`torch.Tensor` of shape :obj:`[2, num_edges]` or a :class:`torch_sparse.SparseTensor`. :rtype: :obj:`Dict[str, Optional[torch.Tensor]]` - The output node embeddings for each node type. In case a node type does not receive any message, its output will be set to :obj:`None`. """ F = self.out_channels H = self.heads D = F // H k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {} # Compute K, Q, V over node types: kqv_dict = self.kqv_lin(x_dict) for key, val in kqv_dict.items(): k, q, v = torch.tensor_split(val, 3, dim=1) k_dict[key] = k.view(-1, H, D) q_dict[key] = q.view(-1, H, D) v_dict[key] = v.view(-1, H, D) q, dst_offset = self._cat(q_dict) k, v, src_offset = self._construct_src_node_feat( k_dict, v_dict, edge_index_dict) edge_index, edge_attr = construct_bipartite_edge_index( edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel, num_nodes=k.size(0)) out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr) # Reconstruct output node embeddings dict: for node_type, start_offset in dst_offset.items(): end_offset = start_offset + q_dict[node_type].size(0) if node_type in self.dst_node_types: out_dict[node_type] = out[start_offset:end_offset] # Transform output node embeddings: a_dict = self.out_lin({ k: torch.nn.functional.gelu(v) if v is not None else v for k, v in out_dict.items() }) # Iterate over node types: for node_type, out in out_dict.items(): out = a_dict[node_type] if out.size(-1) == x_dict[node_type].size(-1): alpha = self.skip[node_type].sigmoid() out = alpha * out + (1 - alpha) * x_dict[node_type] 
out_dict[node_type] = out return out_dict def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor, index: Tensor, ptr: Optional[Tensor], size_i: Optional[int]) -> Tensor: alpha = (q_i * k_j).sum(dim=-1) * edge_attr alpha = alpha / math.sqrt(q_i.size(-1)) alpha = softmax(alpha, index, ptr, size_i) out = v_j * alpha.view(-1, self.heads, 1) return out.view(-1, self.out_channels) def __repr__(self) -> str: return (f'{self.__class__.__name__}(-1, {self.out_channels}, ' f'heads={self.heads})') ================================================ FILE: torch_geometric/nn/conv/hypergraph_conv.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot, zeros from torch_geometric.utils import scatter, softmax class HypergraphConv(MessagePassing): r"""The hypergraph convolutional operator from the `"Hypergraph Convolution and Hypergraph Attention" `_ paper. .. math:: \mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W} \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta} where :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` is the incidence matrix, :math:`\mathbf{W} \in \mathbb{R}^M` is the diagonal hyperedge weight matrix, and :math:`\mathbf{D}` and :math:`\mathbf{B}` are the corresponding degree matrices. For example, in the hypergraph scenario :math:`\mathcal{G} = (\mathcal{V}, \mathcal{E})` with :math:`\mathcal{V} = \{ 0, 1, 2, 3 \}` and :math:`\mathcal{E} = \{ \{ 0, 1, 2 \}, \{ 1, 2, 3 \} \}`, the :obj:`hyperedge_index` is represented as: .. 
code-block:: python hyperedge_index = torch.tensor([ [0, 1, 2, 1, 2, 3], [0, 0, 0, 1, 1, 1], ]) Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. use_attention (bool, optional): If set to :obj:`True`, attention will be added to this layer. (default: :obj:`False`) attention_mode (str, optional): The mode on how to compute attention. If set to :obj:`"node"`, will compute attention scores of nodes within all nodes belonging to the same hyperedge. If set to :obj:`"edge"`, will compute attention scores of nodes across all edges holding this node belongs to. (default: :obj:`"node"`) heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) concat (bool, optional): If set to :obj:`False`, the multi-head attentions are averaged instead of concatenated. (default: :obj:`True`) negative_slope (float, optional): LeakyReLU angle of the negative slope. (default: :obj:`0.2`) dropout (float, optional): Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: :obj:`0`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, hyperedge indices :math:`(|\mathcal{V}|, |\mathcal{E}|)`, hyperedge weights :math:`(|\mathcal{E}|)` *(optional)* hyperedge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__( self, in_channels: int, out_channels: int, use_attention: bool = False, attention_mode: str = 'node', heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0, bias: bool = True, **kwargs, ): kwargs.setdefault('aggr', 'add') super().__init__(flow='source_to_target', node_dim=0, **kwargs) assert attention_mode in ['node', 'edge'] self.in_channels = in_channels self.out_channels = out_channels self.use_attention = use_attention self.attention_mode = attention_mode if self.use_attention: self.heads = heads self.concat = concat self.negative_slope = negative_slope self.dropout = dropout self.lin = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='glorot') self.att = Parameter(torch.empty(1, heads, 2 * out_channels)) else: self.heads = 1 self.concat = True self.lin = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot') if bias and concat: self.bias = Parameter(torch.empty(heads * out_channels)) elif bias and not concat: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin.reset_parameters() if self.use_attention: glorot(self.att) zeros(self.bias) @disable_dynamic_shapes(required_args=['num_edges']) def forward(self, x: Tensor, hyperedge_index: Tensor, hyperedge_weight: Optional[Tensor] = None, hyperedge_attr: Optional[Tensor] = None, num_edges: Optional[int] = None) -> Tensor: r"""Runs the forward pass of the module. Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. 
hyperedge_index (torch.Tensor): The hyperedge indices, *i.e.* the sparse incidence matrix :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` mapping from nodes to edges. hyperedge_weight (torch.Tensor, optional): Hyperedge weights :math:`\mathbf{W} \in \mathbb{R}^M`. (default: :obj:`None`) hyperedge_attr (torch.Tensor, optional): Hyperedge feature matrix in :math:`\mathbb{R}^{M \times F}`. These features only need to get passed in case :obj:`use_attention=True`. (default: :obj:`None`) num_edges (int, optional) : The number of edges :math:`M`. (default: :obj:`None`) """ num_nodes = x.size(0) if num_edges is None: num_edges = 0 if hyperedge_index.numel() > 0: num_edges = int(hyperedge_index[1].max()) + 1 if hyperedge_weight is None: hyperedge_weight = x.new_ones(num_edges) x = self.lin(x) alpha = None if self.use_attention: assert hyperedge_attr is not None x = x.view(-1, self.heads, self.out_channels) hyperedge_attr = self.lin(hyperedge_attr) hyperedge_attr = hyperedge_attr.view(-1, self.heads, self.out_channels) x_i = x[hyperedge_index[0]] x_j = hyperedge_attr[hyperedge_index[1]] alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1) alpha = F.leaky_relu(alpha, self.negative_slope) if self.attention_mode == 'node': alpha = softmax(alpha, hyperedge_index[1], num_nodes=num_edges) else: alpha = softmax(alpha, hyperedge_index[0], num_nodes=num_nodes) alpha = F.dropout(alpha, p=self.dropout, training=self.training) D = scatter(hyperedge_weight[hyperedge_index[1]], hyperedge_index[0], dim=0, dim_size=num_nodes, reduce='sum') D = 1.0 / D D[D == float("inf")] = 0 B = scatter(x.new_ones(hyperedge_index.size(1)), hyperedge_index[1], dim=0, dim_size=num_edges, reduce='sum') B = 1.0 / B B[B == float("inf")] = 0 out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha, size=(num_nodes, num_edges)) out = self.propagate(hyperedge_index.flip([0]), x=out, norm=D, alpha=alpha, size=(num_edges, num_nodes)) if self.concat is True: out = out.view(-1, self.heads * 
self.out_channels) else: out = out.mean(dim=1) if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, norm_i: Tensor, alpha: Tensor) -> Tensor: H, F = self.heads, self.out_channels out = norm_i.view(-1, 1, 1) * x_j.view(-1, H, F) if alpha is not None: out = alpha.view(-1, self.heads, 1) * out return out ================================================ FILE: torch_geometric/nn/conv/le_conv.py ================================================ from typing import Tuple, Union from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptTensor, PairTensor class LEConv(MessagePassing): r"""The local extremum graph neural network operator from the `"ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations" `_ paper. :class:`LEConv` finds the importance of nodes with respect to their neighbors using the difference operator: .. math:: \mathbf{x}^{\prime}_i = \mathbf{x}_i \cdot \mathbf{\Theta}_1 + \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot (\mathbf{\Theta}_2 \mathbf{x}_i - \mathbf{\Theta}_3 \mathbf{x}_j) where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target node :obj:`i` (default: :obj:`1`) Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`). **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, bias: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.lin1 = Linear(in_channels[0], out_channels, bias=bias) self.lin2 = Linear(in_channels[1], out_channels, bias=False) self.lin3 = Linear(in_channels[1], out_channels, bias=bias) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin1.reset_parameters() self.lin2.reset_parameters() self.lin3.reset_parameters() def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: if isinstance(x, Tensor): x = (x, x) a = self.lin1(x[0]) b = self.lin2(x[1]) # propagate_type: (a: Tensor, b: Tensor, edge_weight: OptTensor) out = self.propagate(edge_index, a=a, b=b, edge_weight=edge_weight) return out + self.lin3(x[1]) def message(self, a_j: Tensor, b_i: Tensor, edge_weight: OptTensor) -> Tensor: out = a_j - b_i return out if edge_weight is None else out * edge_weight.view(-1, 1) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/conv/lg_conv.py ================================================ from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils import 
spmm class LGConv(MessagePassing): r"""The Light Graph Convolution (LGC) operator from the `"LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \frac{e_{j,i}}{\sqrt{\deg(i)\deg(j)}} \mathbf{x}_j Args: normalize (bool, optional): If set to :obj:`False`, output features will not be normalized via symmetric normalization. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F)`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F)` """ def __init__(self, normalize: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.normalize = normalize def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: if self.normalize and isinstance(edge_index, Tensor): out = gcn_norm(edge_index, edge_weight, x.size(self.node_dim), add_self_loops=False, flow=self.flow, dtype=x.dtype) edge_index, edge_weight = out elif self.normalize and isinstance(edge_index, SparseTensor): edge_index = gcn_norm(edge_index, None, x.size(self.node_dim), add_self_loops=False, flow=self.flow, dtype=x.dtype) # propagate_type: (x: Tensor, edge_weight: OptTensor) return self.propagate(edge_index, x=x, edge_weight=edge_weight) def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) ================================================ FILE: torch_geometric/nn/conv/meshcnn_conv.py ================================================ # The below is to suppress the warning on torch.nn.conv.MeshCNNConv::update # pyright: reportIncompatibleMethodOverride=false import 
warnings from typing import Optional import torch from torch.nn import Linear, Module, ModuleList from torch_geometric.nn.conv import MessagePassing from torch_geometric.typing import Tensor class MeshCNNConv(MessagePassing): r"""The convolutional layer introduced by the paper `"MeshCNN: A Network With An Edge" `_. Recall that, given a set of categories :math:`C`, MeshCNN is a function that takes as its input a triangular mesh :math:`\mathcal{m} = (V, F) \in \mathbb{R}^{|V| \times 3} \times \{0,...,|V|-1\}^{3 \times |F|}`, and returns as its output a :math:`|C|`-dimensional vector, whose :math:`i` th component denotes the probability of the input mesh belonging to category :math:`c_i \in C`. Let :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}` denote the output value of the prior (e.g. :math:`k` th ) layer of our neural network. The :math:`i` th row of :math:`X^{(k)}` is a :math:`\text{Dim-Out}(k)`-dimensional vector that represents the features computed by the :math:`k` th layer for edge :math:`e_i` of the input mesh :math:`\mathcal{m}`. Let :math:`A \in \{0, ..., |E|-1\}^{2 \times 4*|E|}` denote the *edge adjacency* matrix of our input mesh :math:`\mathcal{m}`. The :math:`j` th column of :math:`A` returns a pair of indices :math:`k,l \in \{0,...,|E|-1\}`, which means that edge :math:`e_k` is adjacent to edge :math:`e_l` in our input mesh :math:`\mathcal{m}`. The definition of edge adjacency in a triangular mesh is illustrated in Figure 1. In a triangular mesh, each edge :math:`e_i` is expected to be adjacent to exactly :math:`4` neighboring edges, hence the number of columns of :math:`A`: :math:`4*|E|`. We write *the neighborhood* of edge :math:`e_i` as :math:`\mathcal{N}(i) = (a(i), b(i), c(i), d(i))` where 1. :math:`a(i)` denotes the index of the *first* counter-clockwise edge of the face *above* :math:`e_i`. 2. :math:`b(i)` denotes the index of the *second* counter-clockwise edge of the face *above* :math:`e_i`. 3. 
:math:`c(i)` denotes the index of the *first* counter-clockwise edge of the face *below* :math:`e_i`. 4. :math:`d(i)` denotes the index of the *second* counter-clockwise edge of the face *below* :math:`e_i`. .. figure:: ../_figures/meshcnn_edge_adjacency.svg :align: center :width: 80% **Figure 1:** The neighbors of edge :math:`\mathbf{e_1}` are :math:`\mathbf{e_2}, \mathbf{e_3}, \mathbf{e_4}` and :math:`\mathbf{e_5}`, respectively. We write this as :math:`\mathcal{N}(1) = (a(1), b(1), c(1), d(1)) = (2, 3, 4, 5)` Because of this ordering constraint, :obj:`MeshCNNConv` **requires that the columns of** :math:`A` **be ordered in the following way**: .. math:: &A[:,0] = (0, \text{The index of the "a" edge for edge } 0) \\ &A[:,1] = (0, \text{The index of the "b" edge for edge } 0) \\ &A[:,2] = (0, \text{The index of the "c" edge for edge } 0) \\ &A[:,3] = (0, \text{The index of the "d" edge for edge } 0) \\ \vdots \\ &A[:,4*|E|-4] = \bigl(|E|-1, a\bigl(|E|-1\bigr)\bigr) \\ &A[:,4*|E|-3] = \bigl(|E|-1, b\bigl(|E|-1\bigr)\bigr) \\ &A[:,4*|E|-2] = \bigl(|E|-1, c\bigl(|E|-1\bigr)\bigr) \\ &A[:,4*|E|-1] = \bigl(|E|-1, d\bigl(|E|-1\bigr)\bigr) Stated a bit more compactly, for every edge :math:`e_i` in the input mesh, :math:`A`, should have the following entries .. math:: A[:, 4*i] &= (i, a(i)) \\ A[:, 4*i + 1] &= (i, b(i)) \\ A[:, 4*i + 2] &= (i, c(i)) \\ A[:, 4*i + 3] &= (i, d(i)) To summarize so far, we have defined 3 things: 1. The activation of the prior (e.g. :math:`k` th) layer, :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}` 2. The edge adjacency matrix and the definition of edge adjacency. :math:`A \in \{0,...,|E|-1\}^{2 \times 4*|E|}` 3. The ways the columns of :math:`A` must be ordered. We are now finally able to define the :obj:`MeshCNNConv` class/layer. In the following definition we assume :obj:`MeshCNNConv` is at the :math:`k+1` th layer of our neural network. The :obj:`MeshCNNConv` layer is a function, .. 
math:: \text{MeshCNNConv}^{(k+1)}(X^{(k)}, A) = X^{(k+1)}, that, given the prior layer's output :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}` and the edge adjacency matrix :math:`A` of the input mesh (graph) :math:`\mathcal{m}` , returns a new edge feature tensor :math:`X^{(k+1)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k+1)}`, where the :math:`i` th row of :math:`X^{(k+1)}`, denoted by :math:`x^{(k+1)}_i`, represents the :math:`\text{Dim-Out}(k+1)`-dimensional feature vector of edge :math:`e_i`, **and is defined as follows**: .. math:: x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\ &+ W^{(k+1)}_1 \bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \bigr| \\ &+ W^{(k+1)}_2 \bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \bigr) \\ &+ W^{(k+1)}_3 \bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \bigr| \\ &+ W^{(k+1)}_4 \bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \bigr). :math:`W_0^{(k+1)},W_1^{(k+1)},W_2^{(k+1)},W_3^{(k+1)}, W_4^{(k+1)} \in \mathbb{R}^{\text{Dim-Out}(k+1) \times \text{Dim-Out}(k)}` are trainable linear functions (i.e. "the weights" of this layer). :math:`x_i` is the :math:`\text{Dim-Out}(k)`-dimensional feature of edge :math:`e_i` vector computed by the prior (e.g. :math:`k`) th layer. :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`, and :math:`x^{(k)}_{d(i)}` are the :math:`\text{Dim-Out}(k)`-feature vectors, computed in the :math:`k` th layer, that are associated with the :math:`4` neighboring edges of :math:`e_i`. Args: in_channels (int): Corresponds to :math:`\text{Dim-Out}(k)` in the above overview. This represents the output dimension of the prior layer. For the given input mesh :math:`\mathcal{m} = (V, F)`, the prior layer is expected to output a :math:`X \in \mathbb{R}^{|E| \times \textit{in_channels}}` feature matrix. Assuming the instance of this class is situated at layer :math:`k+1`, we write that :math:`X^{(k)} \in \mathbb{R}^{|E| \times \textit{in_channels}}`. out_channels (int): Corresponds to :math:`\text{Dim-Out}(k+1)` in the above overview. 
This represents the output dimension of this layer. Assuming the instance of this class is situated at layer :math:`k+1`, we write that :math:`X^{(k+1)} \in \mathbb{R}^{|E| \times \textit{out_channels}}`. kernels (torch.nn.ModuleList, optional): A list of length of 5, where each element is a :class:`torch.nn.module` (i.e a neural network), that each MUST take as input a vector of dimension :`obj:in_channels` and return a vector of dimension :obj:`out_channels`. In particular, `obj:kernels[0]` is :math:`W^{(k+1)}_0` in the above overview (see :obj:`MeshCNNConv`), `obj:kernels[1]` is :math:`W^{(k+1)}_1`, `obj:kernels[2]` is :math:`W^{(k+1)}_2`, `obj:kernels[3]` is :math:`W^{(k+1)}_3` `obj:kernels[4]` is :math:`W^{(k+1)}_4`. Note that this input is optional, in which case each of the 5 elements in the kernels will be a linear neural network :class:`torch.nn.modules.Linear` correctly configured to take as input :attr:`in_channels`-dimensional vectors and return a vector of dimensions :attr:`out_channels`. Discussion: The key difference that separates :obj:`MeshCNNConv` from a traditional message passing graph neural network is that :obj:`MeshCNNConv` requires the set of neighbors for a node :math:`\mathcal{N}(u) = (v_1, v_2, ...)` to *be an ordered set* (i.e. a tuple). In fact, :obj:`MeshCNNConv` goes further, requiring that :math:`\mathcal{N}(u)` always return a set of size :math:`4`. This is different to most message passing graph neural networks, which assume that :math:`\mathcal{N}(u) = \{v_1, v_2, ...\}` returns an ordered set. This lends :obj:`MeshCNNConv` more expressive power, at the cost of no longer being permutation invariant to :math:`\mathbb{S}_4`. Put more plainly, in tradition message passing GNNs, the network is *unable* to distinguish one neighboring node from another. In contrast, in :obj:`MeshCNNConv`, each of the 4 neighbors has a "role", either the "a", "b", "c", or "d" neighbor. 
We encode this fact by requiring that :math:`\mathcal{N}` return the 4-tuple, where the first component is the "a" neighbor, and so on. To summarize this comparison, it may re-define :obj:`MeshCNNConv` in terms of :math:`\text{UPDATE}` and :math:`\text{AGGREGATE}` functions, which is a general way to define a traditional GNN layer. If we let :math:`x_i^{(k+1)}` denote the output of a GNN layer for node :math:`i` at layer :math:`k+1`, and let :math:`\mathcal{N}(i)` denote the set of nodes adjacent to node :math:`i`, then we can describe the :math:`k+1` th layer as traditional GNN as .. math:: x_i^{(k+1)} = \text{UPDATE}^{(k+1)}\bigl(x^{(k)}_i, \text{AGGREGATE}^{(k+1)}\bigl(\mathcal{N}(i)\bigr)\bigr). Here, :math:`\text{UPDATE}^{(k+1)}` is a function of :math:`2` :math:`\text{Dim-Out}(k)`-dimensional vectors, and returns a :math:`\text{Dim-Out}(k+1)`-dimensional vector. :math:`\text{AGGREGATE}^{(k+1)}` function is a function of a *unordered set* of nodes that are neighbors of node :math:`i`, as defined by :math:`\mathcal{N}(i)`. Usually the size of this set varies across different nodes :math:`i`, and one of the most basic examples of such a function is the "sum aggregation", defined as :math:`\text{AGGREGATE}^{(k+1)}(\mathcal{N}(i)) = \sum_{j \in \mathcal{N}(i)} x^{(k)}_j`. See :class:`SumAggregation ` for more. In contrast, while :obj:`MeshCNNConv` 's :math:`\text{UPDATE}` function follows a tradition GNN, its :math:`\text{AGGREGATE}` is a function of a tuple (i.e. an ordered set) of neighbors rather than a unordered set of neighbors. In particular, while the :math:`\text{UPDATE}` function of :obj:`MeshCNNConv` for :math:`e_i` is .. math:: x_i^{(k+1)} = \text{UPDATE}^{(k+1)}(x_i^{(k)}, s_i^{(k+1)}) = W_0^{(k+1)}x_i^{(k)} + s_i^{(k+1)}, in contrast, :obj:`MeshCNNConv` 's :math:`\text{AGGREGATE}` function is .. 
math:: s_i^{(k+1)} = \text{AGGREGATE}^{(k+1)}(A, B, C, D) &= W_1^{(k+1)}\bigl|A - C \bigr| \\ &= W_2^{(k+1)}\bigl(A + C \bigr) \\ &= W_3^{(k+1)}\bigl|B - D \bigr| \\ &= W_4^{(k+1)}\bigl(B + D \bigr), where :math:`A=x_{a(i)}^{(k)}, B=x_{b(i)}^{(k)}, C=x_{c(i)}^{(k)},` and :math:`D=x_{d(i)}^{(k)}`. .. The :math:`i` th row of :math:`V \in \mathbb{R}^{|V| \times 3}` holds the cartesian :math:`xyz` coordinates for node :math:`v_i` in the mesh, and the :math:`j` th column in :math:`F \in \{1,...,|V|\}^{3 \times |V|}` holds the :math:`3` indices :math:`(k,l,m)` that correspond to the :math:`3` nodes :math:`(v_k, v_l, v_m)` that construct face :math:`j` of the mesh. """ def __init__(self, in_channels: int, out_channels: int, kernels: Optional[ModuleList] = None): super().__init__(aggr='add') self.in_channels = in_channels self.out_channels = out_channels if kernels is None: self.kernels = ModuleList( [Linear(in_channels, out_channels) for _ in range(5)]) else: # ensures kernels is properly formed, otherwise throws # the appropriate error. self._assert_kernels(kernels) self.kernels = kernels def forward(self, x: Tensor, edge_index: Tensor): r"""Forward pass. Args: x(torch.Tensor): :math:`X^{(k)} \in \mathbb{R}^{|E| \times \textit{in_channels}}`. The edge feature tensor returned by the prior layer (e.g. :math:`k`). The tensor is of shape :math:`|E| \times \text{Dim-Out}(k)`, or equivalently, :obj:`(|E|, self.in_channels)`. edge_index(torch.Tensor): :math:`A \in \{0,...,|E|-1\}^{2 \times 4*|E|}`. The edge adjacency tensor of the networks input mesh :math:`\mathcal{m} = (V, F)`. The edge adjacency tensor **MUST** have the following form: .. 
math:: &A[:,0] = (0, \text{The index of the "a" edge for edge } 0) \\ &A[:,1] = (0, \text{The index of the "b" edge for edge } 0) \\ &A[:,2] = (0, \text{The index of the "c" edge for edge } 0) \\ &A[:,3] = (0, \text{The index of the "d" edge for edge } 0) \\ \vdots \\ &A[:,4*|E|-4] = \bigl(|E|-1, a\bigl(|E|-1\bigr)\bigr) \\ &A[:,4*|E|-3] = \bigl(|E|-1, b\bigl(|E|-1\bigr)\bigr) \\ &A[:,4*|E|-2] = \bigl(|E|-1, c\bigl(|E|-1\bigr)\bigr) \\ &A[:,4*|E|-1] = \bigl(|E|-1, d\bigl(|E|-1\bigr)\bigr) See :obj:`MeshCNNConv` for what "index of the 'a'(b,c,d) edge for edge i" means, and also for the general definition of edge adjacency in MeshCNN. These definitions are also provided in the `paper `_ itself. Returns: torch.Tensor: :math:`X^{(k+1)} \in \mathbb{R}^{|E| \times \textit{out_channels}}`. The edge feature tensor for this (e.g. the :math:`k+1` th) layer. The :math:`i` th row of :math:`X^{(k+1)}` is computed according to the formula .. math:: x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\ &+ W^{(k+1)}_1 \bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \bigr| \\ &+ W^{(k+1)}_2 \bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \bigr) \\ &+ W^{(k+1)}_3 \bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \bigr| \\ &+ W^{(k+1)}_4 \bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \bigr), where :math:`W_0^{(k+1)},W_1^{(k+1)}, W_2^{(k+1)},W_3^{(k+1)}, W_4^{(k+1)} \in \mathbb{R}^{\text{Dim-Out}(k+1) \times \text{Dim-Out}(k)}` are the trainable linear functions (i.e. the trainable "weights") of this layer, and :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`, :math:`x^{(k)}_{d(i)}` are the :math:`\text{Dim-Out}(k)`-dimensional edge feature vectors computed by the prior (:math:`k` th) layer, that are associated with the :math:`4` neighboring edges of :math:`e_i`. """ return self.propagate(edge_index, x=x) def message(self, x_j: Tensor) -> Tensor: r"""The messaging passing step of :obj:`MeshCNNConv`. Args: x_j: A :obj:`[4*|E|, num_node_features]` tensor. 
Its ith row holds the value stored by the source node in the previous layer of edge i. Returns: A :obj:`[|E|, num_node_features]` tensor, whose ith row will be the value that the target node of edge i will receive. """ # The following variables names are taken from the paper # MeshCNN computes the features associated with edge # e by (|a - c|, a + c, |b - c|, b + c), where a, b, c, d are the # neighboring edges of e, a being the 1 edge of the upper face, # b being the second edge of the upper face, c being the first edge # of the lower face, # and d being the second edge of the lower face of the input Mesh # TODO: It is unclear if view is faster. If it is not, # then we should prefer the strided method commented out below E4, in_channels = x_j.size() # E4 = 4|E|, i.e. num edges in line graph # Option 1 n_a = x_j[0::4] # shape: |E| x in_channels n_b = x_j[1::4] # shape: |E| x in_channels n_c = x_j[2::4] # shape: |E| x in_channels n_d = x_j[3::4] # shape: |E| x in_channels m = torch.empty(E4, self.out_channels) m[0::4] = self.kernels[1].forward(torch.abs(n_a - n_c)) m[1::4] = self.kernels[2].forward(n_a + n_c) m[2::4] = self.kernels[3].forward(torch.abs(n_b - n_d)) m[3::4] = self.kernels[4].forward(n_b + n_d) return m # Option 2 # E4, in_channels = x_j.size() # E = E4 // 4 # x_j = x_j.view(E, 4, in_channels) # shape: (|E| x 4 x in_channels) # n_a, n_b, n_c, n_d = x_j.unbind( # dim=1) # shape: (4 x |E| x in_channels) # m = torch.stack( # [ # (n_a - n_c).abs(), # shape: |E| x in_channels # n_a + n_c, # (n_b - n_d).abs(), # n_b + n_d, # ], # dim=1) # shape: (|E| x 4 x in_channels) # m.view(E4, in_channels) # shape 4*|E| x in_channels # return m def update(self, inputs: Tensor, x: Tensor) -> Tensor: r"""The UPDATE step, in reference to the UPDATE and AGGREGATE formulation of message passing convolution. Args: inputs(torch.Tensor): The :attr:`in_channels`-dimensional vector returned by aggregate. x(torch.Tensor): :math:`X^{(k)}`. The original inputs to this layer. 
Returns: torch.Tensor: :math:`X^{(k+1)}`. The output of this layer, which has shape :obj:`(|E|, out_channels)`. """ return self.kernels[0].forward(x) + inputs def _assert_kernels(self, kernels: ModuleList): r"""Ensures that :obj:`kernels` is a list of 5 :obj:`torch.nn.Module` modules (i.e. networks). In addition, it also ensures that each network takes in input of dimension :attr:`in_channels`, and returns output of dimension :attr:`out_channels`. This method throws an error otherwise. .. warn:: This method throws an error if :obj:`kernels` is not valid. (Otherwise this method returns nothing) """ assert isinstance(kernels, ModuleList), \ f"Parameter 'kernels' must be a \ torch.nn.module.ModuleList with 5 members, but we got \ {type(kernels)}." assert len(kernels) == 5, "Parameter 'kernels' must be a \ torch.nn.module.ModuleList of with exactly 5 members" for i, network in enumerate(kernels): assert isinstance(network, Module), \ f"kernels[{i}] must be torch.nn.Module, got \ {type(network)}" if not hasattr(network, "in_channels") and \ not hasattr(network, "in_features"): warnings.warn( f"kernel[{i}] does not have attribute 'in_channels' nor " f"'out_features'. The network must take as input a " f"{self.in_channels}-dimensional tensor.", stacklevel=2) else: input_dimension = getattr(network, "in_channels", network.in_features) assert input_dimension == self.in_channels, f"The input \ dimension of the neural network in kernel[{i}] must \ be \ equal to 'in_channels', but input_dimension = \ {input_dimension}, and \ self.in_channels={self.in_channels}." if not hasattr(network, "out_channels") and \ not hasattr(network, "out_features"): warnings.warn( f"kernel[{i}] does not have attribute 'in_channels' nor " f"'out_features'. 
The network must take as input a " f"{self.in_channels}-dimensional tensor.", stacklevel=2) else: output_dimension = getattr(network, "out_channels", network.out_features) assert output_dimension == self.out_channels, f"The output \ dimension of the neural network in kernel[{i}] must \ be \ equal to 'out_channels', but out_dimension = \ {output_dimension}, and \ self.out_channels={self.out_channels}." ================================================ FILE: torch_geometric/nn/conv/message_passing.py ================================================ import os.path as osp import warnings from abc import abstractmethod from inspect import Parameter from typing import ( Any, Callable, Dict, Final, List, Optional, OrderedDict, Set, Tuple, Union, ) import torch from torch import Tensor from torch.utils.hooks import RemovableHandle from torch_geometric import EdgeIndex, is_compiling from torch_geometric.index import ptr2index from torch_geometric.inspector import Inspector, Signature from torch_geometric.nn.aggr import Aggregation from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver from torch_geometric.template import module_from_template from torch_geometric.typing import Adj, Size, SparseTensor from torch_geometric.utils import ( is_sparse, is_torch_sparse_tensor, to_edge_index, ) FUSE_AGGRS = {'add', 'sum', 'mean', 'min', 'max'} HookDict = OrderedDict[int, Callable] class MessagePassing(torch.nn.Module): r"""Base class for creating message passing layers. Message passing layers follow the form .. math:: \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right), where :math:`\bigoplus` denotes a differentiable, permutation invariant function, *e.g.*, sum, mean, min, max or mul, and :math:`\gamma_{\mathbf{\Theta}}` and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as MLPs. 
See `here `__ for the accompanying tutorial. Args: aggr (str or [str] or Aggregation, optional): The aggregation scheme to use, *e.g.*, :obj:`"sum"` :obj:`"mean"`, :obj:`"min"`, :obj:`"max"` or :obj:`"mul"`. In addition, can be any :class:`~torch_geometric.nn.aggr.Aggregation` module (or any string that automatically resolves to it). If given as a list, will make use of multiple aggregations in which different outputs will get concatenated in the last dimension. If set to :obj:`None`, the :class:`MessagePassing` instantiation is expected to implement its own aggregation logic via :meth:`aggregate`. (default: :obj:`"add"`) aggr_kwargs (Dict[str, Any], optional): Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: :obj:`None`) flow (str, optional): The flow direction of message passing (:obj:`"source_to_target"` or :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) node_dim (int, optional): The axis along which to propagate. (default: :obj:`-2`) decomposed_layers (int, optional): The number of feature decomposition layers, as introduced in the `"Optimizing Memory Efficiency of Graph Neural Networks on Edge Computing Platforms" `_ paper. Feature decomposition reduces the peak memory usage by slicing the feature dimensions into separated feature decomposition layers during GNN aggregation. This method can accelerate GNN execution on CPU-based platforms (*e.g.*, 2-3x speedup on the :class:`~torch_geometric.datasets.Reddit` dataset) for common GNN models such as :class:`~torch_geometric.nn.models.GCN`, :class:`~torch_geometric.nn.models.GraphSAGE`, :class:`~torch_geometric.nn.models.GIN`, etc. However, this method is not applicable to all GNN operators available, in particular for operators in which message computation can not easily be decomposed, *e.g.* in attention-based GNNs. 
The selection of the optimal value of :obj:`decomposed_layers` depends both on the specific graph dataset and available hardware resources. A value of :obj:`2` is suitable in most cases. Although the peak memory usage is directly associated with the granularity of feature decomposition, the same is not necessarily true for execution speedups. (default: :obj:`1`) """ special_args: Set[str] = { 'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j', 'ptr', 'index', 'dim_size' } # Supports `message_and_aggregate` via `EdgeIndex`. # TODO Remove once migration is finished. SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = False def __init__( self, aggr: Optional[Union[str, List[str], Aggregation]] = 'sum', *, aggr_kwargs: Optional[Dict[str, Any]] = None, flow: str = "source_to_target", node_dim: int = -2, decomposed_layers: int = 1, ) -> None: super().__init__() if flow not in ['source_to_target', 'target_to_source']: raise ValueError(f"Expected 'flow' to be either 'source_to_target'" f" or 'target_to_source' (got '{flow}')") # Cast `aggr` into a string representation for backward compatibility: self.aggr: Optional[Union[str, List[str]]] if aggr is None: self.aggr = None elif isinstance(aggr, (str, Aggregation)): self.aggr = str(aggr) elif isinstance(aggr, (tuple, list)): self.aggr = [str(x) for x in aggr] self.aggr_module = aggr_resolver(aggr, **(aggr_kwargs or {})) self.flow = flow self.node_dim = node_dim # Collect attribute names requested in message passing hooks: self.inspector = Inspector(self.__class__) self.inspector.inspect_signature(self.message) self.inspector.inspect_signature(self.aggregate, exclude=[0, 'aggr']) self.inspector.inspect_signature(self.message_and_aggregate, [0]) self.inspector.inspect_signature(self.update, exclude=[0]) self.inspector.inspect_signature(self.edge_update) self._user_args: List[str] = self.inspector.get_flat_param_names( ['message', 'aggregate', 'update'], exclude=self.special_args) self._fused_user_args: 
List[str] = self.inspector.get_flat_param_names( ['message_and_aggregate', 'update'], exclude=self.special_args) self._edge_user_args: List[str] = self.inspector.get_param_names( 'edge_update', exclude=self.special_args) # Support for "fused" message passing: self.fuse = self.inspector.implements('message_and_aggregate') if self.aggr is not None: self.fuse &= isinstance(self.aggr, str) and self.aggr in FUSE_AGGRS # Hooks: self._propagate_forward_pre_hooks: HookDict = OrderedDict() self._propagate_forward_hooks: HookDict = OrderedDict() self._message_forward_pre_hooks: HookDict = OrderedDict() self._message_forward_hooks: HookDict = OrderedDict() self._aggregate_forward_pre_hooks: HookDict = OrderedDict() self._aggregate_forward_hooks: HookDict = OrderedDict() self._message_and_aggregate_forward_pre_hooks: HookDict = OrderedDict() self._message_and_aggregate_forward_hooks: HookDict = OrderedDict() self._edge_update_forward_pre_hooks: HookDict = OrderedDict() self._edge_update_forward_hooks: HookDict = OrderedDict() # Set jittable `propagate` and `edge_updater` function templates: self._set_jittable_templates() # Explainability: self._explain: Optional[bool] = None self._edge_mask: Optional[Tensor] = None self._loop_mask: Optional[Tensor] = None self._apply_sigmoid: bool = True # Inference Decomposition: self._decomposed_layers = 1 self.decomposed_layers = decomposed_layers def reset_parameters(self) -> None: r"""Resets all learnable parameters of the module.""" if self.aggr_module is not None: self.aggr_module.reset_parameters() def __setstate__(self, data: Dict[str, Any]) -> None: self.inspector = data['inspector'] self.fuse = data['fuse'] self._set_jittable_templates() super().__setstate__(data) def __repr__(self) -> str: channels_repr = '' if hasattr(self, 'in_channels') and hasattr(self, 'out_channels'): channels_repr = f'{self.in_channels}, {self.out_channels}' elif hasattr(self, 'channels'): channels_repr = f'{self.channels}' return 
f'{self.__class__.__name__}({channels_repr})' # Utilities ############################################################### def _check_input( self, edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[Optional[int], Optional[int]]], ) -> List[Optional[int]]: if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): return [edge_index.num_rows, edge_index.num_cols] if is_sparse(edge_index): if self.flow == 'target_to_source': raise ValueError( 'Flow direction "target_to_source" is invalid for ' 'message propagation via `torch_sparse.SparseTensor` ' 'or `torch.sparse.Tensor`. If you really want to make ' 'use of a reverse message passing flow, pass in the ' 'transposed sparse tensor to the message passing module, ' 'e.g., `adj_t.t()`.') if isinstance(edge_index, SparseTensor): return [edge_index.size(1), edge_index.size(0)] return [edge_index.size(1), edge_index.size(0)] elif isinstance(edge_index, Tensor): int_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) if edge_index.dtype not in int_dtypes: raise ValueError(f"Expected 'edge_index' to be of integer " f"type (got '{edge_index.dtype}')") if edge_index.dim() != 2: raise ValueError(f"Expected 'edge_index' to be two-dimensional" f" (got {edge_index.dim()} dimensions)") if not torch.jit.is_tracing() and edge_index.size(0) != 2: raise ValueError(f"Expected 'edge_index' to have size '2' in " f"the first dimension (got " f"'{edge_index.size(0)}')") return list(size) if size is not None else [None, None] raise ValueError( '`MessagePassing.propagate` only supports integer tensors of ' 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or ' '`torch.sparse.Tensor` for argument `edge_index`.') def _set_size( self, size: List[Optional[int]], dim: int, src: Tensor, ) -> None: the_size = size[dim] if the_size is None: size[dim] = src.size(self.node_dim) elif the_size != src.size(self.node_dim): raise ValueError( f'Encountered tensor with size {src.size(self.node_dim)} in ' 
f'dimension {self.node_dim}, but expected size {the_size}.') def _index_select(self, src: Tensor, index) -> Tensor: if torch.jit.is_scripting() or is_compiling(): return src.index_select(self.node_dim, index) else: return self._index_select_safe(src, index) def _index_select_safe(self, src: Tensor, index: Tensor) -> Tensor: try: return src.index_select(self.node_dim, index) except (IndexError, RuntimeError) as e: if index.numel() > 0 and index.min() < 0: raise IndexError( f"Found negative indices in 'edge_index' (got " f"{index.min().item()}). Please ensure that all " f"indices in 'edge_index' point to valid indices " f"in the interval [0, {src.size(self.node_dim)}) in " f"your node feature matrix and try again.") from e if (index.numel() > 0 and index.max() >= src.size(self.node_dim)): raise IndexError( f"Found indices in 'edge_index' that are larger " f"than {src.size(self.node_dim) - 1} (got " f"{index.max().item()}). Please ensure that all " f"indices in 'edge_index' point to valid indices " f"in the interval [0, {src.size(self.node_dim)}) in " f"your node feature matrix and try again.") from e raise e def _lift( self, src: Tensor, edge_index: Union[Tensor, SparseTensor], dim: int, ) -> Tensor: if not torch.jit.is_scripting() and is_torch_sparse_tensor(edge_index): assert dim == 0 or dim == 1 if edge_index.layout == torch.sparse_coo: index = edge_index._indices()[1 - dim] elif edge_index.layout == torch.sparse_csr: if dim == 0: index = edge_index.col_indices() else: index = ptr2index(edge_index.crow_indices()) elif edge_index.layout == torch.sparse_csc: if dim == 0: index = ptr2index(edge_index.ccol_indices()) else: index = edge_index.row_indices() else: raise ValueError(f"Unsupported sparse tensor layout " f"(got '{edge_index.layout}')") return src.index_select(self.node_dim, index) elif isinstance(edge_index, Tensor): if torch.jit.is_scripting(): # Try/catch blocks are not supported. 
index = edge_index[dim] return src.index_select(self.node_dim, index) return self._index_select(src, edge_index[dim]) elif isinstance(edge_index, SparseTensor): row, col, _ = edge_index.coo() if dim == 0: return src.index_select(self.node_dim, col) elif dim == 1: return src.index_select(self.node_dim, row) raise ValueError( '`MessagePassing.propagate` only supports integer tensors of ' 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` ' 'or `torch.sparse.Tensor` for argument `edge_index`.') def _collect( self, args: Set[str], edge_index: Union[Tensor, SparseTensor], size: List[Optional[int]], kwargs: Dict[str, Any], ) -> Dict[str, Any]: i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1) out = {} for arg in args: if arg[-2:] not in ['_i', '_j']: out[arg] = kwargs.get(arg, Parameter.empty) else: dim = j if arg[-2:] == '_j' else i data = kwargs.get(arg[:-2], Parameter.empty) if isinstance(data, (tuple, list)): assert len(data) == 2 if isinstance(data[1 - dim], Tensor): self._set_size(size, 1 - dim, data[1 - dim]) data = data[dim] if isinstance(data, Tensor): self._set_size(size, dim, data) data = self._lift(data, edge_index, dim) out[arg] = data if is_torch_sparse_tensor(edge_index): indices, values = to_edge_index(edge_index) out['adj_t'] = edge_index out['edge_index'] = None out['edge_index_i'] = indices[0] out['edge_index_j'] = indices[1] out['ptr'] = None # TODO Get `rowptr` from CSR representation. 
if out.get('edge_weight', None) is None: out['edge_weight'] = values if out.get('edge_attr', None) is None: out['edge_attr'] = None if values.dim() == 1 else values if out.get('edge_type', None) is None: out['edge_type'] = values elif isinstance(edge_index, Tensor): out['adj_t'] = None out['edge_index'] = edge_index out['edge_index_i'] = edge_index[i] out['edge_index_j'] = edge_index[j] out['ptr'] = None if isinstance(edge_index, EdgeIndex): if i == 0 and edge_index.is_sorted_by_row: (out['ptr'], _), _ = edge_index.get_csr() elif i == 1 and edge_index.is_sorted_by_col: (out['ptr'], _), _ = edge_index.get_csc() elif isinstance(edge_index, SparseTensor): row, col, value = edge_index.coo() rowptr, _, _ = edge_index.csr() out['adj_t'] = edge_index out['edge_index'] = None out['edge_index_i'] = row out['edge_index_j'] = col out['ptr'] = rowptr if out.get('edge_weight', None) is None: out['edge_weight'] = value if out.get('edge_attr', None) is None: out['edge_attr'] = value if out.get('edge_type', None) is None: out['edge_type'] = value out['index'] = out['edge_index_i'] out['size'] = size out['size_i'] = size[i] if size[i] is not None else size[j] out['size_j'] = size[j] if size[j] is not None else size[i] out['dim_size'] = out['size_i'] return out # Message Passing ######################################################### def forward(self, *args: Any, **kwargs: Any) -> Any: r"""Runs the forward pass of the module.""" def propagate( self, edge_index: Adj, size: Size = None, **kwargs: Any, ) -> Tensor: r"""The initial call to start propagating messages. Args: edge_index (torch.Tensor or SparseTensor): A :class:`torch.Tensor`, a :class:`torch_sparse.SparseTensor` or a :class:`torch.sparse.Tensor` that defines the underlying graph connectivity/message passing flow. :obj:`edge_index` holds the indices of a general (sparse) assignment matrix of shape :obj:`[N, M]`. 
If :obj:`edge_index` is a :obj:`torch.Tensor`, its :obj:`dtype` should be :obj:`torch.long` and its shape needs to be defined as :obj:`[2, num_messages]` where messages from nodes in :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]` (in case :obj:`flow="source_to_target"`). If :obj:`edge_index` is a :class:`torch_sparse.SparseTensor` or a :class:`torch.sparse.Tensor`, its sparse indices :obj:`(row, col)` should relate to :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`. The major difference between both formats is that we need to input the *transposed* sparse adjacency matrix into :meth:`propagate`. size ((int, int), optional): The size :obj:`(N, M)` of the assignment matrix in case :obj:`edge_index` is a :class:`torch.Tensor`. If set to :obj:`None`, the size will be automatically inferred and assumed to be quadratic. This argument is ignored in case :obj:`edge_index` is a :class:`torch_sparse.SparseTensor` or a :class:`torch.sparse.Tensor`. (default: :obj:`None`) **kwargs: Any additional data which is needed to construct and aggregate messages, and to update node embeddings. """ decomposed_layers = 1 if self.explain else self.decomposed_layers for hook in self._propagate_forward_pre_hooks.values(): res = hook(self, (edge_index, size, kwargs)) if res is not None: edge_index, size, kwargs = res mutable_size = self._check_input(edge_index, size) # Run "fused" message and aggregation (if applicable). 
fuse = False if self.fuse and not self.explain: if is_sparse(edge_index): fuse = True elif (not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex)): if (self.SUPPORTS_FUSED_EDGE_INDEX and edge_index.is_sorted_by_col): fuse = True if fuse: coll_dict = self._collect(self._fused_user_args, edge_index, mutable_size, kwargs) msg_aggr_kwargs = self.inspector.collect_param_data( 'message_and_aggregate', coll_dict) for hook in self._message_and_aggregate_forward_pre_hooks.values(): res = hook(self, (edge_index, msg_aggr_kwargs)) if res is not None: edge_index, msg_aggr_kwargs = res out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs) for hook in self._message_and_aggregate_forward_hooks.values(): res = hook(self, (edge_index, msg_aggr_kwargs), out) if res is not None: out = res update_kwargs = self.inspector.collect_param_data( 'update', coll_dict) out = self.update(out, **update_kwargs) else: # Otherwise, run both functions in separation. if decomposed_layers > 1: user_args = self._user_args decomp_args = {a[:-2] for a in user_args if a[-2:] == '_j'} decomp_kwargs = { a: kwargs[a].chunk(decomposed_layers, -1) for a in decomp_args } decomp_out = [] for i in range(decomposed_layers): if decomposed_layers > 1: for arg in decomp_args: kwargs[arg] = decomp_kwargs[arg][i] coll_dict = self._collect(self._user_args, edge_index, mutable_size, kwargs) msg_kwargs = self.inspector.collect_param_data( 'message', coll_dict) for hook in self._message_forward_pre_hooks.values(): res = hook(self, (msg_kwargs, )) if res is not None: msg_kwargs = res[0] if isinstance(res, tuple) else res out = self.message(**msg_kwargs) for hook in self._message_forward_hooks.values(): res = hook(self, (msg_kwargs, ), out) if res is not None: out = res if self.explain: explain_msg_kwargs = self.inspector.collect_param_data( 'explain_message', coll_dict) out = self.explain_message(out, **explain_msg_kwargs) aggr_kwargs = self.inspector.collect_param_data( 'aggregate', coll_dict) for 
hook in self._aggregate_forward_pre_hooks.values(): res = hook(self, (aggr_kwargs, )) if res is not None: aggr_kwargs = res[0] if isinstance(res, tuple) else res out = self.aggregate(out, **aggr_kwargs) for hook in self._aggregate_forward_hooks.values(): res = hook(self, (aggr_kwargs, ), out) if res is not None: out = res update_kwargs = self.inspector.collect_param_data( 'update', coll_dict) out = self.update(out, **update_kwargs) if decomposed_layers > 1: decomp_out.append(out) if decomposed_layers > 1: out = torch.cat(decomp_out, dim=-1) for hook in self._propagate_forward_hooks.values(): res = hook(self, (edge_index, mutable_size, kwargs), out) if res is not None: out = res return out def message(self, x_j: Tensor) -> Tensor: r"""Constructs messages from node :math:`j` to node :math:`i` in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in :obj:`edge_index`. This function can take any argument as input which was initially passed to :meth:`propagate`. Furthermore, tensors passed to :meth:`propagate` can be mapped to the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`. """ return x_j def aggregate( self, inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, ) -> Tensor: r"""Aggregates messages from neighbors as :math:`\bigoplus_{j \in \mathcal{N}(i)}`. Takes in the output of message computation as first argument and any argument which was initially passed to :meth:`propagate`. By default, this function will delegate its call to the underlying :class:`~torch_geometric.nn.aggr.Aggregation` module to reduce messages as specified in :meth:`__init__` by the :obj:`aggr` argument. """ return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size, dim=self.node_dim) @abstractmethod def message_and_aggregate(self, edge_index: Adj) -> Tensor: r"""Fuses computations of :func:`message` and :func:`aggregate` into a single function. 
If applicable, this saves both time and memory since messages do not explicitly need to be materialized. This function will only gets called in case it is implemented and propagation takes place based on a :obj:`torch_sparse.SparseTensor` or a :obj:`torch.sparse.Tensor`. """ raise NotImplementedError def update(self, inputs: Tensor) -> Tensor: r"""Updates node embeddings in analogy to :math:`\gamma_{\mathbf{\Theta}}` for each node :math:`i \in \mathcal{V}`. Takes in the output of aggregation as first argument and any argument which was initially passed to :meth:`propagate`. """ return inputs # Edge-level Updates ###################################################### def edge_updater( self, edge_index: Adj, size: Size = None, **kwargs: Any, ) -> Tensor: r"""The initial call to compute or update features for each edge in the graph. Args: edge_index (torch.Tensor or SparseTensor): A :obj:`torch.Tensor`, a :class:`torch_sparse.SparseTensor` or a :class:`torch.sparse.Tensor` that defines the underlying graph connectivity/message passing flow. See :meth:`propagate` for more information. size ((int, int), optional): The size :obj:`(N, M)` of the assignment matrix in case :obj:`edge_index` is a :class:`torch.Tensor`. If set to :obj:`None`, the size will be automatically inferred and assumed to be quadratic. This argument is ignored in case :obj:`edge_index` is a :class:`torch_sparse.SparseTensor` or a :class:`torch.sparse.Tensor`. (default: :obj:`None`) **kwargs: Any additional data which is needed to compute or update features for each edge in the graph. 
""" for hook in self._edge_update_forward_pre_hooks.values(): res = hook(self, (edge_index, size, kwargs)) if res is not None: edge_index, size, kwargs = res mutable_size = self._check_input(edge_index, size=None) coll_dict = self._collect(self._edge_user_args, edge_index, mutable_size, kwargs) edge_kwargs = self.inspector.collect_param_data( 'edge_update', coll_dict) out = self.edge_update(**edge_kwargs) for hook in self._edge_update_forward_hooks.values(): res = hook(self, (edge_index, size, kwargs), out) if res is not None: out = res return out @abstractmethod def edge_update(self) -> Tensor: r"""Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to :meth:`edge_updater`. Furthermore, tensors passed to :meth:`edge_updater` can be mapped to the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`. """ raise NotImplementedError # Inference Decomposition ################################################# @property def decomposed_layers(self) -> int: return self._decomposed_layers @decomposed_layers.setter def decomposed_layers(self, decomposed_layers: int) -> None: if torch.jit.is_scripting(): raise ValueError("Inference decomposition of message passing " "modules is only supported on the Python module") if decomposed_layers == self._decomposed_layers: return # Abort early if nothing to do. 
self._decomposed_layers = decomposed_layers if decomposed_layers != 1: if hasattr(self.__class__, '_orig_propagate'): self.propagate = self.__class__._orig_propagate.__get__( self, MessagePassing) elif self.explain is None or self.explain is False: if hasattr(self.__class__, '_jinja_propagate'): self.propagate = self.__class__._jinja_propagate.__get__( self, MessagePassing) # Explainability ########################################################## @property def explain(self) -> Optional[bool]: return self._explain @explain.setter def explain(self, explain: Optional[bool]) -> None: if torch.jit.is_scripting(): raise ValueError("Explainability of message passing modules " "is only supported on the Python module") if explain == self._explain: return # Abort early if nothing to do. self._explain = explain if explain is True: assert self.decomposed_layers == 1 self.inspector.remove_signature(self.explain_message) self.inspector.inspect_signature(self.explain_message, exclude=[0]) self._user_args = self.inspector.get_flat_param_names( funcs=['message', 'explain_message', 'aggregate', 'update'], exclude=self.special_args, ) if hasattr(self.__class__, '_orig_propagate'): self.propagate = self.__class__._orig_propagate.__get__( self, MessagePassing) else: self._user_args = self.inspector.get_flat_param_names( funcs=['message', 'aggregate', 'update'], exclude=self.special_args, ) if self.decomposed_layers == 1: if hasattr(self.__class__, '_jinja_propagate'): self.propagate = self.__class__._jinja_propagate.__get__( self, MessagePassing) def explain_message( self, inputs: Tensor, dim_size: Optional[int], ) -> Tensor: # NOTE Replace this method in custom explainers per message-passing # layer to customize how messages shall be explained, e.g., via: # conv.explain_message = explain_message.__get__(conv, MessagePassing) # see stackoverflow.com: 394770/override-a-method-at-instance-level edge_mask = self._edge_mask if edge_mask is None: raise ValueError("Could not find a 
pre-defined 'edge_mask' " "to explain. Did you forget to initialize it?") if self._apply_sigmoid: edge_mask = edge_mask.sigmoid() # Some ops add self-loops to `edge_index`. We need to do the same for # `edge_mask` (but do not train these entries). if inputs.size(self.node_dim) != edge_mask.size(0): assert dim_size is not None edge_mask = edge_mask[self._loop_mask] loop = edge_mask.new_ones(dim_size) edge_mask = torch.cat([edge_mask, loop], dim=0) assert inputs.size(self.node_dim) == edge_mask.size(0) size = [1] * inputs.dim() size[self.node_dim] = -1 return inputs * edge_mask.view(size) # Hooks ################################################################### def register_propagate_forward_pre_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward pre-hook on the module. The hook will be called every time before :meth:`propagate` is invoked. It should have the following signature: .. code-block:: python hook(module, inputs) -> None or modified input The hook can modify the input. Input keyword arguments are passed to the hook as a dictionary in :obj:`inputs[-1]`. Returns a :class:`torch.utils.hooks.RemovableHandle` that can be used to remove the added hook by calling :obj:`handle.remove()`. """ handle = RemovableHandle(self._propagate_forward_pre_hooks) self._propagate_forward_pre_hooks[handle.id] = hook return handle def register_propagate_forward_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward hook on the module. The hook will be called every time after :meth:`propagate` has computed an output. It should have the following signature: .. code-block:: python hook(module, inputs, output) -> None or modified output The hook can modify the output. Input keyword arguments are passed to the hook as a dictionary in :obj:`inputs[-1]`. Returns a :class:`torch.utils.hooks.RemovableHandle` that can be used to remove the added hook by calling :obj:`handle.remove()`. 
""" handle = RemovableHandle(self._propagate_forward_hooks) self._propagate_forward_hooks[handle.id] = hook return handle def register_message_forward_pre_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward pre-hook on the module. The hook will be called every time before :meth:`message` is invoked. See :meth:`register_propagate_forward_pre_hook` for more information. """ handle = RemovableHandle(self._message_forward_pre_hooks) self._message_forward_pre_hooks[handle.id] = hook return handle def register_message_forward_hook(self, hook: Callable) -> RemovableHandle: r"""Registers a forward hook on the module. The hook will be called every time after :meth:`message` has computed an output. See :meth:`register_propagate_forward_hook` for more information. """ handle = RemovableHandle(self._message_forward_hooks) self._message_forward_hooks[handle.id] = hook return handle def register_aggregate_forward_pre_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward pre-hook on the module. The hook will be called every time before :meth:`aggregate` is invoked. See :meth:`register_propagate_forward_pre_hook` for more information. """ handle = RemovableHandle(self._aggregate_forward_pre_hooks) self._aggregate_forward_pre_hooks[handle.id] = hook return handle def register_aggregate_forward_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward hook on the module. The hook will be called every time after :meth:`aggregate` has computed an output. See :meth:`register_propagate_forward_hook` for more information. """ handle = RemovableHandle(self._aggregate_forward_hooks) self._aggregate_forward_hooks[handle.id] = hook return handle def register_message_and_aggregate_forward_pre_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward pre-hook on the module. The hook will be called every time before :meth:`message_and_aggregate` is invoked. 
See :meth:`register_propagate_forward_pre_hook` for more information. """ handle = RemovableHandle(self._message_and_aggregate_forward_pre_hooks) self._message_and_aggregate_forward_pre_hooks[handle.id] = hook return handle def register_message_and_aggregate_forward_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward hook on the module. The hook will be called every time after :meth:`message_and_aggregate` has computed an output. See :meth:`register_propagate_forward_hook` for more information. """ handle = RemovableHandle(self._message_and_aggregate_forward_hooks) self._message_and_aggregate_forward_hooks[handle.id] = hook return handle def register_edge_update_forward_pre_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward pre-hook on the module. The hook will be called every time before :meth:`edge_update` is invoked. See :meth:`register_propagate_forward_pre_hook` for more information. """ handle = RemovableHandle(self._edge_update_forward_pre_hooks) self._edge_update_forward_pre_hooks[handle.id] = hook return handle def register_edge_update_forward_hook( self, hook: Callable, ) -> RemovableHandle: r"""Registers a forward hook on the module. The hook will be called every time after :meth:`edge_update` has computed an output. See :meth:`register_propagate_forward_hook` for more information. 
""" handle = RemovableHandle(self._edge_update_forward_hooks) self._edge_update_forward_hooks[handle.id] = hook return handle # TorchScript Support ##################################################### def _set_jittable_templates(self, raise_on_error: bool = False) -> None: root_dir = osp.dirname(osp.realpath(__file__)) jinja_prefix = f'{self.__module__}_{self.__class__.__name__}' # Optimize `propagate()` via `*.jinja` templates: if not self.propagate.__module__.startswith(jinja_prefix): try: if ('propagate' in self.__class__.__dict__ and self.__class__.__dict__['propagate'] != MessagePassing.propagate): raise ValueError("Cannot compile custom 'propagate' " "method") module = module_from_template( module_name=f'{jinja_prefix}_propagate', template_path=osp.join(root_dir, 'propagate.jinja'), tmp_dirname='message_passing', # Keyword arguments: modules=self.inspector._modules, collect_name='collect', signature=self._get_propagate_signature(), collect_param_dict=self.inspector.get_flat_param_dict( ['message', 'aggregate', 'update']), message_args=self.inspector.get_param_names('message'), aggregate_args=self.inspector.get_param_names('aggregate'), message_and_aggregate_args=self.inspector.get_param_names( 'message_and_aggregate'), update_args=self.inspector.get_param_names('update'), fuse=self.fuse, ) self.__class__._orig_propagate = self.__class__.propagate self.__class__._jinja_propagate = module.propagate self.__class__.propagate = module.propagate self.__class__.collect = module.collect except Exception as e: # pragma: no cover if raise_on_error: raise e self.__class__._orig_propagate = self.__class__.propagate self.__class__._jinja_propagate = self.__class__.propagate # Optimize `edge_updater()` via `*.jinja` templates (if implemented): if (self.inspector.implements('edge_update') and not self.edge_updater.__module__.startswith(jinja_prefix)): try: if ('edge_updater' in self.__class__.__dict__ and self.__class__.__dict__['edge_updater'] != 
MessagePassing.edge_updater): raise ValueError("Cannot compile custom 'edge_updater' " "method") module = module_from_template( module_name=f'{jinja_prefix}_edge_updater', template_path=osp.join(root_dir, 'edge_updater.jinja'), tmp_dirname='message_passing', # Keyword arguments: modules=self.inspector._modules, collect_name='edge_collect', signature=self._get_edge_updater_signature(), collect_param_dict=self.inspector.get_param_dict( 'edge_update'), ) self.__class__._orig_edge_updater = self.__class__.edge_updater self.__class__._jinja_edge_updater = module.edge_updater self.__class__.edge_updater = module.edge_updater self.__class__.edge_collect = module.edge_collect except Exception as e: # pragma: no cover if raise_on_error: raise e self.__class__._orig_edge_updater = self.__class__.edge_updater self.__class__._jinja_edge_updater = ( self.__class__.edge_updater) def _get_propagate_signature(self) -> Signature: param_dict = self.inspector.get_params_from_method_call( 'propagate', exclude=[0, 'edge_index', 'size']) update_signature = self.inspector.get_signature('update') return Signature( param_dict=param_dict, return_type=update_signature.return_type, return_type_repr=update_signature.return_type_repr, ) def _get_edge_updater_signature(self) -> Signature: param_dict = self.inspector.get_params_from_method_call( 'edge_updater', exclude=[0, 'edge_index', 'size']) edge_update_signature = self.inspector.get_signature('edge_update') return Signature( param_dict=param_dict, return_type=edge_update_signature.return_type, return_type_repr=edge_update_signature.return_type_repr, ) def jittable(self, typing: Optional[str] = None) -> 'MessagePassing': r"""Analyzes the :class:`MessagePassing` instance and produces a new jittable module that can be used in combination with :meth:`torch.jit.script`. .. note:: :meth:`jittable` is deprecated and a no-op from :pyg:`PyG` 2.5 onwards. """ warnings.warn( f"'{self.__class__.__name__}.jittable' is deprecated " f"and a no-op. 
Please remove its usage.", stacklevel=2) return self ================================================ FILE: torch_geometric/nn/conv/mf_conv.py ================================================ from typing import Tuple, Union import torch from torch import Tensor from torch.nn import ModuleList from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor from torch_geometric.utils import degree, spmm class MFConv(MessagePassing): r"""The graph neural network operator from the `"Convolutional Networks on Graphs for Learning Molecular Fingerprints" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{W}^{(\deg(i))}_1 \mathbf{x}_i + \mathbf{W}^{(\deg(i))}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j which trains a distinct weight matrix for each possible vertex degree. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. max_degree (int, optional): The maximum node degree to consider when updating weights (default: :obj:`10`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **inputs:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V_t}|, F_{out})` if bipartite """ def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, max_degree: int = 10, bias=True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.max_degree = max_degree if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.lins_l = ModuleList([ Linear(in_channels[0], out_channels, bias=bias) for _ in range(max_degree + 1) ]) self.lins_r = ModuleList([ Linear(in_channels[1], out_channels, bias=False) for _ in range(max_degree + 1) ]) self.reset_parameters() def reset_parameters(self): super().reset_parameters() for lin in self.lins_l: lin.reset_parameters() for lin in self.lins_r: lin.reset_parameters() def forward( self, x: Union[Tensor, OptPairTensor], edge_index: Adj, size: Size = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) x_r = x[1] deg = x[0] # Dummy. 
if isinstance(edge_index, SparseTensor): deg = edge_index.storage.rowcount() elif isinstance(edge_index, Tensor): i = 1 if self.flow == 'source_to_target' else 0 N = x[0].size(self.node_dim) N = size[1] if size is not None else N N = x_r.size(self.node_dim) if x_r is not None else N deg = degree(edge_index[i], N, dtype=torch.long) deg.clamp_(max=self.max_degree) # propagate_type: (x: OptPairTensor) h = self.propagate(edge_index, x=x, size=size) out = h.new_empty(list(h.size())[:-1] + [self.out_channels]) for i, (lin_l, lin_r) in enumerate(zip(self.lins_l, self.lins_r)): idx = (deg == i).nonzero().view(-1) r = lin_l(h.index_select(self.node_dim, idx)) if x_r is not None: r = r + lin_r(x_r.index_select(self.node_dim, idx)) out.index_copy_(self.node_dim, idx, r) return out def message(self, x_j: Tensor) -> Tensor: return x_j def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor: if isinstance(adj_t, SparseTensor): adj_t = adj_t.set_value(None, layout=None) return spmm(adj_t, x[0], reduce=self.aggr) ================================================ FILE: torch_geometric/nn/conv/mixhop_conv.py ================================================ from typing import List, Optional import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils import spmm class MixHopConv(MessagePassing): r"""The Mix-Hop graph convolutional operator from the `"MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing" `_ paper. .. 
math:: \mathbf{X}^{\prime}={\Bigg\Vert}_{p\in P} {\left( \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^p \mathbf{X} \mathbf{\Theta}, where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with inserted self-loops and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. powers (List[int], optional): The powers of the adjacency matrix to use. (default: :obj:`[0, 1, 2]`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, |P| \cdot F_{out})` """ def __init__( self, in_channels: int, out_channels: int, powers: Optional[List[int]] = None, add_self_loops: bool = True, bias: bool = True, **kwargs, ): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) if powers is None: powers = [0, 1, 2] self.in_channels = in_channels self.out_channels = out_channels self.powers = powers self.add_self_loops = add_self_loops self.lins = torch.nn.ModuleList([ Linear(in_channels, out_channels, bias=False) if p in powers else torch.nn.Identity() for p in range(max(powers) + 1) ]) if bias: self.bias = Parameter(torch.empty(len(powers) * out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): for lin in self.lins: if hasattr(lin, 'reset_parameters'): lin.reset_parameters() zeros(self.bias) def 
forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: if isinstance(edge_index, Tensor): edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, x.dtype) elif isinstance(edge_index, SparseTensor): edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, x.dtype) outs = [self.lins[0](x)] for lin in self.lins[1:]: # propagate_type: (x: Tensor, edge_weight: OptTensor) x = self.propagate(edge_index, x=x, edge_weight=edge_weight) outs.append(lin.forward(x)) out = torch.cat([outs[p] for p in self.powers], dim=-1) if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, powers={self.powers})') ================================================ FILE: torch_geometric/nn/conv/nn_conv.py ================================================ from typing import Callable, Tuple, Union import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import reset, zeros from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size class NNConv(MessagePassing): r"""The continuous kernel-based convolutional operator from the `"Neural Message Passing for Quantum Chemistry" `_ paper. This convolution is also known as the edge-conditioned convolution from the `"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs" `_ paper (see :class:`torch_geometric.nn.conv.ECConv` for an alias): .. 
math:: \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}), where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* a MLP. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that maps edge features :obj:`edge_attr` of shape :obj:`[-1, num_edge_features]` to shape :obj:`[-1, in_channels * out_channels]`, *e.g.*, defined by :class:`torch.nn.Sequential`. aggr (str, optional): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`) root_weight (bool, optional): If set to :obj:`False`, the layer will not add the transformed root node features to the output. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, nn: Callable, aggr: str = 'add', root_weight: bool = True, bias: bool = True, **kwargs): super().__init__(aggr=aggr, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.nn = nn self.root_weight = root_weight if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.in_channels_l = in_channels[0] if root_weight: self.lin = Linear(in_channels[1], out_channels, bias=False, weight_initializer='uniform') if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() reset(self.nn) if self.root_weight: self.lin.reset_parameters() zeros(self.bias) def forward( self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) x_r = x[1] if x_r is not None and self.root_weight: out = out + self.lin(x_r) if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: weight = self.nn(edge_attr) weight = weight.view(-1, self.in_channels_l, self.out_channels) return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, aggr={self.aggr}, nn={self.nn})') ================================================ FILE: 
torch_geometric/nn/conv/pan_conv.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, SparseTensor from torch_geometric.utils import is_torch_sparse_tensor, spmm class PANConv(MessagePassing): r"""The path integral based convolutional operator from the `"Path Integral Based Convolution and Pooling for Graph Neural Networks" `_ paper. .. math:: \mathbf{X}^{\prime} = \mathbf{M} \mathbf{X} \mathbf{W} where :math:`\mathbf{M}` denotes the normalized and learned maximal entropy transition (MET) matrix that includes neighbors up to :obj:`filter_size` hops: .. math:: \mathbf{M} = \mathbf{Z}^{-1/2} \sum_{n=0}^L e^{-\frac{E(n)}{T}} \mathbf{A}^n \mathbf{Z}^{-1/2} Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. filter_size (int): The filter size :math:`L`. **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__(self, in_channels: int, out_channels: int, filter_size: int, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.filter_size = filter_size self.lin = Linear(in_channels, out_channels) self.weight = Parameter(torch.empty(filter_size + 1)) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin.reset_parameters() self.weight.data.fill_(0.5) def forward( self, x: Tensor, edge_index: Adj, ) -> Tuple[Tensor, SparseTensor]: adj_t: Optional[SparseTensor] = None if isinstance(edge_index, Tensor): if is_torch_sparse_tensor(edge_index): # TODO Handle PyTorch sparse tensor directly. if edge_index.layout == torch.sparse_coo: adj_t = SparseTensor.from_torch_sparse_coo_tensor( edge_index) elif edge_index.layout == torch.sparse_csr: adj_t = SparseTensor.from_torch_sparse_csr_tensor( edge_index) else: raise ValueError(f"Unexpected sparse tensor layout " f"(got '{edge_index.layout}')") else: adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], sparse_sizes=(x.size(0), x.size(0))) elif isinstance(edge_index, SparseTensor): adj_t = edge_index.set_value(None) adj_t = self.panentropy(adj_t, dtype=x.dtype) deg = adj_t.storage.rowcount().to(x.dtype) deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0. 
M = deg_inv_sqrt.view(1, -1) * adj_t * deg_inv_sqrt.view(-1, 1) out = self.propagate(M, x=x, edge_weight=None) out = self.lin(out) return out, M def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def panentropy(self, adj_t: SparseTensor, dtype: Optional[int] = None) -> SparseTensor: if not adj_t.has_value(): adj_t = adj_t.fill_value(1.0) tmp = SparseTensor.eye(adj_t.size(0), adj_t.size(1), has_value=True, dtype=dtype, device=adj_t.device()) tmp = tmp.mul_nnz(self.weight[0], layout='coo') outs = [tmp] for i in range(1, self.filter_size + 1): tmp = tmp @ adj_t tmp = tmp.mul_nnz(self.weight[i], layout='coo') outs += [tmp] row = torch.cat([out.storage.row() for out in outs], dim=0) col = torch.cat([out.storage.col() for out in outs], dim=0) value = torch.cat([out.storage.value() for out in outs], dim=0) out = SparseTensor(row=row, col=col, value=value, sparse_sizes=adj_t.sparse_sizes()).coalesce() return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, filter_size={self.filter_size})') ================================================ FILE: torch_geometric/nn/conv/pdn_conv.py ================================================ import torch from torch import Tensor from torch.nn import Linear, Parameter, ReLU, Sequential, Sigmoid from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils import spmm class PDNConv(MessagePassing): r"""The pathfinder discovery network convolutional operator from the `"Pathfinder Discovery Networks for Neural Message Passing" `_ paper. .. 
math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{i\}}f_{\Theta}(\textbf{e}_{(j,i)}) \cdot f_{\Omega}(\mathbf{x}_{j}) where :math:`z_{i,j}` denotes the edge feature vector from source node :math:`j` to target node :math:`i`, and :math:`\mathbf{x}_{j}` denotes the node feature vector of node :math:`j`. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. edge_dim (int): Edge feature dimensionality. hidden_channels (int): Hidden edge feature dimensionality. add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) normalize (bool, optional): Whether to add self-loops and compute symmetric normalization coefficients on the fly. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__(self, in_channels: int, out_channels: int, edge_dim: int, hidden_channels: int, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs): kwargs.setdefault("aggr", "add") super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.edge_dim = edge_dim self.hidden_channels = hidden_channels self.add_self_loops = add_self_loops self.normalize = normalize self.lin = Linear(in_channels, out_channels, bias=False) self.mlp = Sequential( Linear(edge_dim, hidden_channels), ReLU(inplace=True), Linear(hidden_channels, 1), Sigmoid(), ) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): 
super().reset_parameters() glorot(self.lin.weight) glorot(self.mlp[0].weight) glorot(self.mlp[2].weight) zeros(self.mlp[0].bias) zeros(self.mlp[2].bias) zeros(self.bias) def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None) -> Tensor: if isinstance(edge_index, SparseTensor): edge_attr = edge_index.storage.value() if edge_attr is not None: edge_attr = self.mlp(edge_attr).squeeze(-1) if isinstance(edge_index, SparseTensor): edge_index = edge_index.set_value(edge_attr, layout='coo') if self.normalize: if isinstance(edge_index, Tensor): edge_index, edge_attr = gcn_norm(edge_index, edge_attr, x.size(self.node_dim), False, self.add_self_loops, self.flow, x.dtype) elif isinstance(edge_index, SparseTensor): edge_index = gcn_norm(edge_index, None, x.size(self.node_dim), False, self.add_self_loops, self.flow, x.dtype) x = self.lin(x) # propagate_type: (x: Tensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=x, edge_weight=edge_attr) if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self): return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/conv/pna_conv.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import torch from torch import Tensor from torch.nn import ModuleList, Sequential from torch.utils.data import DataLoader from torch_geometric.nn.aggr import DegreeScalerAggregation from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import reset from torch_geometric.nn.resolver import activation_resolver from torch_geometric.typing import Adj, OptTensor from 
torch_geometric.utils import degree class PNAConv(MessagePassing): r"""The Principal Neighbourhood Aggregation graph convolution operator from the `"Principal Neighbourhood Aggregation for Graph Nets" `_ paper. .. math:: \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus} h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right) \right) with .. math:: \bigoplus = \underbrace{\begin{bmatrix} 1 \\ S(\mathbf{D}, \alpha=1) \\ S(\mathbf{D}, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}}, where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` denote MLPs. .. note:: For an example of using :obj:`PNAConv`, see `examples/pna.py `_. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. aggregators (List[str]): Set of aggregation function identifiers, namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"var"` and :obj:`"std"`. scalers (List[str]): Set of scaling function identifiers, namely :obj:`"identity"`, :obj:`"amplification"`, :obj:`"attenuation"`, :obj:`"linear"` and :obj:`"inverse_linear"`. deg (torch.Tensor): Histogram of in-degrees of nodes in the training set, used by scalers to normalize. edge_dim (int, optional): Edge feature dimensionality (in case there are any). (default :obj:`None`) towers (int, optional): Number of towers (default: :obj:`1`). pre_layers (int, optional): Number of transformation layers before aggregation (default: :obj:`1`). post_layers (int, optional): Number of transformation layers after aggregation (default: :obj:`1`). divide_input (bool, optional): Whether the input features should be split between towers or not (default: :obj:`False`). act (str or callable, optional): Pre- and post-layer activation function to use. 
(default: :obj:`"relu"`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) train_norm (bool, optional): Whether normalization parameters are trainable. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__( self, in_channels: int, out_channels: int, aggregators: List[str], scalers: List[str], deg: Tensor, edge_dim: Optional[int] = None, towers: int = 1, pre_layers: int = 1, post_layers: int = 1, divide_input: bool = False, act: Union[str, Callable, None] = "relu", act_kwargs: Optional[Dict[str, Any]] = None, train_norm: bool = False, **kwargs, ): aggr = DegreeScalerAggregation(aggregators, scalers, deg, train_norm) super().__init__(aggr=aggr, node_dim=0, **kwargs) if divide_input: assert in_channels % towers == 0 assert out_channels % towers == 0 self.in_channels = in_channels self.out_channels = out_channels self.edge_dim = edge_dim self.towers = towers self.divide_input = divide_input self.F_in = in_channels // towers if divide_input else in_channels self.F_out = self.out_channels // towers if self.edge_dim is not None: self.edge_encoder = Linear(edge_dim, self.F_in) self.pre_nns = ModuleList() self.post_nns = ModuleList() for _ in range(towers): modules = [Linear((3 if edge_dim else 2) * self.F_in, self.F_in)] for _ in range(pre_layers - 1): modules += [activation_resolver(act, **(act_kwargs or {}))] modules += [Linear(self.F_in, self.F_in)] self.pre_nns.append(Sequential(*modules)) in_channels = (len(aggregators) * len(scalers) + 1) * self.F_in modules = [Linear(in_channels, self.F_out)] for _ in range(post_layers - 1): modules += [activation_resolver(act, **(act_kwargs or 
{}))] modules += [Linear(self.F_out, self.F_out)] self.post_nns.append(Sequential(*modules)) self.lin = Linear(out_channels, out_channels) self.reset_parameters() def reset_parameters(self): super().reset_parameters() if self.edge_dim is not None: self.edge_encoder.reset_parameters() for nn in self.pre_nns: reset(nn) for nn in self.post_nns: reset(nn) self.lin.reset_parameters() def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None) -> Tensor: if self.divide_input: x = x.view(-1, self.towers, self.F_in) else: x = x.view(-1, 1, self.F_in).repeat(1, self.towers, 1) # propagate_type: (x: Tensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, edge_attr=edge_attr) out = torch.cat([x, out], dim=-1) outs = [nn(out[:, i]) for i, nn in enumerate(self.post_nns)] out = torch.cat(outs, dim=1) return self.lin(out) def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor) -> Tensor: h: Tensor = x_i # Dummy. if edge_attr is not None: edge_attr = self.edge_encoder(edge_attr) edge_attr = edge_attr.view(-1, 1, self.F_in) edge_attr = edge_attr.repeat(1, self.towers, 1) h = torch.cat([x_i, x_j, edge_attr], dim=-1) else: h = torch.cat([x_i, x_j], dim=-1) hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)] return torch.stack(hs, dim=1) def __repr__(self): return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, towers={self.towers}, ' f'edge_dim={self.edge_dim})') @staticmethod def get_degree_histogram(loader: DataLoader) -> Tensor: r"""Returns the degree histogram to be used as input for the :obj:`deg` argument in :class:`PNAConv`. 
""" deg_histogram = torch.zeros(1, dtype=torch.long) for data in loader: deg = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) deg_bincount = torch.bincount(deg, minlength=deg_histogram.numel()) deg_histogram = deg_histogram.to(deg_bincount.device) if deg_bincount.numel() > deg_histogram.numel(): deg_bincount[:deg_histogram.size(0)] += deg_histogram deg_histogram = deg_bincount else: assert deg_bincount.numel() == deg_histogram.numel() deg_histogram += deg_bincount return deg_histogram ================================================ FILE: torch_geometric/nn/conv/point_conv.py ================================================ from typing import Callable, Optional, Union import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset from torch_geometric.typing import ( Adj, OptTensor, PairOptTensor, PairTensor, SparseTensor, torch_sparse, ) from torch_geometric.utils import add_self_loops, remove_self_loops class PointNetConv(MessagePassing): r"""The PointNet set layer from the `"PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation" `_ and `"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" `_ papers. .. math:: \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \mathbf{p}_j - \mathbf{p}_i) \right), where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.* MLPs, and :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of each point. Args: local_nn (torch.nn.Module, optional): A neural network :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` and relative spatial coordinates :obj:`pos_j - pos_i` of shape :obj:`[-1, in_channels + num_dimensions]` to shape :obj:`[-1, out_channels]`, *e.g.*, defined by :class:`torch.nn.Sequential`. 
(default: :obj:`None`) global_nn (torch.nn.Module, optional): A neural network :math:`\gamma_{\mathbf{\Theta}}` that maps aggregated node features of shape :obj:`[-1, out_channels]` to shape :obj:`[-1, final_out_channels]`, *e.g.*, defined by :class:`torch.nn.Sequential`. (default: :obj:`None`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, positions :math:`(|\mathcal{V}|, 3)` or :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, local_nn: Optional[Callable] = None, global_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs): kwargs.setdefault('aggr', 'max') super().__init__(**kwargs) self.local_nn = local_nn self.global_nn = global_nn self.add_self_loops = add_self_loops self.reset_parameters() def reset_parameters(self): super().reset_parameters() reset(self.local_nn) reset(self.global_nn) def forward( self, x: Union[OptTensor, PairOptTensor], pos: Union[Tensor, PairTensor], edge_index: Adj, ) -> Tensor: if not isinstance(x, tuple): x = (x, None) if isinstance(pos, Tensor): pos = (pos, pos) if self.add_self_loops: if isinstance(edge_index, Tensor): edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops( edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0))) elif isinstance(edge_index, SparseTensor): edge_index = torch_sparse.set_diag(edge_index) # propagate_type: (x: PairOptTensor, pos: PairTensor) out = self.propagate(edge_index, x=x, pos=pos) if self.global_nn is not None: out = self.global_nn(out) return out def 
message(self, x_j: Optional[Tensor], pos_i: Tensor, pos_j: Tensor) -> Tensor: msg = pos_j - pos_i if x_j is not None: msg = torch.cat([x_j, msg], dim=1) if self.local_nn is not None: msg = self.local_nn(msg) return msg def __repr__(self) -> str: return (f'{self.__class__.__name__}(local_nn={self.local_nn}, ' f'global_nn={self.global_nn})') ================================================ FILE: torch_geometric/nn/conv/point_gnn_conv.py ================================================ import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset from torch_geometric.typing import Adj class PointGNNConv(MessagePassing): r"""The PointGNN operator from the `"Point-GNN: Graph Neural Network for 3D Object Detection in a Point Cloud" `_ paper. .. math:: \Delta \textrm{pos}_i &= h_{\mathbf{\Theta}}(\mathbf{x}_i) \mathbf{e}_{j,i} &= f_{\mathbf{\Theta}}(\textrm{pos}_j - \textrm{pos}_i + \Delta \textrm{pos}_i, \mathbf{x}_j) \mathbf{x}^{\prime}_i &= g_{\mathbf{\Theta}}(\max_{j \in \mathcal{N}(i)} \mathbf{e}_{j,i}) + \mathbf{x}_i The relative position is used in the message passing step to introduce global translation invariance. To also counter shifts in the local neighborhood of the center node, the authors propose to utilize an alignment offset. The graph should be statically constructed using radius-based cutoff. Args: mlp_h (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that maps node features of size :math:`F_{in}` to three-dimensional coordination offsets :math:`\Delta \textrm{pos}_i`. mlp_f (torch.nn.Module): A neural network :math:`f_{\mathbf{\Theta}}` that computes :math:`\mathbf{e}_{j,i}` from the features of neighbors of size :math:`F_{in}` and the three-dimensional vector :math:`\textrm{pos_j} - \textrm{pos_i} + \Delta \textrm{pos}_i`. mlp_g (torch.nn.Module): A neural network :math:`g_{\mathbf{\Theta}}` that maps the aggregated edge features back to :math:`F_{in}`. 
**kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, positions :math:`(|\mathcal{V}|, 3)`, edge indices :math:`(2, |\mathcal{E}|)`, - **output:** node features :math:`(|\mathcal{V}|, F_{in})` """ def __init__( self, mlp_h: torch.nn.Module, mlp_f: torch.nn.Module, mlp_g: torch.nn.Module, **kwargs, ): kwargs.setdefault('aggr', 'max') super().__init__(**kwargs) self.mlp_h = mlp_h self.mlp_f = mlp_f self.mlp_g = mlp_g self.reset_parameters() def reset_parameters(self): super().reset_parameters() reset(self.mlp_h) reset(self.mlp_f) reset(self.mlp_g) def forward(self, x: Tensor, pos: Tensor, edge_index: Adj) -> Tensor: # propagate_type: (x: Tensor, pos: Tensor) out = self.propagate(edge_index, x=x, pos=pos) out = self.mlp_g(out) return x + out def message(self, pos_j: Tensor, pos_i: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor: delta = self.mlp_h(x_i) e = torch.cat([pos_j - pos_i + delta, x_j], dim=-1) return self.mlp_f(e) def __repr__(self) -> str: return (f'{self.__class__.__name__}(\n' f' mlp_h={self.mlp_h},\n' f' mlp_f={self.mlp_f},\n' f' mlp_g={self.mlp_g},\n' f')') ================================================ FILE: torch_geometric/nn/conv/point_transformer_conv.py ================================================ from typing import Callable, Optional, Tuple, Union from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import reset from torch_geometric.typing import ( Adj, OptTensor, PairTensor, SparseTensor, torch_sparse, ) from torch_geometric.utils import add_self_loops, remove_self_loops, softmax class PointTransformerConv(MessagePassing): r"""The Point Transformer layer from the `"Point Transformer" `_ paper. .. 
math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3 \mathbf{x}_j + \delta_{ij} \right), where the attention coefficients :math:`\alpha_{i,j}` and positional embedding :math:`\delta_{ij}` are computed as .. math:: \alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta} (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + \delta_{i,j}) \right) and .. math:: \delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j), with :math:`\gamma_\mathbf{\Theta}` and :math:`h_\mathbf{\Theta}` denoting neural networks, *i.e.* MLPs, and :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of each point. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. pos_nn (torch.nn.Module, optional): A neural network :math:`h_\mathbf{\Theta}` which maps relative spatial coordinates :obj:`pos_j - pos_i` of shape :obj:`[-1, 3]` to shape :obj:`[-1, out_channels]`. Will default to a :class:`torch.nn.Linear` transformation if not further specified. (default: :obj:`None`) attn_nn (torch.nn.Module, optional): A neural network :math:`\gamma_\mathbf{\Theta}` which maps transformed node features of shape :obj:`[-1, out_channels]` to shape :obj:`[-1, out_channels]`. (default: :obj:`None`) add_self_loops (bool, optional) : If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, positions :math:`(|\mathcal{V}|, 3)` or :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, pos_nn: Optional[Callable] = None, attn_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.add_self_loops = add_self_loops if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.pos_nn = pos_nn if self.pos_nn is None: self.pos_nn = Linear(3, out_channels) self.attn_nn = attn_nn self.lin = Linear(in_channels[0], out_channels, bias=False) self.lin_src = Linear(in_channels[0], out_channels, bias=False) self.lin_dst = Linear(in_channels[1], out_channels, bias=False) self.reset_parameters() def reset_parameters(self): super().reset_parameters() reset(self.pos_nn) if self.attn_nn is not None: reset(self.attn_nn) self.lin.reset_parameters() self.lin_src.reset_parameters() self.lin_dst.reset_parameters() def forward( self, x: Union[Tensor, PairTensor], pos: Union[Tensor, PairTensor], edge_index: Adj, ) -> Tensor: if isinstance(x, Tensor): alpha = (self.lin_src(x), self.lin_dst(x)) x = (self.lin(x), x) else: alpha = (self.lin_src(x[0]), self.lin_dst(x[1])) x = (self.lin(x[0]), x[1]) if isinstance(pos, Tensor): pos = (pos, pos) if self.add_self_loops: if isinstance(edge_index, Tensor): edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops( edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0))) elif isinstance(edge_index, SparseTensor): edge_index = torch_sparse.set_diag(edge_index) # propagate_type: 
(x: PairTensor, pos: PairTensor, alpha: PairTensor) out = self.propagate(edge_index, x=x, pos=pos, alpha=alpha) return out def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor, alpha_i: Tensor, alpha_j: Tensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor: delta = self.pos_nn(pos_i - pos_j) alpha = alpha_i - alpha_j + delta if self.attn_nn is not None: alpha = self.attn_nn(alpha) alpha = softmax(alpha, index, ptr, size_i) return alpha * (x_j + delta) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/conv/ppf_conv.py ================================================ from typing import Callable, Optional, Union import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset from torch_geometric.typing import ( Adj, OptTensor, PairOptTensor, PairTensor, SparseTensor, torch_sparse, ) from torch_geometric.utils import add_self_loops, remove_self_loops def get_angle(v1: Tensor, v2: Tensor) -> Tensor: return torch.atan2( torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1)) def point_pair_features(pos_i: Tensor, pos_j: Tensor, normal_i: Tensor, normal_j: Tensor) -> Tensor: pseudo = pos_j - pos_i return torch.stack([ pseudo.norm(p=2, dim=1), get_angle(normal_i, pseudo), get_angle(normal_j, pseudo), get_angle(normal_i, normal_j) ], dim=1) class PPFConv(MessagePassing): r"""The PPFNet operator from the `"PPFNet: Global Context Aware Local Features for Robust 3D Point Matching" `_ paper. .. 
math:: \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \| \mathbf{d_{j,i}} \|, \angle(\mathbf{n}_i, \mathbf{d_{j,i}}), \angle(\mathbf{n}_j, \mathbf{d_{j,i}}), \angle(\mathbf{n}_i, \mathbf{n}_j) ) \right) where :math:`\mathbf{d_{j,i}}` denotes the relative position :obj:`pos_j - pos_i`, :math:`\mathbf{n}` denotes the point normals, and :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.* MLPs, which take in node features and :class:`torch_geometric.transforms.PointPairFeatures`. Args: local_nn (torch.nn.Module, optional): A neural network :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` concatenated with the four point pair features (the distance :math:`\| \mathbf{d_{j,i}} \|` and the three angles) of shape :obj:`[-1, in_channels + 4]` (or :obj:`[-1, 4]` if :obj:`x` is :obj:`None`) to shape :obj:`[-1, out_channels]`, *e.g.*, defined by :class:`torch.nn.Sequential`. (default: :obj:`None`) global_nn (torch.nn.Module, optional): A neural network :math:`\gamma_{\mathbf{\Theta}}` that maps aggregated node features of shape :obj:`[-1, out_channels]` to shape :obj:`[-1, final_out_channels]`, *e.g.*, defined by :class:`torch.nn.Sequential`. (default: :obj:`None`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`.
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, positions :math:`(|\mathcal{V}|, 3)` or :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, point normals :math:`(|\mathcal{V}, 3)` or :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, local_nn: Optional[Callable] = None, global_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs): kwargs.setdefault('aggr', 'max') super().__init__(**kwargs) self.local_nn = local_nn self.global_nn = global_nn self.add_self_loops = add_self_loops self.reset_parameters() def reset_parameters(self): super().reset_parameters() reset(self.local_nn) reset(self.global_nn) def forward( self, x: Union[OptTensor, PairOptTensor], pos: Union[Tensor, PairTensor], normal: Union[Tensor, PairTensor], edge_index: Adj, ) -> Tensor: if not isinstance(x, tuple): x = (x, None) if isinstance(pos, Tensor): pos = (pos, pos) if isinstance(normal, Tensor): normal = (normal, normal) if self.add_self_loops: if isinstance(edge_index, Tensor): edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops(edge_index, num_nodes=pos[1].size(0)) elif isinstance(edge_index, SparseTensor): edge_index = torch_sparse.set_diag(edge_index) # propagate_type: (x: PairOptTensor, pos: PairTensor, # normal: PairTensor) out = self.propagate(edge_index, x=x, pos=pos, normal=normal) if self.global_nn is not None: out = self.global_nn(out) return out def message(self, x_j: OptTensor, pos_i: Tensor, pos_j: Tensor, normal_i: Tensor, normal_j: Tensor) -> Tensor: msg = point_pair_features(pos_i, pos_j, normal_i, normal_j) if x_j is not None: msg = torch.cat([x_j, msg], dim=1) if self.local_nn is not None: msg = self.local_nn(msg) return msg def 
__repr__(self) -> str: return (f'{self.__class__.__name__}(local_nn={self.local_nn}, ' f'global_nn={self.global_nn})') ================================================ FILE: torch_geometric/nn/conv/propagate.jinja ================================================ import typing from typing import Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric import is_compiling from torch_geometric.utils import is_sparse from torch_geometric.typing import Size, SparseTensor {% for module in modules %} from {{module}} import * {%- endfor %} {% include "collect.jinja" %} def propagate( self, edge_index: Union[Tensor, SparseTensor], {%- for param in signature.param_dict.values() %} {{param.name}}: {{param.type_repr}}, {%- endfor %} size: Size = None, ) -> {{signature.return_type_repr}}: # Begin Propagate Forward Pre Hook ######################################### if not torch.jit.is_scripting() and not is_compiling(): for hook in self._propagate_forward_pre_hooks.values(): hook_kwargs = dict( {%- for name in signature.param_dict %} {{name}}={{name}}, {%- endfor %} ) res = hook(self, (edge_index, size, hook_kwargs)) if res is not None: edge_index, size, hook_kwargs = res {%- for name in signature.param_dict %} {{name}} = hook_kwargs['{{name}}'] {%- endfor %} # End Propagate Forward Pre Hook ########################################### mutable_size = self._check_input(edge_index, size) # Run "fused" message and aggregation (if applicable). 
fuse = False if self.fuse: if is_sparse(edge_index): fuse = True elif not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): if self.SUPPORTS_FUSED_EDGE_INDEX and edge_index.is_sorted_by_col: fuse = True if fuse: {%- if fuse %} # Begin Message and Aggregate Forward Pre Hook ######################### if not torch.jit.is_scripting() and not is_compiling(): for hook in self._message_and_aggregate_forward_pre_hooks.values(): hook_kwargs = dict( {%- for name in message_and_aggregate_args %} {{name}}={{name}}, {%- endfor %} ) res = hook(self, (edge_index, hook_kwargs)) if res is not None: edge_index, hook_kwargs = res {%- for name in message_and_aggregate_args %} {{name}} = hook_kwargs['{{name}}'] {%- endfor %} # End Message and Aggregate Forward Pre Hook ########################## out = self.message_and_aggregate( edge_index, {%- for name in message_and_aggregate_args %} {{name}}, {%- endfor %} ) # Begin Message and Aggregate Forward Hook ############################# if not torch.jit.is_scripting() and not is_compiling(): for hook in self._message_and_aggregate_forward_hooks.values(): hook_kwargs = dict( {%- for name in message_and_aggregate_args %} {{name}}={{name}}, {%- endfor %} ) res = hook(self, (edge_index, hook_kwargs, ), out) out = res if res is not None else out # End Message and Aggregate Forward Hook ############################### out = self.update( out, {%- for name in update_args %} {{name}}={{name}}, {%- endfor %} ) {%- else %} raise NotImplementedError("'message_and_aggregate' not implemented") {%- endif %} else: kwargs = self.{{collect_name}}( edge_index, {%- for name in signature.param_dict %} {{name}}, {%- endfor %} mutable_size, ) # Begin Message Forward Pre Hook ####################################### if not torch.jit.is_scripting() and not is_compiling(): for hook in self._message_forward_pre_hooks.values(): hook_kwargs = dict( {%- for name in message_args %} {{name}}=kwargs.{{name}}, {%- endfor %} ) res = hook(self, (hook_kwargs, )) 
hook_kwargs = res[0] if isinstance(res, tuple) else res if res is not None: kwargs = CollectArgs( {%- for name in collect_param_dict %} {%- if name in message_args %} {{name}}=hook_kwargs['{{name}}'], {%- else %} {{name}}=kwargs.{{name}}, {%- endif %} {%- endfor %} ) # End Message Forward Pre Hook ######################################### out = self.message( {%- for name in message_args %} {{name}}=kwargs.{{name}}, {%- endfor %} ) # Begin Message Forward Hook ########################################### if not torch.jit.is_scripting() and not is_compiling(): for hook in self._message_forward_hooks.values(): hook_kwargs = dict( {%- for name in message_args %} {{name}}=kwargs.{{name}}, {%- endfor %} ) res = hook(self, (hook_kwargs, ), out) out = res if res is not None else out # End Message Forward Hook ############################################# # Begin Aggregate Forward Pre Hook ##################################### if not torch.jit.is_scripting() and not is_compiling(): for hook in self._aggregate_forward_pre_hooks.values(): hook_kwargs = dict( {%- for name in aggregate_args %} {{name}}=kwargs.{{name}}, {%- endfor %} ) res = hook(self, (hook_kwargs, )) hook_kwargs = res[0] if isinstance(res, tuple) else res if res is not None: kwargs = CollectArgs( {%- for name in collect_param_dict %} {%- if name in aggregate_args %} {{name}}=hook_kwargs['{{name}}'], {%- else %} {{name}}=kwargs.{{name}}, {%- endif %} {%- endfor %} ) # End Aggregate Forward Pre Hook ####################################### out = self.aggregate( out, {%- for name in aggregate_args %} {{name}}=kwargs.{{name}}, {%- endfor %} ) # Begin Aggregate Forward Hook ######################################### if not torch.jit.is_scripting() and not is_compiling(): for hook in self._aggregate_forward_hooks.values(): hook_kwargs = dict( {%- for name in aggregate_args %} {{name}}=kwargs.{{name}}, {%- endfor %} ) res = hook(self, (hook_kwargs, ), out) out = res if res is not None else out # End Aggregate Forward 
Hook ########################################### out = self.update( out, {%- for name in update_args %} {{name}}=kwargs.{{name}}, {%- endfor %} ) # Begin Propagate Forward Hook ############################################ if not torch.jit.is_scripting() and not is_compiling(): for hook in self._propagate_forward_hooks.values(): hook_kwargs = dict( {%- for name in signature.param_dict %} {{name}}={{name}}, {%- endfor %} ) res = hook(self, (edge_index, mutable_size, hook_kwargs), out) out = res if res is not None else out # End Propagate Forward Hook ############################################## return out ================================================ FILE: torch_geometric/nn/conv/res_gated_graph_conv.py ================================================ from typing import Callable, Optional, Tuple, Union import torch from torch import Tensor from torch.nn import Parameter, Sigmoid from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros from torch_geometric.typing import Adj, OptTensor, PairTensor class ResGatedGraphConv(MessagePassing): r"""The residual gated graph convolutional operator from the `"Residual Gated Graph ConvNets" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \eta_{i,j} \odot \mathbf{W}_2 \mathbf{x}_j where the gate :math:`\eta_{i,j}` is defined as .. math:: \eta_{i,j} = \sigma(\mathbf{W}_3 \mathbf{x}_i + \mathbf{W}_4 \mathbf{x}_j) with :math:`\sigma` denoting the sigmoid function. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. act (callable, optional): Gating function :math:`\sigma`. (default: :meth:`torch.nn.Sigmoid()`) edge_dim (int, optional): Edge feature dimensionality (in case there are any). 
(default: :obj:`None`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) root_weight (bool, optional): If set to :obj:`False`, the layer will not add transformed root node features to the output. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **inputs:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V_t}|, F_{out})` if bipartite """ def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, act: Optional[Callable] = Sigmoid(), edge_dim: Optional[int] = None, root_weight: bool = True, bias: bool = True, **kwargs, ): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.act = act self.edge_dim = edge_dim self.root_weight = root_weight if isinstance(in_channels, int): in_channels = (in_channels, in_channels) edge_dim = edge_dim if edge_dim is not None else 0 self.lin_key = Linear(in_channels[1] + edge_dim, out_channels) self.lin_query = Linear(in_channels[0] + edge_dim, out_channels) self.lin_value = Linear(in_channels[0] + edge_dim, out_channels) if root_weight: self.lin_skip = Linear(in_channels[1], out_channels, bias=False) else: self.register_parameter('lin_skip', None) if bias: self.bias = Parameter(Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin_key.reset_parameters() self.lin_query.reset_parameters() self.lin_value.reset_parameters() if self.lin_skip is not None: self.lin_skip.reset_parameters() if self.bias is not None: zeros(self.bias) def forward( self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor = 
None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) # In case edge features are not given, we can compute key, query and # value tensors in node-level space, which is a bit more efficient: if self.edge_dim is None: k = self.lin_key(x[1]) q = self.lin_query(x[0]) v = self.lin_value(x[0]) else: k, q, v = x[1], x[0], x[0] # propagate_type: (k: Tensor, q: Tensor, v: Tensor, # edge_attr: OptTensor) out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr) if self.root_weight: out = out + self.lin_skip(x[1]) if self.bias is not None: out = out + self.bias return out def message(self, k_i: Tensor, q_j: Tensor, v_j: Tensor, edge_attr: OptTensor) -> Tensor: assert (edge_attr is not None) == (self.edge_dim is not None) if edge_attr is not None: k_i = self.lin_key(torch.cat([k_i, edge_attr], dim=-1)) q_j = self.lin_query(torch.cat([q_j, edge_attr], dim=-1)) v_j = self.lin_value(torch.cat([v_j, edge_attr], dim=-1)) return self.act(k_i + q_j) * v_j ================================================ FILE: torch_geometric/nn/conv/rgat_conv.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter, ReLU from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot, ones, zeros from torch_geometric.typing import Adj, OptTensor, Size, SparseTensor from torch_geometric.utils import is_torch_sparse_tensor, scatter, softmax from torch_geometric.utils.sparse import set_sparse_value class RGATConv(MessagePassing): r"""The relational graph attentional operator from the `"Relational Graph Attention Networks" `_ paper. Here, attention logits :math:`\mathbf{a}^{(r)}_{i,j}` are computed for each relation type :math:`r` with the help of both query and key kernels, *i.e.* .. 
math:: \mathbf{q}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot \mathbf{Q}^{(r)} \quad \textrm{and} \quad \mathbf{k}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot \mathbf{K}^{(r)}. Two schemes have been proposed to compute attention logits :math:`\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r`: **Additive attention** .. math:: \mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + \mathbf{k}^{(r)}_j) or **multiplicative attention** .. math:: \mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j. If the graph has multi-dimensional edge features :math:`\mathbf{e}^{(r)}_{i,j}`, the attention logits :math:`\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r` are computed as .. math:: \mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + \mathbf{k}^{(r)}_j + \mathbf{W}_2^{(r)}\mathbf{e}^{(r)}_{i,j}) or .. math:: \mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j \cdot \mathbf{W}_2^{(r)} \mathbf{e}^{(r)}_{i,j}, respectively. The attention coefficients :math:`\alpha^{(r)}_{i,j}` for each relation type :math:`r` are then obtained via two different attention mechanisms: The **within-relation** attention mechanism .. math:: \alpha^{(r)}_{i,j} = \frac{\exp(\mathbf{a}^{(r)}_{i,j})} {\sum_{k \in \mathcal{N}_r(i)} \exp(\mathbf{a}^{(r)}_{i,k})} or the **across-relation** attention mechanism .. math:: \alpha^{(r)}_{i,j} = \frac{\exp(\mathbf{a}^{(r)}_{i,j})} {\sum_{r^{\prime} \in \mathcal{R}} \sum_{k \in \mathcal{N}_{r^{\prime}}(i)} \exp(\mathbf{a}^{(r^{\prime})}_{i,k})} where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types. Edge type needs to be a one-dimensional :obj:`torch.long` tensor which stores a relation identifier :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge. 
To enhance the discriminative power of attention-based GNNs, this layer further implements four different cardinality preservation options as proposed in the `"Improving Attention Mechanism in Graph Neural Networks via Cardinality Preservation" `_ paper: .. math:: \text{additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + \mathcal{W} \odot \sum_{j \in \mathcal{N}_r(i)} \mathbf{x}^{(r)}_j \text{scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= \psi(|\mathcal{N}_r(i)|) \odot \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j \text{f-additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= \sum_{j \in \mathcal{N}_r(i)} (\alpha^{(r)}_{i,j} + 1) \cdot \mathbf{x}^{(r)}_j \text{f-scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= |\mathcal{N}_r(i)| \odot \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j * If :obj:`attention_mode="additive-self-attention"` and :obj:`concat=True`, the layer outputs :obj:`heads * out_channels` features for each node. * If :obj:`attention_mode="multiplicative-self-attention"` and :obj:`concat=True`, the layer outputs :obj:`heads * dim * out_channels` features for each node. * If :obj:`attention_mode="additive-self-attention"` and :obj:`concat=False`, the layer outputs :obj:`out_channels` features for each node. * If :obj:`attention_mode="multiplicative-self-attention"` and :obj:`concat=False`, the layer outputs :obj:`dim * out_channels` features for each node. Please make sure to set the :obj:`in_channels` argument of the next layer accordingly if more than one instance of this layer is used. .. note:: For an example of using :class:`RGATConv`, see `examples/rgat.py `_. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. num_relations (int): Number of relations. num_bases (int, optional): If set, this layer will use the basis-decomposition regularization scheme where :obj:`num_bases` denotes the number of bases to use. 
(default: :obj:`None`) num_blocks (int, optional): If set, this layer will use the block-diagonal-decomposition regularization scheme where :obj:`num_blocks` denotes the number of blocks to use. (default: :obj:`None`) mod (str, optional): The cardinality preservation option to use. (:obj:`"additive"`, :obj:`"scaled"`, :obj:`"f-additive"`, :obj:`"f-scaled"`, :obj:`None`). (default: :obj:`None`) attention_mechanism (str, optional): The attention mechanism to use (:obj:`"within-relation"`, :obj:`"across-relation"`). (default: :obj:`"across-relation"`) attention_mode (str, optional): The mode to calculate attention logits. (:obj:`"additive-self-attention"`, :obj:`"multiplicative-self-attention"`). (default: :obj:`"additive-self-attention"`) heads (int, optional): Number of multi-head attentions. (default: :obj:`1`) dim (int, optional): Number of dimensions for query and key kernels. Values greater than :obj:`1` require :obj:`attention_mode="multiplicative-self-attention"`. (default: :obj:`1`) concat (bool, optional): If set to :obj:`False`, the multi-head attentions are averaged instead of concatenated. (default: :obj:`True`) negative_slope (float, optional): LeakyReLU angle of the negative slope. (default: :obj:`0.2`) dropout (float, optional): Dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training. A value greater than :obj:`0` requires :obj:`mod=None`. (default: :obj:`0`) edge_dim (int, optional): Edge feature dimensionality (in case there are any). (default: :obj:`None`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`.
""" _alpha: OptTensor def __init__( self, in_channels: int, out_channels: int, num_relations: int, num_bases: Optional[int] = None, num_blocks: Optional[int] = None, mod: Optional[str] = None, attention_mechanism: str = "across-relation", attention_mode: str = "additive-self-attention", heads: int = 1, dim: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, edge_dim: Optional[int] = None, bias: bool = True, **kwargs, ): kwargs.setdefault('aggr', 'add') super().__init__(node_dim=0, **kwargs) self.heads = heads self.negative_slope = negative_slope self.dropout = dropout self.mod = mod self.activation = ReLU() self.concat = concat self.attention_mode = attention_mode self.attention_mechanism = attention_mechanism self.dim = dim self.edge_dim = edge_dim self.in_channels = in_channels self.out_channels = out_channels self.num_relations = num_relations self.num_bases = num_bases self.num_blocks = num_blocks mod_types = ['additive', 'scaled', 'f-additive', 'f-scaled'] if (self.attention_mechanism != "within-relation" and self.attention_mechanism != "across-relation"): raise ValueError('attention mechanism must either be ' '"within-relation" or "across-relation"') if (self.attention_mode != "additive-self-attention" and self.attention_mode != "multiplicative-self-attention"): raise ValueError('attention mode must either be ' '"additive-self-attention" or ' '"multiplicative-self-attention"') if self.attention_mode == "additive-self-attention" and self.dim > 1: raise ValueError('"additive-self-attention" mode cannot be ' 'applied when value of d is greater than 1. 
' 'Use "multiplicative-self-attention" instead.') if self.dropout > 0.0 and self.mod in mod_types: raise ValueError('mod must be None with dropout value greater ' 'than 0 in order to sample attention ' 'coefficients stochastically') if num_bases is not None and num_blocks is not None: raise ValueError('Can not apply both basis-decomposition and ' 'block-diagonal-decomposition at the same time.') # The learnable parameters to compute both attention logits and # attention coefficients: self.q = Parameter( torch.empty(self.heads * self.out_channels, self.heads * self.dim)) self.k = Parameter( torch.empty(self.heads * self.out_channels, self.heads * self.dim)) if bias and concat: self.bias = Parameter( torch.empty(self.heads * self.dim * self.out_channels)) elif bias and not concat: self.bias = Parameter(torch.empty(self.dim * self.out_channels)) else: self.register_parameter('bias', None) if edge_dim is not None: self.lin_edge = Linear(self.edge_dim, self.heads * self.out_channels, bias=False, weight_initializer='glorot') self.e = Parameter( torch.empty(self.heads * self.out_channels, self.heads * self.dim)) else: self.lin_edge = None self.register_parameter('e', None) if num_bases is not None: self.att = Parameter( torch.empty(self.num_relations, self.num_bases)) self.basis = Parameter( torch.empty(self.num_bases, self.in_channels, self.heads * self.out_channels)) elif num_blocks is not None: assert ( self.in_channels % self.num_blocks == 0 and (self.heads * self.out_channels) % self.num_blocks == 0), ( "both 'in_channels' and 'heads * out_channels' must be " "multiple of 'num_blocks' used") self.weight = Parameter( torch.empty(self.num_relations, self.num_blocks, self.in_channels // self.num_blocks, (self.heads * self.out_channels) // self.num_blocks)) else: self.weight = Parameter( torch.empty(self.num_relations, self.in_channels, self.heads * self.out_channels)) self.w = Parameter(torch.ones(self.out_channels)) self.l1 = Parameter(torch.empty(1, 
self.out_channels)) self.b1 = Parameter(torch.empty(1, self.out_channels)) self.l2 = Parameter(torch.empty(self.out_channels, self.out_channels)) self.b2 = Parameter(torch.empty(1, self.out_channels)) self._alpha = None self.reset_parameters() def reset_parameters(self): super().reset_parameters() if self.num_bases is not None: glorot(self.basis) glorot(self.att) else: glorot(self.weight) glorot(self.q) glorot(self.k) zeros(self.bias) ones(self.l1) zeros(self.b1) torch.full(self.l2.size(), 1 / self.out_channels) zeros(self.b2) if self.lin_edge is not None: glorot(self.lin_edge) glorot(self.e) def forward( self, x: Tensor, edge_index: Adj, edge_type: OptTensor = None, edge_attr: OptTensor = None, size: Size = None, return_attention_weights=None, ): r"""Runs the forward pass of the module. Args: x (torch.Tensor): The input node features. Can be either a :obj:`[num_nodes, in_channels]` node feature matrix, or an optional one-dimensional node index tensor (in which case input features are treated as trainable node embeddings). edge_index (torch.Tensor or SparseTensor): The edge indices. edge_type (torch.Tensor, optional): The one-dimensional relation type/index for each edge in :obj:`edge_index`. Should be only :obj:`None` in case :obj:`edge_index` is of type :class:`torch_sparse.SparseTensor` or :class:`torch.sparse.Tensor`. (default: :obj:`None`) edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`) size ((int, int), optional): The shape of the adjacency matrix. (default: :obj:`None`) return_attention_weights (bool, optional): Will additionally return the tuple :obj:`(edge_index, attention_weights)` whenever it is set to a value, regardless of its actual value (might be `True` or `False`), holding the computed attention weights for each edge. 
(default: :obj:`None`) """ # propagate_type: (x: Tensor, edge_type: OptTensor, # edge_attr: OptTensor) out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x, size=size, edge_attr=edge_attr) alpha = self._alpha assert alpha is not None self._alpha = None if isinstance(return_attention_weights, bool): if isinstance(edge_index, Tensor): if is_torch_sparse_tensor(edge_index): # TODO TorchScript requires to return a tuple adj = set_sparse_value(edge_index, alpha) return out, (adj, alpha) else: return out, (edge_index, alpha) elif isinstance(edge_index, SparseTensor): return out, edge_index.set_value(alpha, layout='coo') else: return out def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor, edge_attr: OptTensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor: if self.num_bases is not None: # Basis-decomposition ================= w = torch.matmul(self.att, self.basis.view(self.num_bases, -1)) w = w.view(self.num_relations, self.in_channels, self.heads * self.out_channels) if self.num_blocks is not None: # Block-diagonal-decomposition ======= if (x_i.dtype == torch.long and x_j.dtype == torch.long and self.num_blocks is not None): raise ValueError('Block-diagonal decomposition not supported ' 'for non-continuous input features.') w = self.weight x_i = x_i.view(-1, 1, w.size(1), w.size(2)) x_j = x_j.view(-1, 1, w.size(1), w.size(2)) w = torch.index_select(w, 0, edge_type) outi = torch.einsum('abcd,acde->ace', x_i, w) outi = outi.contiguous().view(-1, self.heads * self.out_channels) outj = torch.einsum('abcd,acde->ace', x_j, w) outj = outj.contiguous().view(-1, self.heads * self.out_channels) else: # No regularization/Basis-decomposition ======================== if self.num_bases is None: w = self.weight w = torch.index_select(w, 0, edge_type) outi = torch.bmm(x_i.unsqueeze(1), w).squeeze(-2) outj = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2) qi = torch.matmul(outi, self.q) kj = torch.matmul(outj, self.k) alpha_edge, alpha = 0, 
torch.tensor([0]) if edge_attr is not None: if edge_attr.dim() == 1: edge_attr = edge_attr.view(-1, 1) assert self.lin_edge is not None, ( "Please set 'edge_dim = edge_attr.size(-1)' while calling the " "RGATConv layer") edge_attributes = self.lin_edge(edge_attr).view( -1, self.heads * self.out_channels) if edge_attributes.size(0) != edge_attr.size(0): edge_attributes = torch.index_select(edge_attributes, 0, edge_type) alpha_edge = torch.matmul(edge_attributes, self.e) if self.attention_mode == "additive-self-attention": if edge_attr is not None: alpha = torch.add(qi, kj) + alpha_edge else: alpha = torch.add(qi, kj) alpha = F.leaky_relu(alpha, self.negative_slope) elif self.attention_mode == "multiplicative-self-attention": if edge_attr is not None: alpha = (qi * kj) * alpha_edge else: alpha = qi * kj if self.attention_mechanism == "within-relation": across_out = torch.zeros_like(alpha) for r in range(self.num_relations): mask = edge_type == r across_out[mask] = softmax(alpha[mask], index[mask]) alpha = across_out elif self.attention_mechanism == "across-relation": alpha = softmax(alpha, index, ptr, size_i) self._alpha = alpha if self.mod == "additive": if self.attention_mode == "additive-self-attention": ones = torch.ones_like(alpha) h = (outj.view(-1, self.heads, self.out_channels) * ones.view(-1, self.heads, 1)) h = torch.mul(self.w, h) return (outj.view(-1, self.heads, self.out_channels) * alpha.view(-1, self.heads, 1) + h) elif self.attention_mode == "multiplicative-self-attention": ones = torch.ones_like(alpha) h = (outj.view(-1, self.heads, 1, self.out_channels) * ones.view(-1, self.heads, self.dim, 1)) h = torch.mul(self.w, h) return (outj.view(-1, self.heads, 1, self.out_channels) * alpha.view(-1, self.heads, self.dim, 1) + h) elif self.mod == "scaled": if self.attention_mode == "additive-self-attention": ones = alpha.new_ones(index.size()) degree = scatter(ones, index, dim_size=size_i, reduce='sum')[index].unsqueeze(-1) degree = torch.matmul(degree, 
self.l1) + self.b1 degree = self.activation(degree) degree = torch.matmul(degree, self.l2) + self.b2 return torch.mul( outj.view(-1, self.heads, self.out_channels) * alpha.view(-1, self.heads, 1), degree.view(-1, 1, self.out_channels)) elif self.attention_mode == "multiplicative-self-attention": ones = alpha.new_ones(index.size()) degree = scatter(ones, index, dim_size=size_i, reduce='sum')[index].unsqueeze(-1) degree = torch.matmul(degree, self.l1) + self.b1 degree = self.activation(degree) degree = torch.matmul(degree, self.l2) + self.b2 return torch.mul( outj.view(-1, self.heads, 1, self.out_channels) * alpha.view(-1, self.heads, self.dim, 1), degree.view(-1, 1, 1, self.out_channels)) elif self.mod == "f-additive": alpha = torch.where(alpha > 0, alpha + 1, alpha) elif self.mod == "f-scaled": ones = alpha.new_ones(index.size()) degree = scatter(ones, index, dim_size=size_i, reduce='sum')[index].unsqueeze(-1) alpha = alpha * degree elif self.training and self.dropout > 0: alpha = F.dropout(alpha, p=self.dropout, training=True) else: alpha = alpha # original if self.attention_mode == "additive-self-attention": return alpha.view(-1, self.heads, 1) * outj.view( -1, self.heads, self.out_channels) else: return (alpha.view(-1, self.heads, self.dim, 1) * outj.view(-1, self.heads, 1, self.out_channels)) def update(self, aggr_out: Tensor) -> Tensor: if self.attention_mode == "additive-self-attention": if self.concat is True: aggr_out = aggr_out.view(-1, self.heads * self.out_channels) else: aggr_out = aggr_out.mean(dim=1) if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out else: if self.concat is True: aggr_out = aggr_out.view( -1, self.heads * self.dim * self.out_channels) else: aggr_out = aggr_out.mean(dim=1) aggr_out = aggr_out.view(-1, self.dim * self.out_channels) if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out def __repr__(self) -> str: return '{}({}, {}, heads={})'.format(self.__class__.__name__, self.in_channels, 
self.out_channels, self.heads) ================================================ FILE: torch_geometric/nn/conv/rgcn_conv.py ================================================ from typing import Optional, Tuple, Union import torch from torch import Tensor from torch.nn import Parameter import torch_geometric.backend import torch_geometric.typing from torch_geometric import is_compiling from torch_geometric.index import index2ptr from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import ( Adj, OptTensor, SparseTensor, pyg_lib, torch_sparse, ) from torch_geometric.utils import index_sort, one_hot, scatter, spmm def masked_edge_index(edge_index: Adj, edge_mask: Tensor) -> Adj: if isinstance(edge_index, Tensor): return edge_index[:, edge_mask] return torch_sparse.masked_select_nnz(edge_index, edge_mask, layout='coo') class RGCNConv(MessagePassing): r"""The relational graph convolutional operator from the `"Modeling Relational Data with Graph Convolutional Networks" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j, where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types. Edge type needs to be a one-dimensional :obj:`torch.long` tensor which stores a relation identifier :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge. .. note:: This implementation is as memory-efficient as possible by iterating over each individual relation type. Therefore, it may result in low GPU utilization in case the graph has a large number of relations. As an alternative approach, :class:`FastRGCNConv` does not iterate over each individual type, but may consume a large amount of memory to compensate. We advise to check out both implementations to see which one fits your needs. .. 
note:: :class:`RGCNConv` can use `dynamic shapes `_, which means that the shape of the interim tensors can be determined at runtime. If your device doesn't support dynamic shapes, use :class:`FastRGCNConv` instead. Args: in_channels (int or tuple): Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities. In case no input features are given, this argument should correspond to the number of nodes in your graph. out_channels (int): Size of each output sample. num_relations (int): Number of relations. num_bases (int, optional): If set, this layer will use the basis-decomposition regularization scheme where :obj:`num_bases` denotes the number of bases to use. (default: :obj:`None`) num_blocks (int, optional): If set, this layer will use the block-diagonal-decomposition regularization scheme where :obj:`num_blocks` denotes the number of blocks to use. (default: :obj:`None`) aggr (str, optional): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"mean"`) root_weight (bool, optional): If set to :obj:`False`, the layer will not add transformed root node features to the output. (default: :obj:`True`) is_sorted (bool, optional): If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by :obj:`edge_type`. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
""" def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int, num_bases: Optional[int] = None, num_blocks: Optional[int] = None, aggr: str = 'mean', root_weight: bool = True, is_sorted: bool = False, bias: bool = True, **kwargs, ): kwargs.setdefault('aggr', aggr) super().__init__(node_dim=0, **kwargs) if num_bases is not None and num_blocks is not None: raise ValueError('Can not apply both basis-decomposition and ' 'block-diagonal-decomposition at the same time.') self.in_channels = in_channels self.out_channels = out_channels self.num_relations = num_relations self.num_bases = num_bases self.num_blocks = num_blocks self.is_sorted = is_sorted if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.in_channels_l = in_channels[0] self._use_segment_matmul_heuristic_output: torch.jit.Attribute( None, Optional[float]) if num_bases is not None: self.weight = Parameter( torch.empty(num_bases, in_channels[0], out_channels)) self.comp = Parameter(torch.empty(num_relations, num_bases)) elif num_blocks is not None: assert (in_channels[0] % num_blocks == 0 and out_channels % num_blocks == 0) self.weight = Parameter( torch.empty(num_relations, num_blocks, in_channels[0] // num_blocks, out_channels // num_blocks)) self.register_parameter('comp', None) else: self.weight = Parameter( torch.empty(num_relations, in_channels[0], out_channels)) self.register_parameter('comp', None) if root_weight: self.root = Parameter(torch.empty(in_channels[1], out_channels)) else: self.register_parameter('root', None) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() glorot(self.weight) glorot(self.comp) glorot(self.root) zeros(self.bias) def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]], edge_index: Adj, edge_type: OptTensor = None): r"""Runs the forward pass of the module. 
Args: x (torch.Tensor or tuple, optional): The input node features. Can be either a :obj:`[num_nodes, in_channels]` node feature matrix, or an optional one-dimensional node index tensor (in which case input features are treated as trainable node embeddings). Furthermore, :obj:`x` can be of type :obj:`tuple` denoting source and destination node features. edge_index (torch.Tensor or SparseTensor): The edge indices. edge_type (torch.Tensor, optional): The one-dimensional relation type/index for each edge in :obj:`edge_index`. Should be only :obj:`None` in case :obj:`edge_index` is of type :class:`torch_sparse.SparseTensor`. (default: :obj:`None`) """ # Convert input features to a pair of node features or node indices. x_l: OptTensor = None if isinstance(x, tuple): x_l = x[0] else: x_l = x if x_l is None: x_l = torch.arange(self.in_channels_l, device=self.weight.device) x_r: Tensor = x_l if isinstance(x, tuple): x_r = x[1] size = (x_l.size(0), x_r.size(0)) if isinstance(edge_index, SparseTensor): edge_type = edge_index.storage.value() assert edge_type is not None # propagate_type: (x: Tensor, edge_type_ptr: OptTensor) out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device) weight = self.weight if self.num_bases is not None: # Basis-decomposition ================= weight = (self.comp @ weight.view(self.num_bases, -1)).view( self.num_relations, self.in_channels_l, self.out_channels) if self.num_blocks is not None: # Block-diagonal-decomposition ===== if not torch.is_floating_point( x_r) and self.num_blocks is not None: raise ValueError('Block-diagonal decomposition not supported ' 'for non-continuous input features.') for i in range(self.num_relations): tmp = masked_edge_index(edge_index, edge_type == i) h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size) h = h.view(-1, weight.size(1), weight.size(2)) h = torch.einsum('abc,bcd->abd', h, weight[i]) out = out + h.contiguous().view(-1, self.out_channels) else: # No regularization/Basis-decomposition 
======================== use_segment_matmul = torch_geometric.backend.use_segment_matmul # If `use_segment_matmul` is not specified, use a simple heuristic # to determine whether `segment_matmul` can speed up computation # given the observed input sizes: if use_segment_matmul is None: segment_count = scatter(torch.ones_like(edge_type), edge_type, dim_size=self.num_relations) self._use_segment_matmul_heuristic_output = ( torch_geometric.backend.use_segment_matmul_heuristic( num_segments=self.num_relations, max_segment_size=int(segment_count.max()), in_channels=self.weight.size(1), out_channels=self.weight.size(2), )) assert self._use_segment_matmul_heuristic_output is not None use_segment_matmul = self._use_segment_matmul_heuristic_output if (use_segment_matmul and torch_geometric.typing.WITH_SEGMM and not is_compiling() and self.num_bases is None and x_l.is_floating_point() and isinstance(edge_index, Tensor)): if not self.is_sorted: if (edge_type[1:] < edge_type[:-1]).any(): edge_type, perm = index_sort( edge_type, max_value=self.num_relations) edge_index = edge_index[:, perm] edge_type_ptr = index2ptr(edge_type, self.num_relations) out = self.propagate(edge_index, x=x_l, edge_type_ptr=edge_type_ptr, size=size) else: for i in range(self.num_relations): tmp = masked_edge_index(edge_index, edge_type == i) if not torch.is_floating_point(x_r): out = out + self.propagate( tmp, x=weight[i, x_l], edge_type_ptr=None, size=size, ) else: h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size) out = out + (h @ weight[i]) root = self.root if root is not None: if not torch.is_floating_point(x_r): out = out + root[x_r] else: out = out + x_r @ root if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor: if (torch_geometric.typing.WITH_SEGMM and not is_compiling() and edge_type_ptr is not None): # TODO Re-weight according to edge type degree for `aggr=mean`. 
return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight) return x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: if isinstance(adj_t, SparseTensor): adj_t = adj_t.set_value(None) return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, num_relations={self.num_relations})') class FastRGCNConv(RGCNConv): r"""See :class:`RGCNConv`.""" def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]], edge_index: Adj, edge_type: OptTensor = None): self.fuse = False assert self.aggr in ['add', 'sum', 'mean'] # Convert input features to a pair of node features or node indices. x_l: OptTensor = None if isinstance(x, tuple): x_l = x[0] else: x_l = x if x_l is None: x_l = torch.arange(self.in_channels_l, device=self.weight.device) x_r: Tensor = x_l if isinstance(x, tuple): x_r = x[1] size = (x_l.size(0), x_r.size(0)) # propagate_type: (x: Tensor, edge_type: OptTensor) out = self.propagate(edge_index, x=x_l, edge_type=edge_type, size=size) root = self.root if root is not None: if not torch.is_floating_point(x_r): out = out + root[x_r] else: out = out + x_r @ root if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, edge_type: Tensor, edge_index_j: Tensor) -> Tensor: weight = self.weight if self.num_bases is not None: # Basis-decomposition ================= weight = (self.comp @ weight.view(self.num_bases, -1)).view( self.num_relations, self.in_channels_l, self.out_channels) if self.num_blocks is not None: # Block-diagonal-decomposition ======= if not torch.is_floating_point(x_j): raise ValueError('Block-diagonal decomposition not supported ' 'for non-continuous input features.') weight = weight[edge_type].view(-1, weight.size(2), weight.size(3)) x_j = x_j.view(-1, 1, weight.size(1)) return torch.bmm(x_j, weight).view(-1, self.out_channels) else: # No regularization/Basis-decomposition ======================== if not 
torch.is_floating_point(x_j): weight_index = edge_type * weight.size(1) + edge_index_j return weight.view(-1, self.out_channels)[weight_index] return torch.bmm(x_j.unsqueeze(-2), weight[edge_type]).squeeze(-2) def aggregate(self, inputs: Tensor, edge_type: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor: # Compute normalization in separation for each `edge_type`. if self.aggr == 'mean': norm = one_hot(edge_type, self.num_relations, dtype=inputs.dtype) norm = scatter(norm, index, dim=0, dim_size=dim_size)[index] norm = torch.gather(norm, 1, edge_type.view(-1, 1)) norm = 1. / norm.clamp_(1.) inputs = norm * inputs return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size) ================================================ FILE: torch_geometric/nn/conv/sage_conv.py ================================================ from typing import List, Optional, Tuple, Union import torch.nn.functional as F from torch import Tensor from torch_geometric.nn.aggr import Aggregation, MultiAggregation from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor from torch_geometric.utils import spmm class SAGEConv(MessagePassing): r"""The GraphSAGE operator from the `"Inductive Representation Learning on Large Graphs" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j If :obj:`project = True`, then :math:`\mathbf{x}_j` will first get projected via .. math:: \mathbf{x}_j \leftarrow \sigma ( \mathbf{W}_3 \mathbf{x}_j + \mathbf{b}) as described in Eq. (3) of the paper. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. 
aggr (str or Aggregation, optional): The aggregation scheme to use. Any aggregation of :obj:`torch_geometric.nn.aggr` can be used, *e.g.*, :obj:`"mean"`, :obj:`"max"`, or :obj:`"lstm"`. (default: :obj:`"mean"`) normalize (bool, optional): If set to :obj:`True`, output features will be :math:`\ell_2`-normalized, *i.e.*, :math:`\frac{\mathbf{x}^{\prime}_i} {\| \mathbf{x}^{\prime}_i \|_2}`. (default: :obj:`False`) root_weight (bool, optional): If set to :obj:`False`, the layer will not add transformed root node features to the output. (default: :obj:`True`) project (bool, optional): If set to :obj:`True`, the layer will apply a linear transformation followed by an activation function before aggregation (as described in Eq. (3) of the paper). (default: :obj:`False`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **inputs:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V_t}|, F_{out})` if bipartite """ def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: Optional[Union[str, List[str], Aggregation]] = "mean", normalize: bool = False, root_weight: bool = True, project: bool = False, bias: bool = True, **kwargs, ): self.in_channels = in_channels self.out_channels = out_channels self.normalize = normalize self.root_weight = root_weight self.project = project if isinstance(in_channels, int): in_channels = (in_channels, in_channels) if aggr == 'lstm': kwargs.setdefault('aggr_kwargs', {}) kwargs['aggr_kwargs'].setdefault('in_channels', in_channels[0]) kwargs['aggr_kwargs'].setdefault('out_channels', in_channels[0]) super().__init__(aggr, **kwargs) if self.project: if 
in_channels[0] <= 0: raise ValueError(f"'{self.__class__.__name__}' does not " f"support lazy initialization with " f"`project=True`") self.lin = Linear(in_channels[0], in_channels[0], bias=True) if isinstance(self.aggr_module, MultiAggregation): aggr_out_channels = self.aggr_module.get_out_channels( in_channels[0]) else: aggr_out_channels = in_channels[0] self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias) if self.root_weight: self.lin_r = Linear(in_channels[1], out_channels, bias=False) self.reset_parameters() def reset_parameters(self): super().reset_parameters() if self.project: self.lin.reset_parameters() self.lin_l.reset_parameters() if self.root_weight: self.lin_r.reset_parameters() def forward( self, x: Union[Tensor, OptPairTensor], edge_index: Adj, size: Size = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) if self.project and hasattr(self, 'lin'): x = (self.lin(x[0]).relu(), x[1]) # propagate_type: (x: OptPairTensor) out = self.propagate(edge_index, x=x, size=size) out = self.lin_l(out) x_r = x[1] if self.root_weight and x_r is not None: out = out + self.lin_r(x_r) if self.normalize: out = F.normalize(out, p=2., dim=-1) return out def message(self, x_j: Tensor) -> Tensor: return x_j def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor: if isinstance(adj_t, SparseTensor): adj_t = adj_t.set_value(None, layout=None) return spmm(adj_t, x[0], reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, aggr={self.aggr})') ================================================ FILE: torch_geometric/nn/conv/sg_conv.py ================================================ from typing import Optional from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils 
import spmm class SGConv(MessagePassing): r"""The simple graph convolutional operator from the `"Simplifying Graph Convolutional Networks" `_ paper. .. math:: \mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta}, where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with inserted self-loops and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. The adjacency matrix can include other values than :obj:`1` representing edge weights via the optional :obj:`edge_weight` tensor. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. K (int, optional): Number of hops :math:`K`. (default: :obj:`1`) cached (bool, optional): If set to :obj:`True`, the layer will cache the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X}` on first execution, and will use the cached version for further executions. This parameter should only be set to :obj:`True` in transductive learning scenarios. (default: :obj:`False`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ _cached_x: Optional[Tensor] def __init__(self, in_channels: int, out_channels: int, K: int = 1, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.K = K self.cached = cached self.add_self_loops = add_self_loops self._cached_x = None self.lin = Linear(in_channels, out_channels, bias=bias) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin.reset_parameters() self._cached_x = None def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: cache = self._cached_x if cache is None: if isinstance(edge_index, Tensor): edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) elif isinstance(edge_index, SparseTensor): edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) for _ in range(self.K): # propagate_type: (x: Tensor, edge_weight: OptTensor) x = self.propagate(edge_index, x=x, edge_weight=edge_weight) if self.cached: self._cached_x = x else: x = cache.detach() return self.lin(x) def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, K={self.K})') ================================================ FILE: torch_geometric/nn/conv/signed_conv.py ================================================ from 
typing import Union import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, PairTensor, SparseTensor from torch_geometric.utils import spmm class SignedConv(MessagePassing): r"""The signed graph convolutional operator from the `"Signed Graph Convolutional Network" `_ paper. .. math:: \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w , \mathbf{x}_v \right] \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})} \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w , \mathbf{x}_v \right] if :obj:`first_aggr` is set to :obj:`True`, and .. math:: \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})}, \mathbf{x}_v^{(\textrm{pos})} \right] \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})}, \mathbf{x}_v^{(\textrm{neg})} \right] otherwise. In case :obj:`first_aggr` is :obj:`False`, the layer expects :obj:`x` to be a tensor where :obj:`x[:, :in_channels]` denotes the positive node features :math:`\mathbf{X}^{(\textrm{pos})}` and :obj:`x[:, in_channels:]` denotes the negative node features :math:`\mathbf{X}^{(\textrm{neg})}`. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. first_aggr (bool): Denotes which aggregation formula to use. 
bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))` if bipartite, positive edge indices :math:`(2, |\mathcal{E}^{(+)}|)`, negative edge indices :math:`(2, |\mathcal{E}^{(-)}|)` - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V_t}|, F_{out})` if bipartite """ def __init__(self, in_channels: int, out_channels: int, first_aggr: bool, bias: bool = True, **kwargs): kwargs.setdefault('aggr', 'mean') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.first_aggr = first_aggr if first_aggr: self.lin_pos_l = Linear(in_channels, out_channels, False) self.lin_pos_r = Linear(in_channels, out_channels, bias) self.lin_neg_l = Linear(in_channels, out_channels, False) self.lin_neg_r = Linear(in_channels, out_channels, bias) else: self.lin_pos_l = Linear(2 * in_channels, out_channels, False) self.lin_pos_r = Linear(in_channels, out_channels, bias) self.lin_neg_l = Linear(2 * in_channels, out_channels, False) self.lin_neg_r = Linear(in_channels, out_channels, bias) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin_pos_l.reset_parameters() self.lin_pos_r.reset_parameters() self.lin_neg_l.reset_parameters() self.lin_neg_r.reset_parameters() def forward( self, x: Union[Tensor, PairTensor], pos_edge_index: Adj, neg_edge_index: Adj, ): if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: PairTensor) if self.first_aggr: out_pos = self.propagate(pos_edge_index, x=x) out_pos = self.lin_pos_l(out_pos) out_pos = out_pos + self.lin_pos_r(x[1]) out_neg = self.propagate(neg_edge_index, x=x) out_neg = self.lin_neg_l(out_neg) out_neg = out_neg + self.lin_neg_r(x[1]) return torch.cat([out_pos, 
out_neg], dim=-1) else: F_in = self.in_channels out_pos1 = self.propagate(pos_edge_index, x=(x[0][..., :F_in], x[1][..., :F_in])) out_pos2 = self.propagate(neg_edge_index, x=(x[0][..., F_in:], x[1][..., F_in:])) out_pos = torch.cat([out_pos1, out_pos2], dim=-1) out_pos = self.lin_pos_l(out_pos) out_pos = out_pos + self.lin_pos_r(x[1][..., :F_in]) out_neg1 = self.propagate(pos_edge_index, x=(x[0][..., F_in:], x[1][..., F_in:])) out_neg2 = self.propagate(neg_edge_index, x=(x[0][..., :F_in], x[1][..., :F_in])) out_neg = torch.cat([out_neg1, out_neg2], dim=-1) out_neg = self.lin_neg_l(out_neg) out_neg = out_neg + self.lin_neg_r(x[1][..., F_in:]) return torch.cat([out_pos, out_neg], dim=-1) def message(self, x_j: Tensor) -> Tensor: return x_j def message_and_aggregate(self, adj_t: Adj, x: PairTensor) -> Tensor: if isinstance(adj_t, SparseTensor): adj_t = adj_t.set_value(None, layout=None) return spmm(adj_t, x[0], reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, first_aggr={self.first_aggr})') ================================================ FILE: torch_geometric/nn/conv/simple_conv.py ================================================ from typing import List, Optional, Union import torch from torch import Tensor from torch_geometric.nn.aggr import Aggregation from torch_geometric.nn.conv import MessagePassing from torch_geometric.typing import ( Adj, OptPairTensor, OptTensor, Size, SparseTensor, torch_sparse, ) from torch_geometric.utils import add_self_loops, spmm class SimpleConv(MessagePassing): r"""A simple message passing operator that performs (non-trainable) propagation. .. math:: \mathbf{x}^{\prime}_i = \bigoplus_{j \in \mathcal{N(i)}} e_{ji} \cdot \mathbf{x}_j where :math:`\bigoplus` defines a custom aggregation scheme. 
Args: aggr (str or [str] or Aggregation, optional): The aggregation scheme to use, *e.g.*, :obj:`"add"`, :obj:`"sum"` :obj:`"mean"`, :obj:`"min"`, :obj:`"max"` or :obj:`"mul"`. In addition, can be any :class:`~torch_geometric.nn.aggr.Aggregation` module (or any string that automatically resolves to it). (default: :obj:`"sum"`) combine_root (str, optional): Specifies whether or how to combine the central node representation (one of :obj:`"sum"`, :obj:`"cat"`, :obj:`"self_loop"`, :obj:`None`). (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **inputs:** node features :math:`(|\mathcal{V}|, F)` or :math:`((|\mathcal{V_s}|, F), (|\mathcal{V_t}|, *))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **outputs:** node features :math:`(|\mathcal{V}|, F)` or :math:`(|\mathcal{V_t}|, F)` if bipartite """ def __init__( self, aggr: Optional[Union[str, List[str], Aggregation]] = "sum", combine_root: Optional[str] = None, **kwargs, ): if combine_root not in ['sum', 'cat', 'self_loop', None]: raise ValueError(f"Received invalid value for 'combine_root' " f"(got '{combine_root}')") super().__init__(aggr, **kwargs) self.combine_root = combine_root def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight: OptTensor = None, size: Size = None) -> Tensor: if self.combine_root is not None: if self.combine_root == 'self_loop': if not isinstance(x, Tensor) or (size is not None and size[0] != size[1]): raise ValueError("Cannot use `combine_root='self_loop'` " "for bipartite message passing") if isinstance(edge_index, Tensor): edge_index, edge_weight = add_self_loops( edge_index, edge_weight, num_nodes=x.size(0)) elif isinstance(edge_index, SparseTensor): edge_index = torch_sparse.set_diag(edge_index) if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: OptPairTensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size) x_dst = x[1] 
if x_dst is not None and self.combine_root is not None: if self.combine_root == 'sum': out = out + x_dst elif self.combine_root == 'cat': out = torch.cat([x_dst, out], dim=-1) return out def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor: assert isinstance(self.aggr, str) return spmm(adj_t, x[0], reduce=self.aggr) ================================================ FILE: torch_geometric/nn/conv/spline_conv.py ================================================ import warnings from typing import List, Tuple, Union import torch from torch import Tensor, nn from torch.nn import Parameter import torch_geometric.typing from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import uniform, zeros from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size from torch_geometric.utils.repeat import repeat if torch_geometric.typing.WITH_SPLINE: from pyg_lib.ops import spline_basis, spline_weighting else: spline_basis = spline_weighting = None class SplineConv(MessagePassing): r"""The spline-based convolutional operator from the `"SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}), where :math:`h_{\mathbf{\Theta}}` denotes a kernel function defined over the weighted B-Spline tensor product basis. .. note:: Pseudo-coordinates must lay in the fixed interval :math:`[0, 1]` for this method to work as intended. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. 
dim (int): Pseudo-coordinate dimensionality. kernel_size (int or [int]): Size of the convolving kernel. is_open_spline (bool or [bool], optional): If set to :obj:`False`, the operator will use a closed B-spline basis in this dimension. (default :obj:`True`) degree (int, optional): B-spline basis degrees. (default: :obj:`1`) aggr (str, optional): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"mean"`) root_weight (bool, optional): If set to :obj:`False`, the layer will not add transformed root node features to the output. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. """ def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, dim: int, kernel_size: Union[int, List[int]], is_open_spline: bool = True, degree: int = 1, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs, ): super().__init__(aggr=aggr, **kwargs) if spline_basis is None: raise ImportError("'SplineConv' requires 'pyg-lib>=0.6.0'") self.in_channels = in_channels self.out_channels = out_channels self.dim = dim self.degree = degree self.root_weight = root_weight kernel_size = torch.tensor(repeat(kernel_size, dim), dtype=torch.long) self.register_buffer('kernel_size', kernel_size) is_open_spline = repeat(is_open_spline, dim) is_open_spline = torch.tensor(is_open_spline, dtype=torch.uint8) self.register_buffer('is_open_spline', is_open_spline) if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.K = kernel_size.prod().item() if in_channels[0] > 0: self.weight = Parameter( torch.empty(self.K, in_channels[0], out_channels)) else: self.weight = torch.nn.parameter.UninitializedParameter() self._hook = self.register_forward_pre_hook( self.initialize_parameters) if root_weight: self.lin = Linear(in_channels[1], 
out_channels, bias=False, weight_initializer='uniform') if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() if not isinstance(self.weight, nn.UninitializedParameter): size = self.weight.size(0) * self.weight.size(1) uniform(size, self.weight) if self.root_weight: self.lin.reset_parameters() zeros(self.bias) def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None) -> Tensor: if isinstance(x, Tensor): x = (x, x) if not x[0].is_cuda: warnings.warn( 'We do not recommend using the non-optimized CPU version of ' '`SplineConv`. If possible, please move your data to GPU.', stacklevel=2) # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) x_r = x[1] if x_r is not None and self.root_weight: out = out + self.lin(x_r) if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: data = spline_basis(edge_attr, self.kernel_size, self.is_open_spline, self.degree) return spline_weighting(x_j, self.weight, *data) @torch.no_grad() def initialize_parameters(self, module, input): if isinstance(self.weight, torch.nn.parameter.UninitializedParameter): x = input[0][0] if isinstance(input, tuple) else input[0] in_channels = x.size(-1) self.weight.materialize((self.K, in_channels, self.out_channels)) size = self.weight.size(0) * self.weight.size(1) uniform(size, self.weight) module._hook.remove() delattr(module, '_hook') def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, dim={self.dim})') ================================================ FILE: torch_geometric/nn/conv/ssg_conv.py ================================================ from typing import Optional from torch import Tensor from torch_geometric.nn.conv import 
MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils import spmm class SSGConv(MessagePassing): r"""The simple spectral graph convolutional operator from the `"Simple Spectral Graph Convolution" `_ paper. .. math:: \mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K\left((1-\alpha) {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^k \mathbf{X}+\alpha \mathbf{X}\right) \mathbf{\Theta}, where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with inserted self-loops and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. The adjacency matrix can include other values than :obj:`1` representing edge weights via the optional :obj:`edge_weight` tensor. :class:`~torch_geometric.nn.conv.SSGConv` is an improved operator of :class:`~torch_geometric.nn.conv.SGConv` by introducing the :obj:`alpha` parameter to address the oversmoothing issue. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. alpha (float): Teleport probability :math:`\alpha \in [0, 1]`. K (int, optional): Number of hops :math:`K`. (default: :obj:`1`) cached (bool, optional): If set to :obj:`True`, the layer will cache the computation of :math:`\frac{1}{K} \sum_{k=1}^K\left((1-\alpha) {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^k \mathbf{X}+ \alpha \mathbf{X}\right)` on first execution, and will use the cached version for further executions. This parameter should only be set to :obj:`True` in transductive learning scenarios. (default: :obj:`False`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. 
(default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ _cached_h: Optional[Tensor] def __init__(self, in_channels: int, out_channels: int, alpha: float, K: int = 1, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.alpha = alpha self.K = K self.cached = cached self.add_self_loops = add_self_loops self._cached_h = None self.lin = Linear(in_channels, out_channels, bias=bias) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin.reset_parameters() self._cached_h = None def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: cache = self._cached_h if cache is None: if isinstance(edge_index, Tensor): edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) elif isinstance(edge_index, SparseTensor): edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), False, self.add_self_loops, self.flow, dtype=x.dtype) h = x * self.alpha for _ in range(self.K): # propagate_type: (x: Tensor, edge_weight: OptTensor) x = self.propagate(edge_index, x=x, edge_weight=edge_weight) h = h + (1 - self.alpha) / self.K * x if self.cached: self._cached_h = h else: h = cache.detach() return self.lin(h) def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> 
Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, K={self.K}, alpha={self.alpha})') ================================================ FILE: torch_geometric/nn/conv/supergat_conv.py ================================================ import math from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse from torch_geometric.utils import ( add_self_loops, batched_negative_sampling, dropout_edge, is_undirected, negative_sampling, remove_self_loops, softmax, to_undirected, ) class SuperGATConv(MessagePassing): r"""The self-supervised graph attentional operator from the `"How to Find Your Friendly Neighborhood: Graph Attention Design with Self-Supervision" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, where the two types of attention :math:`\alpha_{i,j}^{\mathrm{MX\ or\ SD}}` are computed as: .. 
math:: \alpha_{i,j}^{\mathrm{MX\ or\ SD}} &= \frac{ \exp\left(\mathrm{LeakyReLU}\left( e_{i,j}^{\mathrm{MX\ or\ SD}} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( e_{i,k}^{\mathrm{MX\ or\ SD}} \right)\right)} e_{i,j}^{\mathrm{MX}} &= \mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \cdot \sigma \left( \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j \right) e_{i,j}^{\mathrm{SD}} &= \frac{ \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j }{ \sqrt{d} } The self-supervised task is a link prediction using the attention values as input to predict the likelihood :math:`\phi_{i,j}^{\mathrm{MX\ or\ SD}}` that an edge exists between nodes: .. math:: \phi_{i,j}^{\mathrm{MX}} &= \sigma \left( \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j \right) \phi_{i,j}^{\mathrm{SD}} &= \sigma \left( \frac{ \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j }{ \sqrt{d} } \right) .. note:: For an example of using SuperGAT, see `examples/super_gat.py `_. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) concat (bool, optional): If set to :obj:`False`, the multi-head attentions are averaged instead of concatenated. (default: :obj:`True`) negative_slope (float, optional): LeakyReLU angle of the negative slope. (default: :obj:`0.2`) dropout (float, optional): Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: :obj:`0`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. 
(default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) attention_type (str, optional): Type of attention to use (:obj:`'MX'`, :obj:`'SD'`). (default: :obj:`'MX'`) neg_sample_ratio (float, optional): The ratio of the number of sampled negative edges to the number of positive edges. (default: :obj:`0.5`) edge_sample_ratio (float, optional): The ratio of samples to use for training among the number of training edges. (default: :obj:`1.0`) is_undirected (bool, optional): Whether the input graph is undirected. If not given, will be automatically computed with the input graph when negative sampling is performed. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, edge indices :math:`(2, |\mathcal{E}|)`, negative edge indices :math:`(2, |\mathcal{E}^{(-)}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` """ att_x: OptTensor att_y: OptTensor def __init__(self, in_channels: int, out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, bias: bool = True, attention_type: str = 'MX', neg_sample_ratio: float = 0.5, edge_sample_ratio: float = 1.0, is_undirected: bool = False, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(node_dim=0, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.concat = concat self.negative_slope = negative_slope self.dropout = dropout self.add_self_loops = add_self_loops self.attention_type = attention_type self.neg_sample_ratio = neg_sample_ratio self.edge_sample_ratio = edge_sample_ratio self.is_undirected = is_undirected assert attention_type in ['MX', 'SD'] assert 0.0 < neg_sample_ratio and 0.0 < edge_sample_ratio <= 1.0 self.lin = Linear(in_channels, heads * 
out_channels, bias=False, weight_initializer='glorot') if self.attention_type == 'MX': self.att_l = Parameter(torch.empty(1, heads, out_channels)) self.att_r = Parameter(torch.empty(1, heads, out_channels)) else: # self.attention_type == 'SD' self.register_parameter('att_l', None) self.register_parameter('att_r', None) self.att_x = self.att_y = None # x/y for self-supervision if bias and concat: self.bias = Parameter(torch.empty(heads * out_channels)) elif bias and not concat: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin.reset_parameters() glorot(self.att_l) glorot(self.att_r) zeros(self.bias) def forward( self, x: Tensor, edge_index: Adj, neg_edge_index: OptTensor = None, batch: OptTensor = None, ) -> Tensor: r"""Runs the forward pass of the module. Args: x (torch.Tensor): The input node features. edge_index (torch.Tensor or SparseTensor): The edge indices. neg_edge_index (torch.Tensor, optional): The negative edges to train against. If not given, uses negative sampling to calculate negative edges. (default: :obj:`None`) batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. Used when sampling negatives on-the-fly in mini-batch scenarios. (default: :obj:`None`) """ N, H, C = x.size(0), self.heads, self.out_channels if self.add_self_loops: if isinstance(edge_index, SparseTensor): edge_index = torch_sparse.fill_diag(edge_index, 1.) 
else: edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops(edge_index, num_nodes=N) x = self.lin(x).view(-1, H, C) # propagate_type: (x: Tensor) out = self.propagate(edge_index, x=x) if self.training: if isinstance(edge_index, SparseTensor): col, row, _ = edge_index.coo() edge_index = torch.stack([row, col], dim=0) pos_edge_index = self.positive_sampling(edge_index) pos_att = self.get_attention( edge_index_i=pos_edge_index[1], x_i=x[pos_edge_index[1]], x_j=x[pos_edge_index[0]], num_nodes=x.size(0), return_logits=True, ) if neg_edge_index is None: neg_edge_index = self.negative_sampling(edge_index, N, batch) neg_att = self.get_attention( edge_index_i=neg_edge_index[1], x_i=x[neg_edge_index[1]], x_j=x[neg_edge_index[0]], num_nodes=x.size(0), return_logits=True, ) self.att_x = torch.cat([pos_att, neg_att], dim=0) self.att_y = self.att_x.new_zeros(self.att_x.size(0)) self.att_y[:pos_edge_index.size(1)] = 1. if self.concat is True: out = out.view(-1, self.heads * self.out_channels) else: out = out.mean(dim=1) if self.bias is not None: out = out + self.bias return out def message(self, edge_index_i: Tensor, x_i: Tensor, x_j: Tensor, size_i: Optional[int]) -> Tensor: alpha = self.get_attention(edge_index_i, x_i, x_j, num_nodes=size_i) alpha = F.dropout(alpha, p=self.dropout, training=self.training) return x_j * alpha.view(-1, self.heads, 1) def negative_sampling(self, edge_index: Tensor, num_nodes: int, batch: OptTensor = None) -> Tensor: num_neg_samples = int(self.neg_sample_ratio * self.edge_sample_ratio * edge_index.size(1)) if not self.is_undirected and not is_undirected( edge_index, num_nodes=num_nodes): edge_index = to_undirected(edge_index, num_nodes=num_nodes) if batch is None: neg_edge_index = negative_sampling(edge_index, num_nodes, num_neg_samples=num_neg_samples) else: neg_edge_index = batched_negative_sampling( edge_index, batch, num_neg_samples=num_neg_samples) return neg_edge_index def positive_sampling(self, edge_index: Tensor) -> 
Tensor: pos_edge_index, _ = dropout_edge(edge_index, p=1. - self.edge_sample_ratio, training=self.training) return pos_edge_index def get_attention(self, edge_index_i: Tensor, x_i: Tensor, x_j: Tensor, num_nodes: Optional[int], return_logits: bool = False) -> Tensor: if self.attention_type == 'MX': logits = (x_i * x_j).sum(dim=-1) if return_logits: return logits alpha = (x_j * self.att_l).sum(-1) + (x_i * self.att_r).sum(-1) alpha = alpha * logits.sigmoid() else: # self.attention_type == 'SD' alpha = (x_i * x_j).sum(dim=-1) / math.sqrt(self.out_channels) if return_logits: return alpha alpha = F.leaky_relu(alpha, self.negative_slope) alpha = softmax(alpha, edge_index_i, num_nodes=num_nodes) return alpha def get_attention_loss(self) -> Tensor: r"""Computes the self-supervised graph attention loss.""" if not self.training: return torch.tensor([0], device=self.lin.weight.device) return F.binary_cross_entropy_with_logits( self.att_x.mean(dim=-1), self.att_y, ) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads}, ' f'type={self.attention_type})') ================================================ FILE: torch_geometric/nn/conv/tag_conv.py ================================================ import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils import spmm class TAGConv(MessagePassing): r"""The topology adaptive graph convolutional networks operator from the `"Topology Adaptive Graph Convolutional Networks" `_ paper. .. 
math:: \mathbf{X}^{\prime} = \sum_{k=0}^K \left( \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \right)^k \mathbf{X} \mathbf{W}_{k}, where :math:`\mathbf{A}` denotes the adjacency matrix and :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix. The adjacency matrix can include other values than :obj:`1` representing edge weights via the optional :obj:`edge_weight` tensor. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. K (int, optional): Number of hops :math:`K`. (default: :obj:`3`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) normalize (bool, optional): Whether to apply symmetric normalization. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node_features :math:`(|\mathcal{V}|, F_{in})`, edge_index :math:`(2, |\mathcal{E}|)`, edge_weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__(self, in_channels: int, out_channels: int, K: int = 3, bias: bool = True, normalize: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.K = K self.normalize = normalize self.lins = torch.nn.ModuleList([ Linear(in_channels, out_channels, bias=False) for _ in range(K + 1) ]) if bias: self.bias = torch.nn.Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() for lin in self.lins: lin.reset_parameters() zeros(self.bias) def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: if self.normalize: if isinstance(edge_index, Tensor): edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, 
edge_weight, x.size(self.node_dim), improved=False, add_self_loops=False, flow=self.flow, dtype=x.dtype) elif isinstance(edge_index, SparseTensor): edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), add_self_loops=False, flow=self.flow, dtype=x.dtype) out = self.lins[0](x) for lin in self.lins[1:]: # propagate_type: (x: Tensor, edge_weight: OptTensor) x = self.propagate(edge_index, x=x, edge_weight=edge_weight) out = out + lin.forward(x) if self.bias is not None: out = out + self.bias return out def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, K={self.K})') ================================================ FILE: torch_geometric/nn/conv/transformer_conv.py ================================================ import math import typing from typing import Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import ( Adj, NoneType, OptTensor, PairTensor, SparseTensor, ) from torch_geometric.utils import softmax if typing.TYPE_CHECKING: from typing import overload else: from torch.jit import _overload_method as overload class TransformerConv(MessagePassing): r"""The graph transformer operator from the `"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j}, where the attention coefficients :math:`\alpha_{i,j}` are computed via multi-head dot product attention: .. 
math:: \alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)} {\sqrt{d}} \right) Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) concat (bool, optional): If set to :obj:`False`, the multi-head attentions are averaged instead of concatenated. (default: :obj:`True`) beta (bool, optional): If set, will combine aggregation and skip information via .. math:: \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i} with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top} [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`) dropout (float, optional): Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: :obj:`0`) edge_dim (int, optional): Edge feature dimensionality (in case there are any). Edge features are added to the keys after linear transformation, that is, prior to computing the attention dot product. They are also added to final values after the same linear transformation. The model is: .. math:: \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left( \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij} \right), where the attention coefficients :math:`\alpha_{i,j}` are now computed via: .. 
math:: \alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})} {\sqrt{d}} \right) (default :obj:`None`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) root_weight (bool, optional): If set to :obj:`False`, the layer will not add the transformed root node features to the output and the option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. """ _alpha: OptTensor def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, beta: bool = False, dropout: float = 0., edge_dim: Optional[int] = None, bias: bool = True, root_weight: bool = True, **kwargs, ): kwargs.setdefault('aggr', 'add') super().__init__(node_dim=0, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.beta = beta and root_weight self.root_weight = root_weight self.concat = concat self.dropout = dropout self.edge_dim = edge_dim self._alpha = None if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.lin_key = Linear(in_channels[0], heads * out_channels, bias=bias) self.lin_query = Linear(in_channels[1], heads * out_channels, bias=bias) self.lin_value = Linear(in_channels[0], heads * out_channels, bias=bias) if edge_dim is not None: self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False) else: self.lin_edge = self.register_parameter('lin_edge', None) if concat: self.lin_skip = Linear(in_channels[1], heads * out_channels, bias=bias) if self.beta: self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False) else: self.lin_beta = self.register_parameter('lin_beta', None) else: self.lin_skip = Linear(in_channels[1], out_channels, bias=bias) if self.beta: self.lin_beta = Linear(3 * out_channels, 1, bias=False) else: 
self.lin_beta = self.register_parameter('lin_beta', None) self.reset_parameters() def reset_parameters(self): super().reset_parameters() self.lin_key.reset_parameters() self.lin_query.reset_parameters() self.lin_value.reset_parameters() if self.edge_dim: self.lin_edge.reset_parameters() self.lin_skip.reset_parameters() if self.beta: self.lin_beta.reset_parameters() @overload def forward( self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor = None, return_attention_weights: NoneType = None, ) -> Tensor: pass @overload def forward( # noqa: F811 self, x: Union[Tensor, PairTensor], edge_index: Tensor, edge_attr: OptTensor = None, return_attention_weights: bool = None, ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: pass @overload def forward( # noqa: F811 self, x: Union[Tensor, PairTensor], edge_index: SparseTensor, edge_attr: OptTensor = None, return_attention_weights: bool = None, ) -> Tuple[Tensor, SparseTensor]: pass def forward( # noqa: F811 self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor = None, return_attention_weights: Optional[bool] = None, ) -> Union[ Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]], Tuple[Tensor, SparseTensor], ]: r"""Runs the forward pass of the module. Args: x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node features. edge_index (torch.Tensor or SparseTensor): The edge indices. edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`) return_attention_weights (bool, optional): Will additionally return the tuple :obj:`(edge_index, attention_weights)` whenever it is set to a value, regardless of its actual value (might be `True` or `False`), holding the computed attention weights for each edge. 
(default: :obj:`None`) """ H, C = self.heads, self.out_channels if isinstance(x, Tensor): x = (x, x) query = self.lin_query(x[1]).view(-1, H, C) key = self.lin_key(x[0]).view(-1, H, C) value = self.lin_value(x[0]).view(-1, H, C) # propagate_type: (query: Tensor, key:Tensor, value: Tensor, # edge_attr: OptTensor) out = self.propagate(edge_index, query=query, key=key, value=value, edge_attr=edge_attr) alpha = self._alpha self._alpha = None if self.concat: out = out.view(-1, self.heads * self.out_channels) else: out = out.mean(dim=1) if self.root_weight: x_r = self.lin_skip(x[1]) if self.lin_beta is not None: beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1)) beta = beta.sigmoid() out = beta * x_r + (1 - beta) * out else: out = out + x_r if isinstance(return_attention_weights, bool): assert alpha is not None if isinstance(edge_index, Tensor): return out, (edge_index, alpha) elif isinstance(edge_index, SparseTensor): return out, edge_index.set_value(alpha, layout='coo') else: return out def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor, edge_attr: OptTensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor: if self.lin_edge is not None: assert edge_attr is not None edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels) key_j = key_j + edge_attr alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels) alpha = softmax(alpha, index, ptr, size_i) self._alpha = alpha alpha = F.dropout(alpha, p=self.dropout, training=self.training) out = value_j if edge_attr is not None: out = out + edge_attr out = out * alpha.view(-1, self.heads, 1) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads})') ================================================ FILE: torch_geometric/nn/conv/utils/__init__.py ================================================ r"""GNN utility package.""" from .cheatsheet import paper_title, paper_link from 
.cheatsheet import supports_sparse_tensor from .cheatsheet import supports_edge_weights from .cheatsheet import supports_edge_features from .cheatsheet import supports_bipartite_graphs from .cheatsheet import supports_static_graphs from .cheatsheet import supports_lazy_initialization from .cheatsheet import processes_heterogeneous_graphs from .cheatsheet import processes_hypergraphs from .cheatsheet import processes_point_clouds __all__ = [ 'paper_title', 'paper_link', 'supports_sparse_tensor', 'supports_edge_weights', 'supports_edge_features', 'supports_bipartite_graphs', 'supports_static_graphs', 'supports_lazy_initialization', 'processes_heterogeneous_graphs', 'processes_hypergraphs', 'processes_point_clouds', ] ================================================ FILE: torch_geometric/nn/conv/utils/cheatsheet.py ================================================ import importlib import inspect import re from typing import Optional def paper_title(cls: str) -> Optional[str]: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] match = re.search('`\".+?\"', inspect.getdoc(cls), flags=re.DOTALL) return None if match is None else match.group().replace('\n', ' ')[2:-1] def paper_link(cls: str) -> Optional[str]: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] match = re.search('<.+?>', inspect.getdoc(cls), flags=re.DOTALL) return None if match is None else match.group().replace('\n', ' ')[1:-1] def supports_sparse_tensor(cls: str) -> bool: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] signature = inspect.signature(cls.forward) return 'SparseTensor' in str(signature) def supports_edge_weights(cls: str) -> bool: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] signature = inspect.signature(cls.forward) return 'edge_weight' in str(signature) def supports_edge_features(cls: str) -> bool: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] signature = 
inspect.signature(cls.forward) return 'edge_attr' in str(signature) def supports_bipartite_graphs(cls: str) -> bool: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] signature = inspect.signature(cls.forward) return 'Union[torch.Tensor, Tuple[torch.Tensor' in str(signature) def supports_static_graphs(cls: str) -> bool: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] return 'node_dim=' not in inspect.getsource(cls.__init__) def supports_lazy_initialization(cls: str) -> bool: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] doc = re.sub(' +', ' ', inspect.getdoc(cls).replace('\n', ' ')) match = re.search('or :obj:`-1` to derive the size from the first', doc) return match is not None def processes_heterogeneous_graphs(cls: str) -> bool: if 'hetero' in cls.lower(): return True cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] signature = inspect.signature(cls.forward) return 'edge_index_dict' in str(signature) or 'edge_type' in str(signature) def processes_hypergraphs(cls: str) -> bool: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] signature = inspect.signature(cls.forward) return 'hyperedge_index' in str(signature) def processes_point_clouds(cls: str) -> bool: cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] signature = inspect.signature(cls.forward) return (('edge_index' not in str(signature) and 'csc' not in str(signature)) or 'pos' in str(signature)) ================================================ FILE: torch_geometric/nn/conv/wl_conv.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.typing import Adj from torch_geometric.utils import ( degree, is_sparse, scatter, sort_edge_index, to_edge_index, ) class WLConv(torch.nn.Module): r"""The Weisfeiler Lehman (WL) operator from the `"A Reduction of a Graph to a Canonical Form and an Algebra Arising 
During this Reduction" `_ paper. :class:`WLConv` iteratively refines node colorings according to: .. math:: \mathbf{x}^{\prime}_i = \textrm{hash} \left( \mathbf{x}_i, \{ \mathbf{x}_j \colon j \in \mathcal{N}(i) \} \right) Shapes: - **input:** node coloring :math:`(|\mathcal{V}|, F_{in})` *(one-hot encodings)* or :math:`(|\mathcal{V}|)` *(integer-based)*, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node coloring :math:`(|\mathcal{V}|)` *(integer-based)* """ def __init__(self): super().__init__() self.hashmap = {} def reset_parameters(self): r"""Clears the color hash map. :class:`WLConv` has no learnable parameters; resetting forgets all colors registered so far.""" self.hashmap = {} @torch.no_grad() def forward(self, x: Tensor, edge_index: Adj) -> Tensor: r"""Runs one round of color refinement. Expects :obj:`x` to hold integer colors of dtype :obj:`torch.long`, or one-hot rows which are converted to integer colors via :obj:`argmax`. The new color of each node is looked up (or newly registered) in :obj:`self.hashmap` via the hash of its own color together with the sorted colors of its neighbors, so that new colors receive consecutive ids in order of first appearance.""" if x.dim() > 1: assert (x.sum(dim=-1) == 1).sum() == x.size(0) x = x.argmax(dim=-1) # one-hot -> integer. assert x.dtype == torch.long if is_sparse(edge_index): col_and_row, _ = to_edge_index(edge_index) col = col_and_row[0] row = col_and_row[1] else: edge_index = sort_edge_index(edge_index, num_nodes=x.size(0), sort_by_row=False) row, col = edge_index[0], edge_index[1] # `col` is sorted, so we can use it to `split` neighbors to groups: deg = degree(col, x.size(0), dtype=torch.long).tolist() out = [] for node, neighbors in zip(x.tolist(), x[row].split(deg)): idx = hash(tuple([node] + neighbors.sort()[0].tolist())) if idx not in self.hashmap: self.hashmap[idx] = len(self.hashmap) out.append(self.hashmap[idx]) return torch.tensor(out, device=x.device) def histogram(self, x: Tensor, batch: Optional[Tensor] = None, norm: bool = False) -> Tensor: r"""Given a node coloring :obj:`x`, computes the color histograms of the respective graphs (separated by :obj:`batch`). The length of each histogram equals the number of colors registered in :obj:`self.hashmap` so far. If :obj:`norm` is set to :obj:`True`, each histogram is cast to float and divided by its L2 norm.
""" if batch is None: batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device) num_colors = len(self.hashmap) batch_size = int(batch.max()) + 1 index = batch * num_colors + x out = scatter(torch.ones_like(index), index, dim=0, dim_size=num_colors * batch_size, reduce='sum') out = out.view(batch_size, num_colors) if norm: out = out.to(torch.float) out /= out.norm(dim=-1, keepdim=True) return out ================================================ FILE: torch_geometric/nn/conv/wl_conv_continuous.py ================================================ from typing import Union from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.typing import ( Adj, OptPairTensor, OptTensor, Size, SparseTensor, ) from torch_geometric.utils import scatter, spmm class WLConvContinuous(MessagePassing): r"""The Weisfeiler Lehman operator from the `"Wasserstein Weisfeiler-Lehman Graph Kernels" `_ paper. Refinement is done through a degree-scaled mean aggregation and works on nodes with continuous attributes: .. math:: \mathbf{x}^{\prime}_i = \frac{1}{2}\big(\mathbf{x}_i + \frac{1}{\textrm{deg}(i)} \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j \big) where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target node :obj:`i` (default: :obj:`1`), and :math:`\textrm{deg}(i)` denotes the sum of the weights of all edges pointing to node :obj:`i`. For nodes without incoming edges, the neighborhood term is set to zero. Args: **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`.
Shapes: - **input:** node features :math:`(|\mathcal{V}|, F)` or :math:`((|\mathcal{V_s}|, F), (|\mathcal{V_t}|, F))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge weights :math:`(|\mathcal{E}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F)` or :math:`(|\mathcal{V}_t|, F)` if bipartite """ def __init__(self, **kwargs): super().__init__(aggr='add', **kwargs) def forward( self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight: OptTensor = None, size: Size = None, ) -> Tensor: if isinstance(x, Tensor): x = (x, x) # propagate_type: (x: OptPairTensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size) if isinstance(edge_index, SparseTensor): assert edge_weight is None dst_index, _, edge_weight = edge_index.coo() else: dst_index = edge_index[1] if edge_weight is None: edge_weight = x[0].new_ones(dst_index.numel()) deg = scatter(edge_weight, dst_index, 0, out.size(0), reduce='sum') deg_inv = 1. / deg deg_inv.masked_fill_(deg_inv == float('inf'), 0) out = deg_inv.view(-1, 1) * out x_dst = x[1] if x_dst is not None: out = 0.5 * (x_dst + out) return out def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor: return spmm(adj_t, x[0], reduce=self.aggr) ================================================ FILE: torch_geometric/nn/conv/x_conv.py ================================================ from math import ceil from typing import Optional import torch from torch import Tensor from torch.nn import ELU from torch.nn import BatchNorm1d as BN from torch.nn import Conv1d from torch.nn import Linear as L from torch.nn import Sequential as S import torch_geometric.typing from torch_geometric.nn import Reshape from torch_geometric.nn.inits import reset if torch_geometric.typing.WITH_TORCH_CLUSTER: from torch_cluster import knn_graph else: 
knn_graph = None class XConv(torch.nn.Module): r"""The convolutional operator on :math:`\mathcal{X}`-transformed points from the `"PointCNN: Convolution On X-Transformed Points" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathrm{Conv}\left(\mathbf{K}, \gamma_{\mathbf{\Theta}}(\mathbf{P}_i - \mathbf{p}_i) \times \left( h_\mathbf{\Theta}(\mathbf{P}_i - \mathbf{p}_i) \, \Vert \, \mathbf{x}_i \right) \right), where :math:`\mathbf{K}` and :math:`\mathbf{P}_i` denote the trainable filter and neighboring point positions of :math:`\mathbf{x}_i`, respectively. :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` describe neural networks, *i.e.* MLPs, where :math:`h_{\mathbf{\Theta}}` individually lifts each point into a higher-dimensional space, and :math:`\gamma_{\mathbf{\Theta}}` computes the :math:`\mathcal{X}`-transformation matrix based on *all* points in a neighborhood. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. dim (int): Point cloud dimensionality. kernel_size (int): Size of the convolving kernel, *i.e.* number of neighbors including self-loops. hidden_channels (int, optional): Output size of :math:`h_{\mathbf{\Theta}}`, *i.e.* dimensionality of lifted points. If set to :obj:`None`, will be automatically set to :obj:`in_channels // 4`. The resulting value must be positive. (default: :obj:`None`) dilation (int, optional): The factor by which the neighborhood is extended, from which :obj:`kernel_size` neighbors are then uniformly sampled. Can be interpreted as the dilation rate of classical convolutional operators. (default: :obj:`1`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) num_workers (int, optional): Number of workers to use for k-NN computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU.
(default: :obj:`1`) Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})`, positions :math:`(|\mathcal{V}|, D)`, batch vector :math:`(|\mathcal{V}|)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ def __init__(self, in_channels: int, out_channels: int, dim: int, kernel_size: int, hidden_channels: Optional[int] = None, dilation: int = 1, bias: bool = True, num_workers: int = 1): super().__init__() if knn_graph is None: raise ImportError('`XConv` requires `torch-cluster`.') self.in_channels = in_channels if hidden_channels is None: hidden_channels = in_channels // 4 assert hidden_channels > 0 self.hidden_channels = hidden_channels self.out_channels = out_channels self.dim = dim self.kernel_size = kernel_size self.dilation = dilation self.num_workers = num_workers C_in, C_delta, C_out = in_channels, hidden_channels, out_channels D, K = dim, kernel_size self.mlp1 = S( L(dim, C_delta), ELU(), BN(C_delta), L(C_delta, C_delta), ELU(), BN(C_delta), Reshape(-1, K, C_delta), ) self.mlp2 = S( L(D * K, K**2), ELU(), BN(K**2), Reshape(-1, K, K), Conv1d(K, K**2, K, groups=K), ELU(), BN(K**2), Reshape(-1, K, K), Conv1d(K, K**2, K, groups=K), BN(K**2), Reshape(-1, K, K), ) C_in = C_in + C_delta depth_multiplier = int(ceil(C_out / C_in)) self.conv = S( Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in), Reshape(-1, C_in * depth_multiplier), L(C_in * depth_multiplier, C_out, bias=bias), ) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset(self.mlp1) reset(self.mlp2) reset(self.conv) def forward(self, x: Tensor, pos: Tensor, batch: Optional[Tensor] = None): r"""Runs the forward pass of the module.""" pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos (N, D), K = pos.size(), self.kernel_size edge_index = knn_graph(pos, K * self.dilation, batch, loop=True, flow='target_to_source', num_workers=self.num_workers) if self.dilation > 1: edge_index = edge_index[:, ::self.dilation] row, 
col = edge_index[0], edge_index[1] pos = pos[col] - pos[row] x_star = self.mlp1(pos) if x is not None: x = x.unsqueeze(-1) if x.dim() == 1 else x x = x[col].view(N, K, self.in_channels) x_star = torch.cat([x_star, x], dim=-1) x_star = x_star.transpose(1, 2).contiguous() transform_matrix = self.mlp2(pos.view(N, K * D)) x_transformed = torch.matmul(x_star, transform_matrix) out = self.conv(x_transformed) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/data_parallel.py ================================================ import logging import warnings from itertools import chain import torch from torch_geometric.data import Batch from torch_geometric.utils import cumsum class DataParallel(torch.nn.DataParallel): r"""Implements data parallelism at the module level. This container parallelizes the application of the given :attr:`module` by splitting a list of :class:`torch_geometric.data.Data` objects and copying them as :class:`torch_geometric.data.Batch` objects to each device. In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. The batch size should be larger than the number of GPUs used. The parallelized :attr:`module` must have its parameters and buffers on :obj:`device_ids[0]`. .. note:: You need to use the :class:`torch_geometric.loader.DataListLoader` for this module. .. warning:: It is recommended to use :class:`torch.nn.parallel.DistributedDataParallel` instead of :class:`DataParallel` for multi-GPU training. :class:`DataParallel` is usually much slower than :class:`~torch.nn.parallel.DistributedDataParallel` even on a single machine. Take a look `here `_ for an example on how to use :pyg:`PyG` in combination with :class:`~torch.nn.parallel.DistributedDataParallel`. 
Args: module (Module): Module to be parallelized. device_ids (list of int or torch.device): CUDA devices. (default: all devices) output_device (int or torch.device): Device location of output. (default: :obj:`device_ids[0]`) follow_batch (list or tuple, optional): Creates assignment batch vectors for each key in the list. (default: :obj:`None`) exclude_keys (list or tuple, optional): Will exclude each key in the list. (default: :obj:`None`) """ def __init__(self, module, device_ids=None, output_device=None, follow_batch=None, exclude_keys=None): super().__init__(module, device_ids, output_device) warnings.warn( "'DataParallel' is usually much slower than " "'DistributedDataParallel' even on a single machine. " "Please consider switching to 'DistributedDataParallel' " "for multi-GPU training.", stacklevel=2) self.src_device = torch.device(f'cuda:{self.device_ids[0]}') self.follow_batch = follow_batch or [] self.exclude_keys = exclude_keys or [] def forward(self, data_list): """""" # noqa: D419 if len(data_list) == 0: logging.warning('DataParallel received an empty data list, which ' 'may result in unexpected behavior.') return None if not self.device_ids or len(self.device_ids) == 1: # Fallback data = Batch.from_data_list( data_list, follow_batch=self.follow_batch, exclude_keys=self.exclude_keys).to(self.src_device) return self.module(data) for t in chain(self.module.parameters(), self.module.buffers()): if t.device != self.src_device: raise RuntimeError( f"Module must have its parameters and buffers on device " f"'{self.src_device}' but found one of them on device " f"'{t.device}'") inputs = self.scatter(data_list, self.device_ids) replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) outputs = self.parallel_apply(replicas, inputs, None) return self.gather(outputs, self.output_device) def scatter(self, data_list, device_ids): num_devices = min(len(device_ids), len(data_list)) count = torch.tensor([data.num_nodes for data in data_list]) ptr = 
cumsum(count) device_id = num_devices * ptr.to(torch.float) / ptr[-1].item() device_id = (device_id[:-1] + device_id[1:]) / 2.0 device_id = device_id.to(torch.long) # round. split = cumsum(device_id.bincount()) split = torch.unique(split, sorted=True) split = split.tolist() return [ Batch.from_data_list(data_list[split[i]:split[i + 1]], follow_batch=self.follow_batch, exclude_keys=self.exclude_keys).to( torch.device(f'cuda:{device_ids[i]}')) for i in range(len(split) - 1) ] ================================================ FILE: torch_geometric/nn/dense/__init__.py ================================================ r"""Dense neural network module package. This package provides modules applicable for operating on dense tensor representations. """ from .linear import Linear, HeteroLinear, HeteroDictLinear from .dense_gat_conv import DenseGATConv from .dense_sage_conv import DenseSAGEConv from .dense_gcn_conv import DenseGCNConv from .dense_graph_conv import DenseGraphConv from .dense_gin_conv import DenseGINConv from .diff_pool import dense_diff_pool from .mincut_pool import dense_mincut_pool from .dmon_pool import DMoNPooling __all__ = [ 'Linear', 'HeteroLinear', 'HeteroDictLinear', 'DenseGCNConv', 'DenseGINConv', 'DenseGraphConv', 'DenseSAGEConv', 'DenseGATConv', 'dense_diff_pool', 'dense_mincut_pool', 'DMoNPooling', ] lin_classes = __all__[:3] conv_classes = __all__[3:8] pool_classes = __all__[8:] ================================================ FILE: torch_geometric/nn/dense/dense_gat_conv.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot, zeros class DenseGATConv(torch.nn.Module): r"""See :class:`torch_geometric.nn.conv.GATConv`.""" def __init__( self, in_channels: int, out_channels: int, heads: int = 1, concat: bool = True, negative_slope: 
float = 0.2, dropout: float = 0.0, bias: bool = True, ): # TODO Add support for edge features. super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.concat = concat self.negative_slope = negative_slope self.dropout = dropout self.lin = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='glorot') # The learnable parameters to compute attention coefficients: self.att_src = Parameter(torch.empty(1, 1, heads, out_channels)) self.att_dst = Parameter(torch.empty(1, 1, heads, out_channels)) if bias and concat: self.bias = Parameter(torch.empty(heads * out_channels)) elif bias and not concat: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): self.lin.reset_parameters() glorot(self.att_src) glorot(self.att_dst) zeros(self.bias) def forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None, add_loop: bool = True): r"""Forward pass. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. adj (torch.Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) add_loop (bool, optional): If set to :obj:`False`, the layer will not automatically add self-loops to the adjacency matrices. 
(default: :obj:`True`) """ x = x.unsqueeze(0) if x.dim() == 2 else x # [B, N, F] adj = adj.unsqueeze(0) if adj.dim() == 2 else adj # [B, N, N] H, C = self.heads, self.out_channels B, N, _ = x.size() if add_loop: adj = adj.clone() idx = torch.arange(N, dtype=torch.long, device=adj.device) adj[:, idx, idx] = 1.0 x = self.lin(x).view(B, N, H, C) # [B, N, H, C] alpha_src = torch.sum(x * self.att_src, dim=-1) # [B, N, H] alpha_dst = torch.sum(x * self.att_dst, dim=-1) # [B, N, H] alpha = alpha_src.unsqueeze(1) + alpha_dst.unsqueeze(2) # [B, N, N, H] # Weighted and masked softmax: alpha = F.leaky_relu(alpha, self.negative_slope) alpha = alpha.masked_fill(adj.unsqueeze(-1) == 0, float('-inf')) alpha = alpha.softmax(dim=2) alpha = F.dropout(alpha, p=self.dropout, training=self.training) out = torch.matmul(alpha.movedim(3, 1), x.movedim(2, 1)) out = out.movedim(1, 2) # [B,N,H,C] if self.concat: out = out.reshape(B, N, H * C) else: out = out.mean(dim=2) if self.bias is not None: out = out + self.bias if mask is not None: out = out * mask.view(-1, N, 1).to(x.dtype) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads})') ================================================ FILE: torch_geometric/nn/dense/dense_gcn_conv.py ================================================ import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros from torch_geometric.typing import OptTensor class DenseGCNConv(torch.nn.Module): r"""See :class:`torch_geometric.nn.conv.GCNConv`.""" def __init__( self, in_channels: int, out_channels: int, improved: bool = False, bias: bool = True, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.improved = improved self.lin = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot') if bias: self.bias = 
Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin.reset_parameters() zeros(self.bias) def forward(self, x: Tensor, adj: Tensor, mask: OptTensor = None, add_loop: bool = True) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. adj (torch.Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) add_loop (bool, optional): If set to :obj:`False`, the layer will not automatically add self-loops to the adjacency matrices. 
(default: :obj:`True`) """ x = x.unsqueeze(0) if x.dim() == 2 else x adj = adj.unsqueeze(0) if adj.dim() == 2 else adj B, N, _ = adj.size() if add_loop: adj = adj.clone() idx = torch.arange(N, dtype=torch.long, device=adj.device) adj[:, idx, idx] = 1 if not self.improved else 2 out = self.lin(x) deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5) adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2) out = torch.matmul(adj, out) if self.bias is not None: out = out + self.bias if mask is not None: out = out * mask.view(B, N, 1).to(x.dtype) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/dense/dense_gin_conv.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import Module from torch_geometric.nn.inits import reset class DenseGINConv(torch.nn.Module): r"""See :class:`torch_geometric.nn.conv.GINConv`.""" def __init__( self, nn: Module, eps: float = 0.0, train_eps: bool = False, ): super().__init__() self.nn = nn self.initial_eps = eps if train_eps: self.eps = torch.nn.Parameter(torch.empty(1)) else: self.register_buffer('eps', torch.empty(1)) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset(self.nn) self.eps.data.fill_(self.initial_eps) def forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None, add_loop: bool = True) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. adj (torch.Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. 
The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) add_loop (bool, optional): If set to :obj:`False`, the layer will not automatically add self-loops to the adjacency matrices. (default: :obj:`True`) """ x = x.unsqueeze(0) if x.dim() == 2 else x adj = adj.unsqueeze(0) if adj.dim() == 2 else adj B, N, _ = adj.size() out = torch.matmul(adj, x) if add_loop: out = (1 + self.eps) * x + out out = self.nn(out) if mask is not None: out = out * mask.view(B, N, 1).to(x.dtype) return out def __repr__(self) -> str: return f'{self.__class__.__name__}(nn={self.nn})' ================================================ FILE: torch_geometric/nn/dense/dense_graph_conv.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import Linear class DenseGraphConv(torch.nn.Module): r"""See :class:`torch_geometric.nn.conv.GraphConv`.""" def __init__( self, in_channels: int, out_channels: int, aggr: str = 'add', bias: bool = True, ): assert aggr in ['add', 'mean', 'max'] super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.aggr = aggr self.lin_rel = Linear(in_channels, out_channels, bias=bias) self.lin_root = Linear(in_channels, out_channels, bias=False) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin_rel.reset_parameters() self.lin_root.reset_parameters() def forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. 
adj (torch.Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) """ x = x.unsqueeze(0) if x.dim() == 2 else x adj = adj.unsqueeze(0) if adj.dim() == 2 else adj B, N, C = x.size() if self.aggr == 'add': out = torch.matmul(adj, x) elif self.aggr == 'mean': out = torch.matmul(adj, x) out = out / adj.sum(dim=-1, keepdim=True).clamp_(min=1) elif self.aggr == 'max': out = x.unsqueeze(-2).repeat(1, 1, N, 1) adj = adj.unsqueeze(-1).expand(B, N, N, C) out[adj == 0] = float('-inf') out = out.max(dim=-3)[0] out[out == float('-inf')] = 0. else: raise NotImplementedError out = self.lin_rel(out) out = out + self.lin_root(x) if mask is not None: out = out * mask.view(-1, N, 1).to(x.dtype) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/dense/dense_sage_conv.py ================================================ import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Linear from torch_geometric.typing import OptTensor class DenseSAGEConv(torch.nn.Module): r"""See :class:`torch_geometric.nn.conv.SAGEConv`. .. note:: :class:`~torch_geometric.nn.dense.DenseSAGEConv` expects to work on binary adjacency matrices. If you want to make use of weighted dense adjacency matrices, please use :class:`torch_geometric.nn.dense.DenseGraphConv` instead. 
""" def __init__( self, in_channels: int, out_channels: int, normalize: bool = False, bias: bool = True, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.normalize = normalize self.lin_rel = Linear(in_channels, out_channels, bias=False) self.lin_root = Linear(in_channels, out_channels, bias=bias) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin_rel.reset_parameters() self.lin_root.reset_parameters() def forward(self, x: Tensor, adj: Tensor, mask: OptTensor = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. adj (torch.Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. 
(default: :obj:`None`) """ x = x.unsqueeze(0) if x.dim() == 2 else x adj = adj.unsqueeze(0) if adj.dim() == 2 else adj B, N, _ = adj.size() out = torch.matmul(adj, x) out = out / adj.sum(dim=-1, keepdim=True).clamp(min=1) out = self.lin_rel(out) + self.lin_root(x) if self.normalize: out = F.normalize(out, p=2.0, dim=-1) if mask is not None: out = out * mask.view(B, N, 1).to(x.dtype) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: torch_geometric/nn/dense/diff_pool.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor def dense_diff_pool( x: Tensor, adj: Tensor, s: Tensor, mask: Optional[Tensor] = None, normalize: bool = True, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: r"""The differentiable pooling operator from the `"Hierarchical Graph Representation Learning with Differentiable Pooling" `_ paper. .. math:: \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X} \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S}) based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`. Returns the pooled node feature matrix, the coarsened adjacency matrix and two auxiliary objectives: (1) The link prediction loss .. math:: \mathcal{L}_{LP} = {\| \mathbf{A} - \mathrm{softmax}(\mathbf{S}) {\mathrm{softmax}(\mathbf{S})}^{\top} \|}_F, and (2) the entropy regularization .. math:: \mathcal{L}_E = \frac{1}{N} \sum_{n=1}^N H(\mathbf{S}_n). Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. adj (torch.Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. 
s (torch.Tensor): Assignment tensor :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` with number of clusters :math:`C`. The softmax does not have to be applied before-hand, since it is executed within this method. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) normalize (bool, optional): If set to :obj:`False`, the link prediction loss is not divided by :obj:`adj.numel()`. (default: :obj:`True`) :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`, :class:`torch.Tensor`, :class:`torch.Tensor`) """ x = x.unsqueeze(0) if x.dim() == 2 else x adj = adj.unsqueeze(0) if adj.dim() == 2 else adj s = s.unsqueeze(0) if s.dim() == 2 else s batch_size, num_nodes, _ = x.size() s = torch.softmax(s, dim=-1) if mask is not None: mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) x, s = x * mask, s * mask out = torch.matmul(s.transpose(1, 2), x) out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) link_loss = adj - torch.matmul(s, s.transpose(1, 2)) link_loss = torch.norm(link_loss, p=2) if normalize is True: link_loss = link_loss / adj.numel() ent_loss = (-s * torch.log(s + 1e-15)).sum(dim=-1).mean() return out, out_adj, link_loss, ent_loss ================================================ FILE: torch_geometric/nn/dense/dmon_pool.py ================================================ from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.nn.dense.mincut_pool import _rank3_trace EPS = 1e-15 class DMoNPooling(torch.nn.Module): r"""The spectral modularity pooling operator from the `"Graph Clustering with Graph Neural Networks" `_ paper. .. 
math:: \mathbf{X}^{\prime} &= \mathrm{SELU}\left({\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\right) \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S}) based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`. Returns the learned cluster assignment matrix, the pooled node feature matrix, the coarsened symmetrically normalized adjacency matrix (whose diagonal is set to zero before normalization), and three auxiliary objectives: (1) The spectral loss .. math:: \mathcal{L}_s = - \frac{1}{2m} \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})} where :math:`\mathbf{B}` is the modularity matrix, (2) the orthogonality loss .. math:: \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F where :math:`C` is the number of clusters, and (3) the cluster loss .. math:: \mathcal{L}_c = \frac{\sqrt{C}}{n} {\left\|\sum_i\mathbf{C_i}^{\top}\right\|}_F - 1. .. note:: For an example of using :class:`DMoNPooling`, see `examples/proteins_dmon_pool.py `_. Args: channels (int or List[int]): Size of each input sample. If given as a list, will construct an MLP based on the given feature sizes. k (int): The number of clusters. dropout (float, optional): Dropout probability applied to the cluster assignment logits before the softmax. (default: :obj:`0.0`) """ def __init__(self, channels: Union[int, List[int]], k: int, dropout: float = 0.0): super().__init__() if isinstance(channels, int): channels = [channels] from torch_geometric.nn.models.mlp import MLP self.mlp = MLP(channels + [k], act=None, norm=None) self.dropout = dropout self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.mlp.reset_parameters() def forward( self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: r"""Forward pass.
Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. Note that the cluster assignment matrix :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` is being created within this method. adj (torch.Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`, :class:`torch.Tensor`, :class:`torch.Tensor`, :class:`torch.Tensor`, :class:`torch.Tensor`) """ x = x.unsqueeze(0) if x.dim() == 2 else x adj = adj.unsqueeze(0) if adj.dim() == 2 else adj s = self.mlp(x) s = F.dropout(s, self.dropout, training=self.training) s = torch.softmax(s, dim=-1) (batch_size, num_nodes, _), C = x.size(), s.size(-1) if mask is None: mask = torch.ones(batch_size, num_nodes, dtype=torch.bool, device=x.device) mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) x, s = x * mask, s * mask out = F.selu(torch.matmul(s.transpose(1, 2), x)) out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) # Spectral loss: degrees = torch.einsum('ijk->ij', adj) # B X N degrees = degrees.unsqueeze(-1) * mask # B x N x 1 degrees_t = degrees.transpose(1, 2) # B x 1 x N m = torch.einsum('ijk->i', degrees) / 2 # B m_expand = m.view(-1, 1, 1).expand(-1, C, C) # B x C x C ca = torch.matmul(s.transpose(1, 2), degrees) # B x C x 1 cb = torch.matmul(degrees_t, s) # B x 1 x C normalizer = torch.matmul(ca, cb) / 2 / m_expand decompose = out_adj - normalizer spectral_loss = -_rank3_trace(decompose) / 2 / m spectral_loss = spectral_loss.mean() # Orthogonality regularization: ss = torch.matmul(s.transpose(1, 2), s) i_s = torch.eye(C).type_as(ss) ortho_loss = torch.norm( ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - i_s / 
torch.norm(i_s), dim=(-1, -2)) ortho_loss = ortho_loss.mean() # Cluster loss: cluster_size = torch.einsum('ijk->ik', s) # B x C cluster_loss = torch.norm(input=cluster_size, dim=1) cluster_loss = cluster_loss / mask.sum(dim=1) * torch.norm(i_s) - 1 cluster_loss = cluster_loss.mean() # Fix and normalize coarsened adjacency matrix: ind = torch.arange(C, device=out_adj.device) out_adj[:, ind, ind] = 0 d = torch.einsum('ijk->ij', out_adj) d = torch.sqrt(d)[:, None] + EPS out_adj = (out_adj / d) / d.transpose(1, 2) return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.mlp.in_channels}, ' f'num_clusters={self.mlp.out_channels})') ================================================ FILE: torch_geometric/nn/dense/linear.py ================================================ import math import os import time from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn.parameter import Parameter import torch_geometric.backend import torch_geometric.typing from torch_geometric import is_compiling from torch_geometric.index import index2ptr from torch_geometric.nn import inits from torch_geometric.typing import pyg_lib from torch_geometric.utils import index_sort def is_uninitialized_parameter(x: Any) -> bool: if not hasattr(torch.nn.parameter, 'UninitializedParameter'): return False return isinstance(x, torch.nn.parameter.UninitializedParameter) def reset_weight_(weight: Tensor, in_channels: int, initializer: Optional[str] = None) -> Tensor: if in_channels <= 0: pass elif initializer == 'glorot': inits.glorot(weight) elif initializer == 'uniform': bound = 1.0 / math.sqrt(in_channels) torch.nn.init.uniform_(weight.data, -bound, bound) elif initializer == 'kaiming_uniform': inits.kaiming_uniform(weight, fan=in_channels, a=math.sqrt(5)) elif initializer is None: inits.kaiming_uniform(weight, fan=in_channels, a=math.sqrt(5)) else: 
raise RuntimeError(f"Weight initializer '{initializer}' not supported") return weight def reset_bias_(bias: Optional[Tensor], in_channels: int, initializer: Optional[str] = None) -> Optional[Tensor]: if bias is None or in_channels <= 0: pass elif initializer == 'zeros': inits.zeros(bias) elif initializer is None: inits.uniform(in_channels, bias) else: raise RuntimeError(f"Bias initializer '{initializer}' not supported") return bias class Linear(torch.nn.Module): r"""Applies a linear transformation to the incoming data. .. math:: \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b} In contrast to :class:`torch.nn.Linear`, it supports lazy initialization and customizable weight and bias initialization. Args: in_channels (int): Size of each input sample. Will be initialized lazily in case it is given as :obj:`-1`. out_channels (int): Size of each output sample. bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) weight_initializer (str, optional): The initializer for the weight matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"` or :obj:`None`). If set to :obj:`None`, will match default weight initialization of :class:`torch.nn.Linear`. (default: :obj:`None`) bias_initializer (str, optional): The initializer for the bias vector (:obj:`"zeros"` or :obj:`None`). If set to :obj:`None`, will match default bias initialization of :class:`torch.nn.Linear`. 
(default: :obj:`None`) Shapes: - **input:** features :math:`(*, F_{in})` - **output:** features :math:`(*, F_{out})` """ def __init__( self, in_channels: int, out_channels: int, bias: bool = True, weight_initializer: Optional[str] = None, bias_initializer: Optional[str] = None, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.weight_initializer = weight_initializer self.bias_initializer = bias_initializer if in_channels > 0: self.weight = Parameter(torch.empty(out_channels, in_channels)) else: self.weight = torch.nn.parameter.UninitializedParameter() self._hook = self.register_forward_pre_hook( self.initialize_parameters) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset_weight_(self.weight, self.in_channels, self.weight_initializer) reset_bias_(self.bias, self.in_channels, self.bias_initializer) def forward(self, x: Tensor) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The input features. 
""" return F.linear(x, self.weight, self.bias) @torch.no_grad() def initialize_parameters(self, module, input): if is_uninitialized_parameter(self.weight): self.in_channels = input[0].size(-1) self.weight.materialize((self.out_channels, self.in_channels)) self.reset_parameters() self._hook.remove() delattr(self, '_hook') def _save_to_state_dict(self, destination, prefix, keep_vars): if (is_uninitialized_parameter(self.weight) or torch.onnx.is_in_onnx_export() or keep_vars): destination[prefix + 'weight'] = self.weight else: destination[prefix + 'weight'] = self.weight.detach() if self.bias is not None: if torch.onnx.is_in_onnx_export() or keep_vars: destination[prefix + 'bias'] = self.bias else: destination[prefix + 'bias'] = self.bias.detach() def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): weight = state_dict.get(prefix + 'weight', None) if weight is not None and is_uninitialized_parameter(weight): self.in_channels = -1 self.weight = torch.nn.parameter.UninitializedParameter() if not hasattr(self, '_hook'): self._hook = self.register_forward_pre_hook( self.initialize_parameters) elif weight is not None and is_uninitialized_parameter(self.weight): self.in_channels = weight.size(-1) self.weight.materialize((self.out_channels, self.in_channels)) if hasattr(self, '_hook'): self._hook.remove() delattr(self, '_hook') super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, bias={self.bias is not None})') class HeteroLinear(torch.nn.Module): r"""Applies separate linear transformations to the incoming data according to types. For type :math:`\kappa`, it computes .. math:: \mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}. It supports lazy initialization and customizable weight and bias initialization. Args: in_channels (int): Size of each input sample. 
Will be initialized lazily in case it is given as :obj:`-1`. out_channels (int): Size of each output sample. num_types (int): The number of types. is_sorted (bool, optional): If set to :obj:`True`, assumes that :obj:`type_vec` is sorted. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.Linear`. Shapes: - **input:** features :math:`(*, F_{in})`, type vector :math:`(*)` - **output:** features :math:`(*, F_{out})` """ _timing_cache: Dict[int, Tuple[float, float]] def __init__( self, in_channels: int, out_channels: int, num_types: int, is_sorted: bool = False, **kwargs, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.num_types = num_types self.is_sorted = is_sorted self.kwargs = kwargs if self.in_channels == -1: self.weight = torch.nn.parameter.UninitializedParameter() self._hook = self.register_forward_pre_hook( self.initialize_parameters) else: self.weight = torch.nn.Parameter( torch.empty(num_types, in_channels, out_channels)) if kwargs.get('bias', True): self.bias = Parameter(torch.empty(num_types, out_channels)) else: self.register_parameter('bias', None) # Timing cache for benchmarking naive vs. 
segment matmul usage: self._timing_cache: Dict[int, Tuple[float, float]] = {} self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset_weight_(self.weight, self.in_channels, self.kwargs.get('weight_initializer', None)) reset_bias_(self.bias, self.in_channels, self.kwargs.get('bias_initializer', None)) def forward_naive(self, x: Tensor, type_ptr: Tensor) -> Tensor: out = x.new_empty(x.size(0), self.out_channels) for i, (start, end) in enumerate(zip(type_ptr[:-1], type_ptr[1:])): out[start:end] = x[start:end] @ self.weight[i] return out def forward_segmm(self, x: Tensor, type_ptr: Tensor) -> Tensor: return pyg_lib.ops.segment_matmul(x, type_ptr, self.weight) @torch.no_grad() def _update_timing_cache( self, x: Tensor, type_ptr: Tensor, key: int, ) -> None: MEASURE_ITER = 1 if 'PYTEST_CURRENT_TEST' in os.environ else 3 if torch.cuda.is_available(): torch.cuda.synchronize() t = time.perf_counter() for _ in range(MEASURE_ITER): _ = self.forward_segmm(x, type_ptr) if torch.cuda.is_available(): torch.cuda.synchronize() time_segmm = time.perf_counter() - t if torch.cuda.is_available(): torch.cuda.synchronize() t = time.perf_counter() for _ in range(MEASURE_ITER): _ = self.forward_naive(x, type_ptr) if torch.cuda.is_available(): torch.cuda.synchronize() time_naive = time.perf_counter() - t self._timing_cache[key] = (time_segmm, time_naive) def forward(self, x: Tensor, type_vec: Tensor) -> Tensor: r"""The forward pass. Args: x (torch.Tensor): The input features. type_vec (torch.Tensor): A vector that maps each entry to a type. 
""" perm: Optional[Tensor] = None if not self.is_sorted and (type_vec[1:] < type_vec[:-1]).any(): type_vec, perm = index_sort(type_vec, self.num_types) x = x[perm] type_ptr = index2ptr(type_vec, self.num_types) if torch_geometric.backend.use_segment_matmul is None: use_segment_matmul = False if (torch_geometric.typing.WITH_SEGMM and not is_compiling() and not torch.jit.is_scripting()): # Use "magnitude" of number of rows as timing key: key = math.floor(math.log10(x.size(0))) if key not in self._timing_cache: self._update_timing_cache(x, type_ptr, key) time_segmm, time_naive = self._timing_cache[key] use_segment_matmul = time_segmm < time_naive else: use_segment_matmul = torch_geometric.backend.use_segment_matmul if (torch_geometric.typing.WITH_SEGMM and not is_compiling() and use_segment_matmul): out = self.forward_segmm(x, type_ptr) else: out = self.forward_naive(x, type_ptr) if self.bias is not None: out += self.bias[type_vec] if perm is not None: # Restore original order (if necessary). out_unsorted = torch.empty_like(out) out_unsorted[perm] = out out = out_unsorted return out @torch.no_grad() def initialize_parameters(self, module, input): if is_uninitialized_parameter(self.weight): self.in_channels = input[0].size(-1) self.weight.materialize( (self.num_types, self.in_channels, self.out_channels)) self.reset_parameters() self._hook.remove() delattr(self, '_hook') def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, num_types={self.num_types}, ' f'bias={self.kwargs.get("bias", True)})') class HeteroDictLinear(torch.nn.Module): r"""Applies separate linear transformations to the incoming data dictionary. For key :math:`\kappa`, it computes .. math:: \mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}. It supports lazy initialization and customizable weight and bias initialization. Args: in_channels (int or Dict[Any, int]): Size of each input sample. 
If passed an integer, :obj:`types` will be a mandatory argument. initialized lazily in case it is given as :obj:`-1`. out_channels (int): Size of each output sample. types (List[Any], optional): The keys of the input dictionary. (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.Linear`. """ def __init__( self, in_channels: Union[int, Dict[Any, int]], out_channels: int, types: Optional[Any] = None, **kwargs, ): super().__init__() if isinstance(in_channels, dict): self.types = list(in_channels.keys()) if any([i == -1 for i in in_channels.values()]): self._hook = self.register_forward_pre_hook( self.initialize_parameters) if types is not None and set(self.types) != set(types): raise ValueError("The provided 'types' do not match with the " "keys in the 'in_channels' dictionary") else: if types is None: raise ValueError("Please provide a list of 'types' if passing " "'in_channels' as an integer") if in_channels == -1: self._hook = self.register_forward_pre_hook( self.initialize_parameters) self.types = types in_channels = {node_type: in_channels for node_type in types} self.in_channels = in_channels self.out_channels = out_channels self.kwargs = kwargs self.lins = torch.nn.ModuleDict({ key: Linear(channels, self.out_channels, **kwargs) for key, channels in self.in_channels.items() }) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" for lin in self.lins.values(): lin.reset_parameters() def forward( self, x_dict: Dict[str, Tensor], ) -> Dict[str, Tensor]: r"""Forward pass. Args: x_dict (Dict[Any, torch.Tensor]): A dictionary holding input features for each individual type. """ out_dict = {} # Only apply fused kernel for more than 10 types, otherwise use # sequential computation (which is generally faster for these cases). 
use_segment_matmul = torch_geometric.backend.use_segment_matmul if use_segment_matmul is None: use_segment_matmul = len(x_dict) >= 10 if (use_segment_matmul and torch_geometric.typing.WITH_GMM and not is_compiling() and not torch.jit.is_scripting()): xs, weights, biases = [], [], [] for key, lin in self.lins.items(): if key in x_dict: xs.append(x_dict[key]) weights.append(lin.weight.t()) biases.append(lin.bias) biases = None if biases[0] is None else biases outs = pyg_lib.ops.grouped_matmul(xs, weights, biases) for key, out in zip(x_dict.keys(), outs): if key in x_dict: out_dict[key] = out else: for key, lin in self.lins.items(): if key in x_dict: out_dict[key] = lin(x_dict[key]) return out_dict @torch.no_grad() def initialize_parameters(self, module, input): for key, x in input[0].items(): lin = self.lins[key] if is_uninitialized_parameter(lin.weight): self.lins[key].initialize_parameters(None, x) self.lins[key].reset_parameters() self._hook.remove() self.in_channels = {key: x.size(-1) for key, x in input[0].items()} delattr(self, '_hook') def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, bias={self.kwargs.get("bias", True)})') ================================================ FILE: torch_geometric/nn/dense/mincut_pool.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor def dense_mincut_pool( x: Tensor, adj: Tensor, s: Tensor, mask: Optional[Tensor] = None, temp: float = 1.0, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: r"""The MinCut pooling operator from the `"Spectral Clustering in Graph Neural Networks for Graph Pooling" `_ paper. .. math:: \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X} \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S}) based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`. 
Returns the pooled node feature matrix, the coarsened and symmetrically normalized adjacency matrix and two auxiliary objectives: (1) The MinCut loss .. math:: \mathcal{L}_c = - \frac{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{A} \mathbf{S})} {\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{D} \mathbf{S})} where :math:`\mathbf{D}` is the degree matrix, and (2) the orthogonality loss .. math:: \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F. Args: x (torch.Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, and feature dimension :math:`F`. adj (torch.Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. s (torch.Tensor): Assignment tensor :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` with number of clusters :math:`C`. The softmax does not have to be applied before-hand, since it is executed within this method. mask (torch.Tensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) temp (float, optional): Temperature parameter for softmax function. (default: :obj:`1.0`) :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`, :class:`torch.Tensor`, :class:`torch.Tensor`) """ x = x.unsqueeze(0) if x.dim() == 2 else x adj = adj.unsqueeze(0) if adj.dim() == 2 else adj s = s.unsqueeze(0) if s.dim() == 2 else s (batch_size, num_nodes, _), k = x.size(), s.size(-1) s = torch.softmax(s / temp if temp != 1.0 else s, dim=-1) if mask is not None: mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) x, s = x * mask, s * mask out = torch.matmul(s.transpose(1, 2), x) out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) # MinCut regularization. 
mincut_num = _rank3_trace(out_adj) d_flat = torch.einsum('ijk->ij', adj) d = _rank3_diag(d_flat) mincut_den = _rank3_trace( torch.matmul(torch.matmul(s.transpose(1, 2), d), s)) mincut_loss = -(mincut_num / mincut_den) mincut_loss = torch.mean(mincut_loss) # Orthogonality regularization. ss = torch.matmul(s.transpose(1, 2), s) i_s = torch.eye(k).type_as(ss) ortho_loss = torch.norm( ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - i_s / torch.norm(i_s), dim=(-1, -2)) ortho_loss = torch.mean(ortho_loss) EPS = 1e-15 # Fix and normalize coarsened adjacency matrix. ind = torch.arange(k, device=out_adj.device) out_adj[:, ind, ind] = 0 d = torch.einsum('ijk->ij', out_adj) d = torch.sqrt(d)[:, None] + EPS out_adj = (out_adj / d) / d.transpose(1, 2) return out, out_adj, mincut_loss, ortho_loss def _rank3_trace(x: Tensor) -> Tensor: return torch.einsum('ijj->i', x) def _rank3_diag(x: Tensor) -> Tensor: eye = torch.eye(x.size(1)).type_as(x) out = eye * x.unsqueeze(2).expand(x.size(0), x.size(1), x.size(1)) return out ================================================ FILE: torch_geometric/nn/encoding.py ================================================ import math from typing import Optional import torch from torch import Tensor __all__ = classes = [ 'PositionalEncoding', 'TemporalEncoding', ] class PositionalEncoding(torch.nn.Module): r"""The positional encoding scheme from the `"Attention Is All You Need" `_ paper. .. math:: PE(x)_{2 \cdot i} &= \sin(x / 10000^{2 \cdot i / d}) PE(x)_{2 \cdot i + 1} &= \cos(x / 10000^{2 \cdot i / d}) where :math:`x` is the position and :math:`i` is the dimension. Args: out_channels (int): Size :math:`d` of each output sample. base_freq (float, optional): The base frequency of sinusoidal functions. (default: :obj:`1e-4`) granularity (float, optional): The granularity of the positions. If set to smaller value, the encoder will capture more fine-grained changes in positions. 
(default: :obj:`1.0`) device (torch.device, optional): The device of the module. (default: :obj:`None`) """ def __init__( self, out_channels: int, base_freq: float = 1e-4, granularity: float = 1.0, device: Optional[torch.device] = None, ): super().__init__() if out_channels % 2 != 0: raise ValueError(f"Cannot use sinusoidal positional encoding with " f"odd 'out_channels' (got {out_channels}).") self.out_channels = out_channels self.base_freq = base_freq self.granularity = granularity frequency = torch.logspace(0, 1, out_channels // 2, base_freq, device=device) self.register_buffer('frequency', frequency) self.reset_parameters() def reset_parameters(self): pass def forward(self, x: Tensor) -> Tensor: """""" # noqa: D419 x = x / self.granularity if self.granularity != 1.0 else x out = x.view(-1, 1) * self.frequency.view(1, -1) return torch.cat([torch.sin(out), torch.cos(out)], dim=-1) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.out_channels})' class TemporalEncoding(torch.nn.Module): r"""The time-encoding function from the `"Do We Really Need Complicated Model Architectures for Temporal Networks?" `_ paper. It first maps each entry to a vector with exponentially decreasing values, and then uses the cosine function to project all values to range :math:`[-1, 1]`. .. math:: y_{i} = \cos \left(x \cdot \sqrt{d}^{-(i - 1)/\sqrt{d}} \right) where :math:`d` defines the output feature dimension, and :math:`1 \leq i \leq d`. Args: out_channels (int): Size :math:`d` of each output sample. device (torch.device, optional): The device of the module. 
(default: :obj:`None`) """ def __init__(self, out_channels: int, device: Optional[torch.device] = None): super().__init__() self.out_channels = out_channels sqrt = math.sqrt(out_channels) weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels, device=device).view(1, -1) self.register_buffer('weight', weight) self.reset_parameters() def reset_parameters(self): pass def forward(self, x: Tensor) -> Tensor: """""" # noqa: D419 return torch.cos(x.view(-1, 1) @ self.weight) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.out_channels})' ================================================ FILE: torch_geometric/nn/functional/__init__.py ================================================ r"""Functional operator package.""" from .bro import bro from .gini import gini __all__ = classes = [ 'bro', 'gini', ] ================================================ FILE: torch_geometric/nn/functional/bro.py ================================================ from typing import Union import torch def bro( x: torch.Tensor, batch: torch.Tensor, p: Union[int, str] = 2, ) -> torch.Tensor: r"""The Batch Representation Orthogonality penalty from the `"Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity" `_ paper. Computes a regularization for each graph representation in a mini-batch according to .. math:: \mathcal{L}_{\textrm{BRO}}^\mathrm{graph} = || \mathbf{HH}^T - \mathbf{I}||_p and returns an average over all graphs in the batch. Args: x (torch.Tensor): The node feature matrix. batch (torch.Tensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. p (int or str, optional): The norm order to use. 
(default: :obj:`2`) """ _, counts = torch.unique(batch, return_counts=True) diags = torch.stack([ torch.diag(x) for x in torch.nn.utils.rnn.pad_sequence( sequences=torch.ones_like(batch).split_with_sizes(counts.tolist()), padding_value=0., batch_first=True, ) ]) x = x.split_with_sizes(split_sizes=counts.tolist()) x = torch.nn.utils.rnn.pad_sequence( sequences=x, padding_value=0., batch_first=True, ) return torch.sum(torch.norm(x @ x.transpose(1, 2) - diags, p=p, dim=(1, 2))) / counts.shape[0] ================================================ FILE: torch_geometric/nn/functional/gini.py ================================================ import torch def gini(w: torch.Tensor) -> torch.Tensor: r"""The Gini coefficient from the `"Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity" `_ paper. Computes a regularization penalty :math:`\in [0, 1]` for each row of a matrix according to .. math:: \mathcal{L}_\textrm{Gini}^i = \sum_j^n \sum_{j'}^n \frac{|w_{ij} - w_{ij'}|}{2 (n^2 - n)\bar{w_i}} and returns an average over all rows. Args: w (torch.Tensor): A two-dimensional tensor. """ s = 0 for row in w: t = row.repeat(row.size(0), 1) u = (t - t.T).abs().sum() / (2 * (row.size(-1)**2 - row.size(-1)) * row.abs().mean() + torch.finfo().eps) s += u s /= w.shape[0] return s ================================================ FILE: torch_geometric/nn/fx.py ================================================ import copy import warnings from typing import Any, Callable, Dict, List, Optional, Type, Union import torch from torch import Tensor from torch.nn import Module, ModuleDict, ModuleList, Sequential try: from torch.fx import Graph, GraphModule, Node except (ImportError, ModuleNotFoundError, AttributeError): GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node' class Transformer: r"""A :class:`Transformer` executes an FX graph node-by-node, applies transformations to each node, and produces a new :class:`torch.nn.Module`. 
It exposes a :func:`transform` method that returns the transformed :class:`~torch.nn.Module`. :class:`Transformer` works entirely symbolically. Methods in the :class:`Transformer` class can be overridden to customize the behavior of transformation. .. code-block:: none transform() +-- Iterate over each node in the graph +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- call_message_passing_module() +-- call_global_pooling_module() +-- output() +-- Erase unused nodes in the graph +-- Iterate over each children module +-- init_submodule() In contrast to the :class:`torch.fx.Transformer` class, the :class:`Transformer` exposes additional functionality: #. It subdivides :func:`call_module` into nodes that call a regular :class:`torch.nn.Module` (:func:`call_module`), a :class:`MessagePassing` module (:func:`call_message_passing_module`), or a :class:`GlobalPooling` module (:func:`call_global_pooling_module`). #. It allows to customize or initialize new children modules via :func:`init_submodule` #. It allows to infer whether a node returns node-level or edge-level information via :meth:`is_edge_level`. Args: module (torch.nn.Module): The module to be transformed. input_map (Dict[str, str], optional): A dictionary holding information about the type of input arguments of :obj:`module.forward`. For example, in case :obj:`arg` is a node-level argument, then :obj:`input_map['arg'] = 'node'`, and :obj:`input_map['arg'] = 'edge'` otherwise. In case :obj:`input_map` is not further specified, will try to automatically determine the correct type of input arguments. (default: :obj:`None`) debug (bool, optional): If set to :obj:`True`, will perform transformation in debug mode. 
(default: :obj:`False`) """ def __init__( self, module: Module, input_map: Optional[Dict[str, str]] = None, debug: bool = False, ): self.module = module self.gm = symbolic_trace(module) self.input_map = input_map self.debug = debug # Methods to override ##################################################### def placeholder(self, node: Node, target: Any, name: str): pass def get_attr(self, node: Node, target: Any, name: str): pass def call_message_passing_module(self, node: Node, target: Any, name: str): pass def call_global_pooling_module(self, node: Node, target: Any, name: str): pass def call_module(self, node: Node, target: Any, name: str): pass def call_method(self, node: Node, target: Any, name: str): pass def call_function(self, node: Node, target: Any, name: str): pass def output(self, node: Node, target: Any, name: str): pass def init_submodule(self, module: Module, target: str) -> Module: return module # Internal functionality ################################################## @property def graph(self) -> Graph: return self.gm.graph def transform(self) -> GraphModule: r"""Transforms :obj:`self.module` and returns a transformed :class:`torch.fx.GraphModule`. """ if self.debug: self.graph.print_tabular() print() code = self.graph.python_code('self') print(code.src if hasattr(code, 'src') else code) # We create a private dictionary `self._state` which holds information # about whether a node returns node-level or edge-level information: # `self._state[node.name] in { 'node', 'edge' }` self._state = copy.copy(self.input_map or {}) # We iterate over each node and determine its output level # (node-level, edge-level) by filling `self._state`: for node in list(self.graph.nodes): if node.op == 'call_function' and 'training' in node.kwargs: warnings.warn( f"Found function '{node.name}' with keyword " f"argument 'training'. During FX tracing, this " f"will likely be baked in as a constant value. 
" f"Consider replacing this function by a module " f"to properly encapsulate its training flag.", stacklevel=2) if node.op == 'placeholder': if node.name not in self._state: if 'edge' in node.name or 'adj' in node.name: self._state[node.name] = 'edge' else: self._state[node.name] = 'node' elif is_message_passing_op(self.module, node.op, node.target): self._state[node.name] = 'node' elif is_global_pooling_op(self.module, node.op, node.target): self._state[node.name] = 'graph' elif node.op in ['call_module', 'call_method', 'call_function']: if self.has_edge_level_arg(node): self._state[node.name] = 'edge' elif self.has_node_level_arg(node): self._state[node.name] = 'node' else: self._state[node.name] = 'graph' # We iterate over each node and may transform it: for node in list(self.graph.nodes): # Call the corresponding `Transformer` method for each `node.op`, # e.g.: `call_module(...)`, `call_function(...)`, ... op = node.op if is_message_passing_op(self.module, op, node.target): op = 'call_message_passing_module' elif is_global_pooling_op(self.module, op, node.target): op = 'call_global_pooling_module' getattr(self, op)(node, node.target, node.name) # Remove all unused nodes in the computation graph, i.e., all nodes # which have been replaced by node type-wise or edge type-wise variants # but which are still present in the computation graph. # We do this by iterating over the computation graph in reversed order, # and try to remove every node. This does only succeed in case there # are no users of that node left in the computation graph. 
# Tail of a method whose beginning precedes this chunk: it prunes unused
# graph nodes, attaches the rebuilt submodules and returns `self.gm`.
        # Walk the graph back-to-front and try to erase every node except the
        # placeholders and the output. A `RuntimeError` from `erase_node`
        # (presumably raised for nodes that still have users) is deliberately
        # ignored, so such nodes are simply kept.
        for node in reversed(list(self.graph.nodes)):
            try:
                if node.op not in ['placeholder', 'output']:
                    self.graph.erase_node(node)
            except RuntimeError:
                pass

        # Re-create every direct child of the wrapped module on `self.gm`
        # (containers are converted recursively by `_init_submodule`).
        for target, submodule in dict(self.module._modules).items():
            self.gm._modules[target] = self._init_submodule(submodule, target)

        del self._state

        if self.debug:
            self.gm.graph.print_tabular()
            print()
            code = self.graph.python_code('self')
            # `python_code` may return an object with a `.src` attribute or a
            # printable value without one; both cases are handled.
            print(code.src if hasattr(code, 'src') else code)

        self.gm.graph.lint()
        self.gm.recompile()

        return self.gm

    def _init_submodule(self, module: Module, target: str) -> Module:
        r"""Recursively rebuilds :obj:`module` registered under :obj:`target`.

        :class:`ModuleList` and :class:`Sequential` containers are both
        returned as a :class:`ModuleList`, and :class:`ModuleDict` containers
        as a :class:`ModuleDict`. Their children receive dotted targets
        (:obj:`"{target}.{index}"` / :obj:`"{target}.{key}"`). Any other
        module is delegated to :meth:`init_submodule`.
        """
        if isinstance(module, ModuleList) or isinstance(module, Sequential):
            return ModuleList([
                self._init_submodule(submodule, f'{target}.{i}')
                for i, submodule in enumerate(module)
            ])
        elif isinstance(module, ModuleDict):
            return ModuleDict({
                key: self._init_submodule(submodule, f'{target}.{key}')
                for key, submodule in module.items()
            })
        else:
            return self.init_submodule(module, target)

    def _is_level(self, node: Node, name: str) -> bool:
        r"""Returns whether :obj:`self._state` records :obj:`node` (looked up
        by :obj:`node.name`) with level :obj:`name`."""
        # NOTE(review): `self._state` is presumably a dict keyed by node name;
        # a node missing from it would raise here - confirm against caller.
        return self._state[node.name] == name

    def _has_level_arg(self, node: Node, name: str) -> bool:
        r"""Returns whether any :class:`Node` found (recursively through
        dicts, lists and tuples) in :obj:`node.args` or :obj:`node.kwargs`
        satisfies :obj:`is_{name}_level`."""
        def _recurse(value: Any) -> bool:
            if isinstance(value, Node):
                # Dispatch dynamically, e.g. `is_edge_level` for 'edge'.
                return getattr(self, f'is_{name}_level')(value)
            elif isinstance(value, dict):
                return any([_recurse(v) for v in value.values()])
            elif isinstance(value, (list, tuple)):
                return any([_recurse(v) for v in value])
            else:
                return False

        # The list comprehensions evaluate every element (no short-circuit).
        return (any([_recurse(value) for value in node.args])
                or any([_recurse(value) for value in node.kwargs.values()]))

    def is_node_level(self, node: Node) -> bool:
        r"""Returns whether :obj:`node` is recorded at node level."""
        return self._is_level(node, name='node')

    def is_edge_level(self, node: Node) -> bool:
        r"""Returns whether :obj:`node` is recorded at edge level."""
        return self._is_level(node, name='edge')

    def is_graph_level(self, node: Node) -> bool:
        r"""Returns whether :obj:`node` is recorded at graph level."""
        return self._is_level(node, name='graph')

    def has_node_level_arg(self, node: Node) -> bool:
        r"""Returns whether any argument of :obj:`node` is node-level."""
        return self._has_level_arg(node, name='node')

    def has_edge_level_arg(self, node: Node) -> bool:
        r"""Returns whether any argument of :obj:`node` is edge-level."""
        return self._has_level_arg(node, name='edge')

    def has_graph_level_arg(self, node: Node) -> bool:
        r"""Returns whether any argument of :obj:`node` is graph-level."""
        return self._has_level_arg(node, name='graph')

    def find_by_name(self, name: str) -> Optional[Node]:
        r"""Returns the first graph node called :obj:`name`, or :obj:`None`.
        """
        for node in self.graph.nodes:
            if node.name == name:
                return node
        return None

    def find_by_target(self, target: Any) -> Optional[Node]:
        r"""Returns the first graph node whose target equals :obj:`target`,
        or :obj:`None`."""
        for node in self.graph.nodes:
            if node.target == target:
                return node
        return None

    def replace_all_uses_with(self, to_replace: Node, replace_with: Node):
        r"""Rewrites the args/kwargs of every node that follows
        :obj:`replace_with` so that references to :obj:`to_replace` point to
        :obj:`replace_with` instead.

        Nodes located before :obj:`replace_with` are left untouched.
        """
        def maybe_replace_node(n: Node) -> Node:
            return replace_with if n == to_replace else n

        # Follow the `next` links until the node whose op is 'root'
        # (presumably the graph's sentinel that closes the node list).
        node = replace_with.next
        while node.op != 'root':
            node.args = torch.fx.map_arg(node.args, maybe_replace_node)
            node.kwargs = torch.fx.map_arg(node.kwargs, maybe_replace_node)
            node = node.next


def symbolic_trace(
        module: Module,
        concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
    r"""Symbolically traces :obj:`module` and returns a :class:`GraphModule`.

    A custom :class:`torch.fx.Tracer` is used that treats every submodule
    except :class:`torch.nn.Sequential` as a leaf. It also patches
    :meth:`Aggregation.__call__` so that aggregation modules are recorded
    as module calls.

    Args:
        module (torch.nn.Module): The module to trace.
        concrete_args (Dict[str, Any], optional): Forwarded to
            :meth:`trace` / :meth:`create_args_for_root`.
            (default: :obj:`None`)
    """
    # This is to support compatibility with pytorch version 1.9 and lower
    try:
        import torch.fx._symbolic_trace as st
    except (ImportError, ModuleNotFoundError):
        import torch.fx.symbolic_trace as st

    from torch_geometric.nn import Aggregation

    class Tracer(torch.fx.Tracer):
        def is_leaf_module(self, module: Module, *args, **kwargs) -> bool:
            # TODO We currently only trace top-level modules.
            return not isinstance(module, torch.nn.Sequential)

        # Note: This is a hack around the fact that `Aggregation.__call__`
        # is not patched by the base implementation of `trace`.
        # see https://github.com/pyg-team/pytorch_geometric/pull/5021 for
        # details on the rationale
        # TODO: Revisit https://github.com/pyg-team/pytorch_geometric/pull/5021
        # NOTE(review): the body below mirrors torch.fx's own `Tracer.trace`
        # and relies on private `st._*` helpers; it presumably has to be kept
        # in sync when upgrading PyTorch.
        @st.compatibility(is_backward_compatible=True)
        def trace(self, root: Union[torch.nn.Module, Callable[..., Any]],
                  concrete_args: Optional[Dict[str, Any]] = None) -> Graph:

            if isinstance(root, torch.nn.Module):
                self.root = root
                fn = type(root).forward
                self.submodule_paths = {
                    mod: name
                    for name, mod in root.named_modules()
                }
            else:
                self.root = torch.nn.Module()
                fn = root

            tracer_cls: Optional[Type['Tracer']] = getattr(
                self, '__class__', None)
            self.graph = Graph(tracer_cls=tracer_cls)

            # Map every tensor/ScriptObject attribute reachable from the root
            # to its dotted attribute path.
            self.tensor_attrs: Dict[Union[Tensor, st.ScriptObject], str] = {}

            def collect_tensor_attrs(m: torch.nn.Module,
                                     prefix_atoms: List[str]):
                for k, v in m.__dict__.items():
                    if isinstance(v, (Tensor, st.ScriptObject)):
                        self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
                for k, v in m.named_children():
                    collect_tensor_attrs(v, prefix_atoms + [k])

            collect_tensor_attrs(self.root, [])

            assert isinstance(fn, st.FunctionType)

            fn_globals = fn.__globals__  # run before it gets patched
            fn, args = self.create_args_for_root(
                fn, isinstance(root, torch.nn.Module), concrete_args)

            parameter_proxy_cache: Dict[str, st.Proxy] = {
            }  # Reduce number of get_attr calls

            @st.functools.wraps(st._orig_module_getattr)
            def module_getattr_wrapper(mod, attr):
                attr_val = st._orig_module_getattr(mod, attr)
                # Support for PyTorch > 1.12, see:
                # https://github.com/pytorch/pytorch/pull/84011
                if hasattr(self, 'getattr'):
                    return self.getattr(attr, attr_val, parameter_proxy_cache)
                return self._module_getattr(attr, attr_val,
                                            parameter_proxy_cache)

            @st.functools.wraps(st._orig_module_call)
            def module_call_wrapper(mod, *args, **kwargs):
                def forward(*args, **kwargs):
                    return st._orig_module_call(mod, *args, **kwargs)

                st._autowrap_check(
                    patcher,
                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
                    self._autowrap_function_ids)
                return self.call_module(mod, forward, args, kwargs)

            with st._Patcher() as patcher:
                # allow duplicate patches to support the case of nested calls
                patcher.patch_method(torch.nn.Module, "__getattr__",
                                     module_getattr_wrapper,
                                     deduplicate=False)
                patcher.patch_method(torch.nn.Module, "__call__",
                                     module_call_wrapper, deduplicate=False)
                # PyG-specific: also route `Aggregation.__call__` through
                # `call_module` (see the PR referenced above).
                patcher.patch_method(Aggregation, "__call__",
                                     module_call_wrapper, deduplicate=False)
                st._patch_wrapped_functions(patcher)
                st._autowrap_check(patcher, fn_globals,
                                   self._autowrap_function_ids)
                for module in self._autowrap_search:
                    st._autowrap_check(patcher, module.__dict__,
                                       self._autowrap_function_ids)
                self.create_node(
                    'output', 'output', (self.create_arg(fn(*args)), ), {},
                    type_expr=fn.__annotations__.get('return', None))

            self.submodule_paths = None

            return self.graph

    return GraphModule(module, Tracer().trace(module, concrete_args))


def get_submodule(module: Module, target: str) -> Module:
    r"""Resolves a dotted :obj:`target` path on :obj:`module` through
    repeated :func:`getattr` calls (raising :class:`AttributeError` if a
    component is missing)."""
    out = module
    for attr in target.split('.'):
        out = getattr(out, attr)
    return out


def is_message_passing_op(module: Module, op: str, target: str) -> bool:
    r"""Returns whether a :obj:`call_module` op targets a
    :class:`MessagePassing` submodule of :obj:`module`."""
    from torch_geometric.nn import MessagePassing
    if op == 'call_module':
        return isinstance(get_submodule(module, target), MessagePassing)
    return False


def is_global_pooling_op(module: Module, op: str, target: str) -> bool:
    r"""Returns whether a :obj:`call_module` op targets an
    :class:`Aggregation` submodule of :obj:`module`."""
    from torch_geometric.nn import Aggregation
    if op == 'call_module':
        return isinstance(get_submodule(module, target), Aggregation)
    return False


================================================
FILE: torch_geometric/nn/glob.py
================================================
from torch_geometric.deprecation import deprecated
from torch_geometric.nn import (
    global_add_pool,
    global_max_pool,
    global_mean_pool,
)
from torch_geometric.nn.aggr import AttentionalAggregation, SortAggregation


# Deprecated alias of `AttentionalAggregation` that keeps the legacy `size`
# argument by forwarding it as `dim_size`.
@deprecated(
    details="use 'nn.aggr.AttentionalAggregation' instead",
    func_name='nn.glob.GlobalAttention',
)
class GlobalAttention(AttentionalAggregation):
    def __call__(self, x, batch=None, size=None):
        return super().__call__(x, batch, dim_size=size)
@deprecated( details="use 'nn.aggr.SortAggr' instead", func_name='nn.glob.global_sort_pool', ) def global_sort_pool(x, index, k): module = SortAggregation(k=k) return module(x, index=index) deprecated( details="use 'nn.pool.global_add_pool' instead", func_name='nn.glob.global_add_pool', )(global_add_pool) deprecated( details="use 'nn.pool.global_max_pool' instead", func_name='nn.glob.global_max_pool', )(global_max_pool) deprecated( details="use 'nn.pool.global_mean_pool' instead", func_name='nn.glob.global_mean_pool', )(global_mean_pool) ================================================ FILE: torch_geometric/nn/inits.py ================================================ import math from typing import Any import torch from torch import Tensor def uniform(size: int, value: Any): if isinstance(value, Tensor): bound = 1.0 / math.sqrt(size) value.data.uniform_(-bound, bound) else: for v in value.parameters() if hasattr(value, 'parameters') else []: uniform(size, v) for v in value.buffers() if hasattr(value, 'buffers') else []: uniform(size, v) def kaiming_uniform(value: Any, fan: int, a: float): if isinstance(value, Tensor): bound = math.sqrt(6 / ((1 + a**2) * fan)) value.data.uniform_(-bound, bound) else: for v in value.parameters() if hasattr(value, 'parameters') else []: kaiming_uniform(v, fan, a) for v in value.buffers() if hasattr(value, 'buffers') else []: kaiming_uniform(v, fan, a) def glorot(value: Any): if isinstance(value, Tensor): stdv = math.sqrt(6.0 / (value.size(-2) + value.size(-1))) value.data.uniform_(-stdv, stdv) else: for v in value.parameters() if hasattr(value, 'parameters') else []: glorot(v) for v in value.buffers() if hasattr(value, 'buffers') else []: glorot(v) def glorot_orthogonal(tensor, scale): if tensor is not None: torch.nn.init.orthogonal_(tensor.data) scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var()) tensor.data *= scale.sqrt() def constant(value: Any, fill_value: float): if isinstance(value, Tensor): 
value.data.fill_(fill_value) else: for v in value.parameters() if hasattr(value, 'parameters') else []: constant(v, fill_value) for v in value.buffers() if hasattr(value, 'buffers') else []: constant(v, fill_value) def zeros(value: Any): constant(value, 0.) def ones(tensor: Any): constant(tensor, 1.) def normal(value: Any, mean: float, std: float): if isinstance(value, Tensor): value.data.normal_(mean, std) else: for v in value.parameters() if hasattr(value, 'parameters') else []: normal(v, mean, std) for v in value.buffers() if hasattr(value, 'buffers') else []: normal(v, mean, std) def reset(value: Any): if hasattr(value, 'reset_parameters'): value.reset_parameters() else: for child in value.children() if hasattr(value, 'children') else []: reset(child) ================================================ FILE: torch_geometric/nn/kge/__init__.py ================================================ r"""Knowledge Graph Embedding (KGE) package.""" from .base import KGEModel from .transe import TransE from .complex import ComplEx from .distmult import DistMult from .rotate import RotatE __all__ = classes = [ 'KGEModel', 'TransE', 'ComplEx', 'DistMult', 'RotatE', ] ================================================ FILE: torch_geometric/nn/kge/base.py ================================================ from typing import Tuple import torch from torch import Tensor from torch.nn import Embedding from tqdm import tqdm from torch_geometric.nn.kge.loader import KGTripletLoader class KGEModel(torch.nn.Module): r"""An abstract base class for implementing custom KGE models. Args: num_nodes (int): The number of nodes/entities in the graph. num_relations (int): The number of relations in the graph. hidden_channels (int): The hidden embedding size. sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the embedding matrices will be sparse. 
(default: :obj:`False`) """ def __init__( self, num_nodes: int, num_relations: int, hidden_channels: int, sparse: bool = False, ): super().__init__() self.num_nodes = num_nodes self.num_relations = num_relations self.hidden_channels = hidden_channels self.node_emb = Embedding(num_nodes, hidden_channels, sparse=sparse) self.rel_emb = Embedding(num_relations, hidden_channels, sparse=sparse) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.node_emb.reset_parameters() self.rel_emb.reset_parameters() def forward( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: r"""Returns the score for the given triplet. Args: head_index (torch.Tensor): The head indices. rel_type (torch.Tensor): The relation type. tail_index (torch.Tensor): The tail indices. """ raise NotImplementedError def loss( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: r"""Returns the loss value for the given triplet. Args: head_index (torch.Tensor): The head indices. rel_type (torch.Tensor): The relation type. tail_index (torch.Tensor): The tail indices. """ raise NotImplementedError def loader( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, **kwargs, ) -> Tensor: r"""Returns a mini-batch loader that samples a subset of triplets. Args: head_index (torch.Tensor): The head indices. rel_type (torch.Tensor): The relation type. tail_index (torch.Tensor): The tail indices. **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. """ return KGTripletLoader(head_index, rel_type, tail_index, **kwargs) @torch.no_grad() def test( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, batch_size: int, k: int = 10, log: bool = True, ) -> Tuple[float, float, float]: r"""Evaluates the model quality by computing Mean Rank, MRR and Hits@:math:`k` across all possible tail entities. 
Args: head_index (torch.Tensor): The head indices. rel_type (torch.Tensor): The relation type. tail_index (torch.Tensor): The tail indices. batch_size (int): The batch size to use for evaluating. k (int, optional): The :math:`k` in Hits @ :math:`k`. (default: :obj:`10`) log (bool, optional): If set to :obj:`False`, will not print a progress bar to the console. (default: :obj:`True`) """ arange = range(head_index.numel()) arange = tqdm(arange) if log else arange mean_ranks, reciprocal_ranks, hits_at_k = [], [], [] for i in arange: h, r, t = head_index[i], rel_type[i], tail_index[i] scores = [] tail_indices = torch.arange(self.num_nodes, device=t.device) for ts in tail_indices.split(batch_size): scores.append(self(h.expand_as(ts), r.expand_as(ts), ts)) rank = int((torch.cat(scores).argsort( descending=True) == t).nonzero().view(-1)) mean_ranks.append(rank) reciprocal_ranks.append(1 / (rank + 1)) hits_at_k.append(rank < k) mean_rank = float(torch.tensor(mean_ranks, dtype=torch.float).mean()) mrr = float(torch.tensor(reciprocal_ranks, dtype=torch.float).mean()) hits_at_k = int(torch.tensor(hits_at_k).sum()) / len(hits_at_k) return mean_rank, mrr, hits_at_k @torch.no_grad() def random_sample( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Randomly samples negative triplets by either replacing the head or the tail (but not both). Args: head_index (torch.Tensor): The head indices. rel_type (torch.Tensor): The relation type. tail_index (torch.Tensor): The tail indices. 
""" # Random sample either `head_index` or `tail_index` (but not both): num_negatives = head_index.numel() // 2 rnd_index = torch.randint(self.num_nodes, head_index.size(), device=head_index.device) head_index = head_index.clone() head_index[:num_negatives] = rnd_index[:num_negatives] tail_index = tail_index.clone() tail_index[num_negatives:] = rnd_index[num_negatives:] return head_index, rel_type, tail_index def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.num_nodes}, ' f'num_relations={self.num_relations}, ' f'hidden_channels={self.hidden_channels})') ================================================ FILE: torch_geometric/nn/kge/complex.py ================================================ import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding from torch_geometric.nn.kge import KGEModel class ComplEx(KGEModel): r"""The ComplEx model from the `"Complex Embeddings for Simple Link Prediction" `_ paper. :class:`ComplEx` models relations as complex-valued bilinear mappings between head and tail entities using the Hermetian dot product. The entities and relations are embedded in different dimensional spaces, resulting in the scoring function: .. math:: d(h, r, t) = Re(< \mathbf{e}_h, \mathbf{e}_r, \mathbf{e}_t>) .. note:: For an example of using the :class:`ComplEx` model, see `examples/kge_fb15k_237.py `_. Args: num_nodes (int): The number of nodes/entities in the graph. num_relations (int): The number of relations in the graph. hidden_channels (int): The hidden embedding size. sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the embedding matrices will be sparse. 
(default: :obj:`False`) """ def __init__( self, num_nodes: int, num_relations: int, hidden_channels: int, sparse: bool = False, ): super().__init__(num_nodes, num_relations, hidden_channels, sparse) self.node_emb_im = Embedding(num_nodes, hidden_channels, sparse=sparse) self.rel_emb_im = Embedding(num_relations, hidden_channels, sparse=sparse) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.node_emb.weight) torch.nn.init.xavier_uniform_(self.node_emb_im.weight) torch.nn.init.xavier_uniform_(self.rel_emb.weight) torch.nn.init.xavier_uniform_(self.rel_emb_im.weight) def forward( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: head_re = self.node_emb(head_index) head_im = self.node_emb_im(head_index) rel_re = self.rel_emb(rel_type) rel_im = self.rel_emb_im(rel_type) tail_re = self.node_emb(tail_index) tail_im = self.node_emb_im(tail_index) return (triple_dot(head_re, rel_re, tail_re) + triple_dot(head_im, rel_re, tail_im) + triple_dot(head_re, rel_im, tail_im) - triple_dot(head_im, rel_im, tail_re)) def loss( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: pos_score = self(head_index, rel_type, tail_index) neg_score = self(*self.random_sample(head_index, rel_type, tail_index)) scores = torch.cat([pos_score, neg_score], dim=0) pos_target = torch.ones_like(pos_score) neg_target = torch.zeros_like(neg_score) target = torch.cat([pos_target, neg_target], dim=0) return F.binary_cross_entropy_with_logits(scores, target) def triple_dot(x: Tensor, y: Tensor, z: Tensor) -> Tensor: return (x * y * z).sum(dim=-1) ================================================ FILE: torch_geometric/nn/kge/distmult.py ================================================ import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.nn.kge import KGEModel class DistMult(KGEModel): r"""The DistMult model from the `"Embedding Entities and Relations for Learning and Inference in 
Knowledge Bases" `_ paper. :class:`DistMult` models relations as diagonal matrices, which simplifies the bi-linear interaction between the head and tail entities to the score function: .. math:: d(h, r, t) = < \mathbf{e}_h, \mathbf{e}_r, \mathbf{e}_t > .. note:: For an example of using the :class:`DistMult` model, see `examples/kge_fb15k_237.py `_. Args: num_nodes (int): The number of nodes/entities in the graph. num_relations (int): The number of relations in the graph. hidden_channels (int): The hidden embedding size. margin (float, optional): The margin of the ranking loss. (default: :obj:`1.0`) sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the embedding matrices will be sparse. (default: :obj:`False`) """ def __init__( self, num_nodes: int, num_relations: int, hidden_channels: int, margin: float = 1.0, sparse: bool = False, ): super().__init__(num_nodes, num_relations, hidden_channels, sparse) self.margin = margin self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.node_emb.weight) torch.nn.init.xavier_uniform_(self.rel_emb.weight) def forward( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: head = self.node_emb(head_index) rel = self.rel_emb(rel_type) tail = self.node_emb(tail_index) return (head * rel * tail).sum(dim=-1) def loss( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: pos_score = self(head_index, rel_type, tail_index) neg_score = self(*self.random_sample(head_index, rel_type, tail_index)) return F.margin_ranking_loss( pos_score, neg_score, target=torch.ones_like(pos_score), margin=self.margin, ) ================================================ FILE: torch_geometric/nn/kge/loader.py ================================================ from typing import List, Tuple import torch from torch import Tensor class KGTripletLoader(torch.utils.data.DataLoader): def __init__(self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, **kwargs): 
self.head_index = head_index self.rel_type = rel_type self.tail_index = tail_index super().__init__(range(head_index.numel()), collate_fn=self.sample, **kwargs) def sample(self, index: List[int]) -> Tuple[Tensor, Tensor, Tensor]: index = torch.tensor(index, device=self.head_index.device) head_index = self.head_index[index] rel_type = self.rel_type[index] tail_index = self.tail_index[index] return head_index, rel_type, tail_index ================================================ FILE: torch_geometric/nn/kge/rotate.py ================================================ import math import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding from torch_geometric.nn.kge import KGEModel class RotatE(KGEModel): r"""The RotatE model from the `"RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space" `_ paper. :class:`RotatE` models relations as a rotation in complex space from head to tail such that .. math:: \mathbf{e}_t = \mathbf{e}_h \circ \mathbf{e}_r, resulting in the scoring function .. math:: d(h, r, t) = - {\| \mathbf{e}_h \circ \mathbf{e}_r - \mathbf{e}_t \|}_p .. note:: For an example of using the :class:`RotatE` model, see `examples/kge_fb15k_237.py `_. Args: num_nodes (int): The number of nodes/entities in the graph. num_relations (int): The number of relations in the graph. hidden_channels (int): The hidden embedding size. margin (float, optional): The margin of the ranking loss. sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the embedding matrices will be sparse. 
(default: :obj:`False`) """ def __init__( self, num_nodes: int, num_relations: int, hidden_channels: int, margin: float = 1.0, sparse: bool = False, ): super().__init__(num_nodes, num_relations, hidden_channels, sparse) self.margin = margin self.node_emb_im = Embedding(num_nodes, hidden_channels, sparse=sparse) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.node_emb.weight) torch.nn.init.xavier_uniform_(self.node_emb_im.weight) torch.nn.init.uniform_(self.rel_emb.weight, 0, 2 * math.pi) def forward( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: head_re = self.node_emb(head_index) head_im = self.node_emb_im(head_index) tail_re = self.node_emb(tail_index) tail_im = self.node_emb_im(tail_index) rel_theta = self.rel_emb(rel_type) rel_re, rel_im = torch.cos(rel_theta), torch.sin(rel_theta) re_score = (rel_re * head_re - rel_im * head_im) - tail_re im_score = (rel_re * head_im + rel_im * head_re) - tail_im complex_score = torch.stack([re_score, im_score], dim=2) score = torch.linalg.vector_norm(complex_score, dim=(1, 2)) return self.margin - score def loss( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: pos_score = self(head_index, rel_type, tail_index) neg_score = self(*self.random_sample(head_index, rel_type, tail_index)) scores = torch.cat([pos_score, neg_score], dim=0) pos_target = torch.ones_like(pos_score) neg_target = torch.zeros_like(neg_score) target = torch.cat([pos_target, neg_target], dim=0) return F.binary_cross_entropy_with_logits(scores, target) ================================================ FILE: torch_geometric/nn/kge/transe.py ================================================ import math import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.nn.kge import KGEModel class TransE(KGEModel): r"""The TransE model from the `"Translating Embeddings for Modeling Multi-Relational Data" `_ paper. 
:class:`TransE` models relations as a translation from head to tail entities such that .. math:: \mathbf{e}_h + \mathbf{e}_r \approx \mathbf{e}_t, resulting in the scoring function: .. math:: d(h, r, t) = - {\| \mathbf{e}_h + \mathbf{e}_r - \mathbf{e}_t \|}_p .. note:: For an example of using the :class:`TransE` model, see `examples/kge_fb15k_237.py `_. Args: num_nodes (int): The number of nodes/entities in the graph. num_relations (int): The number of relations in the graph. hidden_channels (int): The hidden embedding size. margin (int, optional): The margin of the ranking loss. (default: :obj:`1.0`) p_norm (int, optional): The order embedding and distance normalization. (default: :obj:`1.0`) sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the embedding matrices will be sparse. (default: :obj:`False`) """ def __init__( self, num_nodes: int, num_relations: int, hidden_channels: int, margin: float = 1.0, p_norm: float = 1.0, sparse: bool = False, ): super().__init__(num_nodes, num_relations, hidden_channels, sparse) self.p_norm = p_norm self.margin = margin self.reset_parameters() def reset_parameters(self): bound = 6. 
/ math.sqrt(self.hidden_channels) torch.nn.init.uniform_(self.node_emb.weight, -bound, bound) torch.nn.init.uniform_(self.rel_emb.weight, -bound, bound) F.normalize(self.rel_emb.weight.data, p=self.p_norm, dim=-1, out=self.rel_emb.weight.data) def forward( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: head = self.node_emb(head_index) rel = self.rel_emb(rel_type) tail = self.node_emb(tail_index) head = F.normalize(head, p=self.p_norm, dim=-1) tail = F.normalize(tail, p=self.p_norm, dim=-1) # Calculate *negative* TransE norm: return -((head + rel) - tail).norm(p=self.p_norm, dim=-1) def loss( self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor, ) -> Tensor: pos_score = self(head_index, rel_type, tail_index) neg_score = self(*self.random_sample(head_index, rel_type, tail_index)) return F.margin_ranking_loss( pos_score, neg_score, target=torch.ones_like(pos_score), margin=self.margin, ) ================================================ FILE: torch_geometric/nn/lr_scheduler.py ================================================ # See HuggingFace `transformers/optimization.py`. import functools import math from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR class ConstantWithWarmupLR(LambdaLR): r"""Creates a LR scheduler with a constant learning rate preceded by a warmup period during which the learning rate increases linearly between :obj:`0` and the initial LR set in the optimizer. Args: optimizer (Optimizer): The optimizer to be scheduled. num_warmup_steps (int): The number of steps for the warmup phase. last_epoch (int, optional): The index of the last epoch when resuming training. 
(default: :obj:`-1`) """ def __init__( self, optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1, ): lr_lambda = functools.partial( self._lr_lambda, num_warmup_steps=num_warmup_steps, ) super().__init__(optimizer, lr_lambda, last_epoch) @staticmethod def _lr_lambda( current_step: int, num_warmup_steps: int, ) -> float: if current_step < num_warmup_steps: return float(current_step) / float(max(1.0, num_warmup_steps)) return 1.0 class LinearWithWarmupLR(LambdaLR): r"""Creates a LR scheduler with a learning rate that decreases linearly from the initial LR set in the optimizer to :obj:`0`, after a warmup period during which it increases linearly from :obj:`0` to the initial LR set in the optimizer. Args: optimizer (Optimizer): The optimizer to be scheduled. num_warmup_steps (int): The number of steps for the warmup phase. num_training_steps (int): The total number of training steps. last_epoch (int, optional): The index of the last epoch when resuming training. (default: :obj:`-1`) """ def __init__( self, optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1, ): lr_lambda = functools.partial( self._lr_lambda, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, ) super().__init__(optimizer, lr_lambda, last_epoch) @staticmethod def _lr_lambda( current_step: int, num_warmup_steps: int, num_training_steps: int, ) -> float: if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) return max( 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)), ) class CosineWithWarmupLR(LambdaLR): r"""Creates a LR scheduler with a learning rate that decreases following the values of the cosine function between the initial LR set in the optimizer to :obj:`0`, after a warmup period during which it increases linearly between :obj:`0` and the initial LR set in the optimizer. Args: optimizer (Optimizer): The optimizer to be scheduled. 
num_warmup_steps (int): The number of steps for the warmup phase. num_training_steps (int): The total number of training steps. num_cycles (float, optional): The number of waves in the cosine schedule (the default decreases LR from the max value to :obj:`0` following a half-cosine). (default: :obj:`0.5`) last_epoch (int, optional): The index of the last epoch when resuming training. (default: :obj:`-1`) """ def __init__( self, optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1, ): lr_lambda = functools.partial( self._lr_lambda, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles, ) super().__init__(optimizer, lr_lambda, last_epoch) @staticmethod def _lr_lambda( current_step: int, num_warmup_steps: int, num_training_steps: int, num_cycles: float, ): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float( max(1, num_training_steps - num_warmup_steps)) return max( 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), ) class CosineWithWarmupRestartsLR(LambdaLR): r"""Creates a LR scheduler with a learning rate that decreases following the values of the cosine function between the initial LR set in the optimizer to :obj:`0`, with several hard restarts, after a warmup period during which it increases linearly between :obj:`0` and the initial LR set in the optimizer. Args: optimizer (Optimizer): The optimizer to be scheduled. num_warmup_steps (int): The number of steps for the warmup phase. num_training_steps (int): The total number of training steps. num_cycles (int, optional): The number of hard restarts to use. (default: :obj:`3`) last_epoch (int, optional): The index of the last epoch when resuming training. 
(default: :obj:`-1`) """ def __init__( self, optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 3, last_epoch: int = -1, ): lr_lambda = functools.partial( self._lr_lambda, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles, ) super().__init__(optimizer, lr_lambda, last_epoch) @staticmethod def _lr_lambda( current_step: int, num_warmup_steps: int, num_training_steps: int, num_cycles: int, ) -> float: if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float( max(1, num_training_steps - num_warmup_steps)) if progress >= 1.0: return 0.0 return max( 0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))), ) class PolynomialWithWarmupLR(LambdaLR): r"""Creates a LR scheduler with a learning rate that decreases as a polynomial decay from the initial LR set in the optimizer to end LR defined by `lr_end`, after a warmup period during which it increases linearly from :obj:`0` to the initial LR set in the optimizer. Args: optimizer (Optimizer): The optimizer to be scheduled. num_warmup_steps (int): The number of steps for the warmup phase. num_training_steps (int): The total number of training steps. lr_end (float, optional): The end learning rate. (default: :obj:`1e-7`) power (float, optional): The power factor of the polynomial decay. (default: :obj:`1.0`) last_epoch (int, optional): The index of the last epoch when resuming training. 
(default: :obj:`-1`) """ def __init__( self, optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, lr_end: float = 1e-7, power: float = 1.0, last_epoch: int = -1, ): lr_init = optimizer.defaults["lr"] if not (lr_init > lr_end): raise ValueError(f"`lr_end` ({lr_end}) must be smaller than the " f"initial lr ({lr_init})") lr_lambda = functools.partial( self._lr_lambda, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, lr_init=lr_init, lr_end=lr_end, power=power, ) super().__init__(optimizer, lr_lambda, last_epoch) @staticmethod def _lr_lambda( current_step: int, num_warmup_steps: int, num_training_steps: int, lr_init: float, lr_end: float, power: float, ) -> float: if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) elif current_step > num_training_steps: return lr_end / lr_init # As `LambdaLR` multiplies by `lr_init`. else: lr_range = lr_init - lr_end decay_steps = num_training_steps - num_warmup_steps pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps decay = lr_range * pct_remaining**power + lr_end return decay / lr_init # As `LambdaLR` multiplies by `lr_init`. ================================================ FILE: torch_geometric/nn/model_hub.py ================================================ import os.path as osp from pathlib import Path from typing import Any, Dict, Optional, Union import torch from torch_geometric.io import fs try: from huggingface_hub import ModelHubMixin, hf_hub_download except ImportError: ModelHubMixin = object hf_hub_download = None CONFIG_NAME = 'config.json' MODEL_HUB_ORGANIZATION = "pytorch_geometric" MODEL_WEIGHTS_NAME = 'model.pth' TAGS = ['graph-machine-learning'] class PyGModelHubMixin(ModelHubMixin): r"""A mixin for saving and loading models to the `Huggingface Model Hub `_. .. 
code-block:: python from torch_geometric.datasets import Planetoid from torch_geometric.nn import Node2Vec from torch_geometric.nn.model_hub import PyGModelHubMixin # Define your class with the mixin: class N2V(Node2Vec, PyGModelHubMixin): def __init__(self,model_name, dataset_name, model_kwargs): Node2Vec.__init__(self,**model_kwargs) PyGModelHubMixin.__init__(self, model_name, dataset_name, model_kwargs) # Instantiate your model: n2v = N2V(model_name='node2vec', dataset_name='Cora', model_kwargs=dict( edge_index=data.edge_index, embedding_dim=128, walk_length=20, context_size=10, walks_per_node=10, num_negative_samples=1, p=1, q=1, sparse=True)) # Train the model: ... # Push to the HuggingFace hub: repo_id = ... # your repo id n2v.save_pretrained( local_file_path, push_to_hub=True, repo_id=repo_id, ) # Load the model for inference: # The required arguments are the repo id/local folder, and any model # initialisation arguments that are not native python types (e.g # Node2Vec requires the edge_index argument which is not stored in the # model hub). model = N2V.from_pretrained( repo_id, model_name='node2vec', dataset_name='Cora', edge_index=data.edge_index, ) Args: model_name (str): Name of the model. dataset_name (str): Name of the dataset the model was trained against. model_kwargs (Dict[str, Any]): The arguments to initialise the model. """ def __init__(self, model_name: str, dataset_name: str, model_kwargs: Dict): ModelHubMixin.__init__(self) # Huggingface Hub API only accepts saving the config as a dict. 
# If the model is instantiated with non-native python types # such as torch Tensors (node2vec being an example), we have to remove # these as they are not json serialisable self.model_config = { k: v for k, v in model_kwargs.items() if type(v) in [str, int, float] } self.model_name = model_name self.dataset_name = dataset_name def construct_model_card(self, model_name: str, dataset_name: str) -> Any: from huggingface_hub import ModelCard, ModelCardData card_data = ModelCardData( language='en', license='mit', library_name=MODEL_HUB_ORGANIZATION, tags=TAGS, datasets=dataset_name, model_name=model_name, ) card = ModelCard.from_template(card_data) return card def _save_pretrained(self, save_directory: Union[Path, str]): path = osp.join(save_directory, MODEL_WEIGHTS_NAME) model_to_save = self.module if hasattr(self, 'module') else self torch.save(model_to_save.state_dict(), path) def save_pretrained(self, save_directory: Union[str, Path], push_to_hub: bool = False, repo_id: Optional[str] = None, **kwargs): r"""Save a trained model to a local directory or to the HuggingFace model hub. Args: save_directory (str): The directory where weights are saved. push_to_hub (bool, optional): If :obj:`True`, push the model to the HuggingFace model hub. (default: :obj:`False`) repo_id (str, optional): The repository name in the hub. If not provided will default to the name of :obj:`save_directory` in your namespace. (default: :obj:`None`) **kwargs: Additional keyword arguments passed to :meth:`huggingface_hub.ModelHubMixin.save_pretrained`. 
""" config = self.model_config # due to way huggingface hub handles the loading/saving of models, # the model config can end up in one of the items in the kwargs # this has to be removed to prevent a duplication of arguments to # ModelHubMixin.save_pretrained kwargs.pop('config', None) super().save_pretrained( save_directory=save_directory, config=config, push_to_hub=push_to_hub, repo_id=repo_id, **kwargs, ) model_card = self.construct_model_card(self.model_name, self.dataset_name) if push_to_hub: model_card.push_to_hub(repo_id) @classmethod def _from_pretrained( cls, model_id, revision, cache_dir, force_download, local_files_only, token, proxies=None, resume_download=False, dataset_name='', model_name='', map_location='cpu', strict=False, **model_kwargs, ): map_location = torch.device(map_location) if osp.isdir(model_id): model_file = osp.join(model_id, MODEL_WEIGHTS_NAME) else: model_file = hf_hub_download( repo_id=model_id, filename=MODEL_WEIGHTS_NAME, revision=revision, cache_dir=cache_dir, force_download=force_download, token=token, local_files_only=local_files_only, ) config = model_kwargs.pop('config', None) if config is not None: model_kwargs = {**model_kwargs, **config} model = cls(dataset_name, model_name, model_kwargs) state_dict = fs.torch_load(model_file, map_location=map_location) model.load_state_dict(state_dict, strict=strict) model.eval() return model @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str, force_download: bool = False, token: Optional[Union[str, bool]] = None, cache_dir: Optional[str] = None, local_files_only: bool = False, **model_kwargs, ) -> Any: r"""Downloads and instantiates a model from the HuggingFace hub. Args: pretrained_model_name_or_path (str): Can be either: - The :obj:`model_id` of a pretrained model hosted inside the HuggingFace hub. - You can add a :obj:`revision` by appending :obj:`@` at the end of :obj:`model_id` to load a specific model version. 
- A path to a directory containing the saved model weights. - :obj:`None` if you are both providing the configuration :obj:`config` and state dictionary :obj:`state_dict`. force_download (bool, optional): Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. (default: :obj:`False`) token (str or bool, optional): The token to use as HTTP bearer authorization for remote files. If set to :obj:`True`, will use the token generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). It is **required** if you want to use a private model. (default: :obj:`None`) cache_dir (str, optional): The path to a directory in which a downloaded model configuration should be cached if the standard cache should not be used. (default: :obj:`None`) local_files_only (bool, optional): Whether to only look at local files, *i.e.* do not try to download the model. (default: :obj:`False`) **model_kwargs: Additional keyword arguments passed to the model during initialization. 
""" return super().from_pretrained( pretrained_model_name_or_path, force_download=force_download, use_auth_token=token, cache_dir=cache_dir, local_files_only=local_files_only, **model_kwargs, ) ================================================ FILE: torch_geometric/nn/models/__init__.py ================================================ r"""Model package.""" from .mlp import MLP from .basic_gnn import GCN, GraphSAGE, GIN, GAT, PNA, EdgeCNN from .jumping_knowledge import JumpingKnowledge, HeteroJumpingKnowledge from .meta import MetaLayer from .node2vec import Node2Vec from .deep_graph_infomax import DeepGraphInfomax from .autoencoder import InnerProductDecoder, GAE, VGAE, ARGA, ARGVA from .signed_gcn import SignedGCN from .re_net import RENet from .graph_unet import GraphUNet from .schnet import SchNet from .dimenet import DimeNet, DimeNetPlusPlus from .gpse import GPSE, GPSENodeEncoder from .captum import to_captum_model from .metapath2vec import MetaPath2Vec from .deepgcn import DeepGCNLayer from .tgn import TGNMemory from .label_prop import LabelPropagation from .correct_and_smooth import CorrectAndSmooth from .attentive_fp import AttentiveFP from .rect import RECT_L from .linkx import LINKX from .lightgcn import LightGCN from .mask_label import MaskLabel from .rev_gnn import GroupAddRev from .gnnff import GNNFF from .pmlp import PMLP from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet from .lpformer import LPFormer from .sgformer import SGFormer from .polynormer import Polynormer # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) from .attract_repel import ARLinkPredictor __all__ = classes = [ 'MLP', 'GCN', 'GraphSAGE', 'GIN', 'GAT', 'PNA', 'EdgeCNN', 'JumpingKnowledge', 'HeteroJumpingKnowledge', 'MetaLayer', 'Node2Vec', 'DeepGraphInfomax', 'InnerProductDecoder', 'GAE', 'VGAE', 'ARGA', 'ARGVA', 'SignedGCN', 'RENet', 'GraphUNet', 'SchNet', 'DimeNet', 'DimeNetPlusPlus', 'GPSE', 
'GPSENodeEncoder', 'to_captum_model', 'to_captum_input', 'captum_output_to_dicts', 'MetaPath2Vec', 'DeepGCNLayer', 'TGNMemory', 'LabelPropagation', 'CorrectAndSmooth', 'AttentiveFP', 'RECT_L', 'LINKX', 'LightGCN', 'MaskLabel', 'GroupAddRev', 'GNNFF', 'PMLP', 'NeuralFingerprint', 'ViSNet', 'LPFormer', 'SGFormer', 'Polynormer', 'ARLinkPredictor', ] ================================================ FILE: torch_geometric/nn/models/attentive_fp.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import GRUCell, Linear, Parameter from torch_geometric.nn import GATConv, MessagePassing, global_add_pool from torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import Adj, OptTensor from torch_geometric.utils import softmax class GATEConv(MessagePassing): def __init__( self, in_channels: int, out_channels: int, edge_dim: int, dropout: float = 0.0, ): super().__init__(aggr='add', node_dim=0) self.dropout = dropout self.att_l = Parameter(torch.empty(1, out_channels)) self.att_r = Parameter(torch.empty(1, in_channels)) self.lin1 = Linear(in_channels + edge_dim, out_channels, False) self.lin2 = Linear(out_channels, out_channels, False) self.bias = Parameter(torch.empty(out_channels)) self.reset_parameters() def reset_parameters(self): glorot(self.att_l) glorot(self.att_r) glorot(self.lin1.weight) glorot(self.lin2.weight) zeros(self.bias) def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor: # edge_updater_type: (x: Tensor, edge_attr: Tensor) alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr) # propagate_type: (x: Tensor, alpha: Tensor) out = self.propagate(edge_index, x=x, alpha=alpha) out = out + self.bias return out def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor: x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1))) 
alpha_j = (x_j @ self.att_l.t()).squeeze(-1) alpha_i = (x_i @ self.att_r.t()).squeeze(-1) alpha = alpha_j + alpha_i alpha = F.leaky_relu_(alpha) alpha = softmax(alpha, index, ptr, size_i) alpha = F.dropout(alpha, p=self.dropout, training=self.training) return alpha def message(self, x_j: Tensor, alpha: Tensor) -> Tensor: return self.lin2(x_j) * alpha.unsqueeze(-1) class AttentiveFP(torch.nn.Module): r"""The Attentive FP model for molecular representation learning from the `"Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism" `_ paper, based on graph attention mechanisms. Args: in_channels (int): Size of each input sample. hidden_channels (int): Hidden node feature dimensionality. out_channels (int): Size of each output sample. edge_dim (int): Edge feature dimensionality. num_layers (int): Number of GNN layers. num_timesteps (int): Number of iterative refinement steps for global readout. dropout (float, optional): Dropout probability. 
(default: :obj:`0.0`) """ def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, edge_dim: int, num_layers: int, num_timesteps: int, dropout: float = 0.0, ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels self.out_channels = out_channels self.edge_dim = edge_dim self.num_layers = num_layers self.num_timesteps = num_timesteps self.dropout = dropout self.lin1 = Linear(in_channels, hidden_channels) self.gate_conv = GATEConv(hidden_channels, hidden_channels, edge_dim, dropout) self.gru = GRUCell(hidden_channels, hidden_channels) self.atom_convs = torch.nn.ModuleList() self.atom_grus = torch.nn.ModuleList() for _ in range(num_layers - 1): conv = GATConv(hidden_channels, hidden_channels, dropout=dropout, add_self_loops=False, negative_slope=0.01) self.atom_convs.append(conv) self.atom_grus.append(GRUCell(hidden_channels, hidden_channels)) self.mol_conv = GATConv(hidden_channels, hidden_channels, dropout=dropout, add_self_loops=False, negative_slope=0.01) self.mol_conv.explain = False # Cannot explain global pooling. 
self.mol_gru = GRUCell(hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, out_channels) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin1.reset_parameters() self.gate_conv.reset_parameters() self.gru.reset_parameters() for conv, gru in zip(self.atom_convs, self.atom_grus): conv.reset_parameters() gru.reset_parameters() self.mol_conv.reset_parameters() self.mol_gru.reset_parameters() self.lin2.reset_parameters() def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor, batch: Tensor) -> Tensor: """""" # noqa: D419 # Atom Embedding: x = F.leaky_relu_(self.lin1(x)) h = F.elu_(self.gate_conv(x, edge_index, edge_attr)) h = F.dropout(h, p=self.dropout, training=self.training) x = self.gru(h, x).relu_() for conv, gru in zip(self.atom_convs, self.atom_grus): h = conv(x, edge_index) h = F.elu(h) h = F.dropout(h, p=self.dropout, training=self.training) x = gru(h, x).relu() # Molecule Embedding: row = torch.arange(batch.size(0), device=batch.device) edge_index = torch.stack([row, batch], dim=0) out = global_add_pool(x, batch).relu_() for _ in range(self.num_timesteps): h = F.elu_(self.mol_conv((x, out), edge_index)) h = F.dropout(h, p=self.dropout, training=self.training) out = self.mol_gru(h, out).relu_() # Predictor: out = F.dropout(out, p=self.dropout, training=self.training) return self.lin2(out) def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'in_channels={self.in_channels}, ' f'hidden_channels={self.hidden_channels}, ' f'out_channels={self.out_channels}, ' f'edge_dim={self.edge_dim}, ' f'num_layers={self.num_layers}, ' f'num_timesteps={self.num_timesteps}' f')') ================================================ FILE: torch_geometric/nn/models/attract_repel.py ================================================ import torch import torch.nn.functional as F class ARLinkPredictor(torch.nn.Module): r"""Link predictor using Attract-Repel embeddings from the paper 
`"Pseudo-Euclidean Attract-Repel Embeddings for Undirected Graphs" `_. This model splits node embeddings into: attract and repel. The edge prediction score is computed as the dot product of attract components minus the dot product of repel components. Args: in_channels (int): Size of each input sample. hidden_channels (int): Size of hidden embeddings. out_channels (int, optional): Size of output embeddings. If set to :obj:`None`, will default to :obj:`hidden_channels`. (default: :obj:`None`) num_layers (int): Number of message passing layers. (default: :obj:`2`) dropout (float): Dropout probability. (default: :obj:`0.0`) attract_ratio (float): Ratio to use for attract component. Must be between 0 and 1. (default: :obj:`0.5`) """ def __init__(self, in_channels, hidden_channels, out_channels=None, num_layers=2, dropout=0.0, attract_ratio=0.5): super().__init__() if out_channels is None: out_channels = hidden_channels self.in_channels = in_channels self.hidden_channels = hidden_channels self.out_channels = out_channels self.num_layers = num_layers self.dropout = dropout if not 0 <= attract_ratio <= 1: raise ValueError( f"attract_ratio must be between 0 and 1, got {attract_ratio}") self.attract_ratio = attract_ratio self.attract_dim = int(out_channels * attract_ratio) self.repel_dim = out_channels - self.attract_dim # Create model layers self.lins = torch.nn.ModuleList() self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) for _ in range(num_layers - 2): self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) # Final layer splits into attract and repel components self.lin_attract = torch.nn.Linear(hidden_channels, self.attract_dim) self.lin_repel = torch.nn.Linear(hidden_channels, self.repel_dim) self.reset_parameters() def reset_parameters(self): """Reset all learnable parameters.""" for lin in self.lins: lin.reset_parameters() self.lin_attract.reset_parameters() self.lin_repel.reset_parameters() def encode(self, x, *args, **kwargs): """Encode 
node features into attract-repel embeddings. Args: x (torch.Tensor): Node feature matrix of shape :obj:`[num_nodes, in_channels]`. *args: Variable length argument list **kwargs: Arbitrary keyword arguments """ for lin in self.lins: x = lin(x) x = F.relu(x) x = F.dropout(x, p=self.dropout, training=self.training) # Split into attract and repel components attract_x = self.lin_attract(x) repel_x = self.lin_repel(x) return attract_x, repel_x def decode(self, attract_z, repel_z, edge_index): """Decode edge scores from attract-repel embeddings. Args: attract_z (torch.Tensor): Attract embeddings of shape :obj:`[num_nodes, attract_dim]`. repel_z (torch.Tensor): Repel embeddings of shape :obj:`[num_nodes, repel_dim]`. edge_index (torch.Tensor): Edge indices of shape :obj:`[2, num_edges]`. Returns: torch.Tensor: Edge prediction scores. """ # Get node embeddings for edges row, col = edge_index attract_z_row = attract_z[row] attract_z_col = attract_z[col] repel_z_row = repel_z[row] repel_z_col = repel_z[col] # Compute attract-repel scores attract_score = torch.sum(attract_z_row * attract_z_col, dim=1) repel_score = torch.sum(repel_z_row * repel_z_col, dim=1) return attract_score - repel_score def forward(self, x, edge_index): """Forward pass for link prediction. Args: x (torch.Tensor): Node feature matrix. edge_index (torch.Tensor): Edge indices to predict. Returns: torch.Tensor: Predicted edge scores. """ # Encode nodes into attract-repel embeddings attract_z, repel_z = self.encode(x) # Decode target edges return torch.sigmoid(self.decode(attract_z, repel_z, edge_index)) def calculate_r_fraction(self, attract_z, repel_z): """Calculate the R-fraction (proportion of energy in repel space). Args: attract_z (torch.Tensor): Attract embeddings. repel_z (torch.Tensor): Repel embeddings. Returns: float: R-fraction value. 
""" attract_norm_squared = torch.sum(attract_z**2) repel_norm_squared = torch.sum(repel_z**2) r_fraction = repel_norm_squared / (attract_norm_squared + repel_norm_squared + 1e-10) return r_fraction.item() ================================================ FILE: torch_geometric/nn/models/autoencoder.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch.nn import Module from torch_geometric.nn.inits import reset from torch_geometric.utils import negative_sampling EPS = 1e-15 MAX_LOGSTD = 10 class InnerProductDecoder(torch.nn.Module): r"""The inner product decoder from the `"Variational Graph Auto-Encoders" `_ paper. .. math:: \sigma(\mathbf{Z}\mathbf{Z}^{\top}) where :math:`\mathbf{Z} \in \mathbb{R}^{N \times d}` denotes the latent space produced by the encoder. """ def forward( self, z: Tensor, edge_index: Tensor, sigmoid: bool = True, ) -> Tensor: r"""Decodes the latent variables :obj:`z` into edge probabilities for the given node-pairs :obj:`edge_index`. Args: z (torch.Tensor): The latent space :math:`\mathbf{Z}`. edge_index (torch.Tensor): The edge indices. sigmoid (bool, optional): If set to :obj:`False`, does not apply the logistic sigmoid function to the output. (default: :obj:`True`) """ value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1) return torch.sigmoid(value) if sigmoid else value def forward_all(self, z: Tensor, sigmoid: bool = True) -> Tensor: r"""Decodes the latent variables :obj:`z` into a probabilistic dense adjacency matrix. Args: z (torch.Tensor): The latent space :math:`\mathbf{Z}`. sigmoid (bool, optional): If set to :obj:`False`, does not apply the logistic sigmoid function to the output. (default: :obj:`True`) """ adj = torch.matmul(z, z.t()) return torch.sigmoid(adj) if sigmoid else adj class GAE(torch.nn.Module): r"""The Graph Auto-Encoder model from the `"Variational Graph Auto-Encoders" `_ paper based on user-defined encoder and decoder models. 
Args: encoder (torch.nn.Module): The encoder module. decoder (torch.nn.Module, optional): The decoder module. If set to :obj:`None`, will default to the :class:`torch_geometric.nn.models.InnerProductDecoder`. (default: :obj:`None`) """ def __init__(self, encoder: Module, decoder: Optional[Module] = None): super().__init__() self.encoder = encoder self.decoder = InnerProductDecoder() if decoder is None else decoder GAE.reset_parameters(self) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset(self.encoder) reset(self.decoder) def forward(self, *args, **kwargs) -> Tensor: # pragma: no cover r"""Alias for :meth:`encode`.""" return self.encoder(*args, **kwargs) def encode(self, *args, **kwargs) -> Tensor: r"""Runs the encoder and computes node-wise latent variables.""" return self.encoder(*args, **kwargs) def decode(self, *args, **kwargs) -> Tensor: r"""Runs the decoder and computes edge probabilities.""" return self.decoder(*args, **kwargs) def recon_loss(self, z: Tensor, pos_edge_index: Tensor, neg_edge_index: Optional[Tensor] = None) -> Tensor: r"""Given latent variables :obj:`z`, computes the binary cross entropy loss for positive edges :obj:`pos_edge_index` and negative sampled edges. Args: z (torch.Tensor): The latent space :math:`\mathbf{Z}`. pos_edge_index (torch.Tensor): The positive edges to train against. neg_edge_index (torch.Tensor, optional): The negative edges to train against. If not given, uses negative sampling to calculate negative edges. 
(default: :obj:`None`) """ pos_loss = -torch.log( self.decoder(z, pos_edge_index, sigmoid=True) + EPS).mean() if neg_edge_index is None: neg_edge_index = negative_sampling(pos_edge_index, z.size(0)) neg_loss = -torch.log(1 - self.decoder(z, neg_edge_index, sigmoid=True) + EPS).mean() return pos_loss + neg_loss def test(self, z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) -> Tuple[Tensor, Tensor]: r"""Given latent variables :obj:`z`, positive edges :obj:`pos_edge_index` and negative edges :obj:`neg_edge_index`, computes area under the ROC curve (AUC) and average precision (AP) scores. Args: z (torch.Tensor): The latent space :math:`\mathbf{Z}`. pos_edge_index (torch.Tensor): The positive edges to evaluate against. neg_edge_index (torch.Tensor): The negative edges to evaluate against. """ from sklearn.metrics import average_precision_score, roc_auc_score pos_y = z.new_ones(pos_edge_index.size(1)) neg_y = z.new_zeros(neg_edge_index.size(1)) y = torch.cat([pos_y, neg_y], dim=0) pos_pred = self.decoder(z, pos_edge_index, sigmoid=True) neg_pred = self.decoder(z, neg_edge_index, sigmoid=True) pred = torch.cat([pos_pred, neg_pred], dim=0) y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy() return roc_auc_score(y, pred), average_precision_score(y, pred) class VGAE(GAE): r"""The Variational Graph Auto-Encoder model from the `"Variational Graph Auto-Encoders" `_ paper. Args: encoder (torch.nn.Module): The encoder module to compute :math:`\mu` and :math:`\log\sigma^2`. decoder (torch.nn.Module, optional): The decoder module. If set to :obj:`None`, will default to the :class:`torch_geometric.nn.models.InnerProductDecoder`. 
(default: :obj:`None`) """ def __init__(self, encoder: Module, decoder: Optional[Module] = None): super().__init__(encoder, decoder) def reparametrize(self, mu: Tensor, logstd: Tensor) -> Tensor: if self.training: return mu + torch.randn_like(logstd) * torch.exp(logstd) else: return mu def encode(self, *args, **kwargs) -> Tensor: """""" # noqa: D419 self.__mu__, self.__logstd__ = self.encoder(*args, **kwargs) self.__logstd__ = self.__logstd__.clamp(max=MAX_LOGSTD) z = self.reparametrize(self.__mu__, self.__logstd__) return z def kl_loss(self, mu: Optional[Tensor] = None, logstd: Optional[Tensor] = None) -> Tensor: r"""Computes the KL loss, either for the passed arguments :obj:`mu` and :obj:`logstd`, or based on latent variables from last encoding. Args: mu (torch.Tensor, optional): The latent space for :math:`\mu`. If set to :obj:`None`, uses the last computation of :math:`\mu`. (default: :obj:`None`) logstd (torch.Tensor, optional): The latent space for :math:`\log\sigma`. If set to :obj:`None`, uses the last computation of :math:`\log\sigma^2`. (default: :obj:`None`) """ mu = self.__mu__ if mu is None else mu logstd = self.__logstd__ if logstd is None else logstd.clamp( max=MAX_LOGSTD) return -0.5 * torch.mean( torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1)) class ARGA(GAE): r"""The Adversarially Regularized Graph Auto-Encoder model from the `"Adversarially Regularized Graph Autoencoder for Graph Embedding" `_ paper. Args: encoder (torch.nn.Module): The encoder module. discriminator (torch.nn.Module): The discriminator module. decoder (torch.nn.Module, optional): The decoder module. If set to :obj:`None`, will default to the :class:`torch_geometric.nn.models.InnerProductDecoder`. 
(default: :obj:`None`) """ def __init__( self, encoder: Module, discriminator: Module, decoder: Optional[Module] = None, ): super().__init__(encoder, decoder) self.discriminator = discriminator reset(self.discriminator) def reset_parameters(self): super().reset_parameters() reset(self.discriminator) def reg_loss(self, z: Tensor) -> Tensor: r"""Computes the regularization loss of the encoder. Args: z (torch.Tensor): The latent space :math:`\mathbf{Z}`. """ real = torch.sigmoid(self.discriminator(z)) real_loss = -torch.log(real + EPS).mean() return real_loss def discriminator_loss(self, z: Tensor) -> Tensor: r"""Computes the loss of the discriminator. Args: z (torch.Tensor): The latent space :math:`\mathbf{Z}`. """ real = torch.sigmoid(self.discriminator(torch.randn_like(z))) fake = torch.sigmoid(self.discriminator(z.detach())) real_loss = -torch.log(real + EPS).mean() fake_loss = -torch.log(1 - fake + EPS).mean() return real_loss + fake_loss class ARGVA(ARGA): r"""The Adversarially Regularized Variational Graph Auto-Encoder model from the `"Adversarially Regularized Graph Autoencoder for Graph Embedding" `_ paper. Args: encoder (torch.nn.Module): The encoder module to compute :math:`\mu` and :math:`\log\sigma^2`. discriminator (torch.nn.Module): The discriminator module. decoder (torch.nn.Module, optional): The decoder module. If set to :obj:`None`, will default to the :class:`torch_geometric.nn.models.InnerProductDecoder`. 
(default: :obj:`None`) """ def __init__( self, encoder: Module, discriminator: Module, decoder: Optional[Module] = None, ): super().__init__(encoder, discriminator, decoder) self.VGAE = VGAE(encoder, decoder) @property def __mu__(self) -> Tensor: return self.VGAE.__mu__ @property def __logstd__(self) -> Tensor: return self.VGAE.__logstd__ def reparametrize(self, mu: Tensor, logstd: Tensor) -> Tensor: return self.VGAE.reparametrize(mu, logstd) def encode(self, *args, **kwargs) -> Tensor: """""" # noqa: D419 return self.VGAE.encode(*args, **kwargs) def kl_loss( self, mu: Optional[Tensor] = None, logstd: Optional[Tensor] = None, ) -> Tensor: return self.VGAE.kl_loss(mu, logstd) ================================================ FILE: torch_geometric/nn/models/basic_gnn.py ================================================ import copy import inspect from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union import torch from torch import Tensor from torch.nn import Linear, ModuleList from tqdm import tqdm from torch_geometric.data import Data from torch_geometric.loader import CachedLoader, NeighborLoader from torch_geometric.nn.conv import ( EdgeConv, GATConv, GATv2Conv, GCNConv, GINConv, MessagePassing, PNAConv, SAGEConv, ) from torch_geometric.nn.models import MLP from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge from torch_geometric.nn.resolver import ( activation_resolver, normalization_resolver, ) from torch_geometric.typing import Adj, OptTensor from torch_geometric.utils._trim_to_layer import TrimToLayer class BasicGNN(torch.nn.Module): r"""An abstract class for implementing basic GNN models. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. hidden_channels (int): Size of each hidden sample. num_layers (int): Number of message passing layers. 
out_channels (int, optional): If not set to :obj:`None`, will apply a final linear transformation to convert hidden node embeddings to output size :obj:`out_channels`. (default: :obj:`None`) dropout (float, optional): Dropout probability. (default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. (default: :obj:`"relu"`) act_first (bool, optional): If set to :obj:`True`, activation is applied before normalization. (default: :obj:`False`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`None`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. (default: :obj:`None`) jk (str, optional): The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`) **kwargs (optional): Additional arguments of the underlying :class:`torch_geometric.nn.conv.MessagePassing` layers. 
""" supports_edge_weight: Final[bool] supports_edge_attr: Final[bool] supports_norm_batch: Final[bool] def __init__( self, in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Union[str, Callable, None] = "relu", act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Union[str, Callable, None] = None, norm_kwargs: Optional[Dict[str, Any]] = None, jk: Optional[str] = None, **kwargs, ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels self.num_layers = num_layers self.dropout = torch.nn.Dropout(p=dropout) self.act = activation_resolver(act, **(act_kwargs or {})) self.jk_mode = jk self.act_first = act_first self.norm = norm if isinstance(norm, str) else None self.norm_kwargs = norm_kwargs if out_channels is not None: self.out_channels = out_channels else: self.out_channels = hidden_channels self.convs = ModuleList() if num_layers > 1: self.convs.append( self.init_conv(in_channels, hidden_channels, **kwargs)) if isinstance(in_channels, (tuple, list)): in_channels = (hidden_channels, hidden_channels) else: in_channels = hidden_channels for _ in range(num_layers - 2): self.convs.append( self.init_conv(in_channels, hidden_channels, **kwargs)) if isinstance(in_channels, (tuple, list)): in_channels = (hidden_channels, hidden_channels) else: in_channels = hidden_channels if out_channels is not None and jk is None: self._is_conv_to_out = True self.convs.append( self.init_conv(in_channels, out_channels, **kwargs)) else: self.convs.append( self.init_conv(in_channels, hidden_channels, **kwargs)) self.norms = ModuleList() norm_layer = normalization_resolver( norm, hidden_channels, **(norm_kwargs or {}), ) if norm_layer is None: norm_layer = torch.nn.Identity() self.supports_norm_batch = False if hasattr(norm_layer, 'forward'): norm_params = inspect.signature(norm_layer.forward).parameters self.supports_norm_batch = 'batch' in norm_params for _ in 
range(num_layers - 1): self.norms.append(copy.deepcopy(norm_layer)) if jk is not None: self.norms.append(copy.deepcopy(norm_layer)) else: self.norms.append(torch.nn.Identity()) if jk is not None and jk != 'last': self.jk = JumpingKnowledge(jk, hidden_channels, num_layers) if jk is not None: if jk == 'cat': in_channels = num_layers * hidden_channels else: in_channels = hidden_channels self.lin = Linear(in_channels, self.out_channels) # We define `trim_to_layer` functionality as a module such that we can # still use `to_hetero` on-top. self._trim = TrimToLayer() def init_conv(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, **kwargs) -> MessagePassing: raise NotImplementedError def reset_parameters(self): r"""Resets all learnable parameters of the module.""" for conv in self.convs: conv.reset_parameters() for norm in self.norms: if hasattr(norm, 'reset_parameters'): norm.reset_parameters() if hasattr(self, 'jk'): self.jk.reset_parameters() if hasattr(self, 'lin'): self.lin.reset_parameters() def forward( self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, edge_attr: OptTensor = None, batch: OptTensor = None, batch_size: Optional[int] = None, num_sampled_nodes_per_hop: Optional[List[int]] = None, num_sampled_edges_per_hop: Optional[List[int]] = None, ) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The input node features. edge_index (torch.Tensor or SparseTensor): The edge indices. edge_weight (torch.Tensor, optional): The edge weights (if supported by the underlying GNN layer). (default: :obj:`None`) edge_attr (torch.Tensor, optional): The edge features (if supported by the underlying GNN layer). (default: :obj:`None`) batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. Only needs to be passed in case the underlying normalization layers require the :obj:`batch` information. 
(default: :obj:`None`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. Only needs to be passed in case the underlying normalization layers require the :obj:`batch` information. (default: :obj:`None`) num_sampled_nodes_per_hop (List[int], optional): The number of sampled nodes per hop. Useful in :class:`~torch_geometric.loader.NeighborLoader` scenarios to only operate on minimal-sized representations. (default: :obj:`None`) num_sampled_edges_per_hop (List[int], optional): The number of sampled edges per hop. Useful in :class:`~torch_geometric.loader.NeighborLoader` scenarios to only operate on minimal-sized representations. (default: :obj:`None`) """ if (num_sampled_nodes_per_hop is not None and isinstance(edge_weight, Tensor) and isinstance(edge_attr, Tensor)): raise NotImplementedError("'trim_to_layer' functionality does not " "yet support trimming of both " "'edge_weight' and 'edge_attr'") xs: List[Tensor] = [] assert len(self.convs) == len(self.norms) for i, (conv, norm) in enumerate(zip(self.convs, self.norms)): if (not torch.jit.is_scripting() and num_sampled_nodes_per_hop is not None): x, edge_index, value = self._trim( i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop, x, edge_index, edge_weight if edge_weight is not None else edge_attr, ) if edge_weight is not None: edge_weight = value else: edge_attr = value # Tracing the module is not allowed with *args and **kwargs :( # As such, we rely on a static solution to pass optional edge # weights and edge attributes to the module. 
if self.supports_edge_weight and self.supports_edge_attr: x = conv(x, edge_index, edge_weight=edge_weight, edge_attr=edge_attr) elif self.supports_edge_weight: x = conv(x, edge_index, edge_weight=edge_weight) elif self.supports_edge_attr: x = conv(x, edge_index, edge_attr=edge_attr) else: x = conv(x, edge_index) if i < self.num_layers - 1 or self.jk_mode is not None: if self.act is not None and self.act_first: x = self.act(x) if self.supports_norm_batch: x = norm(x, batch, batch_size) else: x = norm(x) if self.act is not None and not self.act_first: x = self.act(x) x = self.dropout(x) if hasattr(self, 'jk'): xs.append(x) x = self.jk(xs) if hasattr(self, 'jk') else x x = self.lin(x) if hasattr(self, 'lin') else x return x @torch.no_grad() def inference_per_layer( self, layer: int, x: Tensor, edge_index: Adj, batch_size: int, ) -> Tensor: x = self.convs[layer](x, edge_index)[:batch_size] if layer == self.num_layers - 1 and self.jk_mode is None: return x if self.act is not None and self.act_first: x = self.act(x) if self.norms is not None: x = self.norms[layer](x) if self.act is not None and not self.act_first: x = self.act(x) if layer == self.num_layers - 1 and hasattr(self, 'lin'): x = self.lin(x) return x @torch.no_grad() def inference( self, loader: NeighborLoader, device: Optional[Union[str, torch.device]] = None, embedding_device: Union[str, torch.device] = 'cpu', progress_bar: bool = False, cache: bool = False, ) -> Tensor: r"""Performs layer-wise inference on large-graphs using a :class:`~torch_geometric.loader.NeighborLoader`, where :class:`~torch_geometric.loader.NeighborLoader` should sample the full neighborhood for only one layer. This is an efficient way to compute the output embeddings for all nodes in the graph. Only applicable in case :obj:`jk=None` or `jk='last'`. Args: loader (torch_geometric.loader.NeighborLoader): A neighbor loader object that generates full 1-hop subgraphs, *i.e.*, :obj:`loader.num_neighbors = [-1]`. 
device (torch.device, optional): The device to run the GNN on. (default: :obj:`None`) embedding_device (torch.device, optional): The device to store intermediate embeddings on. If intermediate embeddings fit on GPU, this option helps to avoid unnecessary device transfers. (default: :obj:`"cpu"`) progress_bar (bool, optional): If set to :obj:`True`, will print a progress bar during computation. (default: :obj:`False`) cache (bool, optional): If set to :obj:`True`, caches intermediate sampler outputs for usage in later epochs. This will avoid repeated sampling to accelerate inference. (default: :obj:`False`) """ assert self.jk_mode is None or self.jk_mode == 'last' assert isinstance(loader, NeighborLoader) assert len(loader.dataset) == loader.data.num_nodes assert len(loader.node_sampler.num_neighbors) == 1 assert not self.training # assert not loader.shuffle # TODO (matthias) does not work :( if progress_bar: pbar = tqdm(total=len(self.convs) * len(loader)) pbar.set_description('Inference') x_all = loader.data.x.to(embedding_device) if cache: # Only cache necessary attributes: def transform(data: Data) -> Data: kwargs = dict(n_id=data.n_id, batch_size=data.batch_size) if hasattr(data, 'adj_t'): kwargs['adj_t'] = data.adj_t else: kwargs['edge_index'] = data.edge_index return Data.from_dict(kwargs) loader = CachedLoader(loader, device=device, transform=transform) for i in range(self.num_layers): xs: List[Tensor] = [] for batch in loader: x = x_all[batch.n_id].to(device) batch_size = batch.batch_size if hasattr(batch, 'adj_t'): edge_index = batch.adj_t.to(device) else: edge_index = batch.edge_index.to(device) x = self.inference_per_layer(i, x, edge_index, batch_size) xs.append(x.to(embedding_device)) if progress_bar: pbar.update(1) x_all = torch.cat(xs, dim=0) if progress_bar: pbar.close() return x_all def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, num_layers={self.num_layers})') class GCN(BasicGNN): r"""The 
Graph Neural Network from the `"Semi-supervised Classification with Graph Convolutional Networks" `_ paper, using the :class:`~torch_geometric.nn.conv.GCNConv` operator for message passing. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. hidden_channels (int): Size of each hidden sample. num_layers (int): Number of message passing layers. out_channels (int, optional): If not set to :obj:`None`, will apply a final linear transformation to convert hidden node embeddings to output size :obj:`out_channels`. (default: :obj:`None`) dropout (float, optional): Dropout probability. (default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. (default: :obj:`"relu"`) act_first (bool, optional): If set to :obj:`True`, activation is applied before normalization. (default: :obj:`False`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`None`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. (default: :obj:`None`) jk (str, optional): The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality, while default will not. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.GCNConv`. 
""" supports_edge_weight: Final[bool] = True supports_edge_attr: Final[bool] = False supports_norm_batch: Final[bool] def init_conv(self, in_channels: int, out_channels: int, **kwargs) -> MessagePassing: return GCNConv(in_channels, out_channels, **kwargs) class GraphSAGE(BasicGNN): r"""The Graph Neural Network from the `"Inductive Representation Learning on Large Graphs" `_ paper, using the :class:`~torch_geometric.nn.SAGEConv` operator for message passing. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. hidden_channels (int): Size of each hidden sample. num_layers (int): Number of message passing layers. out_channels (int, optional): If not set to :obj:`None`, will apply a final linear transformation to convert hidden node embeddings to output size :obj:`out_channels`. (default: :obj:`None`) dropout (float, optional): Dropout probability. (default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. (default: :obj:`"relu"`) act_first (bool, optional): If set to :obj:`True`, activation is applied before normalization. (default: :obj:`False`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`None`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. (default: :obj:`None`) jk (str, optional): The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`). 
(default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.SAGEConv`. """ supports_edge_weight: Final[bool] = False supports_edge_attr: Final[bool] = False supports_norm_batch: Final[bool] def init_conv(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, **kwargs) -> MessagePassing: return SAGEConv(in_channels, out_channels, **kwargs) class GIN(BasicGNN): r"""The Graph Neural Network from the `"How Powerful are Graph Neural Networks?" `_ paper, using the :class:`~torch_geometric.nn.GINConv` operator for message passing. Args: in_channels (int): Size of each input sample. hidden_channels (int): Size of each hidden sample. num_layers (int): Number of message passing layers. out_channels (int, optional): If not set to :obj:`None`, will apply a final linear transformation to convert hidden node embeddings to output size :obj:`out_channels`. (default: :obj:`None`) dropout (float, optional): Dropout probability. (default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. (default: :obj:`"relu"`) act_first (bool, optional): If set to :obj:`True`, activation is applied before normalization. (default: :obj:`False`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`None`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. (default: :obj:`None`) jk (str, optional): The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.GINConv`. 
""" supports_edge_weight: Final[bool] = False supports_edge_attr: Final[bool] = False supports_norm_batch: Final[bool] def init_conv(self, in_channels: int, out_channels: int, **kwargs) -> MessagePassing: mlp = MLP( [in_channels, out_channels, out_channels], act=self.act, act_first=self.act_first, norm=self.norm, norm_kwargs=self.norm_kwargs, ) return GINConv(mlp, **kwargs) class GAT(BasicGNN): r"""The Graph Neural Network from `"Graph Attention Networks" `_ or `"How Attentive are Graph Attention Networks?" `_ papers, using the :class:`~torch_geometric.nn.GATConv` or :class:`~torch_geometric.nn.GATv2Conv` operator for message passing, respectively. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. hidden_channels (int): Size of each hidden sample. num_layers (int): Number of message passing layers. out_channels (int, optional): If not set to :obj:`None`, will apply a final linear transformation to convert hidden node embeddings to output size :obj:`out_channels`. (default: :obj:`None`) v2 (bool, optional): If set to :obj:`True`, will make use of :class:`~torch_geometric.nn.conv.GATv2Conv` rather than :class:`~torch_geometric.nn.conv.GATConv`. (default: :obj:`False`) dropout (float, optional): Dropout probability. (default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. (default: :obj:`"relu"`) act_first (bool, optional): If set to :obj:`True`, activation is applied before normalization. (default: :obj:`False`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`None`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. 
(default: :obj:`None`) jk (str, optional): The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.GATConv` or :class:`torch_geometric.nn.conv.GATv2Conv`. """ supports_edge_weight: Final[bool] = False supports_edge_attr: Final[bool] = True supports_norm_batch: Final[bool] def init_conv(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, **kwargs) -> MessagePassing: v2 = kwargs.pop('v2', False) heads = kwargs.pop('heads', 1) concat = kwargs.pop('concat', True) # Do not use concatenation in case the layer `GATConv` layer maps to # the desired output channels (out_channels != None and jk != None): if getattr(self, '_is_conv_to_out', False): concat = False if concat and out_channels % heads != 0: raise ValueError(f"Ensure that the number of output channels of " f"'GATConv' (got '{out_channels}') is divisible " f"by the number of heads (got '{heads}')") if concat: out_channels = out_channels // heads Conv = GATConv if not v2 else GATv2Conv return Conv(in_channels, out_channels, heads=heads, concat=concat, dropout=self.dropout.p, **kwargs) class PNA(BasicGNN): r"""The Graph Neural Network from the `"Principal Neighbourhood Aggregation for Graph Nets" `_ paper, using the :class:`~torch_geometric.nn.conv.PNAConv` operator for message passing. Args: in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. hidden_channels (int): Size of each hidden sample. num_layers (int): Number of message passing layers. out_channels (int, optional): If not set to :obj:`None`, will apply a final linear transformation to convert hidden node embeddings to output size :obj:`out_channels`. 
(default: :obj:`None`) dropout (float, optional): Dropout probability. (default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. (default: :obj:`"relu"`) act_first (bool, optional): If set to :obj:`True`, activation is applied before normalization. (default: :obj:`False`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`None`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. (default: :obj:`None`) jk (str, optional): The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.PNAConv`. """ supports_edge_weight: Final[bool] = False supports_edge_attr: Final[bool] = True supports_norm_batch: Final[bool] def init_conv(self, in_channels: int, out_channels: int, **kwargs) -> MessagePassing: return PNAConv(in_channels, out_channels, **kwargs) class EdgeCNN(BasicGNN): r"""The Graph Neural Network from the `"Dynamic Graph CNN for Learning on Point Clouds" `_ paper, using the :class:`~torch_geometric.nn.conv.EdgeConv` operator for message passing. Args: in_channels (int): Size of each input sample. hidden_channels (int): Size of each hidden sample. num_layers (int): Number of message passing layers. out_channels (int, optional): If not set to :obj:`None`, will apply a final linear transformation to convert hidden node embeddings to output size :obj:`out_channels`. (default: :obj:`None`) dropout (float, optional): Dropout probability. 
(default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. (default: :obj:`"relu"`) act_first (bool, optional): If set to :obj:`True`, activation is applied before normalization. (default: :obj:`False`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`None`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. (default: :obj:`None`) jk (str, optional): The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.EdgeConv`. 
""" supports_edge_weight: Final[bool] = False supports_edge_attr: Final[bool] = False supports_norm_batch: Final[bool] def init_conv(self, in_channels: int, out_channels: int, **kwargs) -> MessagePassing: mlp = MLP( [2 * in_channels, out_channels, out_channels], act=self.act, act_first=self.act_first, norm=self.norm, norm_kwargs=self.norm_kwargs, ) return EdgeConv(mlp, **kwargs) __all__ = [ 'GCN', 'GraphSAGE', 'GIN', 'GAT', 'PNA', 'EdgeCNN', ] ================================================ FILE: torch_geometric/nn/models/captum.py ================================================ from typing import Optional, Union import torch from torch_geometric.explain.algorithm.captum import ( CaptumHeteroModel, CaptumModel, MaskLevelType, ) from torch_geometric.typing import Metadata def to_captum_model( model: torch.nn.Module, mask_type: Union[str, MaskLevelType] = MaskLevelType.edge, output_idx: Optional[int] = None, metadata: Optional[Metadata] = None, ) -> Union[CaptumModel, CaptumHeteroModel]: r"""Converts a model to a model that can be used for `Captum `_ attribution methods. Sample code for homogeneous graphs: .. code-block:: python from captum.attr import IntegratedGradients from torch_geometric.data import Data from torch_geometric.nn import GCN from torch_geometric.nn import to_captum_model, to_captum_input data = Data(x=(...), edge_index(...)) model = GCN(...) ... # Train the model. # Explain predictions for node `10`: mask_type="edge" output_idx = 10 captum_model = to_captum_model(model, mask_type, output_idx) inputs, additional_forward_args = to_captum_input(data.x, data.edge_index,mask_type) ig = IntegratedGradients(captum_model) ig_attr = ig.attribute(inputs = inputs, target=int(y[output_idx]), additional_forward_args=additional_forward_args, internal_batch_size=1) Sample code for heterogeneous graphs: .. 
code-block:: python from captum.attr import IntegratedGradients from torch_geometric.data import HeteroData from torch_geometric.nn import HeteroConv from torch_geometric.nn import (captum_output_to_dicts, to_captum_model, to_captum_input) data = HeteroData(...) model = HeteroConv(...) ... # Train the model. # Explain predictions for node `10`: mask_type="edge" metadata = data.metadata output_idx = 10 captum_model = to_captum_model(model, mask_type, output_idx, metadata) inputs, additional_forward_args = to_captum_input(data.x_dict, data.edge_index_dict, mask_type) ig = IntegratedGradients(captum_model) ig_attr = ig.attribute(inputs=inputs, target=int(y[output_idx]), additional_forward_args=additional_forward_args, internal_batch_size=1) edge_attr_dict = captum_output_to_dicts(ig_attr, mask_type, metadata) .. note:: For an example of using a :captum:`Captum` attribution method within :pyg:`PyG`, see `examples/explain/captum_explainer.py `_. Args: model (torch.nn.Module): The model to be explained. mask_type (str, optional): Denotes the type of mask to be created with a :captum:`Captum` explainer. Valid inputs are :obj:`"edge"`, :obj:`"node"`, and :obj:`"node_and_edge"`. (default: :obj:`"edge"`) output_idx (int, optional): Index of the output element (node or link index) to be explained. With :obj:`output_idx` set, the forward function will return the output of the model for the element at the index specified. (default: :obj:`None`) metadata (Metadata, optional): The metadata of the heterogeneous graph. Only required if explaining a :class:`~torch_geometric.data.HeteroData` object. 
(default: :obj:`None`) """ if metadata is None: return CaptumModel(model, mask_type, output_idx) else: return CaptumHeteroModel(model, mask_type, output_idx, metadata) ================================================ FILE: torch_geometric/nn/models/correct_and_smooth.py ================================================ import torch from torch import Tensor from torch_geometric.nn.models import LabelPropagation from torch_geometric.typing import Adj, OptTensor from torch_geometric.utils import one_hot class CorrectAndSmooth(torch.nn.Module): r"""The correct and smooth (C&S) post-processing model from the `"Combining Label Propagation And Simple Models Out-performs Graph Neural Networks" `_ paper, where soft predictions :math:`\mathbf{Z}` (obtained from a simple base predictor) are first corrected based on ground-truth training label information :math:`\mathbf{Y}` and residual propagation. .. math:: \mathbf{e}^{(0)}_i &= \begin{cases} \mathbf{y}_i - \mathbf{z}_i, & \text{if }i \text{ is training node,}\\ \mathbf{0}, & \text{else} \end{cases} .. math:: \mathbf{E}^{(\ell)} &= \alpha_1 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{E}^{(\ell - 1)} + (1 - \alpha_1) \mathbf{E}^{(\ell - 1)} \mathbf{\hat{Z}} &= \mathbf{Z} + \gamma \cdot \mathbf{E}^{(L_1)}, where :math:`\gamma` denotes the scaling factor (either fixed or automatically determined), and then smoothed over the graph via label propagation .. math:: \mathbf{\hat{z}}^{(0)}_i &= \begin{cases} \mathbf{y}_i, & \text{if }i\text{ is training node,}\\ \mathbf{\hat{z}}_i, & \text{else} \end{cases} .. math:: \mathbf{\hat{Z}}^{(\ell)} = \alpha_2 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{\hat{Z}}^{(\ell - 1)} + (1 - \alpha_2) \mathbf{\hat{Z}}^{(\ell - 1)} to obtain the final prediction :math:`\mathbf{\hat{Z}}^{(L_2)}`. .. note:: For an example of using the C&S model, see `examples/correct_and_smooth.py `_. Args: num_correction_layers (int): The number of propagations :math:`L_1`. 
correction_alpha (float): The :math:`\alpha_1` coefficient. num_smoothing_layers (int): The number of propagations :math:`L_2`. smoothing_alpha (float): The :math:`\alpha_2` coefficient. autoscale (bool, optional): If set to :obj:`True`, will automatically determine the scaling factor :math:`\gamma`. (default: :obj:`True`) scale (float, optional): The scaling factor :math:`\gamma`, in case :obj:`autoscale = False`. (default: :obj:`1.0`) """ def __init__(self, num_correction_layers: int, correction_alpha: float, num_smoothing_layers: int, smoothing_alpha: float, autoscale: bool = True, scale: float = 1.0): super().__init__() self.autoscale = autoscale self.scale = scale self.prop1 = LabelPropagation(num_correction_layers, correction_alpha) self.prop2 = LabelPropagation(num_smoothing_layers, smoothing_alpha) def forward(self, y_soft: Tensor, *args) -> Tensor: # pragma: no cover r"""Applies both :meth:`correct` and :meth:`smooth`.""" y_soft = self.correct(y_soft, *args) return self.smooth(y_soft, *args) def correct(self, y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: r"""Forward pass. Args: y_soft (torch.Tensor): The soft predictions :math:`\mathbf{Z}` obtained from a simple base predictor. y_true (torch.Tensor): The ground-truth label information :math:`\mathbf{Y}` of training nodes. mask (torch.Tensor): A mask or index tensor denoting which nodes were used for training. edge_index (torch.Tensor or SparseTensor): The edge connectivity. edge_weight (torch.Tensor, optional): The edge weights. 
(default: :obj:`None`) """ numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) assert y_true.size(0) == numel if y_true.dtype == torch.long and y_true.size(0) == y_true.numel(): y_true = one_hot(y_true.view(-1), num_classes=y_soft.size(-1), dtype=y_soft.dtype) error = torch.zeros_like(y_soft) error[mask] = y_true - y_soft[mask] if self.autoscale: smoothed_error = self.prop1(error, edge_index, edge_weight=edge_weight, post_step=lambda x: x.clamp_(-1., 1.)) sigma = error[mask].abs().sum() / numel scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True) scale[scale.isinf() | (scale > 1000)] = 1.0 return y_soft + scale * smoothed_error else: def fix_input(x): x[mask] = error[mask] return x smoothed_error = self.prop1(error, edge_index, edge_weight=edge_weight, post_step=fix_input) return y_soft + self.scale * smoothed_error def smooth(self, y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: r"""Forward pass. Args: y_soft (torch.Tensor): The corrected predictions :math:`\mathbf{Z}` obtained from :meth:`correct`. y_true (torch.Tensor): The ground-truth label information :math:`\mathbf{Y}` of training nodes. mask (torch.Tensor): A mask or index tensor denoting which nodes were used for training. edge_index (torch.Tensor or SparseTensor): The edge connectivity. edge_weight (torch.Tensor, optional): The edge weights. 
(default: :obj:`None`) """ numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) assert y_true.size(0) == numel if y_true.dtype == torch.long and y_true.size(0) == y_true.numel(): y_true = one_hot(y_true.view(-1), num_classes=y_soft.size(-1), dtype=y_soft.dtype) y_soft = y_soft.clone() y_soft[mask] = y_true return self.prop2(y_soft, edge_index, edge_weight=edge_weight) def __repr__(self): L1, alpha1 = self.prop1.num_layers, self.prop1.alpha L2, alpha2 = self.prop2.num_layers, self.prop2.alpha return (f'{self.__class__.__name__}(\n' f' correct: num_layers={L1}, alpha={alpha1}\n' f' smooth: num_layers={L2}, alpha={alpha2}\n' f' autoscale={self.autoscale}, scale={self.scale}\n' ')') ================================================ FILE: torch_geometric/nn/models/deep_graph_infomax.py ================================================ import copy from typing import Callable, Tuple import torch from torch import Tensor from torch.nn import Module, Parameter from torch_geometric.nn.inits import reset, uniform EPS = 1e-15 class DeepGraphInfomax(torch.nn.Module): r"""The Deep Graph Infomax model from the `"Deep Graph Infomax" `_ paper based on user-defined encoder and summary model :math:`\mathcal{E}` and :math:`\mathcal{R}` respectively, and a corruption function :math:`\mathcal{C}`. Args: hidden_channels (int): The latent space dimensionality. encoder (torch.nn.Module): The encoder module :math:`\mathcal{E}`. summary (callable): The readout function :math:`\mathcal{R}`. corruption (callable): The corruption function :math:`\mathcal{C}`. 
""" def __init__( self, hidden_channels: int, encoder: Module, summary: Callable, corruption: Callable, ): super().__init__() self.hidden_channels = hidden_channels self.encoder = encoder self.summary = summary self.corruption = corruption self.weight = Parameter(torch.empty(hidden_channels, hidden_channels)) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset(self.encoder) reset(self.summary) uniform(self.hidden_channels, self.weight) def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: """Returns the latent space for the input arguments, their corruptions and their summary representation. """ pos_z = self.encoder(*args, **kwargs) cor = self.corruption(*args, **kwargs) cor = cor if isinstance(cor, tuple) else (cor, ) cor_args = cor[:len(args)] cor_kwargs = copy.copy(kwargs) for key, value in zip(kwargs.keys(), cor[len(args):]): cor_kwargs[key] = value neg_z = self.encoder(*cor_args, **cor_kwargs) summary = self.summary(pos_z, *args, **kwargs) return pos_z, neg_z, summary def discriminate(self, z: Tensor, summary: Tensor, sigmoid: bool = True) -> Tensor: r"""Given the patch-summary pair :obj:`z` and :obj:`summary`, computes the probability scores assigned to this patch-summary pair. Args: z (torch.Tensor): The latent space. summary (torch.Tensor): The summary vector. sigmoid (bool, optional): If set to :obj:`False`, does not apply the logistic sigmoid function to the output. 
(default: :obj:`True`) """ summary = summary.t() if summary.dim() > 1 else summary value = torch.matmul(z, torch.matmul(self.weight, summary)) return torch.sigmoid(value) if sigmoid else value def loss(self, pos_z: Tensor, neg_z: Tensor, summary: Tensor) -> Tensor: r"""Computes the mutual information maximization objective.""" pos_loss = -torch.log( self.discriminate(pos_z, summary, sigmoid=True) + EPS).mean() neg_loss = -torch.log(1 - self.discriminate(neg_z, summary, sigmoid=True) + EPS).mean() return pos_loss + neg_loss def test( self, train_z: Tensor, train_y: Tensor, test_z: Tensor, test_y: Tensor, solver: str = 'lbfgs', *args, **kwargs, ) -> float: r"""Evaluates latent space quality via a logistic regression downstream task. """ from sklearn.linear_model import LogisticRegression clf = LogisticRegression(*args, solver=solver, **kwargs).fit(train_z.detach().cpu().numpy(), train_y.detach().cpu().numpy()) return clf.score(test_z.detach().cpu().numpy(), test_y.detach().cpu().numpy()) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.hidden_channels})' ================================================ FILE: torch_geometric/nn/models/deepgcn.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Module from torch.utils.checkpoint import checkpoint class DeepGCNLayer(torch.nn.Module): r"""The skip connection operations from the `"DeepGCNs: Can GCNs Go as Deep as CNNs?" `_ and `"All You Need to Train Deeper GCNs" `_ papers. The implemented skip connections includes the pre-activation residual connection (:obj:`"res+"`), the residual connection (:obj:`"res"`), the dense connection (:obj:`"dense"`) and no connections (:obj:`"plain"`). * **Res+** (:obj:`"res+"`): .. 
math:: \text{Normalization}\to\text{Activation}\to\text{Dropout}\to \text{GraphConv}\to\text{Res} * **Res** (:obj:`"res"`) / **Dense** (:obj:`"dense"`) / **Plain** (:obj:`"plain"`): .. math:: \text{GraphConv}\to\text{Normalization}\to\text{Activation}\to \text{Res/Dense/Plain}\to\text{Dropout} .. note:: For an example of using :obj:`GENConv`, see `examples/ogbn_proteins_deepgcn.py `_. Args: conv (torch.nn.Module, optional): the GCN operator. (default: :obj:`None`) norm (torch.nn.Module): the normalization layer. (default: :obj:`None`) act (torch.nn.Module): the activation layer. (default: :obj:`None`) block (str, optional): The skip connection operation to use (:obj:`"res+"`, :obj:`"res"`, :obj:`"dense"` or :obj:`"plain"`). (default: :obj:`"res+"`) dropout (float, optional): Whether to apply or dropout. (default: :obj:`0.`) ckpt_grad (bool, optional): If set to :obj:`True`, will checkpoint this part of the model. Checkpointing works by trading compute for memory, since intermediate activations do not need to be kept in memory. Set this to :obj:`True` in case you encounter out-of-memory errors while going deep. 
(default: :obj:`False`) """ def __init__( self, conv: Optional[Module] = None, norm: Optional[Module] = None, act: Optional[Module] = None, block: str = 'res+', dropout: float = 0., ckpt_grad: bool = False, ): super().__init__() self.conv = conv self.norm = norm self.act = act self.block = block.lower() assert self.block in ['res+', 'res', 'dense', 'plain'] self.dropout = dropout self.ckpt_grad = ckpt_grad def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.conv.reset_parameters() self.norm.reset_parameters() def forward(self, *args, **kwargs) -> Tensor: """""" # noqa: D419 args = list(args) x = args.pop(0) if self.block == 'res+': h = x if self.norm is not None: h = self.norm(h) if self.act is not None: h = self.act(h) h = F.dropout(h, p=self.dropout, training=self.training) if self.conv is not None and self.ckpt_grad and h.requires_grad: h = checkpoint(self.conv, h, *args, use_reentrant=True, **kwargs) else: h = self.conv(h, *args, **kwargs) return x + h else: if self.conv is not None and self.ckpt_grad and x.requires_grad: h = checkpoint(self.conv, x, *args, use_reentrant=True, **kwargs) else: h = self.conv(x, *args, **kwargs) if self.norm is not None: h = self.norm(h) if self.act is not None: h = self.act(h) if self.block == 'res': h = x + h elif self.block == 'dense': h = torch.cat([x, h], dim=-1) elif self.block == 'plain': pass return F.dropout(h, p=self.dropout, training=self.training) def __repr__(self) -> str: return f'{self.__class__.__name__}(block={self.block})' ================================================ FILE: torch_geometric/nn/models/dimenet.py ================================================ import os import os.path as osp from functools import partial from math import pi as PI from math import sqrt from typing import Callable, Dict, Optional, Tuple, Union import numpy as np import torch from torch import Tensor from torch.nn import Embedding, Linear from torch_geometric.data import Dataset, download_url from 
torch_geometric.nn import radius_graph from torch_geometric.nn.inits import glorot_orthogonal from torch_geometric.nn.resolver import activation_resolver from torch_geometric.typing import OptTensor, SparseTensor from torch_geometric.utils import scatter qm9_target_dict: Dict[int, str] = { 0: 'mu', 1: 'alpha', 2: 'homo', 3: 'lumo', 5: 'r2', 6: 'zpve', 7: 'U0', 8: 'U', 9: 'H', 10: 'G', 11: 'Cv', } class Envelope(torch.nn.Module): def __init__(self, exponent: int): super().__init__() self.p = exponent + 1 self.a = -(self.p + 1) * (self.p + 2) / 2 self.b = self.p * (self.p + 2) self.c = -self.p * (self.p + 1) / 2 def forward(self, x: Tensor) -> Tensor: p, a, b, c = self.p, self.a, self.b, self.c x_pow_p0 = x.pow(p - 1) x_pow_p1 = x_pow_p0 * x x_pow_p2 = x_pow_p1 * x return (1.0 / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2) * (x < 1.0).to(x.dtype) class BesselBasisLayer(torch.nn.Module): def __init__(self, num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5): super().__init__() self.cutoff = cutoff self.envelope = Envelope(envelope_exponent) self.freq = torch.nn.Parameter(torch.empty(num_radial)) self.reset_parameters() def reset_parameters(self): with torch.no_grad(): torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI) self.freq.requires_grad_() def forward(self, dist: Tensor) -> Tensor: dist = dist.unsqueeze(-1) / self.cutoff return self.envelope(dist) * (self.freq * dist).sin() class SphericalBasisLayer(torch.nn.Module): def __init__( self, num_spherical: int, num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5, ): super().__init__() import sympy as sym from torch_geometric.nn.models.dimenet_utils import ( bessel_basis, real_sph_harm, ) assert num_radial <= 64 self.num_spherical = num_spherical self.num_radial = num_radial self.cutoff = cutoff self.envelope = Envelope(envelope_exponent) bessel_forms = bessel_basis(num_spherical, num_radial) sph_harm_forms = real_sph_harm(num_spherical) self.sph_funcs = [] self.bessel_funcs = 
[] x, theta = sym.symbols('x theta') modules = {'sin': torch.sin, 'cos': torch.cos} for i in range(num_spherical): if i == 0: sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0) self.sph_funcs.append(partial(self._sph_to_tensor, sph1)) else: sph = sym.lambdify([theta], sph_harm_forms[i][0], modules) self.sph_funcs.append(sph) for j in range(num_radial): bessel = sym.lambdify([x], bessel_forms[i][j], modules) self.bessel_funcs.append(bessel) @staticmethod def _sph_to_tensor(sph, x: Tensor) -> Tensor: return torch.zeros_like(x) + sph def forward(self, dist: Tensor, angle: Tensor, idx_kj: Tensor) -> Tensor: dist = dist / self.cutoff rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1) rbf = self.envelope(dist).unsqueeze(-1) * rbf cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1) n, k = self.num_spherical, self.num_radial out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k) return out class EmbeddingBlock(torch.nn.Module): def __init__(self, num_radial: int, hidden_channels: int, act: Callable): super().__init__() self.act = act self.emb = Embedding(95, hidden_channels) self.lin_rbf = Linear(num_radial, hidden_channels) self.lin = Linear(3 * hidden_channels, hidden_channels) self.reset_parameters() def reset_parameters(self): self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) self.lin_rbf.reset_parameters() self.lin.reset_parameters() def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor: x = self.emb(x) rbf = self.act(self.lin_rbf(rbf)) return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) class ResidualLayer(torch.nn.Module): def __init__(self, hidden_channels: int, act: Callable): super().__init__() self.act = act self.lin1 = Linear(hidden_channels, hidden_channels) self.lin2 = Linear(hidden_channels, hidden_channels) self.reset_parameters() def reset_parameters(self): glorot_orthogonal(self.lin1.weight, scale=2.0) self.lin1.bias.data.fill_(0) glorot_orthogonal(self.lin2.weight, 
scale=2.0) self.lin2.bias.data.fill_(0) def forward(self, x: Tensor) -> Tensor: return x + self.act(self.lin2(self.act(self.lin1(x)))) class InteractionBlock(torch.nn.Module): def __init__( self, hidden_channels: int, num_bilinear: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable, ): super().__init__() self.act = act self.lin_rbf = Linear(num_radial, hidden_channels, bias=False) self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear, bias=False) # Dense transformations of input messages. self.lin_kj = Linear(hidden_channels, hidden_channels) self.lin_ji = Linear(hidden_channels, hidden_channels) self.W = torch.nn.Parameter( torch.empty(hidden_channels, num_bilinear, hidden_channels)) self.layers_before_skip = torch.nn.ModuleList([ ResidualLayer(hidden_channels, act) for _ in range(num_before_skip) ]) self.lin = Linear(hidden_channels, hidden_channels) self.layers_after_skip = torch.nn.ModuleList([ ResidualLayer(hidden_channels, act) for _ in range(num_after_skip) ]) self.reset_parameters() def reset_parameters(self): glorot_orthogonal(self.lin_rbf.weight, scale=2.0) glorot_orthogonal(self.lin_sbf.weight, scale=2.0) glorot_orthogonal(self.lin_kj.weight, scale=2.0) self.lin_kj.bias.data.fill_(0) glorot_orthogonal(self.lin_ji.weight, scale=2.0) self.lin_ji.bias.data.fill_(0) self.W.data.normal_(mean=0, std=2 / self.W.size(0)) for res_layer in self.layers_before_skip: res_layer.reset_parameters() glorot_orthogonal(self.lin.weight, scale=2.0) self.lin.bias.data.fill_(0) for res_layer in self.layers_after_skip: res_layer.reset_parameters() def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor, idx_ji: Tensor) -> Tensor: rbf = self.lin_rbf(rbf) sbf = self.lin_sbf(sbf) x_ji = self.act(self.lin_ji(x)) x_kj = self.act(self.lin_kj(x)) x_kj = x_kj * rbf x_kj = torch.einsum('wj,wl,ijl->wi', sbf, x_kj[idx_kj], self.W) x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0), reduce='sum') h = x_ji + x_kj for 
layer in self.layers_before_skip: h = layer(h) h = self.act(self.lin(h)) + x for layer in self.layers_after_skip: h = layer(h) return h class InteractionPPBlock(torch.nn.Module): def __init__( self, hidden_channels: int, int_emb_size: int, basis_emb_size: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable, ): super().__init__() self.act = act # Transformation of Bessel and spherical basis representations: self.lin_rbf1 = Linear(num_radial, basis_emb_size, bias=False) self.lin_rbf2 = Linear(basis_emb_size, hidden_channels, bias=False) self.lin_sbf1 = Linear(num_spherical * num_radial, basis_emb_size, bias=False) self.lin_sbf2 = Linear(basis_emb_size, int_emb_size, bias=False) # Hidden transformation of input message: self.lin_kj = Linear(hidden_channels, hidden_channels) self.lin_ji = Linear(hidden_channels, hidden_channels) # Embedding projections for interaction triplets: self.lin_down = Linear(hidden_channels, int_emb_size, bias=False) self.lin_up = Linear(int_emb_size, hidden_channels, bias=False) # Residual layers before and after skip connection: self.layers_before_skip = torch.nn.ModuleList([ ResidualLayer(hidden_channels, act) for _ in range(num_before_skip) ]) self.lin = Linear(hidden_channels, hidden_channels) self.layers_after_skip = torch.nn.ModuleList([ ResidualLayer(hidden_channels, act) for _ in range(num_after_skip) ]) self.reset_parameters() def reset_parameters(self): glorot_orthogonal(self.lin_rbf1.weight, scale=2.0) glorot_orthogonal(self.lin_rbf2.weight, scale=2.0) glorot_orthogonal(self.lin_sbf1.weight, scale=2.0) glorot_orthogonal(self.lin_sbf2.weight, scale=2.0) glorot_orthogonal(self.lin_kj.weight, scale=2.0) self.lin_kj.bias.data.fill_(0) glorot_orthogonal(self.lin_ji.weight, scale=2.0) self.lin_ji.bias.data.fill_(0) glorot_orthogonal(self.lin_down.weight, scale=2.0) glorot_orthogonal(self.lin_up.weight, scale=2.0) for res_layer in self.layers_before_skip: res_layer.reset_parameters() 
glorot_orthogonal(self.lin.weight, scale=2.0) self.lin.bias.data.fill_(0) for res_layer in self.layers_after_skip: res_layer.reset_parameters() def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor, idx_ji: Tensor) -> Tensor: # Initial transformation: x_ji = self.act(self.lin_ji(x)) x_kj = self.act(self.lin_kj(x)) # Transformation via Bessel basis: rbf = self.lin_rbf1(rbf) rbf = self.lin_rbf2(rbf) x_kj = x_kj * rbf # Down project embedding and generating triple-interactions: x_kj = self.act(self.lin_down(x_kj)) # Transform via 2D spherical basis: sbf = self.lin_sbf1(sbf) sbf = self.lin_sbf2(sbf) x_kj = x_kj[idx_kj] * sbf # Aggregate interactions and up-project embeddings: x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0), reduce='sum') x_kj = self.act(self.lin_up(x_kj)) h = x_ji + x_kj for layer in self.layers_before_skip: h = layer(h) h = self.act(self.lin(h)) + x for layer in self.layers_after_skip: h = layer(h) return h class OutputBlock(torch.nn.Module): def __init__( self, num_radial: int, hidden_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros', ): assert output_initializer in {'zeros', 'glorot_orthogonal'} super().__init__() self.act = act self.output_initializer = output_initializer self.lin_rbf = Linear(num_radial, hidden_channels, bias=False) self.lins = torch.nn.ModuleList() for _ in range(num_layers): self.lins.append(Linear(hidden_channels, hidden_channels)) self.lin = Linear(hidden_channels, out_channels, bias=False) self.reset_parameters() def reset_parameters(self): glorot_orthogonal(self.lin_rbf.weight, scale=2.0) for lin in self.lins: glorot_orthogonal(lin.weight, scale=2.0) lin.bias.data.fill_(0) if self.output_initializer == 'zeros': self.lin.weight.data.fill_(0) elif self.output_initializer == 'glorot_orthogonal': glorot_orthogonal(self.lin.weight, scale=2.0) def forward(self, x: Tensor, rbf: Tensor, i: Tensor, num_nodes: Optional[int] = None) -> Tensor: x = 
self.lin_rbf(rbf) * x x = scatter(x, i, dim=0, dim_size=num_nodes, reduce='sum') for lin in self.lins: x = self.act(lin(x)) return self.lin(x) class OutputPPBlock(torch.nn.Module): def __init__( self, num_radial: int, hidden_channels: int, out_emb_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros', ): assert output_initializer in {'zeros', 'glorot_orthogonal'} super().__init__() self.act = act self.output_initializer = output_initializer self.lin_rbf = Linear(num_radial, hidden_channels, bias=False) # The up-projection layer: self.lin_up = Linear(hidden_channels, out_emb_channels, bias=False) self.lins = torch.nn.ModuleList() for _ in range(num_layers): self.lins.append(Linear(out_emb_channels, out_emb_channels)) self.lin = Linear(out_emb_channels, out_channels, bias=False) self.reset_parameters() def reset_parameters(self): glorot_orthogonal(self.lin_rbf.weight, scale=2.0) glorot_orthogonal(self.lin_up.weight, scale=2.0) for lin in self.lins: glorot_orthogonal(lin.weight, scale=2.0) lin.bias.data.fill_(0) if self.output_initializer == 'zeros': self.lin.weight.data.fill_(0) elif self.output_initializer == 'glorot_orthogonal': glorot_orthogonal(self.lin.weight, scale=2.0) def forward(self, x: Tensor, rbf: Tensor, i: Tensor, num_nodes: Optional[int] = None) -> Tensor: x = self.lin_rbf(rbf) * x x = scatter(x, i, dim=0, dim_size=num_nodes, reduce='sum') x = self.lin_up(x) for lin in self.lins: x = self.act(lin(x)) return self.lin(x) def triplets( edge_index: Tensor, num_nodes: int, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: row, col = edge_index # j->i value = torch.arange(row.size(0), device=row.device) adj_t = SparseTensor(row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes)) adj_t_row = adj_t[row] num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) # Node indices (k->j->i) for triplets. 
idx_i = col.repeat_interleave(num_triplets) idx_j = row.repeat_interleave(num_triplets) idx_k = adj_t_row.storage.col() mask = idx_i != idx_k # Remove i == k triplets. idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] # Edge indices (k-j, j->i) for triplets. idx_kj = adj_t_row.storage.value()[mask] idx_ji = adj_t_row.storage.row()[mask] return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji class DimeNet(torch.nn.Module): r"""The directional message passing neural network (DimeNet) from the `"Directional Message Passing for Molecular Graphs" `_ paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion. .. note:: For an example of using a pretrained DimeNet variant, see `examples/qm9_pretrained_dimenet.py `_. Args: hidden_channels (int): Hidden embedding size. out_channels (int): Size of each output sample. num_blocks (int): Number of building blocks. num_bilinear (int): Size of the bilinear layer tensor. num_spherical (int): Number of spherical harmonics. num_radial (int): Number of radial basis functions. cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`5.0`) max_num_neighbors (int, optional): The maximum number of neighbors to collect for each node within the :attr:`cutoff` distance. (default: :obj:`32`) envelope_exponent (int, optional): Shape of the smooth cutoff. (default: :obj:`5`) num_before_skip (int, optional): Number of residual layers in the interaction blocks before the skip connection. (default: :obj:`1`) num_after_skip (int, optional): Number of residual layers in the interaction blocks after the skip connection. (default: :obj:`2`) num_output_layers (int, optional): Number of linear layers for the output blocks. (default: :obj:`3`) act (str or Callable, optional): The activation function. (default: :obj:`"swish"`) output_initializer (str, optional): The initialization method for the output layer (:obj:`"zeros"`, :obj:`"glorot_orthogonal"`). 
(default: :obj:`"zeros"`) """ url = ('https://github.com/klicperajo/dimenet/raw/master/pretrained/' 'dimenet') def __init__( self, hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros', ): super().__init__() if num_spherical < 2: raise ValueError("'num_spherical' should be greater than 1") act = activation_resolver(act) self.cutoff = cutoff self.max_num_neighbors = max_num_neighbors self.num_blocks = num_blocks self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent) self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff, envelope_exponent) self.emb = EmbeddingBlock(num_radial, hidden_channels, act) self.output_blocks = torch.nn.ModuleList([ OutputBlock( num_radial, hidden_channels, out_channels, num_output_layers, act, output_initializer, ) for _ in range(num_blocks + 1) ]) self.interaction_blocks = torch.nn.ModuleList([ InteractionBlock( hidden_channels, num_bilinear, num_spherical, num_radial, num_before_skip, num_after_skip, act, ) for _ in range(num_blocks) ]) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.rbf.reset_parameters() self.emb.reset_parameters() for out in self.output_blocks: out.reset_parameters() for interaction in self.interaction_blocks: interaction.reset_parameters() @classmethod def from_qm9_pretrained( cls, root: str, dataset: Dataset, target: int, ) -> Tuple['DimeNet', Dataset, Dataset, Dataset]: # pragma: no cover r"""Returns a pre-trained :class:`DimeNet` model on the :class:`~torch_geometric.datasets.QM9` dataset, trained on the specified target :obj:`target`. 
""" os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import tensorflow as tf assert target >= 0 and target <= 12 and not target == 4 root = osp.expanduser(osp.normpath(root)) path = osp.join(root, 'pretrained_dimenet', qm9_target_dict[target]) os.makedirs(path, exist_ok=True) url = f'{cls.url}/{qm9_target_dict[target]}' if not osp.exists(osp.join(path, 'checkpoint')): download_url(f'{url}/checkpoint', path) download_url(f'{url}/ckpt.data-00000-of-00002', path) download_url(f'{url}/ckpt.data-00001-of-00002', path) download_url(f'{url}/ckpt.index', path) path = osp.join(path, 'ckpt') reader = tf.train.load_checkpoint(path) model = cls( hidden_channels=128, out_channels=1, num_blocks=6, num_bilinear=8, num_spherical=7, num_radial=6, cutoff=5.0, envelope_exponent=5, num_before_skip=1, num_after_skip=2, num_output_layers=3, ) def copy_(src, name, transpose=False): init = reader.get_tensor(f'{name}/.ATTRIBUTES/VARIABLE_VALUE') init = torch.from_numpy(init) if name[-6:] == 'kernel': init = init.t() src.data.copy_(init) copy_(model.rbf.freq, 'rbf_layer/frequencies') copy_(model.emb.emb.weight, 'emb_block/embeddings') copy_(model.emb.lin_rbf.weight, 'emb_block/dense_rbf/kernel') copy_(model.emb.lin_rbf.bias, 'emb_block/dense_rbf/bias') copy_(model.emb.lin.weight, 'emb_block/dense/kernel') copy_(model.emb.lin.bias, 'emb_block/dense/bias') for i, block in enumerate(model.output_blocks): copy_(block.lin_rbf.weight, f'output_blocks/{i}/dense_rbf/kernel') for j, lin in enumerate(block.lins): copy_(lin.weight, f'output_blocks/{i}/dense_layers/{j}/kernel') copy_(lin.bias, f'output_blocks/{i}/dense_layers/{j}/bias') copy_(block.lin.weight, f'output_blocks/{i}/dense_final/kernel') for i, block in enumerate(model.interaction_blocks): copy_(block.lin_rbf.weight, f'int_blocks/{i}/dense_rbf/kernel') copy_(block.lin_sbf.weight, f'int_blocks/{i}/dense_sbf/kernel') copy_(block.lin_kj.weight, f'int_blocks/{i}/dense_kj/kernel') copy_(block.lin_kj.bias, f'int_blocks/{i}/dense_kj/bias') 
copy_(block.lin_ji.weight, f'int_blocks/{i}/dense_ji/kernel') copy_(block.lin_ji.bias, f'int_blocks/{i}/dense_ji/bias') copy_(block.W, f'int_blocks/{i}/bilinear') for j, layer in enumerate(block.layers_before_skip): copy_(layer.lin1.weight, f'int_blocks/{i}/layers_before_skip/{j}/dense_1/kernel') copy_(layer.lin1.bias, f'int_blocks/{i}/layers_before_skip/{j}/dense_1/bias') copy_(layer.lin2.weight, f'int_blocks/{i}/layers_before_skip/{j}/dense_2/kernel') copy_(layer.lin2.bias, f'int_blocks/{i}/layers_before_skip/{j}/dense_2/bias') copy_(block.lin.weight, f'int_blocks/{i}/final_before_skip/kernel') copy_(block.lin.bias, f'int_blocks/{i}/final_before_skip/bias') for j, layer in enumerate(block.layers_after_skip): copy_(layer.lin1.weight, f'int_blocks/{i}/layers_after_skip/{j}/dense_1/kernel') copy_(layer.lin1.bias, f'int_blocks/{i}/layers_after_skip/{j}/dense_1/bias') copy_(layer.lin2.weight, f'int_blocks/{i}/layers_after_skip/{j}/dense_2/kernel') copy_(layer.lin2.bias, f'int_blocks/{i}/layers_after_skip/{j}/dense_2/bias') # Use the same random seed as the official DimeNet` implementation. random_state = np.random.RandomState(seed=42) perm = torch.from_numpy(random_state.permutation(np.arange(130831))) perm = perm.long() train_idx = perm[:110000] val_idx = perm[110000:120000] test_idx = perm[120000:] return model, (dataset[train_idx], dataset[val_idx], dataset[test_idx]) def forward( self, z: Tensor, pos: Tensor, batch: OptTensor = None, ) -> Tensor: r"""Forward pass. Args: z (torch.Tensor): Atomic number of each atom with shape :obj:`[num_atoms]`. pos (torch.Tensor): Coordinates of each atom with shape :obj:`[num_atoms, 3]`. batch (torch.Tensor, optional): Batch indices assigning each atom to a separate molecule with shape :obj:`[num_atoms]`. 
(default: :obj:`None`) """ edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors) i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( edge_index, num_nodes=z.size(0)) # Calculate distances. dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() # Calculate angles. if isinstance(self, DimeNetPlusPlus): pos_jk, pos_ij = pos[idx_j] - pos[idx_k], pos[idx_i] - pos[idx_j] a = (pos_ij * pos_jk).sum(dim=-1) b = torch.cross(pos_ij, pos_jk, dim=1).norm(dim=-1) elif isinstance(self, DimeNet): pos_ji, pos_ki = pos[idx_j] - pos[idx_i], pos[idx_k] - pos[idx_i] a = (pos_ji * pos_ki).sum(dim=-1) b = torch.cross(pos_ji, pos_ki, dim=1).norm(dim=-1) angle = torch.atan2(b, a) rbf = self.rbf(dist) sbf = self.sbf(dist, angle, idx_kj) # Embedding block. x = self.emb(z, rbf, i, j) P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0)) # Interaction blocks. for interaction_block, output_block in zip(self.interaction_blocks, self.output_blocks[1:]): x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) P = P + output_block(x, rbf, i, num_nodes=pos.size(0)) if batch is None: return P.sum(dim=0) else: return scatter(P, batch, dim=0, reduce='sum') class DimeNetPlusPlus(DimeNet): r"""The DimeNet++ from the `"Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules" `_ paper. :class:`DimeNetPlusPlus` is an upgrade to the :class:`DimeNet` model that is 8x faster and 10% more accurate than :class:`DimeNet`. Args: hidden_channels (int): Hidden embedding size. out_channels (int): Size of each output sample. num_blocks (int): Number of building blocks. int_emb_size (int): Size of embedding in the interaction block. basis_emb_size (int): Size of basis embedding in the interaction block. out_emb_channels (int): Size of embedding in the output block. num_spherical (int): Number of spherical harmonics. num_radial (int): Number of radial basis functions. cutoff (float, optional): Cutoff distance for interatomic interactions.
(default: :obj:`5.0`) max_num_neighbors (int, optional): The maximum number of neighbors to collect for each node within the :attr:`cutoff` distance. (default: :obj:`32`) envelope_exponent (int, optional): Shape of the smooth cutoff. (default: :obj:`5`) num_before_skip (int, optional): Number of residual layers in the interaction blocks before the skip connection. (default: :obj:`1`) num_after_skip (int, optional): Number of residual layers in the interaction blocks after the skip connection. (default: :obj:`2`) num_output_layers (int, optional): Number of linear layers for the output blocks. (default: :obj:`3`) act (str or Callable, optional): The activation function. (default: :obj:`"swish"`) output_initializer (str, optional): The initialization method for the output layer (:obj:`"zeros"`, :obj:`"glorot_orthogonal"`). (default: :obj:`"zeros"`) """ url = ('https://raw.githubusercontent.com/gasteigerjo/dimenet/' 'master/pretrained/dimenet_pp') def __init__( self, hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros', ): act = activation_resolver(act) super().__init__( hidden_channels=hidden_channels, out_channels=out_channels, num_blocks=num_blocks, num_bilinear=1, num_spherical=num_spherical, num_radial=num_radial, cutoff=cutoff, max_num_neighbors=max_num_neighbors, envelope_exponent=envelope_exponent, num_before_skip=num_before_skip, num_after_skip=num_after_skip, num_output_layers=num_output_layers, act=act, output_initializer=output_initializer, ) # We are re-using the RBF, SBF and embedding layers of `DimeNet` and # redefining the output and interaction blocks in DimeNet++.
# Hence, it is to be noted that in the above initialization, the # variable `num_bilinear` does not have any purpose as it is used # solely in the `InteractionBlock` of DimeNet, which is replaced here: self.output_blocks = torch.nn.ModuleList([ OutputPPBlock( num_radial, hidden_channels, out_emb_channels, out_channels, num_output_layers, act, output_initializer, ) for _ in range(num_blocks + 1) ]) self.interaction_blocks = torch.nn.ModuleList([ InteractionPPBlock( hidden_channels, int_emb_size, basis_emb_size, num_spherical, num_radial, num_before_skip, num_after_skip, act, ) for _ in range(num_blocks) ]) self.reset_parameters() @classmethod def from_qm9_pretrained( cls, root: str, dataset: Dataset, target: int, ) -> Tuple['DimeNetPlusPlus', Dataset, Dataset, Dataset]: # pragma: no cover r"""Returns a pre-trained :class:`DimeNetPlusPlus` model on the :class:`~torch_geometric.datasets.QM9` dataset, trained on the specified target :obj:`target`. """ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import tensorflow as tf assert target >= 0 and target <= 12 and not target == 4 root = osp.expanduser(osp.normpath(root)) path = osp.join(root, 'pretrained_dimenet_pp', qm9_target_dict[target]) os.makedirs(path, exist_ok=True) url = f'{cls.url}/{qm9_target_dict[target]}' if not osp.exists(osp.join(path, 'checkpoint')): download_url(f'{url}/checkpoint', path) download_url(f'{url}/ckpt.data-00000-of-00002', path) download_url(f'{url}/ckpt.data-00001-of-00002', path) download_url(f'{url}/ckpt.index', path) path = osp.join(path, 'ckpt') reader = tf.train.load_checkpoint(path) # Configuration from DimeNet++: # https://github.com/gasteigerjo/dimenet/blob/master/config_pp.yaml model = cls( hidden_channels=128, out_channels=1, num_blocks=4, int_emb_size=64, basis_emb_size=8, out_emb_channels=256, num_spherical=7, num_radial=6, cutoff=5.0, max_num_neighbors=32, envelope_exponent=5, num_before_skip=1, num_after_skip=2, num_output_layers=3, ) def copy_(src, name, transpose=False): init =
reader.get_tensor(f'{name}/.ATTRIBUTES/VARIABLE_VALUE') init = torch.from_numpy(init) if name[-6:] == 'kernel': init = init.t() src.data.copy_(init) copy_(model.rbf.freq, 'rbf_layer/frequencies') copy_(model.emb.emb.weight, 'emb_block/embeddings') copy_(model.emb.lin_rbf.weight, 'emb_block/dense_rbf/kernel') copy_(model.emb.lin_rbf.bias, 'emb_block/dense_rbf/bias') copy_(model.emb.lin.weight, 'emb_block/dense/kernel') copy_(model.emb.lin.bias, 'emb_block/dense/bias') for i, block in enumerate(model.output_blocks): copy_(block.lin_rbf.weight, f'output_blocks/{i}/dense_rbf/kernel') copy_(block.lin_up.weight, f'output_blocks/{i}/up_projection/kernel') for j, lin in enumerate(block.lins): copy_(lin.weight, f'output_blocks/{i}/dense_layers/{j}/kernel') copy_(lin.bias, f'output_blocks/{i}/dense_layers/{j}/bias') copy_(block.lin.weight, f'output_blocks/{i}/dense_final/kernel') for i, block in enumerate(model.interaction_blocks): copy_(block.lin_rbf1.weight, f'int_blocks/{i}/dense_rbf1/kernel') copy_(block.lin_rbf2.weight, f'int_blocks/{i}/dense_rbf2/kernel') copy_(block.lin_sbf1.weight, f'int_blocks/{i}/dense_sbf1/kernel') copy_(block.lin_sbf2.weight, f'int_blocks/{i}/dense_sbf2/kernel') copy_(block.lin_ji.weight, f'int_blocks/{i}/dense_ji/kernel') copy_(block.lin_ji.bias, f'int_blocks/{i}/dense_ji/bias') copy_(block.lin_kj.weight, f'int_blocks/{i}/dense_kj/kernel') copy_(block.lin_kj.bias, f'int_blocks/{i}/dense_kj/bias') copy_(block.lin_down.weight, f'int_blocks/{i}/down_projection/kernel') copy_(block.lin_up.weight, f'int_blocks/{i}/up_projection/kernel') for j, layer in enumerate(block.layers_before_skip): copy_(layer.lin1.weight, f'int_blocks/{i}/layers_before_skip/{j}/dense_1/kernel') copy_(layer.lin1.bias, f'int_blocks/{i}/layers_before_skip/{j}/dense_1/bias') copy_(layer.lin2.weight, f'int_blocks/{i}/layers_before_skip/{j}/dense_2/kernel') copy_(layer.lin2.bias, f'int_blocks/{i}/layers_before_skip/{j}/dense_2/bias') copy_(block.lin.weight, 
f'int_blocks/{i}/final_before_skip/kernel') copy_(block.lin.bias, f'int_blocks/{i}/final_before_skip/bias') for j, layer in enumerate(block.layers_after_skip): copy_(layer.lin1.weight, f'int_blocks/{i}/layers_after_skip/{j}/dense_1/kernel') copy_(layer.lin1.bias, f'int_blocks/{i}/layers_after_skip/{j}/dense_1/bias') copy_(layer.lin2.weight, f'int_blocks/{i}/layers_after_skip/{j}/dense_2/kernel') copy_(layer.lin2.bias, f'int_blocks/{i}/layers_after_skip/{j}/dense_2/bias') random_state = np.random.RandomState(seed=42) perm = torch.from_numpy(random_state.permutation(np.arange(130831))) perm = perm.long() train_idx = perm[:110000] val_idx = perm[110000:120000] test_idx = perm[120000:] return model, (dataset[train_idx], dataset[val_idx], dataset[test_idx]) ================================================ FILE: torch_geometric/nn/models/dimenet_utils.py ================================================ # Shameless steal from: https://github.com/klicperajo/dimenet import math import numpy as np import sympy as sym from scipy import special as sp from scipy.optimize import brentq def Jn(r, n): return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) def Jn_zeros(n, k): zerosj = np.zeros((n, k), dtype='float32') zerosj[0] = np.arange(1, k + 1) * np.pi points = np.arange(1, k + n) * np.pi racines = np.zeros(k + n - 1, dtype='float32') for i in range(1, n): for j in range(k + n - 1 - i): foo = brentq(Jn, points[j], points[j + 1], (i, )) racines[j] = foo points = racines zerosj[i][:k] = racines[:k] return zerosj def spherical_bessel_formulas(n): x = sym.symbols('x') f = [sym.sin(x) / x] a = sym.sin(x) / x for i in range(1, n): b = sym.diff(a, x) / x f += [sym.simplify(b * (-x)**i)] a = sym.simplify(b) return f def bessel_basis(n, k): zeros = Jn_zeros(n, k) normalizer = [] for order in range(n): normalizer_tmp = [] for i in range(k): normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2] normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5 normalizer += [normalizer_tmp] f = 
spherical_bessel_formulas(n) x = sym.symbols('x') bess_basis = [] for order in range(n): bess_basis_tmp = [] for i in range(k): bess_basis_tmp += [ sym.simplify(normalizer[order][i] * f[order].subs(x, zeros[order, i] * x)) ] bess_basis += [bess_basis_tmp] return bess_basis def sph_harm_prefactor(k, m): return ((2 * k + 1) * math.factorial(k - abs(m)) / (4 * np.pi * math.factorial(k + abs(m))))**0.5 def associated_legendre_polynomials(k, zero_m_only=True): r"""Helper function to calculate Y_l^m.""" z = sym.symbols('z') P_l_m = [[0] * (j + 1) for j in range(k)] P_l_m[0][0] = 1 if k > 0: P_l_m[1][0] = z for j in range(2, k): # Use the property of Eq (7) in # https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html: P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0]) / j) if not zero_m_only: for i in range(1, k): P_l_m[i][i] = sym.simplify( (1 - 2 * i) * P_l_m[i - 1][i - 1] * (1 - z**2)**0.5) if i + 1 < k: # Use the property of Eq (11) in # https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html: P_l_m[i + 1][i] = sym.simplify( (2 * i + 1) * z * P_l_m[i][i]) for j in range(i + 2, k): # Use the property of Eq (7) in # https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html: P_l_m[j][i] = sym.simplify( ((2 * j - 1) * z * P_l_m[j - 1][i] - (i + j - 1) * P_l_m[j - 2][i]) / (j - i)) return P_l_m def real_sph_harm(k, zero_m_only=True, spherical_coordinates=True): if not zero_m_only: S_m = [0] C_m = [1] for i in range(1, k): x = sym.symbols('x') y = sym.symbols('y') S_m += [x * S_m[i - 1] + y * C_m[i - 1]] C_m += [x * C_m[i - 1] - y * S_m[i - 1]] P_l_m = associated_legendre_polynomials(k, zero_m_only) if spherical_coordinates: theta = sym.symbols('theta') z = sym.symbols('z') for i in range(len(P_l_m)): for j in range(len(P_l_m[i])): if not isinstance(P_l_m[i][j], int): P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) if not zero_m_only: phi = sym.symbols('phi') for i in range(len(S_m)): S_m[i] = S_m[i].subs(x, 
sym.sin(theta) * sym.cos(phi)).subs( y, sym.sin(theta) * sym.sin(phi)) for i in range(len(C_m)): C_m[i] = C_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs( y, sym.sin(theta) * sym.sin(phi)) Y_func_l_m = [['0'] * (2 * j + 1) for j in range(k)] for i in range(k): Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) if not zero_m_only: for i in range(1, k): for j in range(1, i + 1): Y_func_l_m[i][j] = sym.simplify( 2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) for i in range(1, k): for j in range(1, i + 1): Y_func_l_m[i][-j] = sym.simplify( 2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) return Y_func_l_m ================================================ FILE: torch_geometric/nn/models/gnnff.py ================================================ import torch from torch import Tensor from torch.nn import BatchNorm1d, Embedding, Linear, ModuleList, Sequential from torch_geometric.nn import radius_graph from torch_geometric.nn.inits import reset from torch_geometric.nn.models.dimenet import triplets from torch_geometric.nn.models.schnet import ShiftedSoftplus from torch_geometric.typing import OptTensor from torch_geometric.utils import scatter class GaussianFilter(torch.nn.Module): def __init__(self, start=0.0, stop=5.0, num_gaussians=50): super().__init__() offset = torch.linspace(start, stop, num_gaussians) self.coeff = -0.5 / (float(offset[1]) - float(offset[0]))**2 self.register_buffer('offset', offset) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" def forward(self, dist: Tensor) -> Tensor: dist = dist.view(-1, 1) - self.offset.view(1, -1) return torch.exp(self.coeff * dist.pow(2)) class NodeBlock(torch.nn.Module): def __init__(self, hidden_node_channels: int, hidden_edge_channels: int): super().__init__() self.lin_c1 = Linear(hidden_node_channels + hidden_edge_channels, 2 * hidden_node_channels) # BN was added based on previous studies. 
# ref: https://github.com/txie-93/cgcnn/blob/master/cgcnn/model.py self.bn_c1 = BatchNorm1d(2 * hidden_node_channels) self.bn = BatchNorm1d(hidden_node_channels) def reset_parameters(self): self.lin_c1.reset_parameters() self.bn_c1.reset_parameters() self.bn.reset_parameters() def forward(self, node_emb: Tensor, edge_emb: Tensor, i: Tensor) -> Tensor: c1 = torch.cat([node_emb[i], edge_emb], dim=1) c1 = self.bn_c1(self.lin_c1(c1)) c1_filter, c1_core = c1.chunk(2, dim=1) c1_filter = c1_filter.sigmoid() c1_core = c1_core.tanh() c1_emb = scatter(c1_filter * c1_core, i, dim=0, dim_size=node_emb.size(0), reduce='sum') c1_emb = self.bn(c1_emb) return (node_emb + c1_emb).tanh() class EdgeBlock(torch.nn.Module): def __init__(self, hidden_node_channels: int, hidden_edge_channels: int): super().__init__() self.lin_c2 = Linear(hidden_node_channels, 2 * hidden_edge_channels) self.lin_c3 = Linear( 3 * hidden_node_channels + 2 * hidden_edge_channels, 2 * hidden_edge_channels, ) # BN was added based on previous studies. 
# ref: https://github.com/txie-93/cgcnn/blob/master/cgcnn/model.py self.bn_c2 = BatchNorm1d(2 * hidden_edge_channels) self.bn_c3 = BatchNorm1d(2 * hidden_edge_channels) self.bn_c2_2 = BatchNorm1d(hidden_edge_channels) self.bn_c3_2 = BatchNorm1d(hidden_edge_channels) def reset_parameters(self): self.lin_c2.reset_parameters() self.lin_c3.reset_parameters() self.bn_c2.reset_parameters() self.bn_c3.reset_parameters() self.bn_c2_2.reset_parameters() self.bn_c3_2.reset_parameters() def forward( self, node_emb: Tensor, edge_emb: Tensor, i: Tensor, j: Tensor, idx_i: Tensor, idx_j: Tensor, idx_k: Tensor, idx_ji: Tensor, idx_kj: Tensor, ) -> Tensor: c2 = node_emb[i] * node_emb[j] c2 = self.bn_c2(self.lin_c2(c2)) c2_filter, c2_core = c2.chunk(2, dim=1) c2_filter = c2_filter.sigmoid() c2_core = c2_core.tanh() c2_emb = self.bn_c2_2(c2_filter * c2_core) c3 = torch.cat([ node_emb[idx_i], node_emb[idx_j], node_emb[idx_k], edge_emb[idx_ji], edge_emb[idx_kj], ], dim=1) c3 = self.bn_c3(self.lin_c3(c3)) c3_filter, c3_core = c3.chunk(2, dim=1) c3_filter = c3_filter.sigmoid() c3_core = c3_core.tanh() c3_emb = scatter(c3_filter * c3_core, idx_ji, dim=0, dim_size=edge_emb.size(0), reduce='sum') c3_emb = self.bn_c3_2(c3_emb) return (edge_emb + c2_emb + c3_emb).tanh() class GNNFF(torch.nn.Module): r"""The Graph Neural Network Force Field (GNNFF) from the `"Accurate and scalable graph neural network force field and molecular dynamics with direct force architecture" `_ paper. :class:`GNNFF` directly predicts atomic forces from automatically extracted features of the local atomic environment that are translationally-invariant, but rotationally-covariant to the coordinate of the atoms. Args: hidden_node_channels (int): Hidden node embedding size. hidden_edge_channels (int): Hidden edge embedding size. num_layers (int): Number of message passing blocks. cutoff (float, optional): Cutoff distance for interatomic interactions. 
(default: :obj:`5.0`) max_num_neighbors (int, optional): The maximum number of neighbors to collect for each node within the :attr:`cutoff` distance. (default: :obj:`32`) """ def __init__( self, hidden_node_channels: int, hidden_edge_channels: int, num_layers: int, cutoff: float = 5.0, max_num_neighbors: int = 32, ): super().__init__() self.cutoff = cutoff self.max_num_neighbors = max_num_neighbors self.node_emb = Sequential( Embedding(95, hidden_node_channels), ShiftedSoftplus(), Linear(hidden_node_channels, hidden_node_channels), ShiftedSoftplus(), Linear(hidden_node_channels, hidden_node_channels), ) self.edge_emb = GaussianFilter(0.0, 5.0, hidden_edge_channels) self.node_blocks = ModuleList([ NodeBlock(hidden_node_channels, hidden_edge_channels) for _ in range(num_layers) ]) self.edge_blocks = ModuleList([ EdgeBlock(hidden_node_channels, hidden_edge_channels) for _ in range(num_layers) ]) self.force_predictor = Sequential( Linear(hidden_edge_channels, hidden_edge_channels), ShiftedSoftplus(), Linear(hidden_edge_channels, hidden_edge_channels), ShiftedSoftplus(), Linear(hidden_edge_channels, 1), ) def reset_parameters(self): reset(self.node_emb) self.edge_emb.reset_parameters() for node_block in self.node_blocks: node_block.reset_parameters() for edge_block in self.edge_blocks: edge_block.reset_parameters() reset(self.force_predictor) def forward(self, z: Tensor, pos: Tensor, batch: OptTensor = None) -> Tensor: """""" # noqa: D419 edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors) i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( edge_index, num_nodes=z.size(0)) # Calculate distances and unit vector: dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() unit_vec = (pos[i] - pos[j]) / dist.view(-1, 1) # Embedding blocks: node_emb = self.node_emb(z) edge_emb = self.edge_emb(dist) # Message passing blocks: for node_block, edge_block in zip(self.node_blocks, self.edge_blocks): node_emb = node_block(node_emb, 
edge_emb, i) edge_emb = edge_block(node_emb, edge_emb, i, j, idx_i, idx_j, idx_k, idx_ji, idx_kj) # Force prediction block: force = self.force_predictor(edge_emb) * unit_vec return scatter(force, i, dim=0, reduce='sum') ================================================ FILE: torch_geometric/nn/models/gpse.py ================================================ import logging import os import os.path as osp import time from collections import OrderedDict from typing import List, Optional, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Module from tqdm import trange import torch_geometric.transforms as T from torch_geometric.data import Data, Dataset, download_url from torch_geometric.loader import DataLoader, NeighborLoader from torch_geometric.nn import ( ResGatedGraphConv, global_add_pool, global_max_pool, global_mean_pool, ) from torch_geometric.nn.resolver import activation_resolver from torch_geometric.utils import to_dense_batch class Linear(torch.nn.Module): def __init__( self, in_channels: int, out_channels: int, bias: bool, ) -> None: super().__init__() self.model = torch.nn.Linear(in_channels, out_channels, bias=bias) def forward(self, batch): if isinstance(batch, torch.Tensor): batch = self.model(batch) else: batch.x = self.model(batch.x) return batch class ResGatedGCNConv(torch.nn.Module): def __init__( self, in_channels: int, out_channels: int, bias: bool, **kwargs, ) -> None: super().__init__() self.model = ResGatedGraphConv( in_channels, out_channels, bias=bias, **kwargs, ) def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index) return batch class GeneralLayer(torch.nn.Module): def __init__( self, name: str, in_channels: int, out_channels: int, has_batch_norm: bool, has_l2_norm: bool, dropout: float, act: Optional[str], **kwargs, ): super().__init__() self.has_l2_norm = has_l2_norm layer_dict = { 'linear': Linear, 'resgatedgcnconv': ResGatedGCNConv, } self.layer = 
layer_dict[name]( in_channels, out_channels, bias=not has_batch_norm, **kwargs, ) post_layers = [] if has_batch_norm: post_layers.append( torch.nn.BatchNorm1d(out_channels, eps=1e-5, momentum=0.1)) if dropout > 0: post_layers.append(torch.nn.Dropout(p=dropout, inplace=False)) if act is not None: post_layers.append(activation_resolver(act)) self.post_layer = nn.Sequential(*post_layers) def forward(self, batch): batch = self.layer(batch) if isinstance(batch, torch.Tensor): batch = self.post_layer(batch) if self.has_l2_norm: batch = F.normalize(batch, p=2, dim=1) else: batch.x = self.post_layer(batch.x) if self.has_l2_norm: batch.x = F.normalize(batch.x, p=2, dim=1) return batch class GeneralMultiLayer(torch.nn.Module): def __init__( self, name: str, in_channels: int, out_channels: int, hidden_channels: Optional[int], num_layers: int, has_batch_norm: bool, has_l2_norm: bool, dropout: float, act: str, final_act: bool, **kwargs, ) -> None: super().__init__() hidden_channels = hidden_channels or out_channels for i in range(num_layers): d_in = in_channels if i == 0 else hidden_channels d_out = out_channels if i == num_layers - 1 else hidden_channels layer = GeneralLayer( name=name, in_channels=d_in, out_channels=d_out, has_batch_norm=has_batch_norm, has_l2_norm=has_l2_norm, dropout=dropout, act=None if i == num_layers - 1 and not final_act else act, **kwargs, ) self.add_module(f'Layer_{i}', layer) def forward(self, batch): for layer in self.children(): batch = layer(batch) return batch class BatchNorm1dNode(torch.nn.Module): def __init__(self, channels: int) -> None: super().__init__() self.bn = torch.nn.BatchNorm1d(channels, eps=1e-5, momentum=0.1) def forward(self, batch): batch.x = self.bn(batch.x) return batch class BatchNorm1dEdge(torch.nn.Module): def __init__(self, channels: int) -> None: super().__init__() self.bn = torch.nn.BatchNorm1d(channels, eps=1e-5, momentum=0.1) def forward(self, batch): batch.edge_attr = self.bn(batch.edge_attr) return batch class 
MLP(torch.nn.Module): def __init__( self, in_channels: int, out_channels: int, hidden_channels: Optional[int], num_layers: int, has_batch_norm: bool = True, has_l2_norm: bool = True, dropout: float = 0.2, act: str = 'relu', **kwargs, ): super().__init__() hidden_channels = hidden_channels or in_channels layers = [] if num_layers > 1: layer = GeneralMultiLayer( 'linear', in_channels, hidden_channels, hidden_channels, num_layers - 1, has_batch_norm, has_l2_norm, dropout, act, final_act=True, **kwargs, ) layers.append(layer) layers.append(Linear(hidden_channels, out_channels, bias=True)) self.model = nn.Sequential(*layers) def forward(self, batch): if isinstance(batch, torch.Tensor): batch = self.model(batch) else: batch.x = self.model(batch.x) return batch class GNNStackStage(torch.nn.Module): def __init__( self, in_channels: int, out_channels: int, num_layers: int, layer_type: str, stage_type: str = 'skipsum', final_l2_norm: bool = True, has_batch_norm: bool = True, has_l2_norm: bool = True, dropout: float = 0.2, act: Optional[str] = 'relu', ): super().__init__() self.num_layers = num_layers self.stage_type = stage_type self.final_l2_norm = final_l2_norm for i in range(num_layers): if stage_type == 'skipconcat': if i == 0: d_in = in_channels else: d_in = in_channels + i * out_channels else: d_in = in_channels if i == 0 else out_channels layer = GeneralLayer(layer_type, d_in, out_channels, has_batch_norm, has_l2_norm, dropout, act) self.add_module(f'layer{i}', layer) def forward(self, batch): for i, layer in enumerate(self.children()): x = batch.x batch = layer(batch) if self.stage_type == 'skipsum': batch.x = x + batch.x elif self.stage_type == 'skipconcat' and i < self.num_layers - 1: batch.x = torch.cat([x, batch.x], dim=1) if self.final_l2_norm: batch.x = F.normalize(batch.x, p=2, dim=-1) return batch class GNNInductiveHybridMultiHead(torch.nn.Module): r"""GNN prediction head for inductive node and graph prediction tasks using individual MLP for each task. 
Args: dim_in (int): Input dimension. dim_out (int): Output dimension. Not used, as the dimension is determined by :obj:`num_node_targets` and :obj:`num_graph_targets` instead. num_node_targets (int): Number of individual PSEs used as node-level targets in pretraining :class:`GPSE`. num_graph_targets (int): Number of graph-level targets used in pretraining :class:`GPSE`. layers_post_mp (int): Number of MLP layers after GNN message-passing. virtual_node (bool, optional): Whether a virtual node is added to graphs in :class:`GPSE` computation. (default: :obj:`True`) multi_head_dim_inner (int, optional): Width of MLPs for PSE target prediction heads. (default: :obj:`32`) graph_pooling (str, optional): Type of graph pooling applied before post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`. (default: :obj:`add`) has_bn (bool, optional): Whether to apply batch normalization to layer outputs. (default: :obj:`True`) has_l2norm (bool, optional): Whether to apply L2 normalization to the layer outputs. (default: :obj:`True`) dropout (float, optional): Dropout ratio at layer output. (default: :obj:`0.2`) act (str, optional): Activation to apply to layer outputs if :obj:`has_act` is :obj:`True`. 
(default: :obj:`relu`) """ def __init__( self, dim_in: int, dim_out: int, num_node_targets: int, num_graph_targets: int, layers_post_mp: int, virtual_node: bool = True, multi_head_dim_inner: int = 32, graph_pooling: str = 'add', has_bn: bool = True, has_l2norm: bool = True, dropout: float = 0.2, act: str = 'relu', ): super().__init__() pool_dict = { 'add': global_add_pool, 'max': global_max_pool, 'mean': global_mean_pool } self.node_target_dim = num_node_targets self.graph_target_dim = num_graph_targets self.virtual_node = virtual_node num_layers = layers_post_mp self.node_post_mps = nn.ModuleList([ MLP(dim_in, 1, multi_head_dim_inner, num_layers, has_bn, has_l2norm, dropout, act) for _ in range(self.node_target_dim) ]) self.graph_pooling = pool_dict[graph_pooling] self.graph_post_mp = MLP(dim_in, self.graph_target_dim, dim_in, num_layers, has_bn, has_l2norm, dropout, act) def _pad_and_stack(self, x1: torch.Tensor, x2: torch.Tensor, pad1: int, pad2: int): padded_x1 = nn.functional.pad(x1, (0, pad2)) padded_x2 = nn.functional.pad(x2, (pad1, 0)) return torch.vstack([padded_x1, padded_x2]) def _apply_index(self, batch, virtual_node: bool, pad_node: int, pad_graph: int): graph_pred, graph_true = batch.graph_feature, batch.y_graph node_pred, node_true = batch.node_feature, batch.y if virtual_node: # Remove virtual node idx = torch.concat([ torch.where(batch.batch == i)[0][:-1] for i in range(batch.batch.max().item() + 1) ]) node_pred, node_true = node_pred[idx], node_true[idx] # Stack node predictions on top of graph predictions and pad with zeros pred = self._pad_and_stack(node_pred, graph_pred, pad_node, pad_graph) true = self._pad_and_stack(node_true, graph_true, pad_node, pad_graph) return pred, true def forward(self, batch): batch.node_feature = torch.hstack( [m(batch.x) for m in self.node_post_mps]) graph_emb = self.graph_pooling(batch.x, batch.batch) batch.graph_feature = self.graph_post_mp(graph_emb) return self._apply_index(batch, self.virtual_node, 
self.node_target_dim, self.graph_target_dim) class IdentityHead(torch.nn.Module): def forward(self, batch): return batch.x, batch.y class GPSE(torch.nn.Module): r"""The Graph Positional and Structural Encoder (GPSE) model from the `"Graph Positional and Structural Encoder" `_ paper. The GPSE model consists of a (1) deep GNN that consists of stacked message passing layers, and a (2) prediction head to predict pre-computed positional and structural encodings (PSE). When used on downstream datasets, these prediction heads are removed and the final fully-connected layer outputs are used as learned PSE embeddings. GPSE also provides a static method :meth:`from_pretrained` to load pre-trained GPSE models trained on a variety of molecular datasets. .. code-block:: python from torch_geometric.nn import GPSE, GPSENodeEncoder from torch_geometric.transforms import AddGPSE from torch_geometric.nn.models.gpse import precompute_GPSE gpse_model = GPSE.from_pretrained('molpcba') # Option 1: Precompute GPSE encodings in-place for a given dataset dataset = ZINC(path, subset=True, split='train') precompute_gpse(gpse_model, dataset) # Option 2: Use the GPSE model with AddGPSE as a pre_transform to save # the encodings dataset = ZINC(path, subset=True, split='train', pre_transform=AddGPSE(gpse_model, vn=True, rand_type='NormalSE')) Both approaches append the generated encodings to the :obj:`pestat_GPSE` attribute of :class:`~torch_geometric.data.Data` objects. To use the GPSE encodings for a downstream task, one may need to add these encodings to the :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects. To do so, one can use the :class:`GPSENodeEncoder` provided to map these encodings to a desired dimension before appending them to :obj:`x`. Let's say we have a graph dataset with 64 original node features, and we have generated GPSE encodings of dimension 32, i.e. :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an inner dimension of 128. 
To do so, we can map the 32-dimensional GPSE encodings to a higher dimension of 64, and then append them to the :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to obtain a 128-dimensional node feature representation. :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and concatenation to :obj:`x`, the outputs of which can be used as input to a GNN: .. code-block:: python encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64, expand_x=False) gnn = GNN(...) for batch in loader: x = encoder(batch.x, batch.pestat_GPSE) out = gnn(x, batch.edge_index) Args: dim_in (int, optional): Input dimension. (default: :obj:`20`) dim_out (int, optional): Output dimension. (default: :obj:`51`) dim_inner (int, optional): Width of the encoder layers. (default: :obj:`512`) layer_type (str, optional): Type of graph convolutional layer for message-passing. (default: :obj:`resgatedgcnconv`) layers_pre_mp (int, optional): Number of MLP layers before message-passing. (default: :obj:`1`) layers_mp (int, optional): Number of layers for message-passing. (default: :obj:`20`) layers_post_mp (int, optional): Number of MLP layers after message-passing. (default: :obj:`2`) num_node_targets (int, optional): Number of individual PSEs used as node-level targets in pretraining :class:`GPSE`. (default: :obj:`51`) num_graph_targets (int, optional): Number of graph-level targets used in pretraining :class:`GPSE`. (default: :obj:`11`) stage_type (str, optional): The type of staging to apply. Possible values are: :obj:`skipsum`, :obj:`skipconcat`. Any other value will default to no skip connections. (default: :obj:`skipsum`) has_bn (bool, optional): Whether to apply batch normalization in the layer. (default: :obj:`True`) final_l2norm (bool, optional): Whether to apply L2 normalization to the outputs. (default: :obj:`True`) has_l2norm (bool, optional): Whether to apply L2 normalization after the layer. 
(default: :obj:`True`) dropout (float, optional): Dropout ratio at layer output. (default: :obj:`0.2`) has_act (bool, optional): Whether has activation after the layer. (default: :obj:`True`) final_act (bool, optional): Whether to apply activation after the layer stack. (default: :obj:`True`) act (str, optional): Activation to apply to layer output if :obj:`has_act` is :obj:`True`. (default: :obj:`relu`) virtual_node (bool, optional): Whether a virtual node is added to graphs in :class:`GPSE` computation. (default: :obj:`True`) multi_head_dim_inner (int, optional): Width of MLPs for PSE target prediction heads. (default: :obj:`32`) graph_pooling (str, optional): Type of graph pooling applied before post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`. (default: :obj:`add`) use_repr (bool, optional): Whether to use the hidden representation of the final layer as :class:`GPSE` encodings. (default: :obj:`True`) repr_type (str, optional): Type of representation to use. Options are :obj:`no_post_mp`, :obj:`one_layer_before`. (default: :obj:`no_post_mp`) bernoulli_threshold (float, optional): Threshold for Bernoulli sampling of virtual nodes. 
(default: :obj:`0.5`) """ url_dict = { 'molpcba': 'https://zenodo.org/record/8145095/files/' 'gpse_model_molpcba_1.0.pt', 'zinc': 'https://zenodo.org/record/8145095/files/gpse_model_zinc_1.0.pt', 'pcqm4mv2': 'https://zenodo.org/record/8145095/files/' 'gpse_model_pcqm4mv2_1.0.pt', 'geom': 'https://zenodo.org/record/8145095/files/gpse_model_geom_1.0.pt', 'chembl': 'https://zenodo.org/record/8145095/files/gpse_model_chembl_1.0.pt' } def __init__( self, dim_in: int = 20, dim_out: int = 51, dim_inner: int = 512, layer_type: str = 'resgatedgcnconv', layers_pre_mp: int = 1, layers_mp: int = 20, layers_post_mp: int = 2, num_node_targets: int = 51, num_graph_targets: int = 11, stage_type: str = 'skipsum', has_bn: bool = True, head_bn: bool = False, final_l2norm: bool = True, has_l2norm: bool = True, dropout: float = 0.2, has_act: bool = True, final_act: bool = True, act: str = 'relu', virtual_node: bool = True, multi_head_dim_inner: int = 32, graph_pooling: str = 'add', use_repr: bool = True, repr_type: str = 'no_post_mp', bernoulli_threshold: float = 0.5, ): super().__init__() self.use_repr = use_repr self.repr_type = repr_type self.bernoulli_threshold = bernoulli_threshold if layers_pre_mp > 0: self.pre_mp = GeneralMultiLayer( name='linear', in_channels=dim_in, out_channels=dim_inner, hidden_channels=dim_inner, num_layers=layers_pre_mp, has_batch_norm=has_bn, has_l2_norm=has_l2norm, dropout=dropout, act=act, final_act=final_act, ) dim_in = dim_inner if layers_mp > 0: self.mp = GNNStackStage( in_channels=dim_in, out_channels=dim_inner, num_layers=layers_mp, layer_type=layer_type, stage_type=stage_type, final_l2_norm=final_l2norm, has_batch_norm=has_bn, has_l2_norm=has_l2norm, dropout=dropout, act=act if has_act else None, ) self.post_mp = GNNInductiveHybridMultiHead( dim_inner, dim_out, num_node_targets, num_graph_targets, layers_post_mp, virtual_node, multi_head_dim_inner, graph_pooling, head_bn, has_l2norm, dropout, act, ) self.reset_parameters() def 
reset_parameters(self): pass @classmethod def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'): r"""Returns a pretrained :class:`GPSE` model on a dataset. Args: name (str): The name of the dataset (:obj:`"molpcba"`, :obj:`"zinc"`, :obj:`"pcqm4mv2"`, :obj:`"geom"`, :obj:`"chembl"`). root (str, optional): The root directory to save the pre-trained model. (default: :obj:`"GPSE_pretrained"`) """ root = osp.expanduser(osp.normpath(root)) os.makedirs(root, exist_ok=True) path = download_url(cls.url_dict[name], root) model = GPSE() # All pretrained models use the default arguments model_state = torch.load(path, map_location='cpu')['model_state'] model_state_new = OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state.items()]) model.load_state_dict(model_state_new) # Set the final linear layer to identity if we use hidden reprs if model.use_repr: if model.repr_type == 'one_layer_before': model.post_mp.layer_post_mp.model[-1] = torch.nn.Identity() elif model.repr_type == 'no_post_mp': model.post_mp = IdentityHead() else: raise ValueError(f"Unknown type '{model.repr_type}'") model.eval() return model def forward(self, batch): batch = batch.clone() for module in self.children(): batch = module(batch) return batch class GPSENodeEncoder(torch.nn.Module): r"""A helper linear/MLP encoder that takes the :class:`GPSE` encodings (based on the `"Graph Positional and Structural Encoder" `_ paper) precomputed as :obj:`batch.pestat_GPSE` in the input graphs, maps them to a desired dimension defined by :obj:`dim_pe_out` and appends them to node features. Let's say we have a graph dataset with 64 original node features, and we have generated GPSE encodings of dimension 32, i.e. :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an inner dimension of 128. 
To do so, we can map the 32-dimensional GPSE encodings to a higher dimension of 64, and then append them to the :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to obtain a 128-dimensional node feature representation. :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and concatenation to :obj:`x`, the outputs of which can be used as input to a GNN: .. code-block:: python encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64, expand_x=False) gnn = GNN(...) for batch in loader: x = encoder(batch.x, batch.pestat_GPSE) batch = gnn(x, batch.edge_index) Args: dim_emb (int): Size of final node embedding. dim_pe_in (int): Original dimension of :obj:`batch.pestat_GPSE`. dim_pe_out (int): Desired dimension of :class:`GPSE` after the encoder. dim_in (int, optional): Original dimension of input node features, required only if :obj:`expand_x` is set to :obj:`True`. (default: :obj:`None`) expand_x (bool, optional): Expand node features :obj:`x` from :obj:`dim_in` to (:obj:`dim_emb` - :obj:`dim_pe_out`) norm_type (str, optional): Type of normalization to apply. (default: :obj:`batchnorm`) model_type (str, optional): Type of encoder, either :obj:`mlp` or :obj:`linear`. (default: :obj:`mlp`) n_layers (int, optional): Number of MLP layers if :obj:`model_type` is :obj:`mlp`. (default: :obj:`2`) dropout_be (float, optional): Dropout ratio of inputs to encoder, i.e. before encoding. (default: :obj:`0.5`) dropout_ae (float, optional): Dropout ratio of outputs, i.e. after encoding. 
(default: :obj:`0.2`) """ def __init__(self, dim_emb: int, dim_pe_in: int, dim_pe_out: int, dim_in: int = None, expand_x=False, norm_type='batchnorm', model_type='mlp', n_layers=2, dropout_be=0.5, dropout_ae=0.2): super().__init__() assert dim_emb > dim_pe_out, ('Desired GPSE dimension (dim_pe_out) ' 'must be smaller than the final node ' 'embedding dimension (dim_emb).') if expand_x: self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe_out) self.expand_x = expand_x self.raw_norm = None if norm_type == 'batchnorm': self.raw_norm = nn.BatchNorm1d(dim_pe_in) self.dropout_be = nn.Dropout(p=dropout_be) self.dropout_ae = nn.Dropout(p=dropout_ae) activation = nn.ReLU # register.act_dict[cfg.gnn.act] if model_type == 'mlp': layers = [] if n_layers == 1: layers.append(torch.nn.Linear(dim_pe_in, dim_pe_out)) layers.append(activation()) else: layers.append(torch.nn.Linear(dim_pe_in, 2 * dim_pe_out)) layers.append(activation()) for _ in range(n_layers - 2): layers.append( torch.nn.Linear(2 * dim_pe_out, 2 * dim_pe_out)) layers.append(activation()) layers.append(torch.nn.Linear(2 * dim_pe_out, dim_pe_out)) layers.append(activation()) self.pe_encoder = nn.Sequential(*layers) elif model_type == 'linear': self.pe_encoder = nn.Linear(dim_pe_in, dim_pe_out) else: raise ValueError(f"{self.__class__.__name__}: Does not support " f"'{model_type}' encoder model.") def forward(self, x, pos_enc): pos_enc = self.dropout_be(pos_enc) pos_enc = self.raw_norm(pos_enc) if self.raw_norm else pos_enc pos_enc = self.pe_encoder(pos_enc) # (Num nodes) x dim_pe pos_enc = self.dropout_ae(pos_enc) # Expand node features if needed h = self.linear_x(x) if self.expand_x else x # Concatenate final PEs to input embedding return torch.cat((h, pos_enc), 1) @torch.no_grad() def gpse_process( model: Module, data: Data, rand_type: str, use_vn: bool = True, bernoulli_thresh: float = 0.5, neighbor_loader: bool = False, num_neighbors: Optional[List[int]] = None, fillval: int = 5, layers_mp: int = None, **kwargs, ) 
-> torch.Tensor: r"""Processes the data using the :class:`GPSE` model to generate and append GPSE encodings. Identical to :obj:`gpse_process_batch`, but operates on a single :class:`~torch_geometric.data.Dataset` object. Unlike transform-based GPSE processing (i.e. :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument does not append virtual nodes if set to :obj:`True`, and instead assumes the input graphs to :obj:`gpse_process` already have virtual nodes. Under normal circumstances, one does not need to call this function; running :obj:`precompute_GPSE` on your whole dataset is advised instead. Args: model (Module): The :class:`GPSE` model. data (torch_geometric.data.Data): A :class:`~torch_geometric.data.Data` object. rand_type (str, optional): Type of random features to use. Options are :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`. (default: :obj:`NormalSE`) use_vn (bool, optional): Whether the input graphs have virtual nodes. (default: :obj:`True`) bernoulli_thresh (float, optional): Threshold for Bernoulli sampling of virtual nodes. (default: :obj:`0.5`) neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`. (default: :obj:`False`) num_neighbors (List[int], optional): Number of neighbors to consider for each message-passing layer. (default: :obj:`[30, 20, 10]`) fillval (int, optional): Value to fill for missing :obj:`num_neighbors`. (default: :obj:`5`) layers_mp (int, optional): Number of message-passing layers. (default: :obj:`None`) **kwargs (optional): Additional arguments for :obj:`NeighborLoader`. Returns: torch.Tensor: A tensor corresponding to the original :class:`~torch_geometric.data.Data` object, with :class:`GPSE` encodings appended as :obj:`out.pestat_GPSE` attribute. 
""" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Generate random features for the encoder n = data.num_nodes dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1] # Prepare input distributions for GPSE if rand_type == 'NormalSE': rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in)) elif rand_type == 'UniformSE': rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in)) elif rand_type == 'BernoulliSE': rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in)) rand = (rand < bernoulli_thresh) else: raise ValueError(f'Unknown {rand_type=!r}') data.x = torch.from_numpy(rand.astype('float32')) if use_vn: data.x[-1] = 0 model, data = model.to(device), data.to(device) # Generate encodings using the pretrained encoder if neighbor_loader: if layers_mp is None: raise ValueError('Please provide the number of message-passing ' 'layers as "layers_mp".') num_neighbors = num_neighbors or [30, 20, 10] diff = layers_mp - len(num_neighbors) if fillval > 0 and diff > 0: num_neighbors += [fillval] * diff loader = NeighborLoader(data, num_neighbors=num_neighbors, shuffle=False, pin_memory=True, **kwargs) out_list = [] pbar = trange(data.num_nodes, position=2) for batch in loader: out, _ = model(batch.to(device)) out = out[:batch.batch_size].to("cpu", non_blocking=True) out_list.append(out) pbar.update(batch.batch_size) out = torch.vstack(out_list) else: out, _ = model(data) out = out.to("cpu") return out @torch.no_grad() def gpse_process_batch( model: GPSE, batch, rand_type: str, use_vn: bool = True, bernoulli_thresh: float = 0.5, neighbor_loader: bool = False, num_neighbors: Optional[List[int]] = None, fillval: int = 5, layers_mp: int = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Process a batch of data using the :class:`GPSE` model to generate and append :class:`GPSE` encodings. Identical to `gpse_process`, but operates on a batch of :class:`~torch_geometric.data.Data` objects. 
Unlike transform-based GPSE processing (i.e. :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument does not append virtual nodes if set to :obj:`True`, and instead assumes the input graphs to :obj:`gpse_process` already have virtual nodes. This is because the virtual nodes are already added to graphs before the call to :obj:`gpse_process_batch` in :obj:`precompute_GPSE` for better efficiency. Under normal circumstances, one does not need to call this function; running :obj:`precompute_GPSE` on your whole dataset is advised instead. Args: model (GPSE): The :class:`GPSE` model. batch: A batch of PyG Data objects. rand_type (str, optional): Type of random features to use. Options are :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`. (default: :obj:`NormalSE`) use_vn (bool, optional): Whether the input graphs have virtual nodes. (default: :obj:`True`) bernoulli_thresh (float, optional): Threshold for Bernoulli sampling of virtual nodes. (default: :obj:`0.5`) neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`. (default: :obj:`False`) num_neighbors (List[int], optional): Number of neighbors to consider for each message-passing layer. (default: :obj:`[30, 20, 10]`) fillval (int, optional): Value to fill for missing :obj:`num_neighbors`. (default: :obj:`5`) layers_mp (int, optional): Number of message-passing layers. (default: :obj:`None`) **kwargs: Additional keyword arguments for :obj:`NeighborLoader`. Returns: Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors corresponding to the stacked :class:`GPSE` encodings and the pointers indicating individual graphs. 
""" n = batch.num_nodes dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1] # Prepare input distributions for GPSE if rand_type == 'NormalSE': rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in)) elif rand_type == 'UniformSE': rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in)) elif rand_type == 'BernoulliSE': rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in)) rand = (rand < bernoulli_thresh) else: raise ValueError(f'Unknown {rand_type=!r}') batch.x = torch.from_numpy(rand.astype('float32')) if use_vn: # HACK: We need to reset virtual node features to zeros to match the # pretraining setting (virtual node applied after random node features # are set, and the default node features for the virtual node are all # zeros). Can potentially test if initializing virtual node features to # random features is better than setting them to zeros. for i in batch.ptr[1:]: batch.x[i - 1] = 0 # Generate encodings using the pretrained encoder device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = model.to(device) if neighbor_loader: if layers_mp is None: raise ValueError('Please provide the number of message-passing ' 'layers as "layers_mp".') num_neighbors = num_neighbors or [30, 20, 10] diff = layers_mp - len(num_neighbors) if fillval > 0 and diff > 0: num_neighbors += [fillval] * diff loader = NeighborLoader(batch, num_neighbors=num_neighbors, shuffle=False, pin_memory=True, **kwargs) out_list = [] pbar = trange(batch.num_nodes, position=2) for batch in loader: out, _ = model(batch.to(device)) out = out[:batch.batch_size].to('cpu', non_blocking=True) out_list.append(out) pbar.update(batch.batch_size) out = torch.vstack(out_list) else: out, _ = model(batch.to(device)) out = out.to('cpu') return out, batch.ptr @torch.no_grad() def precompute_GPSE(model: GPSE, dataset: Dataset, use_vn: bool = True, rand_type: str = 'NormalSE', **kwargs): r"""Precomputes :class:`GPSE` encodings in-place for a given dataset using a 
:class:`GPSE` model. Args: model (GPSE): The :class:`GPSE` model. dataset (Dataset): A PyG Dataset. use_vn (bool, optional): Whether to append virtual nodes to graphs in :class:`GPSE` computation. Should match the setting used when pre-training the :class:`GPSE` model. (default :obj:`True`) rand_type (str, optional): The type of randomization to use. (default :obj:`NormalSE`) **kwargs (optional): Additional arguments for :class:`~torch_geometric.data.DataLoader`. """ # Temporarily replace the transformation orig_dataset_transform = dataset.transform dataset.transform = None if use_vn: dataset.transform = T.VirtualNode() # Remove split indices, to be recovered at the end of the precomputation tmp_store = {} for name in [ 'train_mask', 'val_mask', 'test_mask', 'train_graph_index', 'val_graph_index', 'test_graph_index', 'train_edge_index', 'val_edge_index', 'test_edge_index' ]: if (name in dataset.data) and (dataset.slices is None or name in dataset.slices): tmp_store_data = dataset.data.pop(name) tmp_store_slices = dataset.slices.pop(name) \ if dataset.slices else None tmp_store[name] = (tmp_store_data, tmp_store_slices) loader = DataLoader(dataset, shuffle=False, pin_memory=True, **kwargs) # Batched GPSE precomputation loop data_list = [] curr_idx = 0 pbar = trange(len(dataset), desc='Pre-computing GPSE') tic = time.perf_counter() for batch in loader: batch_out, batch_ptr = gpse_process_batch(model, batch, rand_type, **kwargs) batch_out = batch_out.to('cpu', non_blocking=True) # Need to wait for batch_ptr to finish transferring so that start and # end indices are ready to use batch_ptr = batch_ptr.to('cpu', non_blocking=False) for start, end in zip(batch_ptr[:-1], batch_ptr[1:]): data = dataset.get(curr_idx) if use_vn: end = end - 1 data.pestat_GPSE = batch_out[start:end] data_list.append(data) curr_idx += 1 pbar.update(len(batch_ptr) - 1) pbar.close() # Collate dataset and reset indices and data list dataset.transform = orig_dataset_transform dataset._indices = 
None dataset._data_list = data_list dataset.data, dataset.slices = dataset.collate(data_list) # Recover split indices for name, (tmp_store_data, tmp_store_slices) in tmp_store.items(): dataset.data[name] = tmp_store_data if tmp_store_slices is not None: dataset.slices[name] = tmp_store_slices dataset._data_list = None timestr = time.strftime('%H:%M:%S', time.gmtime(time.perf_counter() - tic)) logging.info(f'Finished GPSE pre-computation, took {timestr}') # Release resource and recover original configs del model torch.cuda.empty_cache() def cosim_col_sep(pred: torch.Tensor, true: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor: r"""Calculates the average cosine similarity between predicted and true features on a batch of graphs. Args: pred (torch.Tensor): Predicted outputs. true (torch.Tensor): Value of ground truths. batch_idx (torch.Tensor): Batch indices to separate the graphs. Returns: torch.Tensor: Average cosine similarity per graph in batch. Raises: ValueError: If batch_index is not specified. """ if batch_idx is None: raise ValueError("mae_cosim_col_sep requires batch index as " "input to distinguish different graphs.") batch_idx = batch_idx + 1 if batch_idx.min() == -1 else batch_idx pred_dense = to_dense_batch(pred, batch_idx)[0] true_dense = to_dense_batch(true, batch_idx)[0] mask = (true_dense == 0).all(1) # exclude trivial features from loss loss = 1 - F.cosine_similarity(pred_dense, true_dense, dim=1)[~mask].mean() return loss def gpse_loss(pred: torch.Tensor, true: torch.Tensor, batch_idx: torch.Tensor = None) \ -> Tuple[torch.Tensor, torch.Tensor]: r"""Calculates :class:`GPSE` loss as the sum of MAE loss and cosine similarity loss over a batch of graphs. Args: pred (torch.Tensor): Predicted outputs. true (torch.Tensor): Value of ground truths. batch_idx (torch.Tensor): Batch indices to separate the graphs. 
Returns: Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors corresponding to the :class:`GPSE` loss and the predicted node-and-graph level outputs. """ if batch_idx is None: raise ValueError("mae_cosim_col_sep requires batch index as " "input to distinguish different graphs.") mae_loss = F.l1_loss(pred, true) cosim_loss = cosim_col_sep(pred, true, batch_idx) loss = mae_loss + cosim_loss return loss, pred def process_batch_idx(batch_idx, true, use_vn=True): r"""Processes batch indices to adjust for the removal of virtual nodes, and pads batch index for hybrid tasks. Args: batch_idx: Batch indices to separate the graphs. true: Value of ground truths. use_vn: If input graphs have virtual nodes that need to be removed. Returns: torch.Tensor: Batch indices that separate the graphs. """ if batch_idx is None: return if use_vn: # remove virtual node batch_idx = torch.concat([ batch_idx[batch_idx == i][:-1] for i in range(batch_idx.max().item() + 1) ]) # Pad batch index for hybrid tasks (set batch index for graph heads to -1) if (pad := true.shape[0] - batch_idx.shape[0]) > 0: pad_idx = -torch.ones(pad, dtype=torch.long, device=batch_idx.device) batch_idx = torch.hstack([batch_idx, pad_idx]) return batch_idx ================================================ FILE: torch_geometric/nn/models/graph_mixer.py ================================================ import numpy as np import torch import torch.nn.functional as F from torch import Tensor from torch.nn import LayerNorm, Linear from torch_geometric.nn import TemporalEncoding from torch_geometric.utils import scatter, to_dense_batch class NodeEncoder(torch.nn.Module): r"""The node encoder module from the `"Do We Really Need Complicated Model Architectures for Temporal Networks?" `_ paper. :class:`NodeEncoder` captures the 1-hop temporal neighborhood information via mean pooling. .. 
math:: \mathbf{x}_v^{\prime}(t_0) = \mathbf{x}_v + \textrm{mean} \left\{ \mathbf{x}_w : w \in \mathcal{N}(v, t_0 - T, t_0) \right\} Args: time_window (int): The temporal window size :math:`T` to define the 1-hop temporal neighborhood. """ def __init__(self, time_window: int): super().__init__() self.time_window = time_window def reset_parameters(self): pass def forward( self, x: Tensor, edge_index: Tensor, edge_time: Tensor, seed_time: Tensor, ) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The input node features. edge_index (torch.Tensor): The edge indices. edge_time (torch.Tensor): The timestamp attached to every edge. seed_time (torch.Tensor): The seed time :math:`t_0` for every destination node. """ mask = ((edge_time <= seed_time[edge_index[1]]) & (edge_time > seed_time[edge_index[1]] - self.time_window)) src, dst = edge_index[:, mask] mean = scatter(x[src], dst, dim=0, dim_size=x.size(0), reduce='mean') return x + mean def __repr__(self) -> str: return f'{self.__class__.__name__}(time_window={self.time_window})' class _MLPMixer(torch.nn.Module): r"""The MLP-Mixer module. Args: num_tokens (int): Number of tokens/patches in each sample. in_channels (int): Input channels. out_channels (int): Output channels. dropout (float, optional): Dropout probability. 
(default: :obj:`0.0`) """ def __init__( self, num_tokens: int, in_channels: int, out_channels: int, dropout: float = 0.0, ): super().__init__() self.dropout = dropout self.token_norm = LayerNorm(in_channels) self.token_lin1 = Linear(num_tokens, num_tokens // 2) self.token_lin2 = Linear(num_tokens // 2, num_tokens) self.channel_norm = LayerNorm(in_channels) self.channel_lin1 = Linear(in_channels, 4 * in_channels) self.channel_lin2 = Linear(4 * in_channels, in_channels) self.head_norm = LayerNorm(in_channels) self.head_lin = Linear(in_channels, out_channels) def reset_parameters(self): self.token_norm.reset_parameters() self.token_lin1.reset_parameters() self.token_lin2.reset_parameters() self.channel_norm.reset_parameters() self.channel_lin1.reset_parameters() self.channel_lin2.reset_parameters() self.head_norm.reset_parameters() self.head_lin.reset_parameters() def forward(self, x: Tensor) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): Tensor of size :obj:`[*, num_tokens, in_channels]`. Returns: Tensor of size :obj:`[*, out_channels]`. """ # Token mixing: h = self.token_norm(x).mT h = self.token_lin1(h) h = F.gelu(h) h = F.dropout(h, p=self.dropout, training=self.training) h = self.token_lin2(h) h = F.dropout(h, p=self.dropout, training=self.training) h_token = h.mT + x # Channel mixing: h = self.channel_norm(h_token) h = self.channel_lin1(h) h = F.gelu(h) h = F.dropout(h, p=self.dropout, training=self.training) h = self.channel_lin2(h) h = F.dropout(h, p=self.dropout, training=self.training) h_channel = h + h_token # Head: out = self.head_norm(h_channel) out = out.mean(dim=1) out = self.head_lin(out) return out def get_latest_k_edge_attr( k: int, edge_index: Tensor, edge_attr: Tensor, edge_time: Tensor, num_nodes: int, is_sorted: bool = False, ) -> Tensor: r"""Returns the latest :obj:`k` incoming edge attributes by :obj:`edge_time` for each node. The shape of the output tensor is :obj:`[num_nodes, k, edge_attr_dim]`. 
Nodes with fewer than :obj:`k` incoming edges are zero-padded. """ _, col = edge_index if not is_sorted: perm = np.lexsort([ -edge_time.detach().cpu().numpy(), col.detach().cpu().numpy(), ]) perm = torch.from_numpy(perm).to(edge_index.device) col = col[perm] edge_attr = edge_attr[perm] return to_dense_batch( edge_attr, col, max_num_nodes=k, batch_size=num_nodes, )[0] class LinkEncoder(torch.nn.Module): r"""The link encoder module from the `"Do We Really Need Complicated Model Architectures for Temporal Networks?" `_ paper. It is composed of two components: (1) :class:`TemporalEncoding` maps each edge timestamp to a :obj:`time_channels`-dimensional vector; (2) an MLP that groups and maps the :math:`k`-latest encoded timestamps and edge features to a :obj:`out_channels`-dimensional representation. Args: k (int): The number of most recent temporal links to use. in_channels (int): The edge feature dimensionality. hidden_channels (int): Size of each hidden sample. time_channels (int): Size of encoded timestamp. out_channels (int): Size of each output sample. is_sorted (bool, optional): If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by column and the rows are sorted according to :obj:`edge_time` within individual neighborhoods. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) dropout (float, optional): Dropout probability of the MLP layer. 
(default: :obj:`0.0`) """ def __init__( self, k: int, in_channels: int, hidden_channels: int, out_channels: int, time_channels: int, is_sorted: bool = False, dropout: float = 0.0, ): super().__init__() self.k = k self.in_channels = in_channels self.hidden_channels = hidden_channels self.out_channels = out_channels self.time_channels = time_channels self.is_sorted = is_sorted self.dropout = dropout self.temporal_encoder = TemporalEncoding(time_channels) self.temporal_head = Linear(time_channels + in_channels, hidden_channels) self.mlp_mixer = _MLPMixer( # MLP that summarizes temporal embeddings: num_tokens=k, in_channels=hidden_channels, out_channels=out_channels, dropout=dropout, ) def reset_parameters(self): self.temporal_encoder.reset_parameters() self.temporal_head.reset_parameters() self.mlp_mixer.reset_parameters() def forward( self, edge_index: Tensor, edge_attr: Tensor, edge_time: Tensor, seed_time: Tensor, ) -> Tensor: r"""Forward pass. Args: edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor): The edge features of shape :obj:`[num_edges, in_channels]`. edge_time (torch.Tensor): The time tensor of shape :obj:`[num_edges]`. This can be in the order of millions. seed_time (torch.Tensor): The seed time :math:`t_0` for every destination node. Returns: A node embedding tensor of shape :obj:`[num_nodes, out_channels]`. 
""" mask = edge_time <= seed_time[edge_index[1]] edge_index = edge_index[:, mask] edge_attr = edge_attr[mask] edge_time = edge_time[mask] time_enc = self.temporal_encoder(seed_time[edge_index[1]] - edge_time) edge_attr = torch.cat([time_enc, edge_attr], dim=-1) edge_attr = self.temporal_head(edge_attr) edge_attr = get_latest_k_edge_attr( k=self.k, edge_index=edge_index, edge_attr=edge_attr, edge_time=edge_time, num_nodes=seed_time.size(0), is_sorted=self.is_sorted, ) return self.mlp_mixer(edge_attr) def __repr__(self) -> str: return (f'{self.__class__.__name__}(k={self.k}, ' f'in_channels={self.in_channels}, ' f'hidden_channels={self.hidden_channels}, ' f'out_channels={self.out_channels}, ' f'time_channels={self.time_channels}, ' f'dropout={self.dropout})') ================================================ FILE: torch_geometric/nn/models/graph_unet.py ================================================ from typing import Callable, List, Union import torch from torch import Tensor from torch_geometric.nn import GCNConv, TopKPooling from torch_geometric.nn.resolver import activation_resolver from torch_geometric.typing import OptTensor, PairTensor from torch_geometric.utils import ( add_self_loops, remove_self_loops, to_torch_csr_tensor, ) from torch_geometric.utils.repeat import repeat class GraphUNet(torch.nn.Module): r"""The Graph U-Net model from the `"Graph U-Nets" `_ paper which implements a U-Net like architecture with graph pooling and unpooling operations. Args: in_channels (int): Size of each input sample. hidden_channels (int): Size of each hidden sample. out_channels (int): Size of each output sample. depth (int): The depth of the U-Net architecture. pool_ratios (float or [float], optional): Graph pooling ratio for each depth. (default: :obj:`0.5`) sum_res (bool, optional): If set to :obj:`False`, will use concatenation for integration of skip connections instead summation. (default: :obj:`True`) act (torch.nn.functional, optional): The nonlinearity to use. 
(default: :obj:`torch.nn.functional.relu`) """ def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, depth: int, pool_ratios: Union[float, List[float]] = 0.5, sum_res: bool = True, act: Union[str, Callable] = 'relu', ): super().__init__() assert depth >= 1 self.in_channels = in_channels self.hidden_channels = hidden_channels self.out_channels = out_channels self.depth = depth self.pool_ratios = repeat(pool_ratios, depth) self.act = activation_resolver(act) self.sum_res = sum_res channels = hidden_channels self.down_convs = torch.nn.ModuleList() self.pools = torch.nn.ModuleList() self.down_convs.append(GCNConv(in_channels, channels, improved=True)) for i in range(depth): self.pools.append(TopKPooling(channels, self.pool_ratios[i])) self.down_convs.append(GCNConv(channels, channels, improved=True)) in_channels = channels if sum_res else 2 * channels self.up_convs = torch.nn.ModuleList() for _ in range(depth - 1): self.up_convs.append(GCNConv(in_channels, channels, improved=True)) self.up_convs.append(GCNConv(in_channels, out_channels, improved=True)) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" for conv in self.down_convs: conv.reset_parameters() for pool in self.pools: pool.reset_parameters() for conv in self.up_convs: conv.reset_parameters() def forward( self, x: Tensor, edge_index: Tensor, batch: OptTensor = None, edge_weight: Tensor = None, ) -> Tensor: """""" # noqa: D419 if batch is None: batch = edge_index.new_zeros(x.size(0)) if edge_weight is None: edge_weight = x.new_ones(edge_index.size(1)) assert edge_weight.dim() == 1 assert edge_weight.size(0) == edge_index.size(1) x = self.down_convs[0](x, edge_index, edge_weight) x = self.act(x) xs = [x] edge_indices = [edge_index] edge_weights = [edge_weight] perms = [] for i in range(1, self.depth + 1): edge_index, edge_weight = self.augment_adj(edge_index, edge_weight, x.size(0)) x, edge_index, edge_weight, batch, perm, _ = 
self.pools[i - 1]( x, edge_index, edge_weight, batch) x = self.down_convs[i](x, edge_index, edge_weight) x = self.act(x) if i < self.depth: xs += [x] edge_indices += [edge_index] edge_weights += [edge_weight] perms += [perm] for i in range(self.depth): j = self.depth - 1 - i res = xs[j] edge_index = edge_indices[j] edge_weight = edge_weights[j] perm = perms[j] up = torch.zeros_like(res) up[perm] = x x = res + up if self.sum_res else torch.cat((res, up), dim=-1) x = self.up_convs[i](x, edge_index, edge_weight) x = self.act(x) if i < self.depth - 1 else x return x def augment_adj(self, edge_index: Tensor, edge_weight: Tensor, num_nodes: int) -> PairTensor: edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) edge_index, edge_weight = add_self_loops(edge_index, edge_weight, num_nodes=num_nodes) adj = to_torch_csr_tensor(edge_index, edge_weight, size=(num_nodes, num_nodes)) adj = (adj @ adj).to_sparse_coo() edge_index, edge_weight = adj.indices(), adj.values() edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) return edge_index, edge_weight def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.hidden_channels}, {self.out_channels}, ' f'depth={self.depth}, pool_ratios={self.pool_ratios})') ================================================ FILE: torch_geometric/nn/models/jumping_knowledge.py ================================================ from typing import Dict, List, Optional import torch from torch import Tensor from torch.nn import LSTM, Linear class JumpingKnowledge(torch.nn.Module): r"""The Jumping Knowledge layer aggregation module from the `"Representation Learning on Graphs with Jumping Knowledge Networks" `_ paper. Jumping knowledge is performed based on either **concatenation** (:obj:`"cat"`) .. math:: \mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)}, **max pooling** (:obj:`"max"`) .. 
math:: \max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right), or **weighted summation** .. math:: \sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)} with attention scores :math:`\alpha_v^{(t)}` obtained from a bi-directional LSTM (:obj:`"lstm"`). Args: mode (str): The aggregation scheme to use (:obj:`"cat"`, :obj:`"max"` or :obj:`"lstm"`). channels (int, optional): The number of channels per representation. Needs to be only set for LSTM-style aggregation. (default: :obj:`None`) num_layers (int, optional): The number of layers to aggregate. Needs to be only set for LSTM-style aggregation. (default: :obj:`None`) """ def __init__( self, mode: str, channels: Optional[int] = None, num_layers: Optional[int] = None, ) -> None: super().__init__() self.mode = mode.lower() assert self.mode in ['cat', 'max', 'lstm'] if mode == 'lstm': assert channels is not None, 'channels cannot be None for lstm' assert num_layers is not None, 'num_layers cannot be None for lstm' self.lstm = LSTM(channels, (num_layers * channels) // 2, bidirectional=True, batch_first=True) self.att = Linear(2 * ((num_layers * channels) // 2), 1) self.channels = channels self.num_layers = num_layers else: self.lstm = None self.att = None self.channels = None self.num_layers = None self.reset_parameters() def reset_parameters(self) -> None: r"""Resets all learnable parameters of the module.""" if self.lstm is not None: self.lstm.reset_parameters() if self.att is not None: self.att.reset_parameters() def forward(self, xs: List[Tensor]) -> Tensor: r"""Forward pass. Args: xs (List[torch.Tensor]): List containing the layer-wise representations. 
""" if self.mode == 'cat': return torch.cat(xs, dim=-1) elif self.mode == 'max': return torch.stack(xs, dim=-1).max(dim=-1)[0] else: # self.mode == 'lstm' assert self.lstm is not None and self.att is not None x = torch.stack(xs, dim=1) # [num_nodes, num_layers, num_channels] alpha, _ = self.lstm(x) alpha = self.att(alpha).squeeze(-1) # [num_nodes, num_layers] alpha = torch.softmax(alpha, dim=-1) return (x * alpha.unsqueeze(-1)).sum(dim=1) def __repr__(self) -> str: if self.mode == 'lstm': return (f'{self.__class__.__name__}({self.mode}, ' f'channels={self.channels}, layers={self.num_layers})') return f'{self.__class__.__name__}({self.mode})' class HeteroJumpingKnowledge(torch.nn.Module): r"""A heterogeneous version of the :class:`JumpingKnowledge` module. Args: types (List[str]): The keys of the input dictionary. mode (str): The aggregation scheme to use (:obj:`"cat"`, :obj:`"max"` or :obj:`"lstm"`). channels (int, optional): The number of channels per representation. Needs to be only set for LSTM-style aggregation. (default: :obj:`None`) num_layers (int, optional): The number of layers to aggregate. Needs to be only set for LSTM-style aggregation. (default: :obj:`None`) """ def __init__( self, types: List[str], mode: str, channels: Optional[int] = None, num_layers: Optional[int] = None, ) -> None: super().__init__() self.mode = mode.lower() self.jk_dict = torch.nn.ModuleDict({ key: JumpingKnowledge(mode, channels, num_layers) for key in types }) def reset_parameters(self) -> None: r"""Resets all learnable parameters of the module.""" for jk in self.jk_dict.values(): jk.reset_parameters() def forward(self, xs_dict: Dict[str, List[Tensor]]) -> Dict[str, Tensor]: r"""Forward pass. Args: xs_dict (Dict[str, List[torch.Tensor]]): A dictionary holding a list of layer-wise representation for each type. 
""" return {key: jk(xs_dict[key]) for key, jk in self.jk_dict.items()} def __repr__(self): if self.mode == 'lstm': jk = next(iter(self.jk_dict.values())) return (f'{self.__class__.__name__}(' f'num_types={len(self.jk_dict)}, ' f'mode={self.mode}, channels={jk.channels}, ' f'layers={jk.num_layers})') return (f'{self.__class__.__name__}(num_types={len(self.jk_dict)}, ' f'mode={self.mode})') ================================================ FILE: torch_geometric/nn/models/label_prop.py ================================================ from typing import Callable, Optional import torch from torch import Tensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils import one_hot, spmm class LabelPropagation(MessagePassing): r"""The label propagation operator, firstly introduced in the `"Learning from Labeled and Unlabeled Data with Label Propagation" `_ paper. .. math:: \mathbf{Y}^{\prime} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{Y} + (1 - \alpha) \mathbf{Y}, where unlabeled data is inferred by labeled data via propagation. This concrete implementation here is derived from the `"Combining Label Propagation And Simple Models Out-performs Graph Neural Networks" `_ paper. .. note:: For an example of using the :class:`LabelPropagation`, see `examples/label_prop.py `_. Args: num_layers (int): The number of propagations. alpha (float): The :math:`\alpha` coefficient. """ def __init__(self, num_layers: int, alpha: float): super().__init__(aggr='add') self.num_layers = num_layers self.alpha = alpha @torch.no_grad() def forward( self, y: Tensor, edge_index: Adj, mask: OptTensor = None, edge_weight: OptTensor = None, post_step: Optional[Callable[[Tensor], Tensor]] = None, ) -> Tensor: r"""Forward pass. Args: y (torch.Tensor): The ground-truth label information :math:`\mathbf{Y}`. 
edge_index (torch.Tensor or SparseTensor): The edge connectivity. mask (torch.Tensor, optional): A mask or index tensor denoting which nodes are used for label propagation. (default: :obj:`None`) edge_weight (torch.Tensor, optional): The edge weights. (default: :obj:`None`) post_step (callable, optional): A post step function specified to apply after label propagation. If no post step function is specified, the output will be clamped between 0 and 1. (default: :obj:`None`) """ if y.dtype == torch.long and y.size(0) == y.numel(): y = one_hot(y.view(-1)) out = y if mask is not None: out = torch.zeros_like(y) out[mask] = y[mask] if isinstance(edge_index, SparseTensor) and not edge_index.has_value(): edge_index = gcn_norm(edge_index, add_self_loops=False) elif isinstance(edge_index, Tensor) and edge_weight is None: edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0), add_self_loops=False) res = (1 - self.alpha) * out for _ in range(self.num_layers): # propagate_type: (x: Tensor, edge_weight: OptTensor) out = self.propagate(edge_index, x=out, edge_weight=edge_weight) out.mul_(self.alpha).add_(res) if post_step is not None: out = post_step(out) else: out.clamp_(0., 1.) 
return out def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor: return spmm(adj_t, x, reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}(num_layers={self.num_layers}, ' f'alpha={self.alpha})') ================================================ FILE: torch_geometric/nn/models/lightgcn.py ================================================ from typing import Optional, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding, ModuleList from torch.nn.modules.loss import _Loss from torch_geometric.nn.conv import LGConv from torch_geometric.typing import Adj, OptTensor from torch_geometric.utils import is_sparse, to_edge_index class LightGCN(torch.nn.Module): r"""The LightGCN model from the `"LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation" `_ paper. :class:`~torch_geometric.nn.models.LightGCN` learns embeddings by linearly propagating them on the underlying graph, and uses the weighted sum of the embeddings learned at all layers as the final embedding .. math:: \textbf{x}_i = \sum_{l=0}^{L} \alpha_l \textbf{x}^{(l)}_i, where each layer's embedding is computed as .. math:: \mathbf{x}^{(l+1)}_i = \sum_{j \in \mathcal{N}(i)} \frac{1}{\sqrt{\deg(i)\deg(j)}}\mathbf{x}^{(l)}_j. Two prediction heads and training objectives are provided: **link prediction** (via :meth:`~torch_geometric.nn.models.LightGCN.link_pred_loss` and :meth:`~torch_geometric.nn.models.LightGCN.predict_link`) and **recommendation** (via :meth:`~torch_geometric.nn.models.LightGCN.recommendation_loss` and :meth:`~torch_geometric.nn.models.LightGCN.recommend`). .. 
note:: Embeddings are propagated according to the graph connectivity specified by :obj:`edge_index` while rankings or link probabilities are computed according to the edges specified by :obj:`edge_label_index`. .. note:: For an example of using :class:`LightGCN`, see `examples/lightgcn.py `_. Args: num_nodes (int): The number of nodes in the graph. embedding_dim (int): The dimensionality of node embeddings. num_layers (int): The number of :class:`~torch_geometric.nn.conv.LGConv` layers. alpha (float or torch.Tensor, optional): The scalar or vector specifying the re-weighting coefficients for aggregating the final embedding. If set to :obj:`None`, the uniform initialization of :obj:`1 / (num_layers + 1)` is used. (default: :obj:`None`) **kwargs (optional): Additional arguments of the underlying :class:`~torch_geometric.nn.conv.LGConv` layers. """ def __init__( self, num_nodes: int, embedding_dim: int, num_layers: int, alpha: Optional[Union[float, Tensor]] = None, **kwargs, ): super().__init__() self.num_nodes = num_nodes self.embedding_dim = embedding_dim self.num_layers = num_layers if alpha is None: alpha = 1. 
/ (num_layers + 1) if isinstance(alpha, Tensor): assert alpha.size(0) == num_layers + 1 else: alpha = torch.tensor([alpha] * (num_layers + 1)) self.register_buffer('alpha', alpha) self.embedding = Embedding(num_nodes, embedding_dim) self.convs = ModuleList([LGConv(**kwargs) for _ in range(num_layers)]) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" torch.nn.init.xavier_uniform_(self.embedding.weight) for conv in self.convs: conv.reset_parameters() def get_embedding( self, edge_index: Adj, edge_weight: OptTensor = None, ) -> Tensor: r"""Returns the embedding of nodes in the graph.""" x = self.embedding.weight out = x * self.alpha[0] for i in range(self.num_layers): x = self.convs[i](x, edge_index, edge_weight) out = out + x * self.alpha[i + 1] return out def forward( self, edge_index: Adj, edge_label_index: OptTensor = None, edge_weight: OptTensor = None, ) -> Tensor: r"""Computes rankings for pairs of nodes. Args: edge_index (torch.Tensor or SparseTensor): Edge tensor specifying the connectivity of the graph. edge_label_index (torch.Tensor, optional): Edge tensor specifying the node pairs for which to compute rankings or probabilities. If :obj:`edge_label_index` is set to :obj:`None`, all edges in :obj:`edge_index` will be used instead. (default: :obj:`None`) edge_weight (torch.Tensor, optional): The weight of each edge in :obj:`edge_index`. (default: :obj:`None`) """ if edge_label_index is None: if is_sparse(edge_index): edge_label_index, _ = to_edge_index(edge_index) else: edge_label_index = edge_index out = self.get_embedding(edge_index, edge_weight) out_src = out[edge_label_index[0]] out_dst = out[edge_label_index[1]] return (out_src * out_dst).sum(dim=-1) def predict_link( self, edge_index: Adj, edge_label_index: OptTensor = None, edge_weight: OptTensor = None, prob: bool = False, ) -> Tensor: r"""Predict links between nodes specified in :obj:`edge_label_index`. 
Args: edge_index (torch.Tensor or SparseTensor): Edge tensor specifying the connectivity of the graph. edge_label_index (torch.Tensor, optional): Edge tensor specifying the node pairs for which to compute probabilities. If :obj:`edge_label_index` is set to :obj:`None`, all edges in :obj:`edge_index` will be used instead. (default: :obj:`None`) edge_weight (torch.Tensor, optional): The weight of each edge in :obj:`edge_index`. (default: :obj:`None`) prob (bool, optional): Whether probabilities should be returned. (default: :obj:`False`) """ pred = self(edge_index, edge_label_index, edge_weight).sigmoid() return pred if prob else pred.round() def recommend( self, edge_index: Adj, edge_weight: OptTensor = None, src_index: OptTensor = None, dst_index: OptTensor = None, k: int = 1, sorted: bool = True, ) -> Tensor: r"""Get top-:math:`k` recommendations for nodes in :obj:`src_index`. Args: edge_index (torch.Tensor or SparseTensor): Edge tensor specifying the connectivity of the graph. edge_weight (torch.Tensor, optional): The weight of each edge in :obj:`edge_index`. (default: :obj:`None`) src_index (torch.Tensor, optional): Node indices for which recommendations should be generated. If set to :obj:`None`, all nodes will be used. (default: :obj:`None`) dst_index (torch.Tensor, optional): Node indices which represent the possible recommendation choices. If set to :obj:`None`, all nodes will be used. (default: :obj:`None`) k (int, optional): Number of recommendations. (default: :obj:`1`) sorted (bool, optional): Whether to sort the recommendations by score. (default: :obj:`True`) """ out_src = out_dst = self.get_embedding(edge_index, edge_weight) if src_index is not None: out_src = out_src[src_index] if dst_index is not None: out_dst = out_dst[dst_index] pred = out_src @ out_dst.t() top_index = pred.topk(k, dim=-1, sorted=sorted).indices if dst_index is not None: # Map local top-indices to original indices. 
top_index = dst_index[top_index.view(-1)].view(*top_index.size()) return top_index def link_pred_loss(self, pred: Tensor, edge_label: Tensor, **kwargs) -> Tensor: r"""Computes the model loss for a link prediction objective via the :class:`torch.nn.BCEWithLogitsLoss`. Args: pred (torch.Tensor): The predictions. edge_label (torch.Tensor): The ground-truth edge labels. **kwargs (optional): Additional arguments of the underlying :class:`torch.nn.BCEWithLogitsLoss` loss function. """ loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs) return loss_fn(pred, edge_label.to(pred.dtype)) def recommendation_loss( self, pos_edge_rank: Tensor, neg_edge_rank: Tensor, node_id: Optional[Tensor] = None, lambda_reg: float = 1e-4, **kwargs, ) -> Tensor: r"""Computes the model loss for a ranking objective via the Bayesian Personalized Ranking (BPR) loss. .. note:: The i-th entry in the :obj:`pos_edge_rank` vector and i-th entry in the :obj:`neg_edge_rank` entry must correspond to ranks of positive and negative edges of the same entity (*e.g.*, user). Args: pos_edge_rank (torch.Tensor): Positive edge rankings. neg_edge_rank (torch.Tensor): Negative edge rankings. node_id (torch.Tensor): The indices of the nodes involved for deriving a prediction for both positive and negative edges. If set to :obj:`None`, all nodes will be used. lambda_reg (int, optional): The :math:`L_2` regularization strength of the Bayesian Personalized Ranking (BPR) loss. (default: :obj:`1e-4`) **kwargs (optional): Additional arguments of the underlying :class:`torch_geometric.nn.models.lightgcn.BPRLoss` loss function. """ loss_fn = BPRLoss(lambda_reg, **kwargs) emb = self.embedding.weight emb = emb if node_id is None else emb[node_id] return loss_fn(pos_edge_rank, neg_edge_rank, emb) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.num_nodes}, ' f'{self.embedding_dim}, num_layers={self.num_layers})') class BPRLoss(_Loss): r"""The Bayesian Personalized Ranking (BPR) loss. 
The BPR loss is a pairwise loss that encourages the prediction of an observed entry to be higher than its unobserved counterparts (see `here `__). .. math:: L_{\text{BPR}} = - \sum_{u=1}^{M} \sum_{i \in \mathcal{N}_u} \sum_{j \not\in \mathcal{N}_u} \ln \sigma(\hat{y}_{ui} - \hat{y}_{uj}) + \lambda \vert\vert \textbf{x}^{(0)} \vert\vert^2 where :math:`\lambda` controls the :math:`L_2` regularization strength. We compute the mean BPR loss for simplicity. Args: lambda_reg (float, optional): The :math:`L_2` regularization strength (default: 0). **kwargs (optional): Additional arguments of the underlying :class:`torch.nn.modules.loss._Loss` class. """ __constants__ = ['lambda_reg'] lambda_reg: float def __init__(self, lambda_reg: float = 0, **kwargs): super().__init__(None, None, "sum", **kwargs) self.lambda_reg = lambda_reg def forward(self, positives: Tensor, negatives: Tensor, parameters: Tensor = None) -> Tensor: r"""Compute the mean Bayesian Personalized Ranking (BPR) loss. .. note:: The i-th entry in the :obj:`positives` vector and i-th entry in the :obj:`negatives` entry should correspond to the same entity (*.e.g*, user), as the BPR is a personalized ranking loss. Args: positives (Tensor): The vector of positive-pair rankings. negatives (Tensor): The vector of negative-pair rankings. parameters (Tensor, optional): The tensor of parameters which should be used for :math:`L_2` regularization (default: :obj:`None`). 
""" log_prob = F.logsigmoid(positives - negatives).mean() regularization = 0 if self.lambda_reg != 0: regularization = self.lambda_reg * parameters.norm(p=2).pow(2) regularization = regularization / positives.size(0) return -log_prob + regularization ================================================ FILE: torch_geometric/nn/models/linkx.py ================================================ import math import torch from torch import Tensor from torch.nn import BatchNorm1d, Parameter from torch_geometric.nn import inits from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.models import MLP from torch_geometric.typing import Adj, OptTensor from torch_geometric.utils import spmm class SparseLinear(MessagePassing): def __init__(self, in_channels: int, out_channels: int, bias: bool = True): super().__init__(aggr='add') self.in_channels = in_channels self.out_channels = out_channels self.weight = Parameter(torch.empty(in_channels, out_channels)) if bias: self.bias = Parameter(torch.empty(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): inits.kaiming_uniform(self.weight, fan=self.in_channels, a=math.sqrt(5)) inits.uniform(self.in_channels, self.bias) def forward( self, edge_index: Adj, edge_weight: OptTensor = None, ) -> Tensor: # propagate_type: (weight: Tensor, edge_weight: OptTensor) out = self.propagate(edge_index, weight=self.weight, edge_weight=edge_weight) if self.bias is not None: out = out + self.bias return out def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor: if edge_weight is None: return weight_j else: return edge_weight.view(-1, 1) * weight_j def message_and_aggregate(self, adj_t: Adj, weight: Tensor) -> Tensor: return spmm(adj_t, weight, reduce=self.aggr) class LINKX(torch.nn.Module): r"""The LINKX model from the `"Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods" `_ paper. .. 
math:: \mathbf{H}_{\mathbf{A}} &= \textrm{MLP}_{\mathbf{A}}(\mathbf{A}) \mathbf{H}_{\mathbf{X}} &= \textrm{MLP}_{\mathbf{X}}(\mathbf{X}) \mathbf{Y} &= \textrm{MLP}_{f} \left( \sigma \left( \mathbf{W} [\mathbf{H}_{\mathbf{A}}, \mathbf{H}_{\mathbf{X}}] + \mathbf{H}_{\mathbf{A}} + \mathbf{H}_{\mathbf{X}} \right) \right) .. note:: For an example of using LINKX, see `examples/linkx.py `_. Args: num_nodes (int): The number of nodes in the graph. in_channels (int): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. hidden_channels (int): Size of each hidden sample. out_channels (int): Size of each output sample. num_layers (int): Number of layers of :math:`\textrm{MLP}_{f}`. num_edge_layers (int, optional): Number of layers of :math:`\textrm{MLP}_{\mathbf{A}}`. (default: :obj:`1`) num_node_layers (int, optional): Number of layers of :math:`\textrm{MLP}_{\mathbf{X}}`. (default: :obj:`1`) dropout (float, optional): Dropout probability of each hidden embedding. 
(default: :obj:`0.0`) """ def __init__( self, num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int, num_edge_layers: int = 1, num_node_layers: int = 1, dropout: float = 0.0, ): super().__init__() self.num_nodes = num_nodes self.in_channels = in_channels self.out_channels = out_channels self.num_edge_layers = num_edge_layers self.edge_lin = SparseLinear(num_nodes, hidden_channels) if self.num_edge_layers > 1: self.edge_norm = BatchNorm1d(hidden_channels) channels = [hidden_channels] * num_edge_layers self.edge_mlp = MLP(channels, dropout=0., act_first=True) else: self.edge_norm = None self.edge_mlp = None channels = [in_channels] + [hidden_channels] * num_node_layers self.node_mlp = MLP(channels, dropout=0., act_first=True) self.cat_lin1 = torch.nn.Linear(hidden_channels, hidden_channels) self.cat_lin2 = torch.nn.Linear(hidden_channels, hidden_channels) channels = [hidden_channels] * num_layers + [out_channels] self.final_mlp = MLP(channels, dropout=dropout, act_first=True) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.edge_lin.reset_parameters() if self.edge_norm is not None: self.edge_norm.reset_parameters() if self.edge_mlp is not None: self.edge_mlp.reset_parameters() self.node_mlp.reset_parameters() self.cat_lin1.reset_parameters() self.cat_lin2.reset_parameters() self.final_mlp.reset_parameters() def forward( self, x: OptTensor, edge_index: Adj, edge_weight: OptTensor = None, ) -> Tensor: """""" # noqa: D419 out = self.edge_lin(edge_index, edge_weight) if self.edge_norm is not None and self.edge_mlp is not None: out = out.relu_() out = self.edge_norm(out) out = self.edge_mlp(out) out = out + self.cat_lin1(out) if x is not None: x = self.node_mlp(x) out = out + x out = out + self.cat_lin2(x) return self.final_mlp(out.relu_()) def __repr__(self) -> str: return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, ' f'in_channels={self.in_channels}, ' 
f'out_channels={self.out_channels})') ================================================ FILE: torch_geometric/nn/models/lpformer.py ================================================ import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from ...nn.conv import MessagePassing from ...nn.dense.linear import Linear from ...nn.inits import glorot, zeros from ...typing import Adj, OptTensor, Tuple from ...utils import get_ppr, is_sparse, scatter, softmax from .basic_gnn import GCN class LPFormer(nn.Module): r"""The LPFormer model from the `"LPFormer: An Adaptive Graph Transformer for Link Prediction" `_ paper. .. note:: For an example of using LPFormer, see `examples/lpformer.py `_. Args: in_channels (int): Size of input dimension hidden_channels (int): Size of hidden dimension num_gnn_layers (int, optional): Number of GNN layers (default: :obj:`2`) gnn_dropout(float, optional): Dropout used for GNN (default: :obj:`0.1`) num_transformer_layers (int, optional): Number of Transformer layers (default: :obj:`1`) num_heads (int, optional): Number of heads to use in MHA (default: :obj:`1`) transformer_dropout (float, optional): Dropout used for Transformer (default: :obj:`0.1`) ppr_thresholds (list): PPR thresholds for different types of nodes. Types include (in order) common neighbors, 1-Hop nodes (that aren't CNs), and all other nodes. (default: :obj:`[0, 1e-4, 1e-2]`) gcn_cache (bool, optional): Whether to cache edge indices during message passing. 
(default: :obj:`False`) """ def __init__( self, in_channels: int, hidden_channels: int, num_gnn_layers: int = 2, gnn_dropout: float = 0.1, num_transformer_layers: int = 1, num_heads: int = 1, transformer_dropout: float = 0.1, ppr_thresholds: list = None, gcn_cache=False, ): super().__init__() # Default thresholds if ppr_thresholds is None: ppr_thresholds = [0, 1e-4, 1e-2] if len(ppr_thresholds) == 3: self.thresh_cn = ppr_thresholds[0] self.thresh_1hop = ppr_thresholds[1] self.thresh_non1hop = ppr_thresholds[2] else: raise ValueError( "Argument 'ppr_thresholds' must only be length 3!") self.in_dim = in_channels self.hid_dim = hidden_channels self.gnn_drop = gnn_dropout self.trans_drop = transformer_dropout self.gnn = GCN(in_channels, hidden_channels, num_gnn_layers, dropout=gnn_dropout, norm="layer_norm", cached=gcn_cache) self.gnn_norm = nn.LayerNorm(hidden_channels) # Create Transformer Layers self.att_layers = nn.ModuleList() for il in range(num_transformer_layers): if il == 0: node_dim = None self.out_dim = self.hid_dim * 2 if num_transformer_layers > 1 \ else self.hid_dim elif il == self.num_layers - 1: node_dim = self.hid_dim else: self.out_dim = node_dim = self.hid_dim self.att_layers.append( LPAttLayer(self.hid_dim, self.out_dim, node_dim, num_heads, self.trans_drop)) self.elementwise_lin = MLP(self.hid_dim, self.hid_dim, self.hid_dim) # Relative Positional Encodings self.ppr_encoder_cn = MLP(2, self.hid_dim, self.hid_dim) self.ppr_encoder_onehop = MLP(2, self.hid_dim, self.hid_dim) self.ppr_encoder_non1hop = MLP(2, self.hid_dim, self.hid_dim) # thresh=1 implies ignoring some set of nodes # Also allows us to be more efficient later if self.thresh_non1hop == 1 and self.thresh_1hop == 1: self.mask = "cn" elif self.thresh_non1hop == 1 and self.thresh_1hop < 1: self.mask = "1-hop" else: self.mask = "all" # 4 is for counts of diff nodes pairwise_dim = self.hid_dim * num_heads + 4 self.pairwise_lin = MLP(pairwise_dim, pairwise_dim, self.hid_dim) self.score_func = 
MLP(self.hid_dim * 2, self.hid_dim * 2, 1, norm=None) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_dim}, ' f'{self.hid_dim}, num_gnn_layers={self.gnn.num_layers}, ' f'num_transformer_layers={len(self.att_layers)})') def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.gnn.reset_parameters() self.gnn_norm.reset_parameters() self.elementwise_lin.reset_parameters() self.pairwise_lin.reset_parameters() self.ppr_encoder_cn.reset_parameters() self.ppr_encoder_onehop.reset_parameters() self.ppr_encoder_non1hop.reset_parameters() self.score_func.reset_parameters() for i in range(len(self.att_layers)): self.att_layers[i].reset_parameters() def forward( self, batch: Tensor, x: Tensor, edge_index: Adj, ppr_matrix: Tensor, ) -> Tensor: r"""Forward Pass of LPFormer. Returns raw logits for each link Args: batch (Tensor): The batch vector. Denotes which node pairs to predict. x (Tensor): Input node features edge_index (torch.Tensor, SparseTensor): The edge indices. 
Either in COO or SparseTensor format ppr_matrix (Tensor): PPR matrix """ batch = batch.to(x.device) X_node = self.propagate(x, edge_index) x_i, x_j = X_node[batch[0]], X_node[batch[1]] elementwise_edge_feats = self.elementwise_lin(x_i * x_j) # Ensure in sparse format # Need as native torch.sparse for later computations # (necessary operations are not supported by PyG SparseTensor) if not edge_index.is_sparse: num_nodes = ppr_matrix.size(1) vals = torch.ones(len(edge_index[0]), device=edge_index.device) edge_index = torch.sparse_coo_tensor(edge_index, vals, [num_nodes, num_nodes]) # Checks if SparseTensor, if so the convert if is_sparse(edge_index) and not edge_index.is_sparse: edge_index = edge_index.to_torch_sparse_coo_tensor() # Ensure {0, 1} edge_index = edge_index.coalesce().bool().int() pairwise_feats = self.calc_pairwise(batch, X_node, edge_index, ppr_matrix) combined_feats = torch.cat((elementwise_edge_feats, pairwise_feats), dim=-1) logits = self.score_func(combined_feats) return logits def propagate(self, x: Tensor, adj: Adj) -> Tensor: """Propagate via GNN. Args: x (Tensor): Node features adj (torch.Tensor, SparseTensor): Adjacency matrix """ x = F.dropout(x, p=self.gnn_drop, training=self.training) X_node = self.gnn(x, adj) X_node = self.gnn_norm(X_node) return X_node def calc_pairwise(self, batch: Tensor, X_node: Tensor, adj_mask: Tensor, ppr_matrix: Tensor) -> Tensor: r"""Calculate the pairwise features for the node pairs. Args: batch (Tensor): The batch vector. Denotes which node pairs to predict. X_node (Tensor): Node representations adj_mask (Tensor): Mask of adjacency matrix used for computing the different node types. 
ppr_matrix (Tensor): PPR matrix """ k_i, k_j = X_node[batch[0]], X_node[batch[1]] pairwise_feats = torch.cat((k_i, k_j), dim=-1) cn_info, onehop_info, non1hop_info = self.compute_node_mask( batch, adj_mask, ppr_matrix) all_mask = cn_info[0] if onehop_info is not None: all_mask = torch.cat((all_mask, onehop_info[0]), dim=-1) if non1hop_info is not None: all_mask = torch.cat((all_mask, non1hop_info[0]), dim=-1) pes = self.get_pos_encodings(cn_info[1:], onehop_info[1:], non1hop_info[1:]) for lay in range(len(self.att_layers)): pairwise_feats = self.att_layers[lay](all_mask, pairwise_feats, X_node, pes) num_cns, num_1hop, num_non1hop, num_neigh = self.get_structure_cnts( batch, cn_info, onehop_info, non1hop_info) pairwise_feats = torch.cat( (pairwise_feats, num_cns, num_1hop, num_non1hop, num_neigh), dim=-1) pairwise_feats = self.pairwise_lin(pairwise_feats) return pairwise_feats def get_pos_encodings( self, cn_ppr: Tuple[Tensor, Tensor], onehop_ppr: Optional[Tuple[Tensor, Tensor]] = None, non1hop_ppr: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor: r"""Calculate the PPR-based relative positional encodings. Due to thresholds, sometimes we don't have 1-hop or >1-hop nodes. In those cases, the value of onehop_ppr and/or non1hop_ppr should be `None`. Args: cn_ppr (tuple, optional): PPR scores of CNs. onehop_ppr (tuple, optional): PPR scores of 1-Hop. (default: :obj:`None`) non1hop_ppr (tuple, optional): PPR scores of >1-Hop. 
(default: :obj:`None`) """ cn_a = self.ppr_encoder_cn(torch.stack((cn_ppr[0], cn_ppr[1])).t()) cn_b = self.ppr_encoder_cn(torch.stack((cn_ppr[1], cn_ppr[0])).t()) cn_pe = cn_a + cn_b if onehop_ppr is None: return cn_pe onehop_a = self.ppr_encoder_onehop( torch.stack((onehop_ppr[0], onehop_ppr[1])).t()) onehop_b = self.ppr_encoder_onehop( torch.stack((onehop_ppr[1], onehop_ppr[0])).t()) onehop_pe = onehop_a + onehop_b if non1hop_ppr is None: return torch.cat((cn_pe, onehop_pe), dim=0) non1hop_a = self.ppr_encoder_non1hop( torch.stack((non1hop_ppr[0], non1hop_ppr[1])).t()) non1hop_b = self.ppr_encoder_non1hop( torch.stack((non1hop_ppr[1], non1hop_ppr[0])).t()) non1hop_pe = non1hop_a + non1hop_b return torch.cat((cn_pe, onehop_pe, non1hop_pe), dim=0) def compute_node_mask( self, batch: Tensor, adj: Tensor, ppr_matrix: Tensor ) -> Tuple[Tuple, Optional[Tuple], Optional[Tuple]]: r"""Get mask based on type of node. When mask_type is not "cn", also return the ppr vals for both the source and target. Args: batch (Tensor): The batch vector. Denotes which node pairs to predict. 
adj (SparseTensor): Adjacency matrix ppr_matrix (Tensor): PPR matrix """ src_adj = torch.index_select(adj, 0, batch[0]) tgt_adj = torch.index_select(adj, 0, batch[1]) if self.mask == "cn": # 1 when CN, 0 otherwise pair_adj = src_adj * tgt_adj else: # Equals: {0: ">1-Hop", 1: "1-Hop (Non-CN)", 2: "CN"} pair_adj = src_adj + tgt_adj pair_ix, node_type, src_ppr, tgt_ppr = self.get_ppr_vals( batch, pair_adj, ppr_matrix) cn_filt_cond = (src_ppr >= self.thresh_cn) & (tgt_ppr >= self.thresh_cn) onehop_filt_cond = (src_ppr >= self.thresh_1hop) & ( tgt_ppr >= self.thresh_1hop) if self.mask != "cn": filt_cond = torch.where(node_type == 1, onehop_filt_cond, cn_filt_cond) else: filt_cond = torch.where(node_type == 0, onehop_filt_cond, cn_filt_cond) pair_ix, node_type = pair_ix[:, filt_cond], node_type[filt_cond] src_ppr, tgt_ppr = src_ppr[filt_cond], tgt_ppr[filt_cond] # >1-Hop mask is gotten separately if self.mask == "all": non1hop_ix, non1hop_sppr, non1hop_tppr = self.get_non_1hop_ppr( batch, adj, ppr_matrix) # Dropout if self.training and self.trans_drop > 0: pair_ix, src_ppr, tgt_ppr, node_type = self.drop_pairwise( pair_ix, src_ppr, tgt_ppr, node_type) if self.mask == "all": non1hop_ix, non1hop_sppr, non1hop_tppr, _ = self.drop_pairwise( non1hop_ix, non1hop_sppr, non1hop_tppr) # Separate out CN and 1-Hop if self.mask != "cn": cn_ind = node_type == 2 cn_ix = pair_ix[:, cn_ind] cn_src_ppr = src_ppr[cn_ind] cn_tgt_ppr = tgt_ppr[cn_ind] one_hop_ind = node_type == 1 onehop_ix = pair_ix[:, one_hop_ind] onehop_src_ppr = src_ppr[one_hop_ind] onehop_tgt_ppr = tgt_ppr[one_hop_ind] if self.mask == "cn": return (pair_ix, src_ppr, tgt_ppr), None, None elif self.mask == "1-hop": return (cn_ix, cn_src_ppr, cn_tgt_ppr), (onehop_ix, onehop_src_ppr, onehop_tgt_ppr), None else: return (cn_ix, cn_src_ppr, cn_tgt_ppr), (onehop_ix, onehop_src_ppr, onehop_tgt_ppr), (non1hop_ix, non1hop_sppr, non1hop_tppr) def get_ppr_vals( self, batch: Tensor, pair_diff_adj: Tensor, ppr_matrix: Tensor) -> 
Tuple[Tensor, Tensor, Tensor, Tensor]: r"""Get the src and tgt ppr vals. Returns the: link the node belongs to, type of node (e.g., CN), PPR relative to src, PPR relative to tgt. Args: batch (Tensor): The batch vector. Denotes which node pairs to predict. pair_diff_adj (SparseTensor): Combination of rows in adjacency for src and tgt nodes (e.g., X1 + X2) ppr_matrix (Tensor): PPR matrix """ # Additional terms for also choosing scores when ppr=0 # Multiplication removes any values for nodes not in batch # Addition then adds offset to ensure we select when ppr=0 # All selected scores are +1 higher than their true val src_ppr_adj = torch.index_select( ppr_matrix, 0, batch[0]) * pair_diff_adj + pair_diff_adj tgt_ppr_adj = torch.index_select( ppr_matrix, 0, batch[1]) * pair_diff_adj + pair_diff_adj # Can now convert ppr scores to dense ppr_ix = src_ppr_adj.coalesce().indices() src_ppr = src_ppr_adj.coalesce().values() tgt_ppr = tgt_ppr_adj.coalesce().values() # TODO: Needed due to a bug in recent torch versions # see here for more - https://github.com/pytorch/pytorch/issues/114529 # note that if one is 0 so is the other zero_vals = (src_ppr != 0) src_ppr = src_ppr[zero_vals] tgt_ppr = tgt_ppr[tgt_ppr != 0] ppr_ix = ppr_ix[:, zero_vals] pair_diff_adj = pair_diff_adj.coalesce().values() node_type = pair_diff_adj[src_ppr != 0] # Remove additional +1 from each ppr val src_ppr = (src_ppr - node_type) / node_type tgt_ppr = (tgt_ppr - node_type) / node_type return ppr_ix, node_type, src_ppr, tgt_ppr def drop_pairwise( self, pair_ix: Tensor, src_ppr: Optional[Tensor] = None, tgt_ppr: Optional[Tensor] = None, node_indicator: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: r"""Perform dropout on pairwise information by randomly dropping a percentage of nodes. 
Done before performing attention for efficiency Args: pair_ix (Tensor): Link node belongs to src_ppr (Tensor, optional): PPR relative to src (default: :obj:`None`) tgt_ppr (Tensor, optional): PPR relative to tgt (default: :obj:`None`) node_indicator (Tensor, optional): Type of node (e.g., CN) (default: :obj:`None`) """ num_indices = math.ceil(pair_ix.size(1) * (1 - self.trans_drop)) indices = torch.randperm(pair_ix.size(1))[:num_indices] pair_ix = pair_ix[:, indices] if src_ppr is not None: src_ppr = src_ppr[indices] if tgt_ppr is not None: tgt_ppr = tgt_ppr[indices] if node_indicator is not None: node_indicator = node_indicator[indices] return pair_ix, src_ppr, tgt_ppr, node_indicator def get_structure_cnts( self, batch: Tensor, cn_info: Tuple[Tensor, Tensor], onehop_info: Tuple[Tensor, Tensor], non1hop_info: Optional[Tuple[Tensor, Tensor]], ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold. Also include total # of neighbors Args: batch (Tensor): The batch vector. Denotes which node pairs to predict. cn_info (tuple): Information of CN nodes Contains (ID of node, src ppr, tgt ppr) onehop_info (tuple): Information of 1-Hop nodes. Contains (ID of node, src ppr, tgt ppr) non1hop_info (tuple): Information of >1-Hop nodes. 
Contains (ID of node, src ppr, tgt ppr) """ num_cns = self.get_num_ppr_thresh(batch, cn_info[0], cn_info[1], cn_info[2], self.thresh_cn) num_1hop = self.get_num_ppr_thresh(batch, onehop_info[0], onehop_info[1], onehop_info[2], self.thresh_1hop) # TOTAL num of 1-hop neighbors union num_ppr_ones = self.get_num_ppr_thresh(batch, onehop_info[0], onehop_info[1], onehop_info[2], thresh=0) num_neighbors = num_cns + num_ppr_ones # Process for >1-hop is different which is why we use get_count below if non1hop_info is None: return num_cns, num_1hop, 0, num_neighbors else: num_non1hop = self.get_count(non1hop_info[0], batch) return num_cns, num_1hop, num_non1hop, num_neighbors def get_num_ppr_thresh(self, batch: Tensor, node_mask: Tensor, src_ppr: Tensor, tgt_ppr: Tensor, thresh: float) -> Tensor: """Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`. Args: batch (Tensor): The batch vector. Denotes which node pairs to predict. node_mask (Tensor): IDs of nodes src_ppr (Tensor): PPR relative to src node tgt_ppr (Tensor): PPR relative to tgt node thresh (float): PPR threshold for nodes (`eta`) """ weight = torch.ones(node_mask.size(1), device=node_mask.device) ppr_above_thresh = (src_ppr >= thresh) & (tgt_ppr >= thresh) num_ppr = scatter(ppr_above_thresh.float() * weight, node_mask[0].long(), dim=0, dim_size=batch.size(1), reduce="sum") num_ppr = num_ppr.unsqueeze(-1) return num_ppr def get_count( self, node_mask: Tensor, batch: Tensor, ) -> Tensor: """# of nodes for each sample in batch. They node have already filtered by PPR beforehand Args: node_mask (Tensor): IDs of nodes batch (Tensor): The batch vector. Denotes which node pairs to predict. 
""" weight = torch.ones(node_mask.size(1), device=node_mask.device) num_nodes = scatter(weight, node_mask[0].long(), dim=0, dim_size=batch.size(1), reduce="sum") num_nodes = num_nodes.unsqueeze(-1) return num_nodes def get_non_1hop_ppr(self, batch: Tensor, adj: Tensor, ppr_matrix: Tensor) -> Tensor: r"""Get PPR scores for non-1hop nodes. Args: batch (Tensor): Links in batch adj (Tensor): Adjacency matrix ppr_matrix (Tensor): Sparse PPR matrix """ # NOTE: Use original adj (one pass in forward() removes links in batch) # Done since removing them converts src/tgt nodes to >1-hop nodes. # Therefore removing CN and 1-hop will also remove the batch links. # During training we add back in the links in the batch # (we're removed from adjacency before being passed to model) # Done since otherwise they will be mistakenly seen as >1-Hop nodes # Instead they're 1-Hop, and get ignored accordingly # Ignored during eval since we know the links aren't in the adj adj2 = adj if self.training: n = adj.size(0) batch_flip = torch.cat( (batch, torch.flip(batch, (0, )).to(batch.device)), dim=-1) batch_ones = torch.ones_like(batch_flip[0], device=batch.device) adj_edges = torch.sparse_coo_tensor(batch_flip, batch_ones, [n, n], device=batch.device) adj_edges = adj_edges adj2 = (adj + adj_edges).coalesce().bool().int() src_adj = torch.index_select(adj2, 0, batch[0]) tgt_adj = torch.index_select(adj2, 0, batch[1]) src_ppr = torch.index_select(ppr_matrix, 0, batch[0]) tgt_ppr = torch.index_select(ppr_matrix, 0, batch[1]) # Remove CN scores src_ppr = src_ppr - src_ppr * (src_adj * tgt_adj) tgt_ppr = tgt_ppr - tgt_ppr * (src_adj * tgt_adj) # Also need to remove CN entries in Adj # Otherwise they leak into next computation src_adj = src_adj - src_adj * (src_adj * tgt_adj) tgt_adj = tgt_adj - tgt_adj * (src_adj * tgt_adj) # Remove 1-Hop scores src_ppr = src_ppr - src_ppr * (src_adj + tgt_adj) tgt_ppr = tgt_ppr - tgt_ppr * (src_adj + tgt_adj) # Make sure we include both when we convert to dense so 
indices align # Do so by adding 1 to each based on the other src_ppr_add = src_ppr + torch.sign(tgt_ppr) tgt_ppr_add = tgt_ppr + torch.sign(src_ppr) src_ix = src_ppr_add.coalesce().indices() src_vals = src_ppr_add.coalesce().values() tgt_vals = tgt_ppr_add.coalesce().values() # Now we can remove value which is just 1 # Technically creates -1 scores for ppr scores that were 0 # Doesn't matter as they'll be filtered out by condition later src_vals = src_vals - 1 tgt_vals = tgt_vals - 1 ppr_condition = (src_vals >= self.thresh_non1hop) & ( tgt_vals >= self.thresh_non1hop) src_ix, src_vals, tgt_vals = src_ix[:, ppr_condition], src_vals[ ppr_condition], tgt_vals[ppr_condition] return src_ix, src_vals, tgt_vals def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int, alpha: float = 0.15, eps: float = 5e-5) -> Tensor: r"""Calculate the PPR of the graph in sparse format. Args: edge_index: The edge indices num_nodes: Number of nodes alpha (float, optional): The alpha value of the PageRank algorithm. (default: :obj:`0.15`) eps (float, optional): Threshold for stopping the PPR calculation (default: :obj:`5e-5`) """ ei, ei_w = get_ppr(edge_index.cpu(), alpha=alpha, eps=eps, num_nodes=num_nodes) ppr_matrix = torch.sparse_coo_tensor(ei, ei_w, [num_nodes, num_nodes]) return ppr_matrix class LPAttLayer(MessagePassing): r"""Attention Layer for pairwise interaction module. Args: in_channels (int): Size of input dimension out_channels (int): Size of output dimension node_dim (int): Dimension of nodes being aggregated num_heads (int): Number of heads to use in MHA dropout (float): Dropout on attention values concat (bool, optional): Whether to concat attention heads. 
Otherwise sum (default: :obj:`True`) """ _alpha: OptTensor def __init__( self, in_channels: int, out_channels: int, node_dim: int, num_heads: int, dropout: float, concat: bool = True, **kwargs, ): super().__init__(node_dim=0, flow="target_to_source", **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.heads = num_heads self.concat = concat self.dropout = dropout self.negative_slope = 0.2 # LeakyRelu out_dim = 2 if node_dim is None: node_dim = in_channels * out_dim else: node_dim = node_dim * out_dim self.lin_l = Linear(in_channels, self.heads * out_channels, weight_initializer='glorot') self.lin_r = Linear(node_dim, self.heads * out_channels, weight_initializer='glorot') att_out = out_channels self.att = Parameter(Tensor(1, self.heads, att_out)) if concat: self.bias = Parameter(Tensor(self.heads * out_channels)) else: self.bias = Parameter(Tensor(out_channels)) self._alpha = None self.dropout = dropout self.post_att_norm = nn.LayerNorm(out_channels) self.reset_parameters() def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads})') def reset_parameters(self): self.lin_l.reset_parameters() self.lin_r.reset_parameters() self.post_att_norm.reset_parameters() glorot(self.att) zeros(self.bias) def forward( self, edge_index: Tensor, edge_feats: Tensor, node_feats: Tensor, ppr_rpes: Tensor, ) -> Tensor: """Runs the forward pass of the module. Args: edge_index (Tensor): The edge indices. 
edge_feats (Tensor): Concatenated representations of src and target nodes for each link node_feats (Tensor): Representations for individual nodes ppr_rpes (Tensor): Relative PEs for each node """ out = self.propagate(edge_index, x=(edge_feats, node_feats), ppr_rpes=ppr_rpes, size=None) alpha = self._alpha assert alpha is not None self._alpha = None if self.concat: out = out.view(-1, self.heads * self.out_channels) else: out = out.mean(dim=1) if self.bias is not None: out = out + self.bias out = self.post_att_norm(out) out = F.dropout(out, p=self.dropout, training=self.training) return out def message(self, x_i: Tensor, x_j: Tensor, ppr_rpes: Tensor, index: Tensor, ptr: Tensor, size_i: Optional[int]) -> Tensor: H, C = self.heads, self.out_channels x_j = torch.cat((x_j, ppr_rpes), dim=-1) x_j = self.lin_r(x_j).view(-1, H, C) # e=(a, b) attending to v e1, e2 = x_i.chunk(2, dim=-1) e1 = self.lin_l(e1).view(-1, H, C) e2 = self.lin_l(e2).view(-1, H, C) x = x_j * (e1 + e2) x = F.leaky_relu(x, self.negative_slope) alpha = (x * self.att).sum(dim=-1) alpha = softmax(alpha, index, ptr, size_i) self._alpha = alpha return x_j * alpha.unsqueeze(-1) class MLP(nn.Module): r"""L Layer MLP.""" def __init__(self, in_channels: int, hid_channels: int, out_channels: int, num_layers: int = 2, drop: int = 0, norm: str = "layer"): super().__init__() self.dropout = drop if norm == "batch": self.norm = nn.BatchNorm1d(hid_channels) elif norm == "layer": self.norm = nn.LayerNorm(hid_channels) else: self.norm = None self.linears = torch.nn.ModuleList() if num_layers == 1: self.linears.append(nn.Linear(in_channels, out_channels)) else: self.linears.append(nn.Linear(in_channels, hid_channels)) for _ in range(num_layers - 2): self.linears.append(nn.Linear(hid_channels, hid_channels)) self.linears.append(nn.Linear(hid_channels, out_channels)) def reset_parameters(self): for lin in self.linears: lin.reset_parameters() if self.norm is not None: self.norm.reset_parameters() def forward(self, x: 
Tensor) -> Tensor: for lin in self.linears[:-1]: x = lin(x) x = self.norm(x) if self.norm is not None else x x = F.relu(x) x = F.dropout(x, p=self.dropout, training=self.training) x = self.linears[-1](x) return x.squeeze(-1) ================================================ FILE: torch_geometric/nn/models/mask_label.py ================================================ import torch from torch import Tensor class MaskLabel(torch.nn.Module): r"""The label embedding and masking layer from the `"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification" `_ paper. Here, node labels :obj:`y` are merged to the initial node features :obj:`x` for a subset of their nodes according to :obj:`mask`. .. note:: For an example of using :class:`MaskLabel`, see `examples/unimp_arxiv.py `_. Args: num_classes (int): The number of classes. out_channels (int): Size of each output sample. method (str, optional): If set to :obj:`"add"`, label embeddings are added to the input. If set to :obj:`"concat"`, label embeddings are concatenated. In case :obj:`method="add"`, then :obj:`out_channels` needs to be identical to the input dimensionality of node features. 
(default: :obj:`"add"`) """ def __init__(self, num_classes: int, out_channels: int, method: str = "add"): super().__init__() self.method = method if method not in ["add", "concat"]: raise ValueError( f"'method' must be either 'add' or 'concat' (got '{method}')") self.emb = torch.nn.Embedding(num_classes, out_channels) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.emb.reset_parameters() def forward(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor: """""" # noqa: D419 if self.method == "concat": out = x.new_zeros(y.size(0), self.emb.weight.size(-1)) out[mask] = self.emb(y[mask]) return torch.cat([x, out], dim=-1) else: x = torch.clone(x) x[mask] += self.emb(y[mask]) return x @staticmethod def ratio_mask(mask: Tensor, ratio: float): r"""Modifies :obj:`mask` by setting :obj:`ratio` of :obj:`True` entries to :obj:`False`. Does not operate in-place. Args: mask (torch.Tensor): The mask to re-mask. ratio (float): The ratio of entries to keep. """ n = int(mask.sum()) out = mask.clone() out[mask] = torch.rand(n, device=mask.device) < ratio return out def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/nn/models/meta.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor class MetaLayer(torch.nn.Module): r"""A meta layer for building any kind of graph network, inspired by the `"Relational Inductive Biases, Deep Learning, and Graph Networks" `_ paper. A graph network takes a graph as input and returns an updated graph as output (with same connectivity). The input graph has node features :obj:`x`, edge features :obj:`edge_attr` as well as graph-level features :obj:`u`. The output graph has the same structure, but updated features. 
Edge features, node features as well as global features are updated by calling the modules :obj:`edge_model`, :obj:`node_model` and :obj:`global_model`, respectively. To allow for batch-wise graph processing, all callable functions take an additional argument :obj:`batch`, which determines the assignment of edges or nodes to their specific graphs. Args: edge_model (torch.nn.Module, optional): A callable which updates a graph's edge features based on its source and target node features, its current edge features and its global features. (default: :obj:`None`) node_model (torch.nn.Module, optional): A callable which updates a graph's node features based on its current node features, its graph connectivity, its edge features and its global features. (default: :obj:`None`) global_model (torch.nn.Module, optional): A callable which updates a graph's global features based on its node features, its graph connectivity, its edge features and its current global features. (default: :obj:`None`) .. code-block:: python from torch.nn import Sequential as Seq, Linear as Lin, ReLU from torch_geometric.utils import scatter from torch_geometric.nn import MetaLayer class EdgeModel(torch.nn.Module): def __init__(self): super().__init__() self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...)) def forward(self, src, dst, edge_attr, u, batch): # src, dst: [E, F_x], where E is the number of edges. # edge_attr: [E, F_e] # u: [B, F_u], where B is the number of graphs. # batch: [E] with max entry B - 1. out = torch.cat([src, dst, edge_attr, u[batch]], 1) return self.edge_mlp(out) class NodeModel(torch.nn.Module): def __init__(self): super().__init__() self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...)) self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...)) def forward(self, x, edge_index, edge_attr, u, batch): # x: [N, F_x], where N is the number of nodes. # edge_index: [2, E] with max entry N - 1. # edge_attr: [E, F_e] # u: [B, F_u] # batch: [N] with max entry B - 1. 
row, col = edge_index out = torch.cat([x[row], edge_attr], dim=1) out = self.node_mlp_1(out) out = scatter(out, col, dim=0, dim_size=x.size(0), reduce='mean') out = torch.cat([x, out, u[batch]], dim=1) return self.node_mlp_2(out) class GlobalModel(torch.nn.Module): def __init__(self): super().__init__() self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...)) def forward(self, x, edge_index, edge_attr, u, batch): # x: [N, F_x], where N is the number of nodes. # edge_index: [2, E] with max entry N - 1. # edge_attr: [E, F_e] # u: [B, F_u] # batch: [N] with max entry B - 1. out = torch.cat([ u, scatter(x, batch, dim=0, reduce='mean'), ], dim=1) return self.global_mlp(out) op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel()) x, edge_attr, u = op(x, edge_index, edge_attr, u, batch) """ def __init__( self, edge_model: Optional[torch.nn.Module] = None, node_model: Optional[torch.nn.Module] = None, global_model: Optional[torch.nn.Module] = None, ): super().__init__() self.edge_model = edge_model self.node_model = node_model self.global_model = global_model self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" for item in [self.node_model, self.edge_model, self.global_model]: if hasattr(item, 'reset_parameters'): item.reset_parameters() def forward( self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None, u: Optional[Tensor] = None, batch: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: r"""Forward pass. Args: x (torch.Tensor): The node features. edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`) u (torch.Tensor, optional): The global graph features. (default: :obj:`None`) batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific graph. 
(default: :obj:`None`) """ row = edge_index[0] col = edge_index[1] if self.edge_model is not None: edge_attr = self.edge_model(x[row], x[col], edge_attr, u, batch if batch is None else batch[row]) if self.node_model is not None: x = self.node_model(x, edge_index, edge_attr, u, batch) if self.global_model is not None: u = self.global_model(x, edge_index, edge_attr, u, batch) return x, edge_attr, u def __repr__(self) -> str: return (f'{self.__class__.__name__}(\n' f' edge_model={self.edge_model},\n' f' node_model={self.node_model},\n' f' global_model={self.global_model}\n' f')') ================================================ FILE: torch_geometric/nn/models/metapath2vec.py ================================================ from typing import Dict, List, Optional, Tuple import torch from torch import Tensor from torch.nn import Embedding from torch.utils.data import DataLoader from torch_geometric.index import index2ptr from torch_geometric.typing import EdgeType, NodeType, OptTensor from torch_geometric.utils import sort_edge_index EPS = 1e-15 class MetaPath2Vec(torch.nn.Module): r"""The MetaPath2Vec model from the `"metapath2vec: Scalable Representation Learning for Heterogeneous Networks" `_ paper where random walks based on a given :obj:`metapath` are sampled in a heterogeneous graph, and node embeddings are learned via negative sampling optimization. .. note:: For an example of using MetaPath2Vec, see `examples/hetero/metapath2vec.py `_. Args: edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): Dictionary holding edge indices for each :obj:`(src_node_type, rel_type, dst_node_type)` edge type present in the heterogeneous graph. embedding_dim (int): The size of each embedding vector. metapath (List[Tuple[str, str, str]]): The metapath described as a list of :obj:`(src_node_type, rel_type, dst_node_type)` tuples. walk_length (int): The walk length. context_size (int): The actual context size which is considered for positive samples. 
This parameter increases the effective sampling rate by reusing samples across different source nodes. walks_per_node (int, optional): The number of walks to sample for each node. (default: :obj:`1`) num_negative_samples (int, optional): The number of negative samples to use for each positive sample. (default: :obj:`1`) num_nodes_dict (Dict[str, int], optional): Dictionary holding the number of nodes for each node type. (default: :obj:`None`) sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the weight matrix will be sparse. (default: :obj:`False`) """ def __init__( self, edge_index_dict: Dict[EdgeType, Tensor], embedding_dim: int, metapath: List[EdgeType], walk_length: int, context_size: int, walks_per_node: int = 1, num_negative_samples: int = 1, num_nodes_dict: Optional[Dict[NodeType, int]] = None, sparse: bool = False, ): super().__init__() if num_nodes_dict is None: num_nodes_dict = {} for keys, edge_index in edge_index_dict.items(): key = keys[0] N = int(edge_index[0].max() + 1) num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N)) key = keys[-1] N = int(edge_index[1].max() + 1) num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N)) self.rowptr_dict, self.col_dict, self.rowcount_dict = {}, {}, {} for keys, edge_index in edge_index_dict.items(): sizes = (num_nodes_dict[keys[0]], num_nodes_dict[keys[-1]]) row, col = sort_edge_index(edge_index, num_nodes=max(sizes)).cpu() rowptr = index2ptr(row, size=sizes[0]) self.rowptr_dict[keys] = rowptr self.col_dict[keys] = col self.rowcount_dict[keys] = rowptr[1:] - rowptr[:-1] for edge_type1, edge_type2 in zip(metapath[:-1], metapath[1:]): if edge_type1[-1] != edge_type2[0]: raise ValueError( "Found invalid metapath. 
Ensure that the destination node " "type matches with the source node type across all " "consecutive edge types.") assert walk_length + 1 >= context_size if walk_length > len(metapath) and metapath[0][0] != metapath[-1][-1]: raise AttributeError( "The 'walk_length' is longer than the given 'metapath', but " "the 'metapath' does not denote a cycle") self.embedding_dim = embedding_dim self.metapath = metapath self.walk_length = walk_length self.context_size = context_size self.walks_per_node = walks_per_node self.num_negative_samples = num_negative_samples self.num_nodes_dict = num_nodes_dict types = {x[0] for x in metapath} | {x[-1] for x in metapath} types = sorted(list(types)) count = 0 self.start, self.end = {}, {} for key in types: self.start[key] = count count += num_nodes_dict[key] self.end[key] = count offset = [self.start[metapath[0][0]]] offset += [self.start[keys[-1]] for keys in metapath ] * int((walk_length / len(metapath)) + 1) offset = offset[:walk_length + 1] assert len(offset) == walk_length + 1 self.offset = torch.tensor(offset) # + 1 denotes a dummy node used to link to for isolated nodes. self.embedding = Embedding(count + 1, embedding_dim, sparse=sparse) self.dummy_idx = count self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.embedding.reset_parameters() def forward(self, node_type: str, batch: OptTensor = None) -> Tensor: r"""Returns the embeddings for the nodes in :obj:`batch` of type :obj:`node_type`. """ emb = self.embedding.weight[self.start[node_type]:self.end[node_type]] return emb if batch is None else emb.index_select(0, batch) def loader(self, **kwargs): r"""Returns the data loader that creates both positive and negative random walks on the heterogeneous graph. Args: **kwargs (optional): Arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. 
""" return DataLoader(range(self.num_nodes_dict[self.metapath[0][0]]), collate_fn=self._sample, **kwargs) def _pos_sample(self, batch: Tensor) -> Tensor: batch = batch.repeat(self.walks_per_node) rws = [batch] for i in range(self.walk_length): edge_type = self.metapath[i % len(self.metapath)] batch = sample( self.rowptr_dict[edge_type], self.col_dict[edge_type], self.rowcount_dict[edge_type], batch, num_neighbors=1, dummy_idx=self.dummy_idx, ).view(-1) rws.append(batch) rw = torch.stack(rws, dim=-1) rw.add_(self.offset.view(1, -1)) rw[rw > self.dummy_idx] = self.dummy_idx walks = [] num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size for j in range(num_walks_per_rw): walks.append(rw[:, j:j + self.context_size]) return torch.cat(walks, dim=0) def _neg_sample(self, batch: Tensor) -> Tensor: batch = batch.repeat(self.walks_per_node * self.num_negative_samples) rws = [batch] for i in range(self.walk_length): keys = self.metapath[i % len(self.metapath)] batch = torch.randint(0, self.num_nodes_dict[keys[-1]], (batch.size(0), ), dtype=torch.long) rws.append(batch) rw = torch.stack(rws, dim=-1) rw.add_(self.offset.view(1, -1)) walks = [] num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size for j in range(num_walks_per_rw): walks.append(rw[:, j:j + self.context_size]) return torch.cat(walks, dim=0) def _sample(self, batch: List[int]) -> Tuple[Tensor, Tensor]: if not isinstance(batch, Tensor): batch = torch.tensor(batch, dtype=torch.long) return self._pos_sample(batch), self._neg_sample(batch) def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor: r"""Computes the loss given positive and negative random walks.""" # Positive loss. 
start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous() h_start = self.embedding(start).view(pos_rw.size(0), 1, self.embedding_dim) h_rest = self.embedding(rest.view(-1)).view(pos_rw.size(0), -1, self.embedding_dim) out = (h_start * h_rest).sum(dim=-1).view(-1) pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean() # Negative loss. start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous() h_start = self.embedding(start).view(neg_rw.size(0), 1, self.embedding_dim) h_rest = self.embedding(rest.view(-1)).view(neg_rw.size(0), -1, self.embedding_dim) out = (h_start * h_rest).sum(dim=-1).view(-1) neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean() return pos_loss + neg_loss def test(self, train_z: Tensor, train_y: Tensor, test_z: Tensor, test_y: Tensor, solver: str = "lbfgs", *args, **kwargs) -> float: r"""Evaluates latent space quality via a logistic regression downstream task. """ from sklearn.linear_model import LogisticRegression clf = LogisticRegression(*args, solver=solver, **kwargs).fit(train_z.detach().cpu().numpy(), train_y.detach().cpu().numpy()) return clf.score(test_z.detach().cpu().numpy(), test_y.detach().cpu().numpy()) def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'{self.embedding.weight.size(0) - 1}, ' f'{self.embedding.weight.size(1)})') def sample(rowptr: Tensor, col: Tensor, rowcount: Tensor, subset: Tensor, num_neighbors: int, dummy_idx: int) -> Tensor: mask = subset >= dummy_idx subset = subset.clamp(min=0, max=rowptr.numel() - 2) count = rowcount[subset] rand = torch.rand((subset.size(0), num_neighbors), device=subset.device) rand *= count.to(rand.dtype).view(-1, 1) rand = rand.to(torch.long) + rowptr[subset].view(-1, 1) rand = rand.clamp(max=col.numel() - 1) # If last node is isolated. 
col = col[rand] if col.numel() > 0 else rand col[mask | (count == 0)] = dummy_idx return col ================================================ FILE: torch_geometric/nn/models/mlp.py ================================================ import inspect import warnings from typing import Any, Callable, Dict, Final, List, Optional, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Identity from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.resolver import ( activation_resolver, normalization_resolver, ) from torch_geometric.typing import NoneType class MLP(torch.nn.Module): r"""A Multi-Layer Perception (MLP) model. There exists two ways to instantiate an :class:`MLP`: 1. By specifying explicit channel sizes, *e.g.*, .. code-block:: python mlp = MLP([16, 32, 64, 128]) creates a three-layer MLP with **differently** sized hidden layers. 1. By specifying fixed hidden channel sizes over a number of layers, *e.g.*, .. code-block:: python mlp = MLP(in_channels=16, hidden_channels=32, out_channels=128, num_layers=3) creates a three-layer MLP with **equally** sized hidden layers. Args: channel_list (List[int] or int, optional): List of input, intermediate and output channels such that :obj:`len(channel_list) - 1` denotes the number of layers of the MLP (default: :obj:`None`) in_channels (int, optional): Size of each input sample. Will override :attr:`channel_list`. (default: :obj:`None`) hidden_channels (int, optional): Size of each hidden sample. Will override :attr:`channel_list`. (default: :obj:`None`) out_channels (int, optional): Size of each output sample. Will override :attr:`channel_list`. (default: :obj:`None`) num_layers (int, optional): The number of layers. Will override :attr:`channel_list`. (default: :obj:`None`) dropout (float or List[float], optional): Dropout probability of each hidden embedding. If a list is provided, sets the dropout value per layer. 
(default: :obj:`0.`) act (str or Callable, optional): The non-linear activation function to use. (default: :obj:`"relu"`) act_first (bool, optional): If set to :obj:`True`, activation is applied before normalization. (default: :obj:`False`) act_kwargs (Dict[str, Any], optional): Arguments passed to the respective activation function defined by :obj:`act`. (default: :obj:`None`) norm (str or Callable, optional): The normalization function to use. (default: :obj:`"batch_norm"`) norm_kwargs (Dict[str, Any], optional): Arguments passed to the respective normalization function defined by :obj:`norm`. (default: :obj:`None`) plain_last (bool, optional): If set to :obj:`False`, will apply non-linearity, batch normalization and dropout to the last layer as well. (default: :obj:`True`) bias (bool or List[bool], optional): If set to :obj:`False`, the module will not learn additive biases. If a list is provided, sets the bias per layer. (default: :obj:`True`) **kwargs (optional): Additional deprecated arguments of the MLP layer. 
""" supports_norm_batch: Final[bool] def __init__( self, channel_list: Optional[Union[List[int], int]] = None, *, in_channels: Optional[int] = None, hidden_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: Optional[int] = None, dropout: Union[float, List[float]] = 0., act: Union[str, Callable, None] = "relu", act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Union[str, Callable, None] = "batch_norm", norm_kwargs: Optional[Dict[str, Any]] = None, plain_last: bool = True, bias: Union[bool, List[bool]] = True, **kwargs, ): super().__init__() # Backward compatibility: act_first = act_first or kwargs.get("relu_first", False) batch_norm = kwargs.get("batch_norm", None) if batch_norm is not None and isinstance(batch_norm, bool): warnings.warn( "Argument `batch_norm` is deprecated, " "please use `norm` to specify normalization layer.", stacklevel=2) norm = 'batch_norm' if batch_norm else None batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None) norm_kwargs = batch_norm_kwargs or {} if isinstance(channel_list, int): in_channels = channel_list if in_channels is not None: if num_layers is None: raise ValueError("Argument `num_layers` must be given") if num_layers > 1 and hidden_channels is None: raise ValueError(f"Argument `hidden_channels` must be given " f"for `num_layers={num_layers}`") if out_channels is None: raise ValueError("Argument `out_channels` must be given") channel_list = [hidden_channels] * (num_layers - 1) channel_list = [in_channels] + channel_list + [out_channels] assert isinstance(channel_list, (tuple, list)) assert len(channel_list) >= 2 self.channel_list = channel_list self.act = activation_resolver(act, **(act_kwargs or {})) self.act_first = act_first self.plain_last = plain_last if isinstance(dropout, float): dropout = [dropout] * (len(channel_list) - 1) if plain_last: dropout[-1] = 0. 
if len(dropout) != len(channel_list) - 1: raise ValueError( f"Number of dropout values provided ({len(dropout)} does not " f"match the number of layers specified " f"({len(channel_list)-1})") self.dropout = dropout if isinstance(bias, bool): bias = [bias] * (len(channel_list) - 1) if len(bias) != len(channel_list) - 1: raise ValueError( f"Number of bias values provided ({len(bias)}) does not match " f"the number of layers specified ({len(channel_list)-1})") self.lins = torch.nn.ModuleList() iterator = zip(channel_list[:-1], channel_list[1:], bias) for in_channels, out_channels, _bias in iterator: self.lins.append(Linear(in_channels, out_channels, bias=_bias)) self.norms = torch.nn.ModuleList() iterator = channel_list[1:-1] if plain_last else channel_list[1:] for hidden_channels in iterator: if norm is not None: norm_layer = normalization_resolver( norm, hidden_channels, **(norm_kwargs or {}), ) else: norm_layer = Identity() self.norms.append(norm_layer) self.supports_norm_batch = False if len(self.norms) > 0 and hasattr(self.norms[0], 'forward'): norm_params = inspect.signature(self.norms[0].forward).parameters self.supports_norm_batch = 'batch' in norm_params self.reset_parameters() @property def in_channels(self) -> int: r"""Size of each input sample.""" return self.channel_list[0] @property def out_channels(self) -> int: r"""Size of each output sample.""" return self.channel_list[-1] @property def num_layers(self) -> int: r"""The number of layers.""" return len(self.channel_list) - 1 def reset_parameters(self): r"""Resets all learnable parameters of the module.""" for lin in self.lins: lin.reset_parameters() for norm in self.norms: if hasattr(norm, 'reset_parameters'): norm.reset_parameters() def forward( self, x: Tensor, batch: Optional[Tensor] = None, batch_size: Optional[int] = None, return_emb: NoneType = None, ) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. 
batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. Only needs to be passed in case the underlying normalization layers require the :obj:`batch` information. (default: :obj:`None`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. Only needs to be passed in case the underlying normalization layers require the :obj:`batch` information. (default: :obj:`None`) return_emb (bool, optional): If set to :obj:`True`, will additionally return the embeddings before execution of the final output layer. (default: :obj:`False`) """ # `return_emb` is annotated here as `NoneType` to be compatible with # TorchScript, which does not support different return types based on # the value of an input argument. emb: Optional[Tensor] = None # If `plain_last=True`, then `len(norms) = len(lins) -1, thus skipping # the execution of the last layer inside the for-loop. for i, (lin, norm) in enumerate(zip(self.lins, self.norms)): x = lin(x) if self.act is not None and self.act_first: x = self.act(x) if self.supports_norm_batch: x = norm(x, batch, batch_size) else: x = norm(x) if self.act is not None and not self.act_first: x = self.act(x) x = F.dropout(x, p=self.dropout[i], training=self.training) if isinstance(return_emb, bool) and return_emb is True: emb = x if self.plain_last: x = self.lins[-1](x) x = F.dropout(x, p=self.dropout[-1], training=self.training) return (x, emb) if isinstance(return_emb, bool) else x def __repr__(self) -> str: return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})' ================================================ FILE: torch_geometric/nn/models/neural_fingerprint.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.nn import Linear, MFConv, global_add_pool from torch_geometric.typing import Adj class 
NeuralFingerprint(torch.nn.Module): r"""The Neural Fingerprint model from the `"Convolutional Networks on Graphs for Learning Molecular Fingerprints" `__ paper to generate fingerprints of molecules. Args: in_channels (int): Size of each input sample. hidden_channels (int): Size of each hidden sample. out_channels (int): Size of each output fingerprint. num_layers (int): Number of layers. **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MFConv`. """ def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int, **kwargs, ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels self.out_channels = out_channels self.num_layers = num_layers self.convs = torch.nn.ModuleList() for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.hidden_channels self.convs.append(MFConv(in_channels, hidden_channels, **kwargs)) self.lins = torch.nn.ModuleList() for _ in range(self.num_layers): self.lins.append(Linear(hidden_channels, out_channels, bias=False)) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" for conv in self.convs: conv.reset_parameters() for lin in self.lins: lin.reset_parameters() def forward( self, x: Tensor, edge_index: Adj, batch: Optional[Tensor] = None, batch_size: Optional[int] = None, ) -> Tensor: """""" # noqa: D419 outs = [] for conv, lin in zip(self.convs, self.lins): x = conv(x, edge_index).sigmoid() y = lin(x).softmax(dim=-1) outs.append(global_add_pool(y, batch, batch_size)) return sum(outs) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, num_layers={self.num_layers})') ================================================ FILE: torch_geometric/nn/models/node2vec.py ================================================ from typing import List, Optional, Tuple, Union import torch from torch import Tensor from torch.nn import Embedding from 
torch.utils.data import DataLoader from torch_geometric.index import index2ptr from torch_geometric.typing import WITH_PYG_LIB, WITH_TORCH_CLUSTER from torch_geometric.utils import sort_edge_index from torch_geometric.utils.num_nodes import maybe_num_nodes class Node2Vec(torch.nn.Module): r"""The Node2Vec model from the `"node2vec: Scalable Feature Learning for Networks" `_ paper where random walks of length :obj:`walk_length` are sampled in a given graph, and node embeddings are learned via negative sampling optimization. .. note:: For an example of using Node2Vec, see `examples/node2vec.py `_. Args: edge_index (torch.Tensor): The edge indices. embedding_dim (int): The size of each embedding vector. walk_length (int): The walk length. context_size (int): The actual context size which is considered for positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes. walks_per_node (int, optional): The number of walks to sample for each node. (default: :obj:`1`) p (float, optional): Likelihood of immediately revisiting a node in the walk. (default: :obj:`1`) q (float, optional): Control parameter to interpolate between breadth-first strategy and depth-first strategy (default: :obj:`1`) num_negative_samples (int, optional): The number of negative samples to use for each positive sample. (default: :obj:`1`) num_nodes (int, optional): The number of nodes. (default: :obj:`None`) sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the weight matrix will be sparse. 
(default: :obj:`False`) """ def __init__( self, edge_index: Tensor, embedding_dim: int, walk_length: int, context_size: int, walks_per_node: int = 1, p: float = 1.0, q: float = 1.0, num_negative_samples: int = 1, num_nodes: Optional[int] = None, sparse: bool = False, ): super().__init__() if WITH_PYG_LIB and p == 1.0 and q == 1.0: self.random_walk_fn = torch.ops.pyg.random_walk elif WITH_TORCH_CLUSTER: self.random_walk_fn = torch.ops.torch_cluster.random_walk else: if p == 1.0 and q == 1.0: raise ImportError(f"'{self.__class__.__name__}' " f"requires either the 'pyg-lib' or " f"'torch-cluster' package") else: raise ImportError(f"'{self.__class__.__name__}' " f"requires the 'torch-cluster' package") self.num_nodes = maybe_num_nodes(edge_index, num_nodes) row, col = sort_edge_index(edge_index, num_nodes=self.num_nodes).cpu() self.rowptr, self.col = index2ptr(row, self.num_nodes), col self.EPS = 1e-15 assert walk_length >= context_size self.embedding_dim = embedding_dim self.walk_length = walk_length - 1 self.context_size = context_size self.walks_per_node = walks_per_node self.p = p self.q = q self.num_negative_samples = num_negative_samples self.embedding = Embedding(self.num_nodes, embedding_dim, sparse=sparse) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.embedding.reset_parameters() def forward(self, batch: Optional[Tensor] = None) -> Tensor: """Returns the embeddings for the nodes in :obj:`batch`.""" emb = self.embedding.weight return emb if batch is None else emb[batch] def loader(self, **kwargs) -> DataLoader: return DataLoader(range(self.num_nodes), collate_fn=self.sample, **kwargs) @torch.jit.export def pos_sample(self, batch: Tensor) -> Tensor: batch = batch.repeat(self.walks_per_node) rw = self.random_walk_fn(self.rowptr, self.col, batch, self.walk_length, self.p, self.q) if not isinstance(rw, Tensor): rw = rw[0] walks = [] num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size for j 
in range(num_walks_per_rw): walks.append(rw[:, j:j + self.context_size]) return torch.cat(walks, dim=0) @torch.jit.export def neg_sample(self, batch: Tensor) -> Tensor: batch = batch.repeat(self.walks_per_node * self.num_negative_samples) rw = torch.randint(self.num_nodes, (batch.size(0), self.walk_length), dtype=batch.dtype, device=batch.device) rw = torch.cat([batch.view(-1, 1), rw], dim=-1) walks = [] num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size for j in range(num_walks_per_rw): walks.append(rw[:, j:j + self.context_size]) return torch.cat(walks, dim=0) @torch.jit.export def sample(self, batch: Union[List[int], Tensor]) -> Tuple[Tensor, Tensor]: if not isinstance(batch, Tensor): batch = torch.tensor(batch) return self.pos_sample(batch), self.neg_sample(batch) @torch.jit.export def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor: r"""Computes the loss given positive and negative random walks.""" # Positive loss. start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous() h_start = self.embedding(start).view(pos_rw.size(0), 1, self.embedding_dim) h_rest = self.embedding(rest.view(-1)).view(pos_rw.size(0), -1, self.embedding_dim) out = (h_start * h_rest).sum(dim=-1).view(-1) pos_loss = -torch.log(torch.sigmoid(out) + self.EPS).mean() # Negative loss. start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous() h_start = self.embedding(start).view(neg_rw.size(0), 1, self.embedding_dim) h_rest = self.embedding(rest.view(-1)).view(neg_rw.size(0), -1, self.embedding_dim) out = (h_start * h_rest).sum(dim=-1).view(-1) neg_loss = -torch.log(1 - torch.sigmoid(out) + self.EPS).mean() return pos_loss + neg_loss def test( self, train_z: Tensor, train_y: Tensor, test_z: Tensor, test_y: Tensor, solver: str = 'lbfgs', *args, **kwargs, ) -> float: r"""Evaluates latent space quality via a logistic regression downstream task. 
""" from sklearn.linear_model import LogisticRegression clf = LogisticRegression(*args, solver=solver, **kwargs).fit(train_z.detach().cpu().numpy(), train_y.detach().cpu().numpy()) return clf.score(test_z.detach().cpu().numpy(), test_y.detach().cpu().numpy()) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.embedding.weight.size(0)}, ' f'{self.embedding.weight.size(1)})') ================================================ FILE: torch_geometric/nn/models/pmlp.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.nn import SimpleConv from torch_geometric.nn.dense.linear import Linear class PMLP(torch.nn.Module): r"""The P(ropagational)MLP model from the `"Graph Neural Networks are Inherently Good Generalizers: Insights by Bridging GNNs and MLPs" `_ paper. :class:`PMLP` is identical to a standard MLP during training, but then adopts a GNN architecture during testing. Args: in_channels (int): Size of each input sample. hidden_channels (int): Size of each hidden sample. out_channels (int): Size of each output sample. num_layers (int): The number of layers. dropout (float, optional): Dropout probability of each hidden embedding. (default: :obj:`0.`) norm (bool, optional): If set to :obj:`False`, will not apply batch normalization. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the module will not learn additive biases. 
(default: :obj:`True`) """ def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int, dropout: float = 0., norm: bool = True, bias: bool = True, ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels self.out_channels = out_channels self.num_layers = num_layers self.dropout = dropout self.bias = bias self.lins = torch.nn.ModuleList() self.lins.append(Linear(in_channels, hidden_channels, self.bias)) for _ in range(self.num_layers - 2): lin = Linear(hidden_channels, hidden_channels, self.bias) self.lins.append(lin) self.lins.append(Linear(hidden_channels, out_channels, self.bias)) self.norm = None if norm: self.norm = torch.nn.BatchNorm1d( hidden_channels, affine=False, track_running_stats=False, ) self.conv = SimpleConv(aggr='mean', combine_root='self_loop') self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" for lin in self.lins: torch.nn.init.xavier_uniform_(lin.weight, gain=1.414) if self.bias: torch.nn.init.zeros_(lin.bias) def forward( self, x: torch.Tensor, edge_index: Optional[Tensor] = None, ) -> torch.Tensor: """""" # noqa: D419 if not self.training and edge_index is None: raise ValueError(f"'edge_index' needs to be present during " f"inference in '{self.__class__.__name__}'") for i in range(self.num_layers): x = x @ self.lins[i].weight.t() if not self.training: x = self.conv(x, edge_index) if self.bias: x = x + self.lins[i].bias if i != self.num_layers - 1: if self.norm is not None: x = self.norm(x) x = x.relu() x = F.dropout(x, p=self.dropout, training=self.training) return x def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, num_layers={self.num_layers})') ================================================ FILE: torch_geometric/nn/models/polynormer.py ================================================ from typing import Optional import torch import torch.nn.functional as F 
from torch import Tensor from torch_geometric.nn import GATConv, GCNConv from torch_geometric.nn.attention import PolynormerAttention from torch_geometric.utils import to_dense_batch class Polynormer(torch.nn.Module): r"""The polynormer module from the `"Polynormer: polynomial-expressive graph transformer in linear time" `_ paper. Args: in_channels (int): Input channels. hidden_channels (int): Hidden channels. out_channels (int): Output channels. local_layers (int): The number of local attention layers. (default: :obj:`7`) global_layers (int): The number of global attention layers. (default: :obj:`2`) in_dropout (float): Input dropout rate. (default: :obj:`0.15`) dropout (float): Dropout rate. (default: :obj:`0.5`) global_dropout (float): Global dropout rate. (default: :obj:`0.5`) heads (int): The number of heads. (default: :obj:`1`) beta (float): Aggregate type. (default: :obj:`0.9`) qk_shared (bool optional): Whether weight of query and key are shared. (default: :obj:`True`) pre_ln (bool): Pre layer normalization. (default: :obj:`False`) post_bn (bool): Post batch normalization. (default: :obj:`True`) local_attn (bool): Whether use local attention. 
(default: :obj:`False`) """ def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, local_layers: int = 7, global_layers: int = 2, in_dropout: float = 0.15, dropout: float = 0.5, global_dropout: float = 0.5, heads: int = 1, beta: float = 0.9, qk_shared: bool = False, pre_ln: bool = False, post_bn: bool = True, local_attn: bool = False, ) -> None: super().__init__() self._global = False self.in_drop = in_dropout self.dropout = dropout self.pre_ln = pre_ln self.post_bn = post_bn self.beta = beta self.h_lins = torch.nn.ModuleList() self.local_convs = torch.nn.ModuleList() self.lins = torch.nn.ModuleList() self.lns = torch.nn.ModuleList() if self.pre_ln: self.pre_lns = torch.nn.ModuleList() if self.post_bn: self.post_bns = torch.nn.ModuleList() # first layer inner_channels = heads * hidden_channels self.h_lins.append(torch.nn.Linear(in_channels, inner_channels)) if local_attn: self.local_convs.append( GATConv(in_channels, hidden_channels, heads=heads, concat=True, add_self_loops=False, bias=False)) else: self.local_convs.append( GCNConv(in_channels, inner_channels, cached=False, normalize=True)) self.lins.append(torch.nn.Linear(in_channels, inner_channels)) self.lns.append(torch.nn.LayerNorm(inner_channels)) if self.pre_ln: self.pre_lns.append(torch.nn.LayerNorm(in_channels)) if self.post_bn: self.post_bns.append(torch.nn.BatchNorm1d(inner_channels)) # following layers for _ in range(local_layers - 1): self.h_lins.append(torch.nn.Linear(inner_channels, inner_channels)) if local_attn: self.local_convs.append( GATConv(inner_channels, hidden_channels, heads=heads, concat=True, add_self_loops=False, bias=False)) else: self.local_convs.append( GCNConv(inner_channels, inner_channels, cached=False, normalize=True)) self.lins.append(torch.nn.Linear(inner_channels, inner_channels)) self.lns.append(torch.nn.LayerNorm(inner_channels)) if self.pre_ln: self.pre_lns.append(torch.nn.LayerNorm(heads * hidden_channels)) if self.post_bn: 
self.post_bns.append(torch.nn.BatchNorm1d(inner_channels)) self.lin_in = torch.nn.Linear(in_channels, inner_channels) self.ln = torch.nn.LayerNorm(inner_channels) self.global_attn = torch.nn.ModuleList() for _ in range(global_layers): self.global_attn.append( PolynormerAttention( channels=hidden_channels, heads=heads, head_channels=hidden_channels, beta=beta, dropout=global_dropout, qk_shared=qk_shared, )) self.pred_local = torch.nn.Linear(inner_channels, out_channels) self.pred_global = torch.nn.Linear(inner_channels, out_channels) self.reset_parameters() def reset_parameters(self) -> None: for local_conv in self.local_convs: local_conv.reset_parameters() for attn in self.global_attn: attn.reset_parameters() for lin in self.lins: lin.reset_parameters() for h_lin in self.h_lins: h_lin.reset_parameters() for ln in self.lns: ln.reset_parameters() if self.pre_ln: for p_ln in self.pre_lns: p_ln.reset_parameters() if self.post_bn: for p_bn in self.post_bns: p_bn.reset_parameters() self.lin_in.reset_parameters() self.ln.reset_parameters() self.pred_local.reset_parameters() self.pred_global.reset_parameters() def forward( self, x: Tensor, edge_index: Tensor, batch: Optional[Tensor], ) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The input node features. edge_index (torch.Tensor or SparseTensor): The edge indices. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. 
""" x = F.dropout(x, p=self.in_drop, training=self.training) # equivariant local attention x_local = 0 for i, local_conv in enumerate(self.local_convs): if self.pre_ln: x = self.pre_lns[i](x) h = self.h_lins[i](x) h = F.relu(h) x = local_conv(x, edge_index) + self.lins[i](x) if self.post_bn: x = self.post_bns[i](x) x = F.relu(x) x = F.dropout(x, p=self.dropout, training=self.training) x = (1 - self.beta) * self.lns[i](h * x) + self.beta * x x_local = x_local + x # equivariant global attention if self._global: batch, indices = batch.sort() rev_perm = torch.empty_like(indices) rev_perm[indices] = torch.arange(len(indices), device=indices.device) x_local = self.ln(x_local[indices]) x_global, mask = to_dense_batch(x_local, batch) for attn in self.global_attn: x_global = attn(x_global, mask) x = x_global[mask][rev_perm] x = self.pred_global(x) else: x = self.pred_local(x_local) return F.log_softmax(x, dim=-1) ================================================ FILE: torch_geometric/nn/models/re_net.py ================================================ import math from typing import Callable, List, Tuple import torch import torch.nn.functional as F from torch import Tensor from torch.nn import GRU, Linear, Parameter from torch_geometric.data.data import Data from torch_geometric.utils import scatter class RENet(torch.nn.Module): r"""The Recurrent Event Network model from the `"Recurrent Event Network for Reasoning over Temporal Knowledge Graphs" `_ paper. .. math:: f_{\mathbf{\Theta}}(\mathbf{e}_s, \mathbf{e}_r, \mathbf{h}^{(t-1)}(s, r)) based on a RNN encoder .. math:: \mathbf{h}^{(t)}(s, r) = \textrm{RNN}(\mathbf{e}_s, \mathbf{e}_r, g(\mathcal{O}^{(t)}_r(s)), \mathbf{h}^{(t-1)}(s, r)) where :math:`\mathbf{e}_s` and :math:`\mathbf{e}_r` denote entity and relation embeddings, and :math:`\mathcal{O}^{(t)}_r(s)` represents the set of objects interacted with subject :math:`s` under relation :math:`r` at timestamp :math:`t`. 
This model implements :math:`g` as the **Mean Aggregator** and :math:`f_{\mathbf{\Theta}}` as a linear projection. Args: num_nodes (int): The number of nodes in the knowledge graph. num_rels (int): The number of relations in the knowledge graph. hidden_channels (int): Hidden size of node and relation embeddings. seq_len (int): The sequence length of past events. num_layers (int, optional): The number of recurrent layers. (default: :obj:`1`) dropout (float): If non-zero, introduces a dropout layer before the final prediction. (default: :obj:`0.`) bias (bool, optional): If set to :obj:`False`, all layers will not learn an additive bias. (default: :obj:`True`) """ def __init__( self, num_nodes: int, num_rels: int, hidden_channels: int, seq_len: int, num_layers: int = 1, dropout: float = 0., bias: bool = True, ): super().__init__() self.num_nodes = num_nodes self.hidden_channels = hidden_channels self.num_rels = num_rels self.seq_len = seq_len self.dropout = dropout self.ent = Parameter(torch.empty(num_nodes, hidden_channels)) self.rel = Parameter(torch.empty(num_rels, hidden_channels)) self.sub_gru = GRU(3 * hidden_channels, hidden_channels, num_layers, batch_first=True, bias=bias) self.obj_gru = GRU(3 * hidden_channels, hidden_channels, num_layers, batch_first=True, bias=bias) self.sub_lin = Linear(3 * hidden_channels, num_nodes, bias=bias) self.obj_lin = Linear(3 * hidden_channels, num_nodes, bias=bias) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.ent, gain=math.sqrt(2.0)) torch.nn.init.xavier_uniform_(self.rel, gain=math.sqrt(2.0)) self.sub_gru.reset_parameters() self.obj_gru.reset_parameters() self.sub_lin.reset_parameters() self.obj_lin.reset_parameters() @staticmethod def pre_transform(seq_len: int) -> Callable: r"""Precomputes history objects. .. 
math:: \{ \mathcal{O}^{(t-k-1)}_r(s), \ldots, \mathcal{O}^{(t-1)}_r(s) \} of a :class:`torch_geometric.datasets.icews.EventDataset` with :math:`k` denoting the sequence length :obj:`seq_len`. """ class PreTransform: def __init__(self, seq_len: int): self.seq_len = seq_len self.inc = 5000 self.t_last = 0 self.sub_hist = self.increase_hist_node_size([]) self.obj_hist = self.increase_hist_node_size([]) def increase_hist_node_size(self, hist: List[int]) -> List[int]: hist_inc = torch.zeros((self.inc, self.seq_len + 1, 0)) return hist + hist_inc.tolist() def get_history( self, hist: List[int], node: int, rel: int, ) -> Tuple[Tensor, Tensor]: hists, ts = [], [] for s in range(seq_len): h = hist[node][s] hists += h ts.append(torch.full((len(h), ), s, dtype=torch.long)) node, r = torch.tensor(hists, dtype=torch.long).view( -1, 2).t().contiguous() node = node[r == rel] t = torch.cat(ts, dim=0)[r == rel] return node, t def step(self, hist: List[int]) -> List[int]: for i in range(len(hist)): hist[i] = hist[i][1:] hist[i].append([]) return hist def __call__(self, data: Data) -> Data: sub, rel, obj, t = data.sub, data.rel, data.obj, data.t if max(sub, obj) + 1 > len(self.sub_hist): # pragma: no cover self.sub_hist = self.increase_hist_node_size(self.sub_hist) self.obj_hist = self.increase_hist_node_size(self.obj_hist) # Delete last timestamp in history. if t > self.t_last: self.sub_hist = self.step(self.sub_hist) self.obj_hist = self.step(self.obj_hist) self.t_last = t # Save history in data object. data.h_sub, data.h_sub_t = self.get_history( self.sub_hist, sub, rel) data.h_obj, data.h_obj_t = self.get_history( self.obj_hist, obj, rel) # Add new event to history. 
self.sub_hist[sub][-1].append([obj, rel]) self.obj_hist[obj][-1].append([sub, rel]) return data def __repr__(self) -> str: # pragma: no cover return f'{self.__class__.__name__}(seq_len={self.seq_len})' return PreTransform(seq_len) def forward(self, data: Data) -> Tuple[Tensor, Tensor]: """Given a :obj:`data` batch, computes the forward pass. Args: data (torch_geometric.data.Data): The input data, holding subject :obj:`sub`, relation :obj:`rel` and object :obj:`obj` information with shape :obj:`[batch_size]`. In addition, :obj:`data` needs to hold history information for subjects, given by a vector of node indices :obj:`h_sub` and their relative timestamps :obj:`h_sub_t` and batch assignments :obj:`h_sub_batch`. The same information must be given for objects (:obj:`h_obj`, :obj:`h_obj_t`, :obj:`h_obj_batch`). """ assert 'h_sub_batch' in data and 'h_obj_batch' in data batch_size, seq_len = data.sub.size(0), self.seq_len h_sub_t = data.h_sub_t + data.h_sub_batch * seq_len h_obj_t = data.h_obj_t + data.h_obj_batch * seq_len h_sub = scatter(self.ent[data.h_sub], h_sub_t, dim=0, dim_size=batch_size * seq_len, reduce='mean').view(batch_size, seq_len, -1) h_obj = scatter(self.ent[data.h_obj], h_obj_t, dim=0, dim_size=batch_size * seq_len, reduce='mean').view(batch_size, seq_len, -1) sub = self.ent[data.sub].unsqueeze(1).repeat(1, seq_len, 1) rel = self.rel[data.rel].unsqueeze(1).repeat(1, seq_len, 1) obj = self.ent[data.obj].unsqueeze(1).repeat(1, seq_len, 1) _, h_sub = self.sub_gru(torch.cat([sub, h_sub, rel], dim=-1)) _, h_obj = self.obj_gru(torch.cat([obj, h_obj, rel], dim=-1)) h_sub, h_obj = h_sub.squeeze(0), h_obj.squeeze(0) h_sub = torch.cat([self.ent[data.sub], h_sub, self.rel[data.rel]], dim=-1) h_obj = torch.cat([self.ent[data.obj], h_obj, self.rel[data.rel]], dim=-1) h_sub = F.dropout(h_sub, p=self.dropout, training=self.training) h_obj = F.dropout(h_obj, p=self.dropout, training=self.training) log_prob_obj = F.log_softmax(self.sub_lin(h_sub), dim=1) log_prob_sub 
= F.log_softmax(self.obj_lin(h_obj), dim=1) return log_prob_obj, log_prob_sub def test(self, logits: Tensor, y: Tensor) -> Tensor: """Given ground-truth :obj:`y`, computes Mean Reciprocal Rank (MRR) and Hits at 1/3/10. """ _, perm = logits.sort(dim=1, descending=True) mask = (y.view(-1, 1) == perm) nnz = mask.nonzero(as_tuple=False) mrr = (1 / (nnz[:, -1] + 1).to(torch.float)).mean().item() hits1 = mask[:, :1].sum().item() / y.size(0) hits3 = mask[:, :3].sum().item() / y.size(0) hits10 = mask[:, :10].sum().item() / y.size(0) return torch.tensor([mrr, hits1, hits3, hits10]) ================================================ FILE: torch_geometric/nn/models/rect.py ================================================ import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Linear from torch_geometric.nn import GCNConv from torch_geometric.typing import Adj, OptTensor from torch_geometric.utils import scatter class RECT_L(torch.nn.Module): r"""The RECT model, *i.e.* its supervised RECT-L part, from the `"Network Embedding with Completely-imbalanced Labels" `_ paper. In particular, a GCN model is trained that reconstructs semantic class knowledge. .. note:: For an example of using RECT, see `examples/rect.py `_. Args: in_channels (int): Size of each input sample. hidden_channels (int): Intermediate size of each sample. normalize (bool, optional): Whether to add self-loops and compute symmetric normalization coefficients on-the-fly. (default: :obj:`True`) dropout (float, optional): The dropout probability. 
(default: :obj:`0.0`) """ def __init__(self, in_channels: int, hidden_channels: int, normalize: bool = True, dropout: float = 0.0): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels self.dropout = dropout self.conv = GCNConv(in_channels, hidden_channels, normalize=normalize) self.lin = Linear(hidden_channels, in_channels) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.conv.reset_parameters() self.lin.reset_parameters() torch.nn.init.xavier_uniform_(self.lin.weight.data) def forward( self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, ) -> Tensor: """""" # noqa: D419 x = self.conv(x, edge_index, edge_weight) x = F.dropout(x, p=self.dropout, training=self.training) return self.lin(x) @torch.jit.export def embed( self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, ) -> Tensor: with torch.no_grad(): return self.conv(x, edge_index, edge_weight) @torch.jit.export def get_semantic_labels( self, x: Tensor, y: Tensor, mask: Tensor, ) -> Tensor: r"""Replaces the original labels by their class-centers.""" with torch.no_grad(): y = y[mask] mean = scatter(x[mask], y, dim=0, reduce='mean') return mean[y] def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.hidden_channels})') ================================================ FILE: torch_geometric/nn/models/rev_gnn.py ================================================ import copy from typing import Any, List, Optional, Union import numpy as np import torch from torch import Tensor import torch_geometric.typing from torch_geometric.typing import Adj class InvertibleFunction(torch.autograd.Function): r"""An invertible autograd function. This allows for automatic backpropagation in a reversible fashion so that the memory of intermediate results can be freed during the forward pass and be constructed on-the-fly during the bachward pass. 
Args: ctx (torch.autograd.function.InvertibleFunctionBackward): A context object that can be used to stash information for backward computation. fn (torch.nn.Module): The forward function. fn_inverse (torch.nn.Module): The inverse function to recompute the freed input. num_bwd_passes (int): Number of backward passes to retain a link with the output. After the last backward pass the output is discarded and memory is freed. num_inputs (int): The number of inputs to the forward function. *args (tuple): Inputs and weights. """ @staticmethod def forward(ctx, fn: torch.nn.Module, fn_inverse: torch.nn.Module, num_bwd_passes: int, num_inputs: int, *args): ctx.fn = fn ctx.fn_inverse = fn_inverse ctx.weights = args[num_inputs:] ctx.num_bwd_passes = num_bwd_passes ctx.num_inputs = num_inputs inputs = args[:num_inputs] ctx.input_requires_grad = [] with torch.no_grad(): # Make a detached copy which shares the storage: x = [] for element in inputs: if isinstance(element, torch.Tensor): x.append(element.detach()) ctx.input_requires_grad.append(element.requires_grad) else: x.append(element) ctx.input_requires_grad.append(None) outputs = ctx.fn(*x) if not isinstance(outputs, tuple): outputs = (outputs, ) # Detaches outputs in-place, allows discarding the intermedate result: detached_outputs = tuple(element.detach_() for element in outputs) # Clear memory of node features: if torch_geometric.typing.WITH_PT20: inputs[0].untyped_storage().resize_(0) else: # pragma: no cover inputs[0].storage().resize_(0) # Store these tensor nodes for backward passes: ctx.inputs = [inputs] * num_bwd_passes ctx.outputs = [detached_outputs] * num_bwd_passes return detached_outputs @staticmethod def backward(ctx, *grad_outputs): if len(ctx.outputs) == 0: raise RuntimeError( f"Trying to perform a backward pass on the " f"'InvertibleFunction' for more than '{ctx.num_bwd_passes}' " f"times. 
Try raising 'num_bwd_passes'.") inputs = ctx.inputs.pop() outputs = ctx.outputs.pop() # Recompute input by swapping out the first argument: with torch.no_grad(): inputs_inverted = ctx.fn_inverse(*(outputs + inputs[1:])) if len(ctx.outputs) == 0: # Clear memory from outputs: for element in outputs: if torch_geometric.typing.WITH_PT20: element.untyped_storage().resize_(0) else: # pragma: no cover element.storage().resize_(0) if not isinstance(inputs_inverted, tuple): inputs_inverted = (inputs_inverted, ) for elem_orig, elem_inv in zip(inputs, inputs_inverted): if torch_geometric.typing.WITH_PT20: elem_orig.untyped_storage().resize_( int(np.prod(elem_orig.size())) * elem_orig.element_size()) else: # pragma: no cover elem_orig.storage().resize_(int(np.prod(elem_orig.size()))) elem_orig.set_(elem_inv) # Compute gradients with grad enabled: with torch.set_grad_enabled(True): detached_inputs = [] for element in inputs: if isinstance(element, torch.Tensor): detached_inputs.append(element.detach()) else: detached_inputs.append(element) detached_inputs = tuple(detached_inputs) for x, req_grad in zip(detached_inputs, ctx.input_requires_grad): if isinstance(x, torch.Tensor): x.requires_grad = req_grad tmp_output = ctx.fn(*detached_inputs) if not isinstance(tmp_output, tuple): tmp_output = (tmp_output, ) filtered_detached_inputs = tuple( filter( lambda x: x.requires_grad if isinstance(x, torch.Tensor) else False, detached_inputs, )) gradients = torch.autograd.grad( outputs=tmp_output, inputs=filtered_detached_inputs + ctx.weights, grad_outputs=grad_outputs, ) input_gradients = [] i = 0 for rg in ctx.input_requires_grad: if rg: input_gradients.append(gradients[i]) i += 1 else: input_gradients.append(None) gradients = tuple(input_gradients) + gradients[-len(ctx.weights):] return (None, None, None, None) + gradients class InvertibleModule(torch.nn.Module): r"""An abstract class for implementing invertible modules. 
Args: disable (bool, optional): If set to :obj:`True`, will disable the usage of :class:`InvertibleFunction` and will execute the module without memory savings. (default: :obj:`False`) num_bwd_passes (int, optional): Number of backward passes to retain a link with the output. After the last backward pass the output is discarded and memory is freed. (default: :obj:`1`) """ def __init__(self, disable: bool = False, num_bwd_passes: int = 1): super().__init__() self.disable = disable self.num_bwd_passes = num_bwd_passes def forward(self, *args): """""" # noqa: D419 return self._fn_apply(args, self._forward, self._inverse) def inverse(self, *args): return self._fn_apply(args, self._inverse, self._forward) def _forward(self): raise NotImplementedError def _inverse(self): raise NotImplementedError def _fn_apply(self, args, fn, fn_inverse): if not self.disable: out = InvertibleFunction.apply( fn, fn_inverse, self.num_bwd_passes, len(args), *args, *tuple(p for p in self.parameters() if p.requires_grad), ) else: out = fn(*args) # If the layer only has one input, we unpack the tuple: if isinstance(out, tuple) and len(out) == 1: return out[0] return out class GroupAddRev(InvertibleModule): r"""The Grouped Reversible GNN module from the `"Graph Neural Networks with 1000 Layers" `_ paper. This module enables training of arbitrary deep GNNs with a memory complexity independent of the number of layers. It does so by partitioning input node features :math:`\mathbf{X}` into :math:`C` groups across the feature dimension. Then, a grouped reversible GNN block :math:`f_{\theta(i)}` operates on a group of inputs and produces a group of outputs: .. math:: \mathbf{X}^{\prime}_0 &= \sum_{i=2}^C \mathbf{X}_i \mathbf{X}^{\prime}_i &= f_{\theta(i)} ( \mathbf{X}^{\prime}_{i - 1}, \mathbf{A}) + \mathbf{X}_i for all :math:`i \in \{ 1, \ldots, C \}`. .. note:: For an example of using :class:`GroupAddRev`, see `examples/rev_gnn.py `_. 
Args: conv (torch.nn.Module or torch.nn.ModuleList]): A seed GNN. The input and output feature dimensions need to match. split_dim (int, optional): The dimension across which to split groups. (default: :obj:`-1`) num_groups (int, optional): The number of groups :math:`C`. (default: :obj:`None`) disable (bool, optional): If set to :obj:`True`, will disable the usage of :class:`InvertibleFunction` and will execute the module without memory savings. (default: :obj:`False`) num_bwd_passes (int, optional): Number of backward passes to retain a link with the output. After the last backward pass the output is discarded and memory is freed. (default: :obj:`1`) """ def __init__( self, conv: Union[torch.nn.Module, torch.nn.ModuleList], split_dim: int = -1, num_groups: Optional[int] = None, disable: bool = False, num_bwd_passes: int = 1, ): super().__init__(disable, num_bwd_passes) self.split_dim = split_dim if isinstance(conv, torch.nn.ModuleList): self.convs = conv else: assert num_groups is not None, "Please specific 'num_groups'" self.convs = torch.nn.ModuleList([conv]) for _ in range(num_groups - 1): conv = copy.deepcopy(self.convs[0]) if hasattr(conv, 'reset_parameters'): conv.reset_parameters() self.convs.append(conv) if len(self.convs) < 2: raise ValueError(f"The number of groups should not be smaller " f"than '2' (got '{self.num_groups}'))") @property def num_groups(self) -> int: return len(self.convs) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" for conv in self.convs: conv.reset_parameters() def _forward(self, x: Tensor, edge_index: Adj, *args): channels = x.size(self.split_dim) xs = self._chunk(x, channels) args = list(zip(*[self._chunk(arg, channels) for arg in args])) args = [[]] * self.num_groups if len(args) == 0 else args ys = [] y_in = sum(xs[1:]) for i in range(self.num_groups): y_in = xs[i] + self.convs[i](y_in, edge_index, *args[i]) ys.append(y_in) return torch.cat(ys, dim=self.split_dim) def _inverse(self, y: Tensor, 
edge_index: Adj, *args): channels = y.size(self.split_dim) ys = self._chunk(y, channels) args = list(zip(*[self._chunk(arg, channels) for arg in args])) args = [[]] * self.num_groups if len(args) == 0 else args xs = [] for i in range(self.num_groups - 1, -1, -1): if i != 0: y_in = ys[i - 1] else: y_in = sum(xs) x = ys[i] - self.convs[i](y_in, edge_index, *args[i]) xs.append(x) return torch.cat(xs[::-1], dim=self.split_dim) def _chunk(self, x: Any, channels: int) -> List[Any]: if not isinstance(x, Tensor): return [x] * self.num_groups try: if x.size(self.split_dim) != channels: return [x] * self.num_groups except IndexError: return [x] * self.num_groups return torch.chunk(x, self.num_groups, dim=self.split_dim) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.convs[0]}, ' f'num_groups={self.num_groups})') ================================================ FILE: torch_geometric/nn/models/schnet.py ================================================ import os import os.path as osp import warnings from math import pi as PI from typing import Callable, Dict, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding, Linear, ModuleList, Sequential from torch_geometric.data import Dataset, download_url, extract_zip from torch_geometric.io import fs from torch_geometric.nn import MessagePassing, SumAggregation, radius_graph from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver from torch_geometric.typing import OptTensor qm9_target_dict: Dict[int, str] = { 0: 'dipole_moment', 1: 'isotropic_polarizability', 2: 'homo', 3: 'lumo', 4: 'gap', 5: 'electronic_spatial_extent', 6: 'zpve', 7: 'energy_U0', 8: 'energy_U', 9: 'enthalpy_H', 10: 'free_energy', 11: 'heat_capacity', } class SchNet(torch.nn.Module): r"""The continuous-filter convolutional neural network SchNet from the `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum 
Interactions" `_ paper that uses the interactions blocks of the form. .. math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))), here :math:`h_{\mathbf{\Theta}}` denotes an MLP and :math:`\mathbf{e}_{j,i}` denotes the interatomic distances between atoms. .. note:: For an example of using a pretrained SchNet variant, see `examples/qm9_pretrained_schnet.py `_. Args: hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`) num_filters (int, optional): The number of filters to use. (default: :obj:`128`) num_interactions (int, optional): The number of interaction blocks. (default: :obj:`6`) num_gaussians (int, optional): The number of gaussians :math:`\mu`. (default: :obj:`50`) interaction_graph (callable, optional): The function used to compute the pairwise interaction graph and interatomic distances. If set to :obj:`None`, will construct a graph based on :obj:`cutoff` and :obj:`max_num_neighbors` properties. If provided, this method takes in :obj:`pos` and :obj:`batch` tensors and should return :obj:`(edge_index, edge_weight)` tensors. (default :obj:`None`) cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`) max_num_neighbors (int, optional): The maximum number of neighbors to collect for each node within the :attr:`cutoff` distance. (default: :obj:`32`) readout (str, optional): Whether to apply :obj:`"add"` or :obj:`"mean"` global aggregation. (default: :obj:`"add"`) dipole (bool, optional): If set to :obj:`True`, will use the magnitude of the dipole moment to make the final prediction, *e.g.*, for target 0 of :class:`torch_geometric.datasets.QM9`. (default: :obj:`False`) mean (float, optional): The mean of the property to predict. (default: :obj:`None`) std (float, optional): The standard deviation of the property to predict. 
(default: :obj:`None`) atomref (torch.Tensor, optional): The reference of single-atom properties. Expects a vector of shape :obj:`(max_atomic_number, )`. """ url = 'http://www.quantum-machine.org/datasets/trained_schnet_models.zip' def __init__( self, hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, interaction_graph: Optional[Callable] = None, max_num_neighbors: int = 32, readout: str = 'add', dipole: bool = False, mean: Optional[float] = None, std: Optional[float] = None, atomref: OptTensor = None, ): super().__init__() self.hidden_channels = hidden_channels self.num_filters = num_filters self.num_interactions = num_interactions self.num_gaussians = num_gaussians self.cutoff = cutoff self.dipole = dipole self.sum_aggr = SumAggregation() self.readout = aggr_resolver('sum' if self.dipole else readout) self.mean = mean self.std = std self.scale = None if self.dipole: import ase atomic_mass = torch.from_numpy(ase.data.atomic_masses) self.register_buffer('atomic_mass', atomic_mass) # Support z == 0 for padding atoms so that their embedding vectors # are zeroed and do not receive any gradients. 
self.embedding = Embedding(100, hidden_channels, padding_idx=0) if interaction_graph is not None: self.interaction_graph = interaction_graph else: self.interaction_graph = RadiusInteractionGraph( cutoff, max_num_neighbors) self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians) self.interactions = ModuleList() for _ in range(num_interactions): block = InteractionBlock(hidden_channels, num_gaussians, num_filters, cutoff) self.interactions.append(block) self.lin1 = Linear(hidden_channels, hidden_channels // 2) self.act = ShiftedSoftplus() self.lin2 = Linear(hidden_channels // 2, 1) self.register_buffer('initial_atomref', atomref) self.atomref = None if atomref is not None: self.atomref = Embedding(100, 1) self.atomref.weight.data.copy_(atomref) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.embedding.reset_parameters() for interaction in self.interactions: interaction.reset_parameters() torch.nn.init.xavier_uniform_(self.lin1.weight) self.lin1.bias.data.fill_(0) torch.nn.init.xavier_uniform_(self.lin2.weight) self.lin2.bias.data.fill_(0) if self.atomref is not None: self.atomref.weight.data.copy_(self.initial_atomref) @staticmethod def from_qm9_pretrained( root: str, dataset: Dataset, target: int, ) -> Tuple['SchNet', Dataset, Dataset, Dataset]: # pragma: no cover r"""Returns a pre-trained :class:`SchNet` model on the :class:`~torch_geometric.datasets.QM9` dataset, trained on the specified target :obj:`target`. 
""" import ase import schnetpack as spk # noqa assert target >= 0 and target <= 12 is_dipole = target == 0 units = [1] * 12 units[0] = ase.units.Debye units[1] = ase.units.Bohr**3 units[5] = ase.units.Bohr**2 root = osp.expanduser(osp.normpath(root)) os.makedirs(root, exist_ok=True) folder = 'trained_schnet_models' if not osp.exists(osp.join(root, folder)): path = download_url(SchNet.url, root) extract_zip(path, root) os.unlink(path) name = f'qm9_{qm9_target_dict[target]}' path = osp.join(root, 'trained_schnet_models', name, 'split.npz') split = np.load(path) train_idx = split['train_idx'] val_idx = split['val_idx'] test_idx = split['test_idx'] # Filter the splits to only contain characterized molecules. idx = dataset.data.idx assoc = idx.new_empty(idx.max().item() + 1) assoc[idx] = torch.arange(idx.size(0)) train_idx = assoc[train_idx[np.isin(train_idx, idx)]] val_idx = assoc[val_idx[np.isin(val_idx, idx)]] test_idx = assoc[test_idx[np.isin(test_idx, idx)]] path = osp.join(root, 'trained_schnet_models', name, 'best_model') with warnings.catch_warnings(): warnings.simplefilter('ignore') state = fs.torch_load(path, map_location='cpu') net = SchNet( hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0, dipole=is_dipole, atomref=dataset.atomref(target), ) net.embedding.weight = state.representation.embedding.weight for int1, int2 in zip(state.representation.interactions, net.interactions): int2.mlp[0].weight = int1.filter_network[0].weight int2.mlp[0].bias = int1.filter_network[0].bias int2.mlp[2].weight = int1.filter_network[1].weight int2.mlp[2].bias = int1.filter_network[1].bias int2.lin.weight = int1.dense.weight int2.lin.bias = int1.dense.bias int2.conv.lin1.weight = int1.cfconv.in2f.weight int2.conv.lin2.weight = int1.cfconv.f2out.weight int2.conv.lin2.bias = int1.cfconv.f2out.bias net.lin1.weight = state.output_modules[0].out_net[1].out_net[0].weight net.lin1.bias = state.output_modules[0].out_net[1].out_net[0].bias 
net.lin2.weight = state.output_modules[0].out_net[1].out_net[1].weight net.lin2.bias = state.output_modules[0].out_net[1].out_net[1].bias mean = state.output_modules[0].atom_pool.average net.readout = aggr_resolver('mean' if mean is True else 'add') dipole = state.output_modules[0].__class__.__name__ == 'DipoleMoment' net.dipole = dipole net.mean = state.output_modules[0].standardize.mean.item() net.std = state.output_modules[0].standardize.stddev.item() if state.output_modules[0].atomref is not None: net.atomref.weight = state.output_modules[0].atomref.weight else: net.atomref = None net.scale = 1.0 / units[target] return net, (dataset[train_idx], dataset[val_idx], dataset[test_idx]) def forward(self, z: Tensor, pos: Tensor, batch: OptTensor = None) -> Tensor: r"""Forward pass. Args: z (torch.Tensor): Atomic number of each atom with shape :obj:`[num_atoms]`. pos (torch.Tensor): Coordinates of each atom with shape :obj:`[num_atoms, 3]`. batch (torch.Tensor, optional): Batch indices assigning each atom to a separate molecule with shape :obj:`[num_atoms]`. (default: :obj:`None`) """ batch = torch.zeros_like(z) if batch is None else batch h = self.embedding(z) edge_index, edge_weight = self.interaction_graph(pos, batch) edge_attr = self.distance_expansion(edge_weight) for interaction in self.interactions: h = h + interaction(h, edge_index, edge_weight, edge_attr) h = self.lin1(h) h = self.act(h) h = self.lin2(h) if self.dipole: # Get center of mass. 
mass = self.atomic_mass[z].view(-1, 1) M = self.sum_aggr(mass, batch, dim=0) c = self.sum_aggr(mass * pos, batch, dim=0) / M h = h * (pos - c.index_select(0, batch)) if not self.dipole and self.mean is not None and self.std is not None: h = h * self.std + self.mean if not self.dipole and self.atomref is not None: h = h + self.atomref(z) out = self.readout(h, batch, dim=0) if self.dipole: out = torch.norm(out, dim=-1, keepdim=True) if self.scale is not None: out = self.scale * out return out def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'hidden_channels={self.hidden_channels}, ' f'num_filters={self.num_filters}, ' f'num_interactions={self.num_interactions}, ' f'num_gaussians={self.num_gaussians}, ' f'cutoff={self.cutoff})') class RadiusInteractionGraph(torch.nn.Module): r"""Creates edges based on atom positions :obj:`pos` to all points within the cutoff distance. Args: cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`) max_num_neighbors (int, optional): The maximum number of neighbors to collect for each node within the :attr:`cutoff` distance with the default interaction graph method. (default: :obj:`32`) """ def __init__(self, cutoff: float = 10.0, max_num_neighbors: int = 32): super().__init__() self.cutoff = cutoff self.max_num_neighbors = max_num_neighbors def forward(self, pos: Tensor, batch: Tensor) -> Tuple[Tensor, Tensor]: r"""Forward pass. Args: pos (Tensor): Coordinates of each atom. batch (LongTensor, optional): Batch indices assigning each atom to a separate molecule. 
:rtype: (:class:`LongTensor`, :class:`Tensor`) """ edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors) row, col = edge_index edge_weight = (pos[row] - pos[col]).norm(dim=-1) return edge_index, edge_weight class InteractionBlock(torch.nn.Module): def __init__(self, hidden_channels: int, num_gaussians: int, num_filters: int, cutoff: float): super().__init__() self.mlp = Sequential( Linear(num_gaussians, num_filters), ShiftedSoftplus(), Linear(num_filters, num_filters), ) self.conv = CFConv(hidden_channels, hidden_channels, num_filters, self.mlp, cutoff) self.act = ShiftedSoftplus() self.lin = Linear(hidden_channels, hidden_channels) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.mlp[0].weight) self.mlp[0].bias.data.fill_(0) torch.nn.init.xavier_uniform_(self.mlp[2].weight) self.mlp[2].bias.data.fill_(0) self.conv.reset_parameters() torch.nn.init.xavier_uniform_(self.lin.weight) self.lin.bias.data.fill_(0) def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor) -> Tensor: x = self.conv(x, edge_index, edge_weight, edge_attr) x = self.act(x) x = self.lin(x) return x class CFConv(MessagePassing): def __init__( self, in_channels: int, out_channels: int, num_filters: int, nn: Sequential, cutoff: float, ): super().__init__(aggr='add') self.lin1 = Linear(in_channels, num_filters, bias=False) self.lin2 = Linear(num_filters, out_channels) self.nn = nn self.cutoff = cutoff self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.lin1.weight) torch.nn.init.xavier_uniform_(self.lin2.weight) self.lin2.bias.data.fill_(0) def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor) -> Tensor: C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0) W = self.nn(edge_attr) * C.view(-1, 1) x = self.lin1(x) x = self.propagate(edge_index, x=x, W=W) x = self.lin2(x) return x def message(self, x_j: Tensor, W: 
Tensor) -> Tensor: return x_j * W class GaussianSmearing(torch.nn.Module): def __init__( self, start: float = 0.0, stop: float = 5.0, num_gaussians: int = 50, ): super().__init__() offset = torch.linspace(start, stop, num_gaussians) self.coeff = -0.5 / (offset[1] - offset[0]).item()**2 self.register_buffer('offset', offset) def forward(self, dist: Tensor) -> Tensor: dist = dist.view(-1, 1) - self.offset.view(1, -1) return torch.exp(self.coeff * torch.pow(dist, 2)) class ShiftedSoftplus(torch.nn.Module): def __init__(self): super().__init__() self.shift = torch.log(torch.tensor(2.0)).item() def forward(self, x: Tensor) -> Tensor: return F.softplus(x) - self.shift ================================================ FILE: torch_geometric/nn/models/sgformer.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.nn.attention import SGFormerAttention from torch_geometric.nn.conv import GCNConv from torch_geometric.utils import to_dense_batch class GraphModule(torch.nn.Module): def __init__( self, in_channels, hidden_channels, num_layers=2, dropout=0.5, ): super().__init__() self.convs = torch.nn.ModuleList() self.fcs = torch.nn.ModuleList() self.fcs.append(torch.nn.Linear(in_channels, hidden_channels)) self.bns = torch.nn.ModuleList() self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) for _ in range(num_layers): self.convs.append(GCNConv(hidden_channels, hidden_channels)) self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) self.dropout = dropout self.activation = F.relu def reset_parameters(self): for conv in self.convs: conv.reset_parameters() for bn in self.bns: bn.reset_parameters() for fc in self.fcs: fc.reset_parameters() def forward(self, x, edge_index): x = self.fcs[0](x) x = self.bns[0](x) x = self.activation(x) x = F.dropout(x, p=self.dropout, training=self.training) last_x = x for i, conv in enumerate(self.convs): x = conv(x, edge_index) x = 
self.bns[i + 1](x) x = self.activation(x) x = F.dropout(x, p=self.dropout, training=self.training) x = x + last_x return x class SGModule(torch.nn.Module): def __init__( self, in_channels, hidden_channels, num_layers=2, num_heads=1, dropout=0.5, ): super().__init__() self.attns = torch.nn.ModuleList() self.fcs = torch.nn.ModuleList() self.fcs.append(torch.nn.Linear(in_channels, hidden_channels)) self.bns = torch.nn.ModuleList() self.bns.append(torch.nn.LayerNorm(hidden_channels)) for _ in range(num_layers): self.attns.append( SGFormerAttention(hidden_channels, num_heads, hidden_channels)) self.bns.append(torch.nn.LayerNorm(hidden_channels)) self.dropout = dropout self.activation = F.relu def reset_parameters(self): for attn in self.attns: attn.reset_parameters() for bn in self.bns: bn.reset_parameters() for fc in self.fcs: fc.reset_parameters() def forward(self, x: Tensor, batch: Tensor): # to dense batch expects sorted batch batch, indices = batch.sort(stable=True) rev_perm = torch.empty_like(indices) rev_perm[indices] = torch.arange(len(indices), device=indices.device) x = x[indices] x, mask = to_dense_batch(x, batch) layer_ = [] # input MLP layer x = self.fcs[0](x) x = self.bns[0](x) x = self.activation(x) x = F.dropout(x, p=self.dropout, training=self.training) # store as residual link layer_.append(x) for i, attn in enumerate(self.attns): x = attn(x, mask) x = (x + layer_[i]) / 2. x = self.bns[i + 1](x) x = self.activation(x) x = F.dropout(x, p=self.dropout, training=self.training) layer_.append(x) x_mask = x[mask] # reverse the sorting unsorted_x_mask = x_mask[rev_perm] return unsorted_x_mask class SGFormer(torch.nn.Module): r"""The sgformer module from the `"SGFormer: Simplifying and Empowering Transformers for Large-Graph Representations" `_ paper. Args: in_channels (int): Input channels. hidden_channels (int): Hidden channels. out_channels (int): Output channels. trans_num_layers (int): The number of layers for all-pair attention. 
(default: :obj:`2`) trans_num_heads (int): The number of heads for attention. (default: :obj:`1`) trans_dropout (float): Global dropout rate. (default: :obj:`0.5`) gnn_num_layers (int): The number of layers for GNN. (default: :obj:`3`) gnn_dropout (float): GNN dropout rate. (default: :obj:`0.5`) graph_weight (float): The weight balance global and gnn module. (default: :obj:`0.5`) aggregate (str): Aggregate type. (default: :obj:`add`) """ def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, trans_num_layers: int = 2, trans_num_heads: int = 1, trans_dropout: float = 0.5, gnn_num_layers: int = 3, gnn_dropout: float = 0.5, graph_weight: float = 0.5, aggregate: str = 'add', ): super().__init__() self.trans_conv = SGModule( in_channels, hidden_channels, trans_num_layers, trans_num_heads, trans_dropout, ) self.graph_conv = GraphModule( in_channels, hidden_channels, gnn_num_layers, gnn_dropout, ) self.graph_weight = graph_weight self.aggregate = aggregate if aggregate == 'add': self.fc = torch.nn.Linear(hidden_channels, out_channels) elif aggregate == 'cat': self.fc = torch.nn.Linear(2 * hidden_channels, out_channels) else: raise ValueError(f'Invalid aggregate type:{aggregate}') self.params1 = list(self.trans_conv.parameters()) self.params2 = list(self.graph_conv.parameters()) self.params2.extend(list(self.fc.parameters())) self.out_channels = out_channels def reset_parameters(self) -> None: self.trans_conv.reset_parameters() self.graph_conv.reset_parameters() self.fc.reset_parameters() def forward( self, x: Tensor, edge_index: Tensor, batch: Optional[Tensor], ) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The input node features. edge_index (torch.Tensor or SparseTensor): The edge indices. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. 
""" x1 = self.trans_conv(x, batch) x2 = self.graph_conv(x, edge_index) if self.aggregate == 'add': x = self.graph_weight * x2 + (1 - self.graph_weight) * x1 else: x = torch.cat((x1, x2), dim=1) x = self.fc(x) return F.log_softmax(x, dim=-1) ================================================ FILE: torch_geometric/nn/models/signed_gcn.py ================================================ from typing import Optional, Tuple import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.nn import SignedConv from torch_geometric.utils import ( coalesce, negative_sampling, structured_negative_sampling, ) class SignedGCN(torch.nn.Module): r"""The signed graph convolutional network model from the `"Signed Graph Convolutional Network" `_ paper. Internally, this module uses the :class:`torch_geometric.nn.conv.SignedConv` operator. Args: in_channels (int): Size of each input sample. hidden_channels (int): Size of each hidden sample. num_layers (int): Number of layers. lamb (float, optional): Balances the contributions of the overall objective. (default: :obj:`5`) bias (bool, optional): If set to :obj:`False`, all layers will not learn an additive bias. 
(default: :obj:`True`) """ def __init__( self, in_channels: int, hidden_channels: int, num_layers: int, lamb: float = 5, bias: bool = True, ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels self.num_layers = num_layers self.lamb = lamb self.conv1 = SignedConv(in_channels, hidden_channels // 2, first_aggr=True) self.convs = torch.nn.ModuleList() for _ in range(num_layers - 1): self.convs.append( SignedConv(hidden_channels // 2, hidden_channels // 2, first_aggr=False)) self.lin = torch.nn.Linear(2 * hidden_channels, 3) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.conv1.reset_parameters() for conv in self.convs: conv.reset_parameters() self.lin.reset_parameters() def split_edges( self, edge_index: Tensor, test_ratio: float = 0.2, ) -> Tuple[Tensor, Tensor]: r"""Splits the edges :obj:`edge_index` into train and test edges. Args: edge_index (LongTensor): The edge indices. test_ratio (float, optional): The ratio of test edges. (default: :obj:`0.2`) """ mask = torch.ones(edge_index.size(1), dtype=torch.bool) mask[torch.randperm(mask.size(0))[:int(test_ratio * mask.size(0))]] = 0 train_edge_index = edge_index[:, mask] test_edge_index = edge_index[:, ~mask] return train_edge_index, test_edge_index def create_spectral_features( self, pos_edge_index: Tensor, neg_edge_index: Tensor, num_nodes: Optional[int] = None, ) -> Tensor: r"""Creates :obj:`in_channels` spectral node features based on positive and negative edges. Args: pos_edge_index (LongTensor): The positive edge indices. neg_edge_index (LongTensor): The negative edge indices. num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`pos_edge_index` and :attr:`neg_edge_index`. 
(default: :obj:`None`) """ import scipy.sparse as sp from sklearn.decomposition import TruncatedSVD edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1) N = edge_index.max().item() + 1 if num_nodes is None else num_nodes edge_index = edge_index.to(torch.device('cpu')) pos_val = torch.full((pos_edge_index.size(1), ), 2, dtype=torch.float) neg_val = torch.full((neg_edge_index.size(1), ), 0, dtype=torch.float) val = torch.cat([pos_val, neg_val], dim=0) row, col = edge_index edge_index = torch.cat([edge_index, torch.stack([col, row])], dim=1) val = torch.cat([val, val], dim=0) edge_index, val = coalesce(edge_index, val, num_nodes=N) val = val - 1 # Borrowed from: # https://github.com/benedekrozemberczki/SGCN/blob/master/src/utils.py edge_index = edge_index.detach().numpy() val = val.detach().numpy() A = sp.coo_matrix((val, edge_index), shape=(N, N)) svd = TruncatedSVD(n_components=self.in_channels, n_iter=128) svd.fit(A) x = svd.components_.T return torch.from_numpy(x).to(torch.float).to(pos_edge_index.device) def forward( self, x: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor, ) -> Tensor: """Computes node embeddings :obj:`z` based on positive edges :obj:`pos_edge_index` and negative edges :obj:`neg_edge_index`. Args: x (torch.Tensor): The input node features. pos_edge_index (torch.Tensor): The positive edge indices. neg_edge_index (torch.Tensor): The negative edge indices. """ z = F.relu(self.conv1(x, pos_edge_index, neg_edge_index)) for conv in self.convs: z = F.relu(conv(z, pos_edge_index, neg_edge_index)) return z def discriminate(self, z: Tensor, edge_index: Tensor) -> Tensor: """Given node embeddings :obj:`z`, classifies the link relation between node pairs :obj:`edge_index` to be either positive, negative or non-existent. Args: z (torch.Tensor): The input node features. edge_index (torch.Tensor): The edge indices. 
""" value = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1) value = self.lin(value) return torch.log_softmax(value, dim=1) def nll_loss( self, z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor, ) -> Tensor: """Computes the discriminator loss based on node embeddings :obj:`z`, and positive edges :obj:`pos_edge_index` and negative nedges :obj:`neg_edge_index`. Args: z (torch.Tensor): The node embeddings. pos_edge_index (torch.Tensor): The positive edge indices. neg_edge_index (torch.Tensor): The negative edge indices. """ edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1) none_edge_index = negative_sampling(edge_index, z.size(0)) nll_loss = 0 nll_loss += F.nll_loss( self.discriminate(z, pos_edge_index), pos_edge_index.new_full((pos_edge_index.size(1), ), 0)) nll_loss += F.nll_loss( self.discriminate(z, neg_edge_index), neg_edge_index.new_full((neg_edge_index.size(1), ), 1)) nll_loss += F.nll_loss( self.discriminate(z, none_edge_index), none_edge_index.new_full((none_edge_index.size(1), ), 2)) return nll_loss / 3.0 def pos_embedding_loss( self, z: Tensor, pos_edge_index: Tensor, ) -> Tensor: """Computes the triplet loss between positive node pairs and sampled non-node pairs. Args: z (torch.Tensor): The node embeddings. pos_edge_index (torch.Tensor): The positive edge indices. """ i, j, k = structured_negative_sampling(pos_edge_index, z.size(0)) out = (z[i] - z[j]).pow(2).sum(dim=1) - (z[i] - z[k]).pow(2).sum(dim=1) return torch.clamp(out, min=0).mean() def neg_embedding_loss(self, z: Tensor, neg_edge_index: Tensor) -> Tensor: """Computes the triplet loss between negative node pairs and sampled non-node pairs. Args: z (torch.Tensor): The node embeddings. neg_edge_index (torch.Tensor): The negative edge indices. 
""" i, j, k = structured_negative_sampling(neg_edge_index, z.size(0)) out = (z[i] - z[k]).pow(2).sum(dim=1) - (z[i] - z[j]).pow(2).sum(dim=1) return torch.clamp(out, min=0).mean() def loss( self, z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor, ) -> Tensor: """Computes the overall objective. Args: z (torch.Tensor): The node embeddings. pos_edge_index (torch.Tensor): The positive edge indices. neg_edge_index (torch.Tensor): The negative edge indices. """ nll_loss = self.nll_loss(z, pos_edge_index, neg_edge_index) loss_1 = self.pos_embedding_loss(z, pos_edge_index) loss_2 = self.neg_embedding_loss(z, neg_edge_index) return nll_loss + self.lamb * (loss_1 + loss_2) def test( self, z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor, ) -> Tuple[float, float]: """Evaluates node embeddings :obj:`z` on positive and negative test edges by computing AUC and F1 scores. Args: z (torch.Tensor): The node embeddings. pos_edge_index (torch.Tensor): The positive edge indices. neg_edge_index (torch.Tensor): The negative edge indices. 
""" from sklearn.metrics import f1_score, roc_auc_score with torch.no_grad(): pos_p = self.discriminate(z, pos_edge_index)[:, :2].max(dim=1)[1] neg_p = self.discriminate(z, neg_edge_index)[:, :2].max(dim=1)[1] pred = (1 - torch.cat([pos_p, neg_p])).cpu() y = torch.cat( [pred.new_ones(pos_p.size(0)), pred.new_zeros(neg_p.size(0))]) pred, y = pred.numpy(), y.numpy() auc = roc_auc_score(y, pred) f1 = f1_score(y, pred, average='binary') if pred.sum() > 0 else 0 return auc, f1 def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.hidden_channels}, num_layers={self.num_layers})') ================================================ FILE: torch_geometric/nn/models/tgn.py ================================================ import copy from typing import Callable, Dict, Tuple import torch from torch import Tensor from torch.nn import GRUCell, Linear from torch_geometric.nn.inits import zeros from torch_geometric.utils import scatter from torch_geometric.utils._scatter import scatter_argmax TGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]] class TGNMemory(torch.nn.Module): r"""The Temporal Graph Network (TGN) memory model from the `"Temporal Graph Networks for Deep Learning on Dynamic Graphs" `_ paper. .. note:: For an example of using TGN, see `examples/tgn.py `_. Args: num_nodes (int): The number of nodes to save memories for. raw_msg_dim (int): The raw message dimensionality. memory_dim (int): The hidden memory dimensionality. time_dim (int): The time encoding dimensionality. message_module (torch.nn.Module): The message function which combines source and destination node memory embeddings, the raw message and the time encoding. aggregator_module (torch.nn.Module): The message aggregator function which aggregates messages to the same destination into a single representation. 
""" def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int, time_dim: int, message_module: Callable, aggregator_module: Callable): super().__init__() self.num_nodes = num_nodes self.raw_msg_dim = raw_msg_dim self.memory_dim = memory_dim self.time_dim = time_dim self.msg_s_module = message_module self.msg_d_module = copy.deepcopy(message_module) self.aggr_module = aggregator_module self.time_enc = TimeEncoder(time_dim) self.gru = GRUCell(message_module.out_channels, memory_dim) self.register_buffer('memory', torch.empty(num_nodes, memory_dim)) last_update = torch.empty(self.num_nodes, dtype=torch.long) self.register_buffer('last_update', last_update) self.register_buffer('_assoc', torch.empty(num_nodes, dtype=torch.long)) self.msg_s_store = {} self.msg_d_store = {} self.reset_parameters() @property def device(self) -> torch.device: return self.time_enc.lin.weight.device def reset_parameters(self): r"""Resets all learnable parameters of the module.""" if hasattr(self.msg_s_module, 'reset_parameters'): self.msg_s_module.reset_parameters() if hasattr(self.msg_d_module, 'reset_parameters'): self.msg_d_module.reset_parameters() if hasattr(self.aggr_module, 'reset_parameters'): self.aggr_module.reset_parameters() self.time_enc.reset_parameters() self.gru.reset_parameters() self.reset_state() def reset_state(self): """Resets the memory to its initial state.""" zeros(self.memory) zeros(self.last_update) self._reset_message_store() def detach(self): """Detaches the memory from gradient computation.""" self.memory.detach_() def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]: """Returns, for all nodes :obj:`n_id`, their current memory and their last updated timestamp. 
""" if self.training: memory, last_update = self._get_updated_memory(n_id) else: memory, last_update = self.memory[n_id], self.last_update[n_id] return memory, last_update def update_state(self, src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor): """Updates the memory with newly encountered interactions :obj:`(src, dst, t, raw_msg)`. """ n_id = torch.cat([src, dst]).unique() if self.training: self._update_memory(n_id) self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store) self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store) else: self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store) self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store) self._update_memory(n_id) def _reset_message_store(self): i = self.memory.new_empty((0, ), device=self.device, dtype=torch.long) msg = self.memory.new_empty((0, self.raw_msg_dim), device=self.device) # Message store format: (src, dst, t, msg) self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)} self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)} def _update_memory(self, n_id: Tensor): memory, last_update = self._get_updated_memory(n_id) self.memory[n_id] = memory self.last_update[n_id] = last_update def _get_updated_memory(self, n_id: Tensor) -> Tuple[Tensor, Tensor]: self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device) # Compute messages (src -> dst). msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store, self.msg_s_module) # Compute messages (dst -> src). msg_d, t_d, src_d, dst_d = self._compute_msg(n_id, self.msg_d_store, self.msg_d_module) # Aggregate messages. idx = torch.cat([src_s, src_d], dim=0) msg = torch.cat([msg_s, msg_d], dim=0) t = torch.cat([t_s, t_d], dim=0) aggr = self.aggr_module(msg, self._assoc[idx], t, n_id.size(0)) # Get local copy of updated memory. memory = self.gru(aggr, self.memory[n_id]) # Get local copy of updated `last_update`. 
dim_size = self.last_update.size(0) last_update = scatter(t, idx, 0, dim_size, reduce='max')[n_id] return memory, last_update def _update_msg_store(self, src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor, msg_store: TGNMessageStoreType): n_id, perm = src.sort() n_id, count = n_id.unique_consecutive(return_counts=True) for i, idx in zip(n_id.tolist(), perm.split(count.tolist())): msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx]) def _compute_msg(self, n_id: Tensor, msg_store: TGNMessageStoreType, msg_module: Callable): data = [msg_store[i] for i in n_id.tolist()] src, dst, t, raw_msg = list(zip(*data)) src = torch.cat(src, dim=0).to(self.device) dst = torch.cat(dst, dim=0).to(self.device) t = torch.cat(t, dim=0).to(self.device) # Filter out empty tensors to avoid `invalid configuration argument`. # TODO Investigate why this is needed. raw_msg = [m for i, m in enumerate(raw_msg) if m.numel() > 0 or i == 0] raw_msg = torch.cat(raw_msg, dim=0).to(self.device) t_rel = t - self.last_update[src] t_enc = self.time_enc(t_rel.to(raw_msg.dtype)) msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc) return msg, t, src, dst def train(self, mode: bool = True): """Sets the module in training mode.""" if self.training and not mode: # Flush message store to memory in case we just entered eval mode. 
self._update_memory( torch.arange(self.num_nodes, device=self.memory.device)) self._reset_message_store() super().train(mode) class IdentityMessage(torch.nn.Module): def __init__(self, raw_msg_dim: int, memory_dim: int, time_dim: int): super().__init__() self.out_channels = raw_msg_dim + 2 * memory_dim + time_dim def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor, t_enc: Tensor): return torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1) class LastAggregator(torch.nn.Module): def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int): argmax = scatter_argmax(t, index, dim=0, dim_size=dim_size) out = msg.new_zeros((dim_size, msg.size(-1))) mask = argmax < msg.size(0) # Filter items with at least one entry. out[mask] = msg[argmax[mask]] return out class MeanAggregator(torch.nn.Module): def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int): return scatter(msg, index, dim=0, dim_size=dim_size, reduce='mean') class TimeEncoder(torch.nn.Module): def __init__(self, out_channels: int): super().__init__() self.out_channels = out_channels self.lin = Linear(1, out_channels) def reset_parameters(self): self.lin.reset_parameters() def forward(self, t: Tensor) -> Tensor: return self.lin(t.view(-1, 1)).cos() class LastNeighborLoader: def __init__(self, num_nodes: int, size: int, device=None): self.size = size self.neighbors = torch.empty((num_nodes, size), dtype=torch.long, device=device) self.e_id = torch.empty((num_nodes, size), dtype=torch.long, device=device) self._assoc = torch.empty(num_nodes, dtype=torch.long, device=device) self.reset_state() def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]: neighbors = self.neighbors[n_id] nodes = n_id.view(-1, 1).repeat(1, self.size) e_id = self.e_id[n_id] # Filter invalid neighbors (identified by `e_id < 0`). mask = e_id >= 0 neighbors, nodes, e_id = neighbors[mask], nodes[mask], e_id[mask] # Relabel node indices. 
n_id = torch.cat([n_id, neighbors]).unique() self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device) neighbors, nodes = self._assoc[neighbors], self._assoc[nodes] return n_id, torch.stack([neighbors, nodes]), e_id def insert(self, src: Tensor, dst: Tensor): # Inserts newly encountered interactions into an ever-growing # (undirected) temporal graph. # Collect central nodes, their neighbors and the current event ids. neighbors = torch.cat([src, dst], dim=0) nodes = torch.cat([dst, src], dim=0) e_id = torch.arange(self.cur_e_id, self.cur_e_id + src.size(0), device=src.device).repeat(2) self.cur_e_id += src.numel() # Convert newly encountered interaction ids so that they point to # locations of a "dense" format of shape [num_nodes, size]. nodes, perm = nodes.sort() neighbors, e_id = neighbors[perm], e_id[perm] n_id = nodes.unique() self._assoc[n_id] = torch.arange(n_id.numel(), device=n_id.device) dense_id = torch.arange(nodes.size(0), device=nodes.device) % self.size dense_id += self._assoc[nodes].mul_(self.size) dense_e_id = e_id.new_full((n_id.numel() * self.size, ), -1) dense_e_id[dense_id] = e_id dense_e_id = dense_e_id.view(-1, self.size) dense_neighbors = e_id.new_empty(n_id.numel() * self.size) dense_neighbors[dense_id] = neighbors dense_neighbors = dense_neighbors.view(-1, self.size) # Collect new and old interactions... e_id = torch.cat([self.e_id[n_id, :self.size], dense_e_id], dim=-1) neighbors = torch.cat( [self.neighbors[n_id, :self.size], dense_neighbors], dim=-1) # And sort them based on `e_id`. 
e_id, perm = e_id.topk(self.size, dim=-1) self.e_id[n_id] = e_id self.neighbors[n_id] = torch.gather(neighbors, 1, perm) def reset_state(self): self.cur_e_id = 0 self.e_id.fill_(-1) ================================================ FILE: torch_geometric/nn/models/visnet.py ================================================ import math from typing import Optional, Tuple import torch from torch import Tensor from torch.autograd import grad from torch.nn import Embedding, LayerNorm, Linear, Parameter from torch_geometric.nn import MessagePassing, radius_graph from torch_geometric.utils import scatter class CosineCutoff(torch.nn.Module): r"""Applies a cosine cutoff to the input distances. .. math:: \text{cutoffs} = \begin{cases} 0.5 * (\cos(\frac{\text{distances} * \pi}{\text{cutoff}}) + 1.0), & \text{if } \text{distances} < \text{cutoff} \\ 0, & \text{otherwise} \end{cases} Args: cutoff (float): A scalar that determines the point at which the cutoff is applied. """ def __init__(self, cutoff: float) -> None: super().__init__() self.cutoff = cutoff def forward(self, distances: Tensor) -> Tensor: r"""Applies a cosine cutoff to the input distances. Args: distances (torch.Tensor): A tensor of distances. Returns: cutoffs (torch.Tensor): A tensor where the cosine function has been applied to the distances, but any values that exceed the cutoff are set to 0. """ cutoffs = 0.5 * ((distances * math.pi / self.cutoff).cos() + 1.0) cutoffs = cutoffs * (distances < self.cutoff).float() return cutoffs class ExpNormalSmearing(torch.nn.Module): r"""Applies exponential normal smearing to the input distances. .. math:: \text{smeared\_dist} = \text{CosineCutoff}(\text{dist}) * e^{-\beta * (e^{\alpha * (-\text{dist})} - \text{means})^2} Args: cutoff (float, optional): A scalar that determines the point at which the cutoff is applied. (default: :obj:`5.0`) num_rbf (int, optional): The number of radial basis functions. 
(default: :obj:`128`) trainable (bool, optional): If set to :obj:`False`, the means and betas of the RBFs will not be trained. (default: :obj:`True`) """ def __init__( self, cutoff: float = 5.0, num_rbf: int = 128, trainable: bool = True, ) -> None: super().__init__() self.cutoff = cutoff self.num_rbf = num_rbf self.trainable = trainable self.cutoff_fn = CosineCutoff(cutoff) self.alpha = 5.0 / cutoff means, betas = self._initial_params() if trainable: self.register_parameter('means', Parameter(means)) self.register_parameter('betas', Parameter(betas)) else: self.register_buffer('means', means) self.register_buffer('betas', betas) def _initial_params(self) -> Tuple[Tensor, Tensor]: r"""Initializes the means and betas for the radial basis functions.""" start_value = torch.exp(torch.tensor(-self.cutoff)) means = torch.linspace(start_value, 1, self.num_rbf) betas = torch.tensor([(2 / self.num_rbf * (1 - start_value))**-2] * self.num_rbf) return means, betas def reset_parameters(self): r"""Resets the means and betas to their initial values.""" means, betas = self._initial_params() self.means.data.copy_(means) self.betas.data.copy_(betas) def forward(self, dist: Tensor) -> Tensor: r"""Applies the exponential normal smearing to the input distance. Args: dist (torch.Tensor): A tensor of distances. """ dist = dist.unsqueeze(-1) smeared_dist = self.cutoff_fn(dist) * (-self.betas * ( (self.alpha * (-dist)).exp() - self.means)**2).exp() return smeared_dist class Sphere(torch.nn.Module): r"""Computes spherical harmonics of the input data. This module computes the spherical harmonics up to a given degree :obj:`lmax` for the input tensor of 3D vectors. The vectors are assumed to be given in Cartesian coordinates. See `here `_ for mathematical details. Args: lmax (int, optional): The maximum degree of the spherical harmonics. 
(default: :obj:`2`) """ def __init__(self, lmax: int = 2) -> None: super().__init__() self.lmax = lmax def forward(self, edge_vec: Tensor) -> Tensor: r"""Computes the spherical harmonics of the input tensor. Args: edge_vec (torch.Tensor): A tensor of 3D vectors. """ return self._spherical_harmonics( self.lmax, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2], ) @staticmethod def _spherical_harmonics( lmax: int, x: Tensor, y: Tensor, z: Tensor, ) -> Tensor: r"""Computes the spherical harmonics up to degree :obj:`lmax` of the input vectors. Args: lmax (int): The maximum degree of the spherical harmonics. x (torch.Tensor): The x coordinates of the vectors. y (torch.Tensor): The y coordinates of the vectors. z (torch.Tensor): The z coordinates of the vectors. """ sh_1_0, sh_1_1, sh_1_2 = x, y, z if lmax == 1: return torch.stack([sh_1_0, sh_1_1, sh_1_2], dim=-1) sh_2_0 = math.sqrt(3.0) * x * z sh_2_1 = math.sqrt(3.0) * x * y y2 = y.pow(2) x2z2 = x.pow(2) + z.pow(2) sh_2_2 = y2 - 0.5 * x2z2 sh_2_3 = math.sqrt(3.0) * y * z sh_2_4 = math.sqrt(3.0) / 2.0 * (z.pow(2) - x.pow(2)) if lmax == 2: return torch.stack([ sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, ], dim=-1) raise ValueError(f"'lmax' needs to be 1 or 2 (got {lmax})") class VecLayerNorm(torch.nn.Module): r"""Applies layer normalization to the input data. This module applies a custom layer normalization to a tensor of vectors. The normalization can either be :obj:`"max_min"` normalization, or no normalization. Args: hidden_channels (int): The number of hidden channels in the input. trainable (bool): If set to :obj:`True`, the normalization weights are trainable parameters. norm_type (str, optional): The type of normalization to apply, one of :obj:`"max_min"` or :obj:`None`. 
(default: :obj:`"max_min"`) """ def __init__( self, hidden_channels: int, trainable: bool, norm_type: Optional[str] = 'max_min', ) -> None: super().__init__() self.hidden_channels = hidden_channels self.norm_type = norm_type self.eps = 1e-12 weight = torch.ones(self.hidden_channels) if trainable: self.register_parameter('weight', Parameter(weight)) else: self.register_buffer('weight', weight) self.reset_parameters() def reset_parameters(self): r"""Resets the normalization weights to their initial values.""" torch.nn.init.ones_(self.weight) def max_min_norm(self, vec: Tensor) -> Tensor: r"""Applies max-min normalization to the input tensor. .. math:: \text{dist} = ||\text{vec}||_2 \text{direct} = \frac{\text{vec}}{\text{dist}} \text{max\_val} = \max(\text{dist}) \text{min\_val} = \min(\text{dist}) \text{delta} = \text{max\_val} - \text{min\_val} \text{dist} = \frac{\text{dist} - \text{min\_val}}{\text{delta}} \text{normed\_vec} = \max(0, \text{dist}) \cdot \text{direct} Args: vec (torch.Tensor): The input tensor. """ dist = torch.norm(vec, dim=1, keepdim=True) if (dist == 0).all(): return torch.zeros_like(vec) dist = dist.clamp(min=self.eps) direct = vec / dist max_val, _ = dist.max(dim=-1) min_val, _ = dist.min(dim=-1) delta = (max_val - min_val).view(-1) delta = torch.where(delta == 0, torch.ones_like(delta), delta) dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1) return dist.relu() * direct def forward(self, vec: Tensor) -> Tensor: r"""Applies the layer normalization to the input tensor. Args: vec (torch.Tensor): The input tensor. 
""" if vec.size(1) == 3: if self.norm_type == 'max_min': vec = self.max_min_norm(vec) return vec * self.weight.unsqueeze(0).unsqueeze(0) elif vec.size(1) == 8: vec1, vec2 = torch.split(vec, [3, 5], dim=1) if self.norm_type == 'max_min': vec1 = self.max_min_norm(vec1) vec2 = self.max_min_norm(vec2) vec = torch.cat([vec1, vec2], dim=1) return vec * self.weight.unsqueeze(0).unsqueeze(0) raise ValueError(f"'{self.__class__.__name__}' only support 3 or 8 " f"channels (got {vec.size(1)})") class Distance(torch.nn.Module): r"""Computes the pairwise distances between atoms in a molecule. This module computes the pairwise distances between atoms in a molecule, represented by their positions :obj:`pos`. The distances are computed only between points that are within a certain cutoff radius. Args: cutoff (float): The cutoff radius beyond which distances are not computed. max_num_neighbors (int, optional): The maximum number of neighbors considered for each point. (default: :obj:`32`) add_self_loops (bool, optional): If set to :obj:`False`, will not include self-loops. (default: :obj:`True`) """ def __init__( self, cutoff: float, max_num_neighbors: int = 32, add_self_loops: bool = True, ) -> None: super().__init__() self.cutoff = cutoff self.max_num_neighbors = max_num_neighbors self.add_self_loops = add_self_loops def forward( self, pos: Tensor, batch: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Computes the pairwise distances between atoms in the molecule. Args: pos (torch.Tensor): The positions of the atoms in the molecule. batch (torch.Tensor): A batch vector, which assigns each node to a specific example. Returns: edge_index (torch.Tensor): The indices of the edges in the graph. edge_weight (torch.Tensor): The distances between connected nodes. edge_vec (torch.Tensor): The vector differences between connected nodes. 
""" edge_index = radius_graph( pos, r=self.cutoff, batch=batch, loop=self.add_self_loops, max_num_neighbors=self.max_num_neighbors, ) edge_vec = pos[edge_index[0]] - pos[edge_index[1]] if self.add_self_loops: mask = edge_index[0] != edge_index[1] edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device) edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1) else: edge_weight = torch.norm(edge_vec, dim=-1) return edge_index, edge_weight, edge_vec class NeighborEmbedding(MessagePassing): r"""The :class:`NeighborEmbedding` module from the `"Enhancing Geometric Representations for Molecules with Equivariant Vector-Scalar Interactive Message Passing" `_ paper. Args: hidden_channels (int): The number of hidden channels in the node embeddings. num_rbf (int): The number of radial basis functions. cutoff (float): The cutoff distance. max_z (int, optional): The maximum atomic numbers. (default: :obj:`100`) """ def __init__( self, hidden_channels: int, num_rbf: int, cutoff: float, max_z: int = 100, ) -> None: super().__init__(aggr='add') self.embedding = Embedding(max_z, hidden_channels) self.distance_proj = Linear(num_rbf, hidden_channels) self.combine = Linear(hidden_channels * 2, hidden_channels) self.cutoff = CosineCutoff(cutoff) self.reset_parameters() def reset_parameters(self): r"""Resets the parameters of the module.""" self.embedding.reset_parameters() torch.nn.init.xavier_uniform_(self.distance_proj.weight) torch.nn.init.xavier_uniform_(self.combine.weight) self.distance_proj.bias.data.zero_() self.combine.bias.data.zero_() def forward( self, z: Tensor, x: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, ) -> Tensor: r"""Computes the neighborhood embedding of the nodes in the graph. Args: z (torch.Tensor): The atomic numbers. x (torch.Tensor): The node features. edge_index (torch.Tensor): The indices of the edges. edge_weight (torch.Tensor): The weights of the edges. edge_attr (torch.Tensor): The edge features. 
Returns: x_neighbors (torch.Tensor): The neighborhood embeddings of the nodes. """ mask = edge_index[0] != edge_index[1] if not mask.all(): edge_index = edge_index[:, mask] edge_weight = edge_weight[mask] edge_attr = edge_attr[mask] C = self.cutoff(edge_weight) W = self.distance_proj(edge_attr) * C.view(-1, 1) x_neighbors = self.embedding(z) x_neighbors = self.propagate(edge_index, x=x_neighbors, W=W) x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1)) return x_neighbors def message(self, x_j: Tensor, W: Tensor) -> Tensor: return x_j * W class EdgeEmbedding(torch.nn.Module): r"""The :class:`EdgeEmbedding` module from the `"Enhancing Geometric Representations for Molecules with Equivariant Vector-Scalar Interactive Message Passing" `_ paper. Args: num_rbf (int): The number of radial basis functions. hidden_channels (int): The number of hidden channels in the node embeddings. """ def __init__(self, num_rbf: int, hidden_channels: int) -> None: super().__init__() self.edge_proj = Linear(num_rbf, hidden_channels) self.reset_parameters() def reset_parameters(self): r"""Resets the parameters of the module.""" torch.nn.init.xavier_uniform_(self.edge_proj.weight) self.edge_proj.bias.data.zero_() def forward( self, edge_index: Tensor, edge_attr: Tensor, x: Tensor, ) -> Tensor: r"""Computes the edge embeddings of the graph. Args: edge_index (torch.Tensor): The indices of the edges. edge_attr (torch.Tensor): The edge features. x (torch.Tensor): The node features. Returns: out_edge_attr (torch.Tensor): The edge embeddings. """ x_j = x[edge_index[0]] x_i = x[edge_index[1]] return (x_i + x_j) * self.edge_proj(edge_attr) class ViS_MP(MessagePassing): r"""The message passing module without vertex geometric features of the equivariant vector-scalar interactive graph neural network (ViSNet) from the `"Enhancing Geometric Representations for Molecules with Equivariant Vector-Scalar Interactive Message Passing" `_ paper. Args: num_heads (int): The number of attention heads. 
hidden_channels (int): The number of hidden channels in the node
            embeddings.
        cutoff (float): The cutoff distance.
        vecnorm_type (str, optional): The type of normalization to apply to
            the vectors.
        trainable_vecnorm (bool): Whether the normalization weights are
            trainable.
        last_layer (bool, optional): Whether this is the last layer in the
            model. If set, no edge-feature update is computed.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        num_heads: int,
        hidden_channels: int,
        cutoff: float,
        vecnorm_type: Optional[str],
        trainable_vecnorm: bool,
        last_layer: bool = False,
    ) -> None:
        # Propagate along dim 0; per-node tensors such as q/k/v are reshaped
        # to [N, num_heads, head_dim] in `forward`.
        super().__init__(aggr='add', node_dim=0)

        if hidden_channels % num_heads != 0:
            raise ValueError(
                f"The number of hidden channels (got {hidden_channels}) must "
                f"be evenly divisible by the number of attention heads "
                f"(got {num_heads})")

        self.num_heads = num_heads
        self.hidden_channels = hidden_channels
        self.head_dim = hidden_channels // num_heads
        self.last_layer = last_layer

        self.layernorm = LayerNorm(hidden_channels)
        self.vec_layernorm = VecLayerNorm(
            hidden_channels,
            trainable=trainable_vecnorm,
            norm_type=vecnorm_type,
        )

        self.act = torch.nn.SiLU()
        self.attn_activation = torch.nn.SiLU()

        self.cutoff = CosineCutoff(cutoff)

        # Bias-free projection of the vector features into three
        # `hidden_channels`-sized chunks (`vec1`, `vec2`, `vec3` in `forward`).
        self.vec_proj = Linear(hidden_channels, hidden_channels * 3, False)

        self.q_proj = Linear(hidden_channels, hidden_channels)
        self.k_proj = Linear(hidden_channels, hidden_channels)
        self.v_proj = Linear(hidden_channels, hidden_channels)
        self.dk_proj = Linear(hidden_channels, hidden_channels)
        self.dv_proj = Linear(hidden_channels, hidden_channels)

        self.s_proj = Linear(hidden_channels, hidden_channels * 2)
        # Edge-update projections are only needed (and only created) when
        # this layer produces an edge-feature residual.
        if not self.last_layer:
            self.f_proj = Linear(hidden_channels, hidden_channels)
            self.w_src_proj = Linear(hidden_channels, hidden_channels, False)
            self.w_trg_proj = Linear(hidden_channels, hidden_channels, False)

        self.o_proj = Linear(hidden_channels, hidden_channels * 3)

        self.reset_parameters()

    @staticmethod
    def vector_rejection(vec: Tensor, d_ij: Tensor) -> Tensor:
        r"""Computes the component of :obj:`vec` orthogonal to
        :obj:`d_ij`.

        This is an exact orthogonal rejection only when :obj:`d_ij` has unit
        norm along its second dimension.

        Args:
            vec (torch.Tensor): The input vector.
            d_ij (torch.Tensor): The reference vector.
        """
        # Inner product with `d_ij` along dim 1; `d_ij` is broadcast over the
        # trailing (channel) axis via `unsqueeze(2)`.
        vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True)
        return vec - vec_proj * d_ij.unsqueeze(2)

    def reset_parameters(self):
        r"""Resets the parameters of the module."""
        self.layernorm.reset_parameters()
        self.vec_layernorm.reset_parameters()
        torch.nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.k_proj.weight)
        self.k_proj.bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.s_proj.weight)
        self.s_proj.bias.data.zero_()

        # These modules only exist on non-last layers (see `__init__`).
        if not self.last_layer:
            torch.nn.init.xavier_uniform_(self.f_proj.weight)
            self.f_proj.bias.data.zero_()
            torch.nn.init.xavier_uniform_(self.w_src_proj.weight)
            torch.nn.init.xavier_uniform_(self.w_trg_proj.weight)

        torch.nn.init.xavier_uniform_(self.vec_proj.weight)
        torch.nn.init.xavier_uniform_(self.dk_proj.weight)
        self.dk_proj.bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.dv_proj.weight)
        self.dv_proj.bias.data.zero_()

    def forward(
        self,
        x: Tensor,
        vec: Tensor,
        edge_index: Tensor,
        r_ij: Tensor,
        f_ij: Tensor,
        d_ij: Tensor,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        r"""Computes the residual scalar and vector features of the nodes and
        scalar features of the edges.

        Args:
            x (torch.Tensor): The scalar features of the nodes.
            vec (torch.Tensor): The vector features of the nodes.
            edge_index (torch.Tensor): The indices of the edges.
            r_ij (torch.Tensor): The distances between connected nodes.
            f_ij (torch.Tensor): The scalar features of the edges.
            d_ij (torch.Tensor): The unit vectors of the edges.

        Returns:
            dx (torch.Tensor): The residual scalar features of the nodes.
            dvec (torch.Tensor): The residual vector features of the nodes.
            df_ij (torch.Tensor, optional): The residual scalar features of the
                edges, or :obj:`None` if this is the last layer.
        """
        x = self.layernorm(x)
        vec = self.vec_layernorm(vec)

        # Split node projections into attention heads.
        q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim)
        # Edge-conditioned modulations of keys and values.
        dk = self.act(self.dk_proj(f_ij))
        dk = dk.reshape(-1, self.num_heads, self.head_dim)
        dv = self.act(self.dv_proj(f_ij))
        dv = dv.reshape(-1, self.num_heads, self.head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec),
                                       self.hidden_channels, dim=-1)
        # Inner product over dim 1 of two projections of each node's own
        # vector features.
        vec_dot = (vec1 * vec2).sum(dim=1)

        # `message` returns a (scalar, vector) pair; `aggregate` below sums
        # both parts onto the target nodes.
        x, vec_out = self.propagate(edge_index, q=q, k=k, v=v, dk=dk, dv=dv,
                                    vec=vec, r_ij=r_ij, d_ij=d_ij)

        o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1)
        dx = vec_dot * o2 + o3
        dvec = vec3 * o1.unsqueeze(1) + vec_out
        if not self.last_layer:
            # `edge_updater` invokes `edge_update` to produce the edge
            # feature residual.
            df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij,
                                      f_ij=f_ij)
            return dx, dvec, df_ij
        else:
            return dx, dvec, None

    def message(self, q_i: Tensor, k_j: Tensor, v_j: Tensor, vec_j: Tensor,
                dk: Tensor, dv: Tensor, r_ij: Tensor,
                d_ij: Tensor) -> Tuple[Tensor, Tensor]:

        # Per-head attention score, activated with SiLU (no softmax
        # normalization over neighbors) and scaled by the cutoff function of
        # the edge length.
        attn = (q_i * k_j * dk).sum(dim=-1)
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)

        v_j = v_j * dv
        v_j = (v_j * attn.unsqueeze(2)).view(-1, self.hidden_channels)

        # Two gates from the scalar message: `s1` scales the neighbor's vector
        # features, `s2` scales the edge direction `d_ij`.
        s1, s2 = torch.split(self.act(self.s_proj(v_j)), self.hidden_channels,
                             dim=1)
        vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2)

        return v_j, vec_j

    def edge_update(self, vec_i: Tensor, vec_j: Tensor, d_ij: Tensor,
                    f_ij: Tensor) -> Tensor:
        # Project the target/source vector features, remove their components
        # along the edge direction, and take their inner product over dim 1.
        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)
        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)
        w_dot = (w1 * w2).sum(dim=1)
        df_ij = self.act(self.f_proj(f_ij)) * w_dot
        return df_ij

    def aggregate(
        self,
        features: Tuple[Tensor, Tensor],
        index: Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[Tensor, Tensor]:
        # `message` returns a tuple, so each part is scattered separately
        # onto the target nodes (`ptr` is unused).
        x, vec = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec


class ViS_MP_Vertex(ViS_MP):
    r"""The message passing module with vertex geometric features of the
    equivariant vector-scalar interactive graph neural network (ViSNet)
    from the `"Enhancing Geometric Representations for Molecules with
    Equivariant Vector-Scalar Interactive Message Passing" `_ paper.

    Args:
        num_heads (int): The number of attention heads.
        hidden_channels (int): The number of hidden channels in the node
            embeddings.
        cutoff (float): The cutoff distance.
        vecnorm_type (str, optional): The type of normalization to apply to
            the vectors.
        trainable_vecnorm (bool): Whether the normalization weights are
            trainable.
        last_layer (bool, optional): Whether this is the last layer in the
            model. (default: :obj:`False`)
    """
    def __init__(
        self,
        num_heads: int,
        hidden_channels: int,
        cutoff: float,
        vecnorm_type: Optional[str],
        trainable_vecnorm: bool,
        last_layer: bool = False,
    ) -> None:
        super().__init__(num_heads, hidden_channels, cutoff, vecnorm_type,
                         trainable_vecnorm, last_layer)

        if not self.last_layer:
            # Replaces the parent's `f_proj` with a twice-as-wide projection,
            # split into the `w_dot` and `t_dot` gates in `edge_update`.
            self.f_proj = Linear(hidden_channels, hidden_channels * 2)
            self.t_src_proj = Linear(hidden_channels, hidden_channels, False)
            self.t_trg_proj = Linear(hidden_channels, hidden_channels, False)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the parameters of the module."""
        super().reset_parameters()

        # `ViS_MP.__init__` already calls `self.reset_parameters()` (this
        # override) before `t_src_proj`/`t_trg_proj` exist, hence the guards.
        if not self.last_layer:
            if hasattr(self, 't_src_proj'):
                torch.nn.init.xavier_uniform_(self.t_src_proj.weight)
            if hasattr(self, 't_trg_proj'):
                torch.nn.init.xavier_uniform_(self.t_trg_proj.weight)

    def edge_update(self, vec_i: Tensor, vec_j: Tensor, d_ij: Tensor,
                    f_ij: Tensor) -> Tensor:
        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)
        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)
        w_dot = (w1 * w2).sum(dim=1)

        # NOTE(review): unlike `w2` (which projects `vec_j`), `t2` projects
        # `vec_i`. This may follow the reference implementation -- confirm
        # before changing.
        t1 = self.vector_rejection(self.t_trg_proj(vec_i), d_ij)
        t2 = self.vector_rejection(self.t_src_proj(vec_i), -d_ij)
        t_dot = (t1 * t2).sum(dim=1)

        f1, f2 = torch.split(self.act(self.f_proj(f_ij)), self.hidden_channels,
                             dim=-1)

        return f1 * w_dot + f2 * t_dot


class ViSNetBlock(torch.nn.Module):
    r"""The representation module of the equivariant vector-scalar
    interactive graph neural network (ViSNet) from the `"Enhancing Geometric
    Representations for Molecules with Equivariant Vector-Scalar Interactive
    Message Passing" `_ paper.

    Args:
        lmax (int, optional): The maximum degree of the spherical harmonics.
            (default: :obj:`1`)
        vecnorm_type (str, optional): The type of normalization to apply to
            the vectors. (default: :obj:`None`)
        trainable_vecnorm (bool, optional):  Whether the normalization weights
            are trainable. (default: :obj:`False`)
        num_heads (int, optional): The number of attention heads.
            (default: :obj:`8`)
        num_layers (int, optional): The number of layers in the network.
            (default: :obj:`6`)
        hidden_channels (int, optional): The number of hidden channels in the
            node embeddings. (default: :obj:`128`)
        num_rbf (int, optional): The number of radial basis functions.
            (default: :obj:`32`)
        trainable_rbf (bool, optional): Whether the radial basis function
            parameters are trainable. (default: :obj:`False`)
        max_z (int, optional): The maximum atomic numbers.
            (default: :obj:`100`)
        cutoff (float, optional): The cutoff distance. (default: :obj:`5.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors
            considered for each atom. (default: :obj:`32`)
        vertex (bool, optional): Whether to use vertex geometric features.
(default: :obj:`False`)
    """
    def __init__(
        self,
        lmax: int = 1,
        vecnorm_type: Optional[str] = None,
        trainable_vecnorm: bool = False,
        num_heads: int = 8,
        num_layers: int = 6,
        hidden_channels: int = 128,
        num_rbf: int = 32,
        trainable_rbf: bool = False,
        max_z: int = 100,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        vertex: bool = False,
    ) -> None:
        super().__init__()

        self.lmax = lmax
        self.vecnorm_type = vecnorm_type
        self.trainable_vecnorm = trainable_vecnorm
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.num_rbf = num_rbf
        self.trainable_rbf = trainable_rbf
        self.max_z = max_z
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors

        self.embedding = Embedding(max_z, hidden_channels)
        # `Distance` is built with its default `add_self_loops=True`, so the
        # radius graph contains self-loops; they are masked out below and in
        # `NeighborEmbedding`.
        self.distance = Distance(cutoff, max_num_neighbors=max_num_neighbors)
        self.sphere = Sphere(lmax=lmax)
        self.distance_expansion = ExpNormalSmearing(cutoff, num_rbf,
                                                    trainable_rbf)
        self.neighbor_embedding = NeighborEmbedding(hidden_channels, num_rbf,
                                                    cutoff, max_z)
        self.edge_embedding = EdgeEmbedding(num_rbf, hidden_channels)

        self.vis_mp_layers = torch.nn.ModuleList()
        vis_mp_kwargs = dict(
            num_heads=num_heads,
            hidden_channels=hidden_channels,
            cutoff=cutoff,
            vecnorm_type=vecnorm_type,
            trainable_vecnorm=trainable_vecnorm,
        )
        vis_mp_class = ViS_MP if not vertex else ViS_MP_Vertex
        # `num_layers - 1` edge-updating layers followed by one final layer
        # without edge update; at least one layer is always created.
        for _ in range(num_layers - 1):
            layer = vis_mp_class(last_layer=False, **vis_mp_kwargs)
            self.vis_mp_layers.append(layer)
        self.vis_mp_layers.append(
            vis_mp_class(last_layer=True, **vis_mp_kwargs))

        self.out_norm = LayerNorm(hidden_channels)
        self.vec_out_norm = VecLayerNorm(
            hidden_channels,
            trainable=trainable_vecnorm,
            norm_type=vecnorm_type,
        )

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the parameters of the module."""
        self.embedding.reset_parameters()
        self.distance_expansion.reset_parameters()
        self.neighbor_embedding.reset_parameters()
        self.edge_embedding.reset_parameters()
        for layer in self.vis_mp_layers:
            layer.reset_parameters()
        self.out_norm.reset_parameters()
        self.vec_out_norm.reset_parameters()

    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        r"""Computes the scalar and vector features of the nodes.

        Args:
            z (torch.Tensor): The atomic numbers.
            pos (torch.Tensor): The coordinates of the atoms.
            batch (torch.Tensor): A batch vector, which assigns each node to a
                specific example.

        Returns:
            x (torch.Tensor): The scalar features of the nodes.
            vec (torch.Tensor): The vector features of the nodes.
        """
        x = self.embedding(z)
        edge_index, edge_weight, edge_vec = self.distance(pos, batch)
        edge_attr = self.distance_expansion(edge_weight)
        # Normalize non-self-loop edge vectors to unit length in place.
        # Self-loop vectors are zero and are left untouched to avoid a
        # division by zero.
        mask = edge_index[0] != edge_index[1]
        edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask],
                                                     dim=1).unsqueeze(1)
        edge_vec = self.sphere(edge_vec)
        x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr)
        # Vector features start at zero; dim 1 has (lmax + 1)^2 - 1 entries,
        # matching the number of components returned by `self.sphere`.
        vec = torch.zeros(x.size(0), ((self.lmax + 1)**2) - 1, x.size(1),
                          dtype=x.dtype, device=x.device)
        edge_attr = self.edge_embedding(edge_index, edge_attr, x)

        # Residual updates of node scalars, node vectors and edge features.
        for attn in self.vis_mp_layers[:-1]:
            dx, dvec, dedge_attr = attn(x, vec, edge_index, edge_weight,
                                        edge_attr, edge_vec)
            x = x + dx
            vec = vec + dvec
            edge_attr = edge_attr + dedge_attr

        # The last layer returns no edge residual.
        dx, dvec, _ = self.vis_mp_layers[-1](x, vec, edge_index, edge_weight,
                                             edge_attr, edge_vec)
        x = x + dx
        vec = vec + dvec

        x = self.out_norm(x)
        vec = self.vec_out_norm(vec)

        return x, vec


class GatedEquivariantBlock(torch.nn.Module):
    r"""Applies a gated equivariant operation to scalar features and vector
    features from the `"Enhancing Geometric Representations for Molecules
    with Equivariant Vector-Scalar Interactive Message Passing" `_ paper.

    Args:
        hidden_channels (int): The number of hidden channels in the node
            embeddings.
        out_channels (int): The number of output channels.
        intermediate_channels (int, optional): The number of channels in the
            intermediate layer, or :obj:`None` to use the same number as
            :obj:`hidden_channels`.
(default: :obj:`None`) scalar_activation (bool, optional): Whether to apply a scalar activation function to the output node features. (default: obj:`False`) """ def __init__( self, hidden_channels: int, out_channels: int, intermediate_channels: Optional[int] = None, scalar_activation: bool = False, ) -> None: super().__init__() self.out_channels = out_channels if intermediate_channels is None: intermediate_channels = hidden_channels self.vec1_proj = Linear(hidden_channels, hidden_channels, bias=False) self.vec2_proj = Linear(hidden_channels, out_channels, bias=False) self.update_net = torch.nn.Sequential( Linear(hidden_channels * 2, intermediate_channels), torch.nn.SiLU(), Linear(intermediate_channels, out_channels * 2), ) self.act = torch.nn.SiLU() if scalar_activation else None self.reset_parameters() def reset_parameters(self): r"""Resets the parameters of the module.""" torch.nn.init.xavier_uniform_(self.vec1_proj.weight) torch.nn.init.xavier_uniform_(self.vec2_proj.weight) torch.nn.init.xavier_uniform_(self.update_net[0].weight) self.update_net[0].bias.data.zero_() torch.nn.init.xavier_uniform_(self.update_net[2].weight) self.update_net[2].bias.data.zero_() def forward(self, x: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]: r"""Applies a gated equivariant operation to node features and vector features. Args: x (torch.Tensor): The scalar features of the nodes. v (torch.Tensor): The vector features of the nodes. """ vec1 = torch.norm(self.vec1_proj(v), dim=-2) vec2 = self.vec2_proj(v) x = torch.cat([x, vec1], dim=-1) x, v = torch.split(self.update_net(x), self.out_channels, dim=-1) v = v.unsqueeze(1) * vec2 if self.act is not None: x = self.act(x) return x, v class EquivariantScalar(torch.nn.Module): r"""Computes final scalar outputs based on node features and vector features. Args: hidden_channels (int): The number of hidden channels in the node embeddings. 
""" def __init__(self, hidden_channels: int) -> None: super().__init__() self.output_network = torch.nn.ModuleList([ GatedEquivariantBlock( hidden_channels, hidden_channels // 2, scalar_activation=True, ), GatedEquivariantBlock( hidden_channels // 2, 1, scalar_activation=False, ), ]) self.reset_parameters() def reset_parameters(self): r"""Resets the parameters of the module.""" for layer in self.output_network: layer.reset_parameters() def pre_reduce(self, x: Tensor, v: Tensor) -> Tensor: r"""Computes the final scalar outputs. Args: x (torch.Tensor): The scalar features of the nodes. v (torch.Tensor): The vector features of the nodes. Returns: out (torch.Tensor): The final scalar outputs of the nodes. """ for layer in self.output_network: x, v = layer(x, v) return x + v.sum() * 0 class Atomref(torch.nn.Module): r"""Adds atom reference values to atomic energies. Args: atomref (torch.Tensor, optional): A tensor of atom reference values, or :obj:`None` if not provided. (default: :obj:`None`) max_z (int, optional): The maximum atomic numbers. (default: :obj:`100`) """ def __init__( self, atomref: Optional[Tensor] = None, max_z: int = 100, ) -> None: super().__init__() if atomref is None: atomref = torch.zeros(max_z, 1) else: atomref = torch.as_tensor(atomref) if atomref.ndim == 1: atomref = atomref.view(-1, 1) self.register_buffer('initial_atomref', atomref) self.atomref = Embedding(len(atomref), 1) self.reset_parameters() def reset_parameters(self): r"""Resets the parameters of the module.""" self.atomref.weight.data.copy_(self.initial_atomref) def forward(self, x: Tensor, z: Tensor) -> Tensor: r"""Adds atom reference values to atomic energies. Args: x (torch.Tensor): The atomic energies. z (torch.Tensor): The atomic numbers. 
""" return x + self.atomref(z) class ViSNet(torch.nn.Module): r"""A :pytorch:`PyTorch` module that implements the equivariant vector-scalar interactive graph neural network (ViSNet) from the `"Enhancing Geometric Representations for Molecules with Equivariant Vector-Scalar Interactive Message Passing" `_ paper. Args: lmax (int, optional): The maximum degree of the spherical harmonics. (default: :obj:`1`) vecnorm_type (str, optional): The type of normalization to apply to the vectors. (default: :obj:`None`) trainable_vecnorm (bool, optional): Whether the normalization weights are trainable. (default: :obj:`False`) num_heads (int, optional): The number of attention heads. (default: :obj:`8`) num_layers (int, optional): The number of layers in the network. (default: :obj:`6`) hidden_channels (int, optional): The number of hidden channels in the node embeddings. (default: :obj:`128`) num_rbf (int, optional): The number of radial basis functions. (default: :obj:`32`) trainable_rbf (bool, optional): Whether the radial basis function parameters are trainable. (default: :obj:`False`) max_z (int, optional): The maximum atomic numbers. (default: :obj:`100`) cutoff (float, optional): The cutoff distance. (default: :obj:`5.0`) max_num_neighbors (int, optional): The maximum number of neighbors considered for each atom. (default: :obj:`32`) vertex (bool, optional): Whether to use vertex geometric features. (default: :obj:`False`) atomref (torch.Tensor, optional): A tensor of atom reference values, or :obj:`None` if not provided. (default: :obj:`None`) reduce_op (str, optional): The type of reduction operation to apply (:obj:`"sum"`, :obj:`"mean"`). (default: :obj:`"sum"`) mean (float, optional): The mean of the output distribution. (default: :obj:`0.0`) std (float, optional): The standard deviation of the output distribution. (default: :obj:`1.0`) derivative (bool, optional): Whether to compute the derivative of the output with respect to the positions. 
(default: :obj:`False`)
    """
    def __init__(
        self,
        lmax: int = 1,
        vecnorm_type: Optional[str] = None,
        trainable_vecnorm: bool = False,
        num_heads: int = 8,
        num_layers: int = 6,
        hidden_channels: int = 128,
        num_rbf: int = 32,
        trainable_rbf: bool = False,
        max_z: int = 100,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        vertex: bool = False,
        atomref: Optional[Tensor] = None,
        reduce_op: str = "sum",
        mean: float = 0.0,
        std: float = 1.0,
        derivative: bool = False,
    ) -> None:
        super().__init__()
        self.representation_model = ViSNetBlock(
            lmax=lmax,
            vecnorm_type=vecnorm_type,
            trainable_vecnorm=trainable_vecnorm,
            num_heads=num_heads,
            num_layers=num_layers,
            hidden_channels=hidden_channels,
            num_rbf=num_rbf,
            trainable_rbf=trainable_rbf,
            max_z=max_z,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors,
            vertex=vertex,
        )

        self.output_model = EquivariantScalar(hidden_channels=hidden_channels)
        # Always constructed here (all-zero references when `atomref` is
        # None), so the `is not None` checks below are always true.
        self.prior_model = Atomref(atomref=atomref, max_z=max_z)

        self.reduce_op = reduce_op
        self.derivative = derivative

        # Output de-standardization constants, stored as buffers.
        self.register_buffer('mean', torch.tensor(mean))
        self.register_buffer('std', torch.tensor(std))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the parameters of the module."""
        self.representation_model.reset_parameters()
        self.output_model.reset_parameters()
        if self.prior_model is not None:
            self.prior_model.reset_parameters()

    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: Tensor,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Computes the energies or properties (forces) for a batch of
        molecules.

        Args:
            z (torch.Tensor): The atomic numbers.
            pos (torch.Tensor): The coordinates of the atoms.
            batch (torch.Tensor): A batch vector,
                which assigns each node to a specific example.

        Returns:
            y (torch.Tensor): The energies or properties for each molecule.
            dy (torch.Tensor, optional): The negative derivative of energies
                with respect to :obj:`pos`, or :obj:`None` if
                :obj:`derivative` is :obj:`False`.

        Raises:
            RuntimeError: If autograd returns no gradient for :obj:`pos`.
        """
        if self.derivative:
            # Enables gradient tracking in place on the caller's `pos`.
            pos.requires_grad_(True)

        x, v = self.representation_model(z, pos, batch)
        x = self.output_model.pre_reduce(x, v)
        # Per-atom outputs are scaled by `std` before the atom references are
        # added; `mean` is added after the per-molecule reduction.
        x = x * self.std

        if self.prior_model is not None:
            x = self.prior_model(x, z)

        y = scatter(x, batch, dim=0, reduce=self.reduce_op)
        y = y + self.mean

        if self.derivative:
            grad_outputs = [torch.ones_like(y)]
            # `create_graph=True` keeps the derivative differentiable so that
            # a loss on `dy` (e.g. forces) can be backpropagated.
            dy = grad(
                [y],
                [pos],
                grad_outputs=grad_outputs,
                create_graph=True,
                retain_graph=True,
            )[0]
            if dy is None:
                raise RuntimeError(
                    "Autograd returned None for the force prediction.")
            return y, -dy

        return y, None

================================================
FILE: torch_geometric/nn/module_dict.py
================================================
from typing import Final, Iterable, Mapping, Optional, Tuple, Union

import torch
from torch.nn import Module

Key = Union[str, Tuple[str, ...]]


# `torch.nn.ModuleDict` doesn't allow `.` to be used in key names.
# This `ModuleDict` will support it by converting the `.` to `#` in the
# internal representation and converts it back to `.` in the external
# representation. It also allows passing tuples as keys.
class ModuleDict(torch.nn.ModuleDict): CLASS_ATTRS: Final[Tuple[str, ...]] = tuple(dir(torch.nn.ModuleDict)) def __init__( self, modules: Optional[Mapping[Union[str, Tuple[str, ...]], Module]] = None, ): if modules is not None: # Replace the keys in modules: modules = { self.to_internal_key(key): module for key, module in modules.items() } super().__init__(modules) @classmethod def to_internal_key(cls, key: Key) -> str: if isinstance(key, tuple): # ModuleDict can't handle tuples as keys assert len(key) > 1 key = f"<{'___'.join(key)}>" assert isinstance(key, str) # ModuleDict cannot handle keys that exists as class attributes: if key in cls.CLASS_ATTRS: key = f'<{key}>' # ModuleDict cannot handle dots in keys: return key.replace('.', '#') @classmethod def to_external_key(cls, key: str) -> Key: key = key.replace('#', '.') if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS: key = key[1:-1] if key[0] == '<' and key[-1] == '>' and '___' in key: key = tuple(key[1:-1].split('___')) return key def __getitem__(self, key: Key) -> Module: return super().__getitem__(self.to_internal_key(key)) def __setitem__(self, key: Key, module: Module): return super().__setitem__(self.to_internal_key(key), module) def __delitem__(self, key: Key): return super().__delitem__(self.to_internal_key(key)) def __contains__(self, key: Key) -> bool: return super().__contains__(self.to_internal_key(key)) def keys(self) -> Iterable[Key]: return [self.to_external_key(key) for key in super().keys()] def items(self) -> Iterable[Tuple[Key, Module]]: return [(self.to_external_key(k), v) for k, v in super().items()] ================================================ FILE: torch_geometric/nn/norm/__init__.py ================================================ r"""Normalization package.""" from .batch_norm import BatchNorm, HeteroBatchNorm from .instance_norm import InstanceNorm from .layer_norm import LayerNorm, HeteroLayerNorm from .graph_norm import GraphNorm from .graph_size_norm import 
GraphSizeNorm from .pair_norm import PairNorm from .mean_subtraction_norm import MeanSubtractionNorm from .msg_norm import MessageNorm from .diff_group_norm import DiffGroupNorm __all__ = [ 'BatchNorm', 'HeteroBatchNorm', 'InstanceNorm', 'LayerNorm', 'HeteroLayerNorm', 'GraphNorm', 'GraphSizeNorm', 'PairNorm', 'MeanSubtractionNorm', 'MessageNorm', 'DiffGroupNorm', ] classes = __all__ ================================================ FILE: torch_geometric/nn/norm/batch_norm.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.aggr.fused import FusedAggregation class BatchNorm(torch.nn.Module): r"""Applies batch normalization over a batch of features as described in the `"Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta The mean and standard-deviation are calculated per-dimension over all nodes inside the mini-batch. Args: in_channels (int): Size of each input sample. eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) momentum (float, optional): The value used for the running mean and running variance computation. (default: :obj:`0.1`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) track_running_stats (bool, optional): If set to :obj:`True`, this module tracks the running mean and variance, and when set to :obj:`False`, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: :obj:`True`) allow_single_element (bool, optional): If set to :obj:`True`, batches with only a single element will work as during in evaluation. 
That is the running mean and variance will be used. Requires :obj:`track_running_stats=True`. (default: :obj:`False`) device (torch.device, optional): The device to use for the module. (default: :obj:`None`) """ def __init__( self, in_channels: int, eps: float = 1e-5, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True, allow_single_element: bool = False, device: Optional[torch.device] = None, ): super().__init__() if allow_single_element and not track_running_stats: raise ValueError("'allow_single_element' requires " "'track_running_stats' to be set to `True`") self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine, track_running_stats, device=device) self.in_channels = in_channels self.allow_single_element = allow_single_element def reset_running_stats(self): r"""Resets all running statistics of the module.""" self.module.reset_running_stats() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.module.reset_parameters() def forward(self, x: Tensor) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. """ if self.allow_single_element and x.size(0) <= 1: return torch.nn.functional.batch_norm( x, self.module.running_mean, self.module.running_var, self.module.weight, self.module.bias, False, # bn_training 0.0, # momentum self.module.eps, ) return self.module(x) def __repr__(self): return f'{self.__class__.__name__}({self.module.extra_repr()})' class HeteroBatchNorm(torch.nn.Module): r"""Applies batch normalization over a batch of heterogeneous features as described in the `"Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" `_ paper. Compared to :class:`BatchNorm`, :class:`HeteroBatchNorm` applies normalization individually for each node or edge type. Args: in_channels (int): Size of each input sample. num_types (int): The number of types. eps (float, optional): A value added to the denominator for numerical stability. 
(default: :obj:`1e-5`) momentum (float, optional): The value used for the running mean and running variance computation. (default: :obj:`0.1`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) track_running_stats (bool, optional): If set to :obj:`True`, this module tracks the running mean and variance, and when set to :obj:`False`, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: :obj:`True`) device (torch.device, optional): The device to use for the module. (default: :obj:`None`) """ def __init__( self, in_channels: int, num_types: int, eps: float = 1e-5, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True, device: Optional[torch.device] = None, ): super().__init__() self.in_channels = in_channels self.num_types = num_types self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: self.weight = Parameter( torch.empty(num_types, in_channels, device=device)) self.bias = Parameter( torch.empty(num_types, in_channels, device=device)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) if self.track_running_stats: self.register_buffer( 'running_mean', torch.empty(num_types, in_channels, device=device)) self.register_buffer( 'running_var', torch.empty(num_types, in_channels, device=device)) self.register_buffer('num_batches_tracked', torch.tensor(0)) else: self.register_buffer('running_mean', None) self.register_buffer('running_var', None) self.register_buffer('num_batches_tracked', None) self.mean_var = FusedAggregation(['mean', 'var']) self.reset_parameters() def reset_running_stats(self): r"""Resets all running statistics of the module.""" if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) self.num_batches_tracked.zero_() def 
reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.reset_running_stats() if self.affine: torch.nn.init.ones_(self.weight) torch.nn.init.zeros_(self.bias) def forward(self, x: Tensor, type_vec: Tensor) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The input features. type_vec (torch.Tensor): A vector that maps each entry to a type. """ if not self.training and self.track_running_stats: mean, var = self.running_mean, self.running_var else: with torch.no_grad(): mean, var = self.mean_var(x, type_vec, dim_size=self.num_types) if self.training and self.track_running_stats: if self.momentum is None: self.num_batches_tracked.add_(1) exp_avg_factor = 1.0 / float(self.num_batches_tracked) else: exp_avg_factor = self.momentum with torch.no_grad(): # Update running mean and variance: type_index = torch.unique(type_vec) self.running_mean[type_index] = ( (1.0 - exp_avg_factor) * self.running_mean[type_index] + exp_avg_factor * mean[type_index]) self.running_var[type_index] = ( (1.0 - exp_avg_factor) * self.running_var[type_index] + exp_avg_factor * var[type_index]) out = (x - mean[type_vec]) / var.clamp(self.eps).sqrt()[type_vec] if self.affine: out = out * self.weight[type_vec] + self.bias[type_vec] return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'num_types={self.num_types})') ================================================ FILE: torch_geometric/nn/norm/diff_group_norm.py ================================================ from typing import Optional import torch from torch import Tensor from torch.nn import BatchNorm1d, Linear class DiffGroupNorm(torch.nn.Module): r"""The differentiable group normalization layer from the `"Towards Deeper Graph Neural Networks with Differentiable Group Normalization" `_ paper, which normalizes node features group-wise via a learnable soft cluster assignment. .. 
math:: \mathbf{S} = \text{softmax} (\mathbf{X} \mathbf{W}) where :math:`\mathbf{W} \in \mathbb{R}^{F \times G}` denotes a trainable weight matrix mapping each node into one of :math:`G` clusters. Normalization is then performed group-wise via: .. math:: \mathbf{X}^{\prime} = \mathbf{X} + \lambda \sum_{i = 1}^G \text{BatchNorm}(\mathbf{S}[:, i] \odot \mathbf{X}) Args: in_channels (int): Size of each input sample :math:`F`. groups (int): The number of groups :math:`G`. lamda (float, optional): The balancing factor :math:`\lambda` between input embeddings and normalized embeddings. (default: :obj:`0.01`) eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) momentum (float, optional): The value used for the running mean and running variance computation. (default: :obj:`0.1`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) track_running_stats (bool, optional): If set to :obj:`True`, this module tracks the running mean and variance, and when set to :obj:`False`, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: :obj:`True`) device (torch.device, optional): The device to use for the module. 
(default: :obj:`None`) """ def __init__( self, in_channels: int, groups: int, lamda: float = 0.01, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, device: Optional[torch.device] = None, ): super().__init__() self.in_channels = in_channels self.groups = groups self.lamda = lamda self.lin = Linear(in_channels, groups, bias=False, device=device) self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine, track_running_stats, device=device) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin.reset_parameters() self.norm.reset_parameters() def forward(self, x: Tensor) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. """ F, G = self.in_channels, self.groups s = self.lin(x).softmax(dim=-1) # [N, G] out = s.unsqueeze(-1) * x.unsqueeze(-2) # [N, G, F] out = self.norm(out.view(-1, G * F)).view(-1, G, F).sum(-2) # [N, F] return x + self.lamda * out @staticmethod def group_distance_ratio(x: Tensor, y: Tensor, eps: float = 1e-5) -> float: r"""Measures the ratio of inter-group distance over intra-group distance. .. math:: R_{\text{Group}} = \frac{\frac{1}{(C-1)^2} \sum_{i!=j} \frac{1}{|\mathbf{X}_i||\mathbf{X}_j|} \sum_{\mathbf{x}_{iv} \in \mathbf{X}_i } \sum_{\mathbf{x}_{jv^{\prime}} \in \mathbf{X}_j} {\| \mathbf{x}_{iv} - \mathbf{x}_{jv^{\prime}} \|}_2 }{ \frac{1}{C} \sum_{i} \frac{1}{{|\mathbf{X}_i|}^2} \sum_{\mathbf{x}_{iv}, \mathbf{x}_{iv^{\prime}} \in \mathbf{X}_i } {\| \mathbf{x}_{iv} - \mathbf{x}_{iv^{\prime}} \|}_2 } where :math:`\mathbf{X}_i` denotes the set of all nodes that belong to class :math:`i`, and :math:`C` denotes the total number of classes in :obj:`y`. """ num_classes = int(y.max()) + 1 numerator = 0. for i in range(num_classes): mask = y == i dist = torch.cdist(x[mask].unsqueeze(0), x[~mask].unsqueeze(0)) numerator += (1 / dist.numel()) * float(dist.sum()) numerator *= 1 / (num_classes - 1)**2 denominator = 0. 
for i in range(num_classes): mask = y == i dist = torch.cdist(x[mask].unsqueeze(0), x[mask].unsqueeze(0)) denominator += (1 / dist.numel()) * float(dist.sum()) denominator *= 1 / num_classes return numerator / (denominator + eps) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'groups={self.groups})') ================================================ FILE: torch_geometric/nn/norm/graph_norm.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.nn.inits import ones, zeros from torch_geometric.typing import OptTensor from torch_geometric.utils import scatter class GraphNorm(torch.nn.Module): r"""Applies graph normalization over individual graphs as described in the `"GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]} {\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]] + \epsilon}} \odot \gamma + \beta where :math:`\alpha` denotes parameters that learn how much information to keep in the mean. Args: in_channels (int): Size of each input sample. eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) device (torch.device, optional): The device to use for the module. 
(default: :obj:`None`) """ def __init__(self, in_channels: int, eps: float = 1e-5, device: Optional[torch.device] = None): super().__init__() self.in_channels = in_channels self.eps = eps self.weight = torch.nn.Parameter( torch.empty(in_channels, device=device)) self.bias = torch.nn.Parameter(torch.empty(in_channels, device=device)) self.mean_scale = torch.nn.Parameter( torch.empty(in_channels, device=device)) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" ones(self.weight) zeros(self.bias) ones(self.mean_scale) def forward(self, x: Tensor, batch: OptTensor = None, batch_size: Optional[int] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. (default: :obj:`None`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. 
(default: :obj:`None`) """ if batch is None: batch = x.new_zeros(x.size(0), dtype=torch.long) batch_size = 1 if batch_size is None: batch_size = int(batch.max()) + 1 mean = scatter(x, batch, 0, batch_size, reduce='mean') out = x - mean.index_select(0, batch) * self.mean_scale var = scatter(out.pow(2), batch, 0, batch_size, reduce='mean') std = (var + self.eps).sqrt().index_select(0, batch) return self.weight * out / std + self.bias def __repr__(self): return f'{self.__class__.__name__}({self.in_channels})' ================================================ FILE: torch_geometric/nn/norm/graph_size_norm.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.typing import OptTensor from torch_geometric.utils import degree class GraphSizeNorm(torch.nn.Module): r"""Applies Graph Size Normalization over each individual graph in a batch of node features as described in the `"Benchmarking Graph Neural Networks" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i}{\sqrt{|\mathcal{V}|}} """ def __init__(self): super().__init__() def forward(self, x: Tensor, batch: OptTensor = None, batch_size: Optional[int] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. (default: :obj:`None`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. 
(default: :obj:`None`) """ if batch is None: batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device) batch_size = 1 inv_sqrt_deg = degree(batch, batch_size, dtype=x.dtype).pow(-0.5) return x * inv_sqrt_deg.index_select(0, batch).view(-1, 1) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/nn/norm/instance_norm.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn.modules.instancenorm import _InstanceNorm from torch_geometric.typing import OptTensor from torch_geometric.utils import degree, scatter class InstanceNorm(_InstanceNorm): r"""Applies instance normalization over each individual example in a batch of node features as described in the `"Instance Normalization: The Missing Ingredient for Fast Stylization" `_ paper. .. math:: \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch. Args: in_channels (int): Size of each input sample. eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) momentum (float, optional): The value used for the running mean and running variance computation. (default: :obj:`0.1`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`False`) track_running_stats (bool, optional): If set to :obj:`True`, this module tracks the running mean and variance, and when set to :obj:`False`, this module does not track such statistics and always uses instance statistics in both training and eval modes. (default: :obj:`False`) device (torch.device, optional): The device to use for the module. 
(default: :obj:`None`) """ def __init__( self, in_channels: int, eps: float = 1e-5, momentum: float = 0.1, affine: bool = False, track_running_stats: bool = False, device: Optional[torch.device] = None, ): super().__init__(in_channels, eps, momentum, affine, track_running_stats, device=device) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" super().reset_parameters() def forward(self, x: Tensor, batch: OptTensor = None, batch_size: Optional[int] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. (default: :obj:`None`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ if batch is None: out = F.instance_norm( x.t().unsqueeze(0), self.running_mean, self.running_var, self.weight, self.bias, self.training or not self.track_running_stats, self.momentum, self.eps) return out.squeeze(0).t() if batch_size is None: batch_size = int(batch.max()) + 1 mean = var = unbiased_var = x # Dummies. 
if self.training or not self.track_running_stats: norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1) norm = norm.view(-1, 1) unbiased_norm = (norm - 1).clamp_(min=1) mean = scatter(x, batch, dim=0, dim_size=batch_size, reduce='sum') / norm x = x - mean.index_select(0, batch) var = scatter(x * x, batch, dim=0, dim_size=batch_size, reduce='sum') unbiased_var = var / unbiased_norm var = var / norm momentum = self.momentum if self.running_mean is not None: self.running_mean = ( 1 - momentum) * self.running_mean + momentum * mean.mean(0) if self.running_var is not None: self.running_var = ( 1 - momentum ) * self.running_var + momentum * unbiased_var.mean(0) else: if self.running_mean is not None: mean = self.running_mean.view(1, -1).expand(batch_size, -1) if self.running_var is not None: var = self.running_var.view(1, -1).expand(batch_size, -1) x = x - mean.index_select(0, batch) out = x / (var + self.eps).sqrt().index_select(0, batch) if self.weight is not None and self.bias is not None: out = out * self.weight.view(1, -1) + self.bias.view(1, -1) return out def __repr__(self) -> str: return f'{self.__class__.__name__}({self.num_features})' ================================================ FILE: torch_geometric/nn/norm/layer_norm.py ================================================ from typing import List, Optional, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.inits import ones, zeros from torch_geometric.typing import OptTensor from torch_geometric.utils import degree, scatter class LayerNorm(torch.nn.Module): r"""Applies layer normalization over each individual example in a batch of features as described in the `"Layer Normalization" `_ paper. .. 
math:: \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta The mean and standard-deviation are calculated across all nodes and all node channels separately for each object in a mini-batch. Args: in_channels (int): Size of each input sample. eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) mode (str, optional): The normalization mode to use for layer normalization (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"` is used, each graph will be considered as an element to be normalized. If `"node"` is used, each node will be considered as an element to be normalized. (default: :obj:`"graph"`) device (torch.device, optional): The device to use for the module. (default: :obj:`None`) """ def __init__( self, in_channels: int, eps: float = 1e-5, affine: bool = True, mode: str = 'graph', device: Optional[torch.device] = None, ): super().__init__() self.in_channels = in_channels self.eps = eps self.affine = affine self.mode = mode if affine: self.weight = Parameter(torch.empty(in_channels, device=device)) self.bias = Parameter(torch.empty(in_channels, device=device)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" ones(self.weight) zeros(self.bias) def forward(self, x: Tensor, batch: OptTensor = None, batch_size: Optional[int] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. (default: :obj:`None`) batch_size (int, optional): The number of examples :math:`B`. 
Automatically calculated if not given. (default: :obj:`None`) """ if self.mode == 'graph': if batch is None: x = x - x.mean() out = x / (x.std(unbiased=False) + self.eps) else: if batch_size is None: batch_size = int(batch.max()) + 1 norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1) norm = norm.mul_(x.size(-1)).view(-1, 1) mean = scatter(x, batch, dim=0, dim_size=batch_size, reduce='sum').sum(dim=-1, keepdim=True) / norm x = x - mean.index_select(0, batch) var = scatter(x * x, batch, dim=0, dim_size=batch_size, reduce='sum').sum(dim=-1, keepdim=True) var = var / norm out = x / (var + self.eps).sqrt().index_select(0, batch) if self.weight is not None and self.bias is not None: out = out * self.weight + self.bias return out if self.mode == 'node': return F.layer_norm(x, (self.in_channels, ), self.weight, self.bias, self.eps) raise ValueError(f"Unknownn normalization mode: {self.mode}") def __repr__(self): return (f'{self.__class__.__name__}({self.in_channels}, ' f'affine={self.affine}, mode={self.mode})') class HeteroLayerNorm(torch.nn.Module): r"""Applies layer normalization over each individual example in a batch of heterogeneous features as described in the `"Layer Normalization" `_ paper. Compared to :class:`LayerNorm`, :class:`HeteroLayerNorm` applies normalization individually for each node or edge type. Args: in_channels (int): Size of each input sample. num_types (int): The number of types. eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) mode (str, optional): The normalization mode to use for layer normalization (:obj:`"node"`). If `"node"` is used, each node will be considered as an element to be normalized. (default: :obj:`"node"`) device (torch.device, optional): The device to use for the module. 
(default: :obj:`None`) """ def __init__( self, in_channels: int, num_types: int, eps: float = 1e-5, affine: bool = True, mode: str = 'node', device: Optional[torch.device] = None, ): super().__init__() assert mode == 'node' self.in_channels = in_channels self.num_types = num_types self.eps = eps self.affine = affine if affine: self.weight = Parameter( torch.empty(num_types, in_channels, device=device)) self.bias = Parameter( torch.empty(num_types, in_channels, device=device)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" if self.affine: torch.nn.init.ones_(self.weight) torch.nn.init.zeros_(self.bias) def forward( self, x: Tensor, type_vec: OptTensor = None, type_ptr: Optional[Union[Tensor, List[int]]] = None, ) -> Tensor: r"""Forward pass. .. note:: Either :obj:`type_vec` or :obj:`type_ptr` needs to be specified. In general, relying on :obj:`type_ptr` is more efficient in case the input tensor is sorted by types. Args: x (torch.Tensor): The input features. type_vec (torch.Tensor, optional): A vector that maps each entry to a type. (default: :obj:`None`) type_ptr (torch.Tensor or List[int]): A vector denoting the boundaries of types. 
(default: :obj:`None`) """ if type_vec is None and type_ptr is None: raise ValueError("Either 'type_vec' or 'type_ptr' must be given") out = F.layer_norm(x, (self.in_channels, ), None, None, self.eps) if self.affine: # TODO Revisit this logic completely as it performs worse than just # operating on a dictionary of tensors # (especially the `type_vec` code path) if type_ptr is not None: h = torch.empty_like(out) for i, (s, e) in enumerate(zip(type_ptr[:-1], type_ptr[1:])): h[s:e] = out[s:e] * self.weight[i] + self.bias[i] out = h else: out = out * self.weight[type_vec] + self.bias[type_vec] return out def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'num_types={self.num_types})') ================================================ FILE: torch_geometric/nn/norm/mean_subtraction_norm.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.utils import scatter class MeanSubtractionNorm(torch.nn.Module): r"""Applies layer normalization by subtracting the mean from the inputs as described in the `"Revisiting 'Over-smoothing' in Deep GCNs" `_ paper. .. math:: \mathbf{x}_i = \mathbf{x}_i - \frac{1}{|\mathcal{V}|} \sum_{j \in \mathcal{V}} \mathbf{x}_j """ def forward(self, x: Tensor, batch: Optional[Tensor] = None, dim_size: Optional[int] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. (default: :obj:`None`) dim_size (int, optional): The number of examples :math:`B` in case :obj:`batch` is given. 
(default: :obj:`None`) """ if batch is None: return x - x.mean(dim=0, keepdim=True) mean = scatter(x, batch, dim=0, dim_size=dim_size, reduce='mean') return x - mean[batch] def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/nn/norm/msg_norm.py ================================================ from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter class MessageNorm(torch.nn.Module): r"""Applies message normalization over the aggregated messages as described in the `"DeeperGCNs: All You Need to Train Deeper GCNs" `_ paper. .. math:: \mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_{i} + s \cdot {\| \mathbf{x}_i \|}_2 \cdot \frac{\mathbf{m}_{i}}{{\|\mathbf{m}_i\|}_2} \right) Args: learn_scale (bool, optional): If set to :obj:`True`, will learn the scaling factor :math:`s` of message normalization. (default: :obj:`False`) device (torch.device, optional): The device to use for the module. (default: :obj:`None`) """ def __init__(self, learn_scale: bool = False, device: Optional[torch.device] = None): super().__init__() self.scale = Parameter(torch.empty(1, device=device), requires_grad=learn_scale) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.scale.data.fill_(1.0) def forward(self, x: Tensor, msg: Tensor, p: float = 2.0) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. msg (torch.Tensor): The message tensor :math:`\mathbf{M}`. p (float, optional): The norm :math:`p` to use for normalization. 
(default: :obj:`2.0`) """ msg = F.normalize(msg, p=p, dim=-1) x_norm = x.norm(p=p, dim=-1, keepdim=True) return msg * x_norm * self.scale def __repr__(self) -> str: return (f'{self.__class__.__name__}' f'(learn_scale={self.scale.requires_grad})') ================================================ FILE: torch_geometric/nn/norm/pair_norm.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.typing import OptTensor from torch_geometric.utils import scatter class PairNorm(torch.nn.Module): r"""Applies pair normalization over node features as described in the `"PairNorm: Tackling Oversmoothing in GNNs" `_ paper. .. math:: \mathbf{x}_i^c &= \mathbf{x}_i - \frac{1}{n} \sum_{i=1}^n \mathbf{x}_i \\ \mathbf{x}_i^{\prime} &= s \cdot \frac{\mathbf{x}_i^c}{\sqrt{\frac{1}{n} \sum_{i=1}^n {\| \mathbf{x}_i^c \|}^2_2}} Args: scale (float, optional): Scaling factor :math:`s` of normalization. (default, :obj:`1.`) scale_individually (bool, optional): If set to :obj:`True`, will compute the scaling step as :math:`\mathbf{x}^{\prime}_i = s \cdot \frac{\mathbf{x}_i^c}{{\| \mathbf{x}_i^c \|}_2}`. (default: :obj:`False`) eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) """ def __init__(self, scale: float = 1., scale_individually: bool = False, eps: float = 1e-5): super().__init__() self.scale = scale self.scale_individually = scale_individually self.eps = eps def forward(self, x: Tensor, batch: OptTensor = None, batch_size: Optional[int] = None) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. (default: :obj:`None`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. 
(default: :obj:`None`) """ scale = self.scale if batch is None: x = x - x.mean(dim=0, keepdim=True) if not self.scale_individually: return scale * x / (self.eps + x.pow(2).sum(-1).mean()).sqrt() else: return scale * x / (self.eps + x.norm(2, -1, keepdim=True)) else: mean = scatter(x, batch, dim=0, dim_size=batch_size, reduce='mean') x = x - mean.index_select(0, batch) if not self.scale_individually: return scale * x / torch.sqrt(self.eps + scatter( x.pow(2).sum(-1, keepdim=True), batch, dim=0, dim_size=batch_size, reduce='mean').index_select(0, batch)) else: return scale * x / (self.eps + x.norm(2, -1, keepdim=True)) def __repr__(self): return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/nn/parameter_dict.py ================================================ from typing import Final, Iterable, Mapping, Optional, Tuple, Union import torch from torch.nn import Parameter Key = Union[str, Tuple[str, ...]] # `torch.nn.ParameterDict` doesn't allow `.` to be used in key names. # This `ParameterDict` will support it by converting the `.` to `#` in the # internal representation and converts it back to `.` in the external # representation. It also allows passing tuples as keys. class ParameterDict(torch.nn.ParameterDict): CLASS_ATTRS: Final[Tuple[str, ...]] = set(dir(torch.nn.ParameterDict)) def __init__( self, parameters: Optional[Mapping[Key, Parameter]] = None, ): # Replace the keys in modules. 
if parameters: parameters = { self.to_internal_key(key): parameter for key, parameter in parameters.items() } super().__init__(parameters) @classmethod def to_internal_key(cls, key: Key) -> str: if isinstance(key, tuple): # ParameterDict can't handle tuples as keys assert len(key) > 1 key = f"<{'___'.join(key)}>" assert isinstance(key, str) # ParameterDict cannot handle keys that exists as class attributes: if key in cls.CLASS_ATTRS: key = f'<{key}>' # ParameterDict cannot handle dots in keys: return key.replace('.', '#') @classmethod def to_external_key(cls, key: str) -> Key: key = key.replace('#', '.') if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS: key = key[1:-1] if key[0] == '<' and key[-1] == '>' and '___' in key: key = tuple(key[1:-1].split('___')) return key def __getitem__(self, key: Key) -> Parameter: return super().__getitem__(self.to_internal_key(key)) def __setitem__(self, key: Key, parameter: Parameter): return super().__setitem__(self.to_internal_key(key), parameter) def __delitem__(self, key: Key): return super().__delitem__(self.to_internal_key(key)) def __contains__(self, key: Key) -> bool: return super().__contains__(self.to_internal_key(key)) def keys(self) -> Iterable[Key]: return [self.to_external_key(key) for key in super().keys()] def items(self) -> Iterable[Tuple[Key, Parameter]]: return [(self.to_external_key(k), v) for k, v in super().items()] ================================================ FILE: torch_geometric/nn/pool/__init__.py ================================================ r"""Pooling package.""" import warnings from typing import Optional from torch import Tensor import torch_geometric.typing from torch_geometric.typing import OptTensor, torch_cluster from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x from .glob import global_add_pool, global_max_pool, global_mean_pool from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex, ApproxMIPSKNNIndex) from .graclus import graclus from 
.max_pool import max_pool, max_pool_neighbor_x, max_pool_x from .topk_pool import TopKPooling from .sag_pool import SAGPooling from .edge_pool import EdgePooling from .cluster_pool import ClusterPooling from .asap import ASAPooling from .pan_pool import PANPooling from .mem_pool import MemPooling from .voxel_grid import voxel_grid from .approx_knn import approx_knn, approx_knn_graph def fps( x: Tensor, batch: OptTensor = None, ratio: float = 0.5, random_start: bool = True, batch_size: Optional[int] = None, ) -> Tensor: r"""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" `_ paper, which iteratively samples the most distant point with regard to the rest points. .. code-block:: python import torch from torch_geometric.nn import fps x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) batch = torch.tensor([0, 0, 0, 0]) index = fps(x, batch, ratio=0.5) Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. batch (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) ratio (float, optional): Sampling ratio. (default: :obj:`0.5`) random_start (bool, optional): If set to :obj:`False`, use the first node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. 
(default: :obj:`None`) :rtype: :class:`torch.Tensor` """ if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE: return torch_cluster.fps(x, batch, ratio, random_start) return torch_cluster.fps(x, batch, ratio, random_start, batch_size) def knn( x: Tensor, y: Tensor, k: int, batch_x: OptTensor = None, batch_y: OptTensor = None, cosine: bool = False, num_workers: int = 1, batch_size: Optional[int] = None, ) -> Tensor: r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in :obj:`x`. .. code-block:: python import torch from torch_geometric.nn import knn x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) batch_x = torch.tensor([0, 0, 0, 0]) y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]]) batch_y = torch.tensor([0, 0]) assign_index = knn(x, y, 2, batch_x, batch_y) Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. y (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{M \times F}`. k (int): The number of neighbors. batch_x (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) batch_y (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each node to a specific example. (default: :obj:`None`) cosine (bool, optional): If :obj:`True`, will use the cosine distance instead of euclidean distance to find nearest neighbors. (default: :obj:`False`) num_workers (int, optional): Number of workers to use for computation. Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. 
(default: :obj:`None`) :rtype: :class:`torch.Tensor` """ if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE: return torch_cluster.knn(x, y, k, batch_x, batch_y, cosine, num_workers) return torch_cluster.knn(x, y, k, batch_x, batch_y, cosine, num_workers, batch_size) def knn_graph( x: Tensor, k: int, batch: OptTensor = None, loop: bool = False, flow: str = 'source_to_target', cosine: bool = False, num_workers: int = 1, batch_size: Optional[int] = None, ) -> Tensor: r"""Computes graph edges to the nearest :obj:`k` points. .. code-block:: python import torch from torch_geometric.nn import knn_graph x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) batch = torch.tensor([0, 0, 0, 0]) edge_index = knn_graph(x, k=2, batch=batch, loop=False) Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. k (int): The number of neighbors. batch (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) loop (bool, optional): If :obj:`True`, the graph will contain self-loops. (default: :obj:`False`) flow (str, optional): The flow direction when using in combination with message passing (:obj:`"source_to_target"` or :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) cosine (bool, optional): If :obj:`True`, will use the cosine distance instead of euclidean distance to find nearest neighbors. (default: :obj:`False`) num_workers (int, optional): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) :rtype: :class:`torch.Tensor` """ if batch is not None and x.device != batch.device: warnings.warn( "Input tensor 'x' and 'batch' are on different devices " "in 'knn_graph'. 
Performing blocking device transfer", stacklevel=2) batch = batch.to(x.device) if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE: return torch_cluster.knn_graph(x, k, batch, loop, flow, cosine, num_workers) return torch_cluster.knn_graph(x, k, batch, loop, flow, cosine, num_workers, batch_size) def radius( x: Tensor, y: Tensor, r: float, batch_x: OptTensor = None, batch_y: OptTensor = None, max_num_neighbors: int = 32, num_workers: int = 1, batch_size: Optional[int] = None, ) -> Tensor: r"""Finds for each element in :obj:`y` all points in :obj:`x` within distance :obj:`r`. .. code-block:: python import torch from torch_geometric.nn import radius x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) batch_x = torch.tensor([0, 0, 0, 0]) y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]]) batch_y = torch.tensor([0, 0]) assign_index = radius(x, y, 1.5, batch_x, batch_y) Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. y (torch.Tensor): Node feature matrix :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`. r (float): The radius. batch_x (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) batch_y (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each node to a specific example. (default: :obj:`None`) max_num_neighbors (int, optional): The maximum number of neighbors to return for each element in :obj:`y`. (default: :obj:`32`) num_workers (int, optional): Number of workers to use for computation. Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) :rtype: :class:`torch.Tensor` .. 
warning:: The CPU implementation of :meth:`radius` with :obj:`max_num_neighbors` is biased towards certain quadrants. Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving inputs to GPU before proceeding. """ if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE: return torch_cluster.radius(x, y, r, batch_x, batch_y, max_num_neighbors, num_workers) return torch_cluster.radius(x, y, r, batch_x, batch_y, max_num_neighbors, num_workers, batch_size) def radius_graph( x: Tensor, r: float, batch: OptTensor = None, loop: bool = False, max_num_neighbors: int = 32, flow: str = 'source_to_target', num_workers: int = 1, batch_size: Optional[int] = None, ) -> Tensor: r"""Computes graph edges to all points within a given distance. .. code-block:: python import torch from torch_geometric.nn import radius_graph x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) batch = torch.tensor([0, 0, 0, 0]) edge_index = radius_graph(x, r=1.5, batch=batch, loop=False) Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. r (float): The radius. batch (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) loop (bool, optional): If :obj:`True`, the graph will contain self-loops. (default: :obj:`False`) max_num_neighbors (int, optional): The maximum number of neighbors to return for each element in :obj:`y`. (default: :obj:`32`) flow (str, optional): The flow direction when using in combination with message passing (:obj:`"source_to_target"` or :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) num_workers (int, optional): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. 
(default: :obj:`None`) :rtype: :class:`torch.Tensor` .. warning:: The CPU implementation of :meth:`radius_graph` with :obj:`max_num_neighbors` is biased towards certain quadrants. Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving inputs to GPU before proceeding. """ if batch is not None and x.device != batch.device: warnings.warn( "Input tensor 'x' and 'batch' are on different devices " "in 'radius_graph'. Performing blocking device transfer", stacklevel=2) batch = batch.to(x.device) if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE: return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors, flow, num_workers) return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors, flow, num_workers, batch_size) def nearest( x: Tensor, y: Tensor, batch_x: OptTensor = None, batch_y: OptTensor = None, ) -> Tensor: r"""Finds for each element in :obj:`y` the :obj:`k` nearest point in :obj:`x`. .. code-block:: python import torch from torch_geometric.nn import nearest x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) batch_x = torch.tensor([0, 0, 0, 0]) y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]]) batch_y = torch.tensor([0, 0]) cluster = nearest(x, y, batch_x, batch_y) Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. y (torch.Tensor): Node feature matrix :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`. batch_x (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) batch_y (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each node to a specific example. 
(default: :obj:`None`) :rtype: :class:`torch.Tensor` """ return torch_cluster.nearest(x, y, batch_x, batch_y) __all__ = [ 'global_add_pool', 'global_mean_pool', 'global_max_pool', 'KNNIndex', 'L2KNNIndex', 'MIPSKNNIndex', 'ApproxL2KNNIndex', 'ApproxMIPSKNNIndex', 'TopKPooling', 'SAGPooling', 'EdgePooling', 'ClusterPooling', 'ASAPooling', 'PANPooling', 'MemPooling', 'max_pool', 'avg_pool', 'max_pool_x', 'max_pool_neighbor_x', 'avg_pool_x', 'avg_pool_neighbor_x', 'graclus', 'voxel_grid', 'fps', 'knn', 'knn_graph', 'approx_knn', 'approx_knn_graph', 'radius', 'radius_graph', 'nearest', ] classes = __all__ ================================================ FILE: torch_geometric/nn/pool/approx_knn.py ================================================ import torch from torch import Tensor def approx_knn( x: Tensor, y: Tensor, k: int, batch_x: Tensor = None, batch_y: Tensor = None, ) -> Tensor: # pragma: no cover r"""Finds for each element in :obj:`y` the :obj:`k` approximated nearest points in :obj:`x`. .. note:: Approximated :math:`k`-nearest neighbor search is performed via the `pynndescent `_ library. Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. y (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{M \times F}`. k (int): The number of neighbors. batch_x (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) batch_y (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each node to a specific example. 
(default: :obj:`None`) :rtype: :class:`torch.Tensor` """ from pynndescent import NNDescent if batch_x is None: batch_x = x.new_zeros(x.size(0), dtype=torch.long) if batch_y is None: batch_y = y.new_zeros(y.size(0), dtype=torch.long) x = x.view(-1, 1) if x.dim() == 1 else x y = y.view(-1, 1) if y.dim() == 1 else y assert x.dim() == 2 and batch_x.dim() == 1 assert y.dim() == 2 and batch_y.dim() == 1 assert x.size(1) == y.size(1) assert x.size(0) == batch_x.size(0) assert y.size(0) == batch_y.size(0) min_xy = min(x.min(), y.min()) x, y = x - min_xy, y - min_xy max_xy = max(x.max(), y.max()) x, y, = x / max_xy, y / max_xy # Concat batch/features to ensure no cross-links between examples exist: x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1) y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1) index = NNDescent(x.detach().cpu().numpy()) col, dist = index.query(y.detach().cpu().numpy(), k=k) dist = torch.from_numpy(dist).view(-1).to(x.device, x.dtype) col = torch.from_numpy(col).view(-1).to(x.device, torch.long) row = torch.arange(y.size(0), device=x.device, dtype=torch.long) row = row.repeat_interleave(k) mask = ~torch.isinf(dist) row, col = row[mask], col[mask] return torch.stack([row, col], dim=0) def approx_knn_graph( x: Tensor, k: int, batch: Tensor = None, loop: bool = False, flow: str = 'source_to_target', ) -> Tensor: # pragma: no cover r"""Computes graph edges to the nearest approximated :obj:`k` points. .. note:: Approximated :math:`k`-nearest neighbor search is performed via the `pynndescent `_ library. Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. k (int): The number of neighbors. batch (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) loop (bool, optional): If :obj:`True`, the graph will contain self-loops. 
(default: :obj:`False`) flow (str, optional): The flow direction when using in combination with message passing (:obj:`"source_to_target"` or :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) :rtype: :class:`torch.Tensor` """ assert flow in ['source_to_target', 'target_to_source'] row, col = approx_knn(x, x, k if loop else k + 1, batch, batch) row, col = (col, row) if flow == 'source_to_target' else (row, col) if not loop: mask = row != col row, col = row[mask], col[mask] return torch.stack([row, col], dim=0) ================================================ FILE: torch_geometric/nn/pool/asap.py ================================================ from typing import Callable, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor from torch.nn import Linear from torch_geometric.nn import LEConv from torch_geometric.nn.pool.select import SelectTopK from torch_geometric.utils import ( add_remaining_self_loops, remove_self_loops, scatter, softmax, to_edge_index, to_torch_coo_tensor, to_torch_csr_tensor, ) class ASAPooling(torch.nn.Module): r"""The Adaptive Structure Aware Pooling operator from the `"ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations" `_ paper. Args: in_channels (int): Size of each input sample. ratio (float or int): Graph pooling ratio, which is used to compute :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`, or the value of :math:`k` itself, depending on whether the type of :obj:`ratio` is :obj:`float` or :obj:`int`. (default: :obj:`0.5`) GNN (torch.nn.Module, optional): A graph neural network layer for using intra-cluster properties. Especially helpful for graphs with higher degree of neighborhood (one of :class:`torch_geometric.nn.conv.GraphConv`, :class:`torch_geometric.nn.conv.GCNConv` or any GNN which supports the :obj:`edge_weight` parameter). 
(default: :obj:`None`) dropout (float, optional): Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: :obj:`0`) negative_slope (float, optional): LeakyReLU angle of the negative slope. (default: :obj:`0.2`) add_self_loops (bool, optional): If set to :obj:`True`, will add self loops to the new graph connectivity. (default: :obj:`False`) **kwargs (optional): Additional parameters for initializing the graph neural network layer. """ def __init__(self, in_channels: int, ratio: Union[float, int] = 0.5, GNN: Optional[Callable] = None, dropout: float = 0.0, negative_slope: float = 0.2, add_self_loops: bool = False, **kwargs): super().__init__() self.in_channels = in_channels self.ratio = ratio self.negative_slope = negative_slope self.dropout = dropout self.GNN = GNN self.add_self_loops = add_self_loops self.lin = Linear(in_channels, in_channels) self.att = Linear(2 * in_channels, 1) self.gnn_score = LEConv(self.in_channels, 1) if self.GNN is not None: self.gnn_intra_cluster = GNN(self.in_channels, self.in_channels, **kwargs) else: self.gnn_intra_cluster = None self.select = SelectTopK(1, ratio) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin.reset_parameters() self.att.reset_parameters() self.gnn_score.reset_parameters() if self.gnn_intra_cluster is not None: self.gnn_intra_cluster.reset_parameters() self.select.reset_parameters() def forward( self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, batch: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor]: r"""Forward pass. Args: x (torch.Tensor): The node feature matrix. edge_index (torch.Tensor): The edge indices. edge_weight (torch.Tensor, optional): The edge weights. 
(default: :obj:`None`) batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) Return types: * **x** (*torch.Tensor*): The pooled node embeddings. * **edge_index** (*torch.Tensor*): The coarsened edge indices. * **edge_weight** (*torch.Tensor, optional*): The coarsened edge weights. * **batch** (*torch.Tensor*): The coarsened batch vector. * **index** (*torch.Tensor*): The top-:math:`k` node indices of nodes which are kept after pooling. """ N = x.size(0) edge_index, edge_weight = add_remaining_self_loops( edge_index, edge_weight, fill_value=1., num_nodes=N) if batch is None: batch = edge_index.new_zeros(x.size(0)) x = x.unsqueeze(-1) if x.dim() == 1 else x x_pool = x if self.gnn_intra_cluster is not None: x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index, edge_weight=edge_weight) x_pool_j = x_pool[edge_index[0]] x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max') x_q = self.lin(x_q)[edge_index[1]] score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1) score = F.leaky_relu(score, self.negative_slope) score = softmax(score, edge_index[1], num_nodes=N) # Sample attention coefficients stochastically. score = F.dropout(score, p=self.dropout, training=self.training) v_j = x[edge_index[0]] * score.view(-1, 1) x = scatter(v_j, edge_index[1], dim=0, reduce='sum') # Cluster selection. fitness = self.gnn_score(x, edge_index).sigmoid().view(-1) perm = self.select(fitness, batch).node_index x = x[perm] * fitness[perm].view(-1, 1) batch = batch[perm] # Graph coarsening. 
A = to_torch_csr_tensor(edge_index, edge_weight, size=(N, N)) S = to_torch_coo_tensor(edge_index, score, size=(N, N)) S = S.index_select(1, perm).to_sparse_csr() A = S.t().to_sparse_csr() @ (A @ S) if edge_weight is None: edge_index, _ = to_edge_index(A) else: edge_index, edge_weight = to_edge_index(A) if self.add_self_loops: edge_index, edge_weight = add_remaining_self_loops( edge_index, edge_weight, num_nodes=A.size(0)) else: edge_index, edge_weight = remove_self_loops( edge_index, edge_weight) return x, edge_index, edge_weight, batch, perm def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'ratio={self.ratio})') ================================================ FILE: torch_geometric/nn/pool/avg_pool.py ================================================ from typing import Callable, Optional, Tuple from torch import Tensor from torch_geometric.data import Batch, Data from torch_geometric.nn.pool.consecutive import consecutive_cluster from torch_geometric.nn.pool.pool import pool_batch, pool_edge, pool_pos from torch_geometric.utils import add_self_loops, scatter def _avg_pool_x( cluster: Tensor, x: Tensor, size: Optional[int] = None, ) -> Tensor: return scatter(x, cluster, dim=0, dim_size=size, reduce='mean') def avg_pool_x( cluster: Tensor, x: Tensor, batch: Tensor, batch_size: Optional[int] = None, size: Optional[int] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r"""Average pools node features according to the clustering defined in :attr:`cluster`. See :meth:`torch_geometric.nn.pool.max_pool_x` for more details. Args: cluster (torch.Tensor): The cluster vector :math:`\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N`, which assigns each node to a specific cluster. x (Tensor): The node feature matrix. batch (torch.Tensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. 
(default: :obj:`None`) size (int, optional): The maximum number of clusters in a single example. (default: :obj:`None`) :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`) if :attr:`size` is :obj:`None`, else :class:`torch.Tensor` """ if size is not None: if batch_size is None: batch_size = int(batch.max().item()) + 1 return _avg_pool_x(cluster, x, batch_size * size), None cluster, perm = consecutive_cluster(cluster) x = _avg_pool_x(cluster, x) batch = pool_batch(perm, batch) return x, batch def avg_pool( cluster: Tensor, data: Data, transform: Optional[Callable] = None, ) -> Data: r"""Pools and coarsens a graph given by the :class:`torch_geometric.data.Data` object according to the clustering defined in :attr:`cluster`. Final node features are defined by the *average* features of all nodes within the same cluster. See :meth:`torch_geometric.nn.pool.max_pool` for more details. Args: cluster (torch.Tensor): The cluster vector :math:`\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N`, which assigns each node to a specific cluster. data (Data): Graph data object. transform (callable, optional): A function/transform that takes in the coarsened and pooled :obj:`torch_geometric.data.Data` object and returns a transformed version. 
(default: :obj:`None`) :rtype: :class:`torch_geometric.data.Data` """ cluster, perm = consecutive_cluster(cluster) x = None if data.x is None else _avg_pool_x(cluster, data.x) index, attr = pool_edge(cluster, data.edge_index, data.edge_attr) batch = None if data.batch is None else pool_batch(perm, data.batch) pos = None if data.pos is None else pool_pos(cluster, data.pos) data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos) if transform is not None: data = transform(data) return data def avg_pool_neighbor_x( data: Data, flow: Optional[str] = 'source_to_target', ) -> Data: r"""Average pools neighboring node features, where each feature in :obj:`data.x` is replaced by the average feature values from the central node and its neighbors. """ x, edge_index = data.x, data.edge_index edge_index, _ = add_self_loops(edge_index, num_nodes=data.num_nodes) row, col = edge_index row, col = (row, col) if flow == 'source_to_target' else (col, row) data.x = scatter(x[row], col, dim=0, dim_size=data.num_nodes, reduce='mean') return data ================================================ FILE: torch_geometric/nn/pool/cluster_pool.py ================================================ from typing import NamedTuple, Optional, Tuple import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.utils import ( dense_to_sparse, one_hot, to_dense_adj, to_scipy_sparse_matrix, ) class UnpoolInfo(NamedTuple): edge_index: Tensor cluster: Tensor batch: Tensor class ClusterPooling(torch.nn.Module): r"""The cluster pooling operator from the `"Edge-Based Graph Component Pooling" `_ paper. :class:`ClusterPooling` computes a score for each edge. Based on the selected edges, graph clusters are calculated and compressed to one node using the injective :obj:`"sum"` aggregation function. Edges are remapped based on the nodes created by each cluster and the original edges. Args: in_channels (int): Size of each input sample. 
edge_score_method (str, optional): The function to apply to compute the edge score from raw edge scores (:obj:`"tanh"`, :obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`) dropout (float, optional): The probability with which to drop edge scores during training. (default: :obj:`0.0`) threshold (float, optional): The threshold of edge scores. If set to :obj:`None`, will be automatically inferred depending on :obj:`edge_score_method`. (default: :obj:`None`) """ def __init__( self, in_channels: int, edge_score_method: str = 'tanh', dropout: float = 0.0, threshold: Optional[float] = None, ): super().__init__() assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax'] if threshold is None: threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0 self.in_channels = in_channels self.edge_score_method = edge_score_method self.dropout = dropout self.threshold = threshold self.lin = torch.nn.Linear(2 * in_channels, 1) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin.reset_parameters() def forward( self, x: Tensor, edge_index: Tensor, batch: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]: r"""Forward pass. Args: x (torch.Tensor): The node features. edge_index (torch.Tensor): The edge indices. batch (torch.Tensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. Return types: * **x** *(torch.Tensor)* - The pooled node features. * **edge_index** *(torch.Tensor)* - The coarsened edge indices. * **batch** *(torch.Tensor)* - The coarsened batch vector. * **unpool_info** *(UnpoolInfo)* - Information that can be consumed for unpooling. 
""" mask = edge_index[0] != edge_index[1] edge_index = edge_index[:, mask] edge_attr = torch.cat( [x[edge_index[0]], x[edge_index[1]]], dim=-1, ) edge_score = self.lin(edge_attr).view(-1) edge_score = F.dropout(edge_score, p=self.dropout, training=self.training) if self.edge_score_method == 'tanh': edge_score = edge_score.tanh() elif self.edge_score_method == 'sigmoid': edge_score = edge_score.sigmoid() else: assert self.edge_score_method == 'log_softmax' edge_score = F.log_softmax(edge_score, dim=0) return self._merge_edges(x, edge_index, batch, edge_score) def _merge_edges( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_score: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]: from scipy.sparse.csgraph import connected_components edge_contract = edge_index[:, edge_score > self.threshold] adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0)) _, cluster_np = connected_components(adj, directed=True, connection="weak") cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device) C = one_hot(cluster) A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0) S = to_dense_adj(edge_index, edge_attr=edge_score, max_num_nodes=x.size(0)).squeeze(0) A_contract = to_dense_adj(edge_contract, max_num_nodes=x.size(0)).squeeze(0) nodes_single = ((A_contract.sum(dim=-1) + A_contract.sum(dim=-2)) == 0).nonzero() S[nodes_single, nodes_single] = 1.0 x_out = (S @ C).t() @ x edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0)) batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch) unpool_info = UnpoolInfo(edge_index, cluster, batch) return x_out, edge_index_out, batch_out, unpool_info def __repr__(self) -> str: return f'{self.__class__.__name__}({self.in_channels})' ================================================ FILE: torch_geometric/nn/pool/connect/__init__.py ================================================ r"""Graph connection package. 
This package provides classes for determining coarsened graph connections in graph pooling scenarios. """ from .base import Connect, ConnectOutput from .filter_edges import FilterEdges __all__ = [ 'Connect', 'ConnectOutput', 'FilterEdges', ] ================================================ FILE: torch_geometric/nn/pool/connect/base.py ================================================ from dataclasses import dataclass from typing import Optional import torch from torch import Tensor from torch_geometric.nn.pool.select import SelectOutput @dataclass(init=False) class ConnectOutput: r"""The output of the :class:`Connect` method, which holds the coarsened graph structure, and optional pooled edge features and batch vectors. Args: edge_index (torch.Tensor): The edge indices of the coarsened graph. edge_attr (torch.Tensor, optional): The pooled edge features of the coarsened graph. (default: :obj:`None`) batch (torch.Tensor, optional): The pooled batch vector of the coarsened graph. (default: :obj:`None`) """ edge_index: Tensor edge_attr: Optional[Tensor] = None batch: Optional[Tensor] = None def __init__( self, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, ): if edge_index.dim() != 2: raise ValueError(f"Expected 'edge_index' to be two-dimensional " f"(got {edge_index.dim()} dimensions)") if edge_index.size(0) != 2: raise ValueError(f"Expected 'edge_index' to have size '2' in the " f"first dimension (got '{edge_index.size(0)}')") if edge_attr is not None and edge_attr.size(0) != edge_index.size(1): raise ValueError(f"Expected 'edge_index' and 'edge_attr' to " f"hold the same number of edges (got " f"{edge_index.size(1)} and {edge_attr.size(0)} " f"edges)") self.edge_index = edge_index self.edge_attr = edge_attr self.batch = batch ConnectOutput = torch.jit.script(ConnectOutput) class Connect(torch.nn.Module): r"""An abstract base class for implementing custom edge connection operators as described in the `"Understanding Pooling in
Graph Neural Networks" `_ paper. Specifically, :class:`Connect` determines for each pair of supernodes the presence or absence of an edge based on the existing edges between the nodes in the two supernodes. The operator also computes pooled edge features and batch vectors (if present). """ def reset_parameters(self): r"""Resets all learnable parameters of the module.""" def forward( self, select_output: SelectOutput, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, ) -> ConnectOutput: r"""Forward pass. Args: select_output (SelectOutput): The output of :class:`Select`. edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`) batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific graph. (default: :obj:`None`) """ raise NotImplementedError @staticmethod def get_pooled_batch( select_output: SelectOutput, batch: Optional[Tensor], ) -> Optional[Tensor]: r"""Returns the batch vector of the coarsened graph. Args: select_output (SelectOutput): The output of :class:`Select`. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example.
(default: :obj:`None`) """ if batch is None: return batch out = torch.arange(select_output.num_clusters, device=batch.device) return out.scatter_(0, select_output.cluster_index, batch[select_output.node_index]) def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/nn/pool/connect/filter_edges.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch_geometric.nn.pool.connect import Connect, ConnectOutput from torch_geometric.nn.pool.select import SelectOutput from torch_geometric.utils.num_nodes import maybe_num_nodes def filter_adj( edge_index: Tensor, edge_attr: Optional[Tensor], node_index: Tensor, cluster_index: Optional[Tensor] = None, num_nodes: Optional[int] = None, ) -> Tuple[Tensor, Optional[Tensor]]: num_nodes = maybe_num_nodes(edge_index, num_nodes) if cluster_index is None: cluster_index = torch.arange(node_index.size(0), device=node_index.device) mask = node_index.new_full((num_nodes, ), -1) mask[node_index] = cluster_index row, col = edge_index[0], edge_index[1] row, col = mask[row], mask[col] mask = (row >= 0) & (col >= 0) row, col = row[mask], col[mask] if edge_attr is not None: edge_attr = edge_attr[mask] return torch.stack([row, col], dim=0), edge_attr class FilterEdges(Connect): r"""Filters out edges if their incident nodes are not in any cluster. .. math:: \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}, where :math:`\mathbf{i}` denotes the set of retained nodes. It is assumed that each cluster contains only one node. 
""" def forward( self, select_output: SelectOutput, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, ) -> ConnectOutput: if (not torch.jit.is_scripting() and select_output.num_clusters != select_output.cluster_index.size(0)): raise ValueError(f"'{self.__class__.__name__}' requires each " f"cluster to contain only one node") edge_index, edge_attr = filter_adj( edge_index, edge_attr, select_output.node_index, select_output.cluster_index, num_nodes=select_output.num_nodes, ) batch = self.get_pooled_batch(select_output, batch) return ConnectOutput(edge_index, edge_attr, batch) ================================================ FILE: torch_geometric/nn/pool/consecutive.py ================================================ import torch def consecutive_cluster(src): unique, inv = torch.unique(src, sorted=True, return_inverse=True) perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device) perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm) return inv, perm ================================================ FILE: torch_geometric/nn/pool/decimation.py ================================================ from typing import Tuple, Union import torch from torch import LongTensor, Tensor from torch_geometric.utils import cumsum def decimation_indices( ptr: LongTensor, decimation_factor: Union[int, float], ) -> Tuple[Tensor, LongTensor]: """Gets indices which downsample each point cloud by a decimation factor. Decimation happens separately for each cloud to prevent emptying smaller point clouds. Empty clouds are prevented: clouds will have a least one node after decimation. Args: ptr (LongTensor): The indices of samples in the batch. decimation_factor (int or float): The value to divide number of nodes with. Should be higher than (or equal to) :obj:`1` for downsampling. :rtype: (:class:`LongTensor`, :class:`LongTensor`): The indices and updated :obj:`ptr` after downsampling. 
""" if decimation_factor < 1: raise ValueError( f"The argument `decimation_factor` should be higher than (or " f"equal to) 1 for downsampling. (got {decimation_factor})") batch_size = ptr.size(0) - 1 count = ptr[1:] - ptr[:-1] decim_count = torch.div(count, decimation_factor, rounding_mode='floor') decim_count.clamp_(min=1) # Prevent empty examples. decim_indices = [ ptr[i] + torch.randperm(count[i], device=ptr.device)[:decim_count[i]] for i in range(batch_size) ] decim_indices = torch.cat(decim_indices, dim=0) # Get updated ptr (e.g., for future decimations): decim_ptr = cumsum(decim_count) return decim_indices, decim_ptr ================================================ FILE: torch_geometric/nn/pool/edge_pool.py ================================================ from typing import Callable, List, NamedTuple, Optional, Tuple import torch import torch.nn.functional as F from torch import Tensor from torch_geometric.utils import coalesce, scatter, softmax class UnpoolInfo(NamedTuple): edge_index: Tensor cluster: Tensor batch: Tensor new_edge_score: Tensor class EdgePooling(torch.nn.Module): r"""The edge pooling operator from the `"Towards Graph Pooling by Edge Contraction" `__ and `"Edge Contraction Pooling for Graph Neural Networks" `__ papers. In short, a score is computed for each edge. Edges are contracted iteratively according to that score unless one of their nodes has already been part of a contracted edge. To duplicate the configuration from the `"Towards Graph Pooling by Edge Contraction" `__ paper, use either :func:`EdgePooling.compute_edge_score_softmax` or :func:`EdgePooling.compute_edge_score_tanh`, and set :obj:`add_to_edge_score` to :obj:`0.0`. To duplicate the configuration from the `"Edge Contraction Pooling for Graph Neural Networks" `__ paper, set :obj:`dropout` to :obj:`0.2`. Args: in_channels (int): Size of each input sample. edge_score_method (callable, optional): The function to apply to compute the edge score from raw edge scores. 
By default, this is the softmax over all incoming edges for each node. This function takes in a :obj:`raw_edge_score` tensor of shape :obj:`[num_edges]`, an :obj:`edge_index` tensor and the number of nodes :obj:`num_nodes`, and produces a new tensor of the same size as :obj:`raw_edge_score` describing normalized edge scores. Included functions are :func:`EdgePooling.compute_edge_score_softmax`, :func:`EdgePooling.compute_edge_score_tanh`, and :func:`EdgePooling.compute_edge_score_sigmoid`. (default: :func:`EdgePooling.compute_edge_score_softmax`) dropout (float, optional): The probability with which to drop edge scores during training. (default: :obj:`0.0`) add_to_edge_score (float, optional): A value to be added to each computed edge score. Adding this greatly helps with unpooling stability. (default: :obj:`0.5`) """ def __init__( self, in_channels: int, edge_score_method: Optional[Callable] = None, dropout: float = 0.0, add_to_edge_score: float = 0.5, ): super().__init__() self.in_channels = in_channels if edge_score_method is None: edge_score_method = self.compute_edge_score_softmax self.compute_edge_score = edge_score_method self.add_to_edge_score = add_to_edge_score self.dropout = dropout self.lin = torch.nn.Linear(2 * in_channels, 1) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin.reset_parameters() @staticmethod def compute_edge_score_softmax( raw_edge_score: Tensor, edge_index: Tensor, num_nodes: int, ) -> Tensor: r"""Normalizes edge scores via softmax application.""" return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes) @staticmethod def compute_edge_score_tanh( raw_edge_score: Tensor, edge_index: Optional[Tensor] = None, num_nodes: Optional[int] = None, ) -> Tensor: r"""Normalizes edge scores via hyperbolic tangent application.""" return torch.tanh(raw_edge_score) @staticmethod def compute_edge_score_sigmoid( raw_edge_score: Tensor, edge_index: Optional[Tensor] = None,
num_nodes: Optional[int] = None, ) -> Tensor: r"""Normalizes edge scores via sigmoid application.""" return torch.sigmoid(raw_edge_score) def forward( self, x: Tensor, edge_index: Tensor, batch: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]: r"""Forward pass. Args: x (torch.Tensor): The node features. edge_index (torch.Tensor): The edge indices. batch (torch.Tensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. Return types: * **x** *(torch.Tensor)* - The pooled node features. * **edge_index** *(torch.Tensor)* - The coarsened edge indices. * **batch** *(torch.Tensor)* - The coarsened batch vector. * **unpool_info** *(UnpoolInfo)* - Information that is consumed by :func:`EdgePooling.unpool` for unpooling. """ e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1) e = self.lin(e).view(-1) e = F.dropout(e, p=self.dropout, training=self.training) e = self.compute_edge_score(e, edge_index, x.size(0)) e = e + self.add_to_edge_score x, edge_index, batch, unpool_info = self._merge_edges( x, edge_index, batch, e) return x, edge_index, batch, unpool_info def _merge_edges( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_score: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]: cluster = torch.empty_like(batch) perm: List[int] = torch.argsort(edge_score, descending=True).tolist() # Iterate through all edges, selecting it if it is not incident to # another already chosen edge. 
mask = torch.ones(x.size(0), dtype=torch.bool) i = 0 new_edge_indices: List[int] = [] edge_index_cpu = edge_index.cpu() for edge_idx in perm: source = int(edge_index_cpu[0, edge_idx]) if not bool(mask[source]): continue target = int(edge_index_cpu[1, edge_idx]) if not bool(mask[target]): continue new_edge_indices.append(edge_idx) cluster[source] = i mask[source] = False if source != target: cluster[target] = i mask[target] = False i += 1 # The remaining nodes are simply kept: j = int(mask.sum()) cluster[mask] = torch.arange(i, i + j, device=x.device) i += j # We compute the new features as an addition of the old ones. new_x = scatter(x, cluster, dim=0, dim_size=i, reduce='sum') new_edge_score = edge_score[new_edge_indices] if int(mask.sum()) > 0: remaining_score = x.new_ones( (new_x.size(0) - len(new_edge_indices), )) new_edge_score = torch.cat([new_edge_score, remaining_score]) new_x = new_x * new_edge_score.view(-1, 1) new_edge_index = coalesce(cluster[edge_index], num_nodes=new_x.size(0)) new_batch = x.new_empty(new_x.size(0), dtype=torch.long) new_batch = new_batch.scatter_(0, cluster, batch) unpool_info = UnpoolInfo(edge_index=edge_index, cluster=cluster, batch=batch, new_edge_score=new_edge_score) return new_x, new_edge_index, new_batch, unpool_info def unpool( self, x: Tensor, unpool_info: UnpoolInfo, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Unpools a previous edge pooling step. For unpooling, :obj:`x` should be of same shape as those produced by this layer's :func:`forward` function. Then, it will produce an unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`. Args: x (torch.Tensor): The node features. unpool_info (UnpoolInfo): Information that has been produced by :func:`EdgePooling.forward`. Return types: * **x** *(torch.Tensor)* - The unpooled node features. * **edge_index** *(torch.Tensor)* - The new edge indices. * **batch** *(torch.Tensor)* - The new batch vector. 
""" new_x = x / unpool_info.new_edge_score.view(-1, 1) new_x = new_x[unpool_info.cluster] return new_x, unpool_info.edge_index, unpool_info.batch def __repr__(self) -> str: return f'{self.__class__.__name__}({self.in_channels})' ================================================ FILE: torch_geometric/nn/pool/glob.py ================================================ from typing import Optional from torch import Tensor from torch_geometric.utils import scatter def global_add_pool(x: Tensor, batch: Optional[Tensor], size: Optional[int] = None) -> Tensor: r"""Returns batch-wise graph-level-outputs by adding node features across the node dimension. For a single graph :math:`\mathcal{G}_i`, its output is computed by .. math:: \mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n. Functional method of the :class:`~torch_geometric.nn.aggr.SumAggregation` module. Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2 if batch is None: return x.sum(dim=dim, keepdim=x.dim() <= 2) return scatter(x, batch, dim=dim, dim_size=size, reduce='sum') def global_mean_pool(x: Tensor, batch: Optional[Tensor], size: Optional[int] = None) -> Tensor: r"""Returns batch-wise graph-level-outputs by averaging node features across the node dimension. For a single graph :math:`\mathcal{G}_i`, its output is computed by .. math:: \mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n. Functional method of the :class:`~torch_geometric.nn.aggr.MeanAggregation` module. Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. 
batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2 if batch is None: return x.mean(dim=dim, keepdim=x.dim() <= 2) return scatter(x, batch, dim=dim, dim_size=size, reduce='mean') def global_max_pool(x: Tensor, batch: Optional[Tensor], size: Optional[int] = None) -> Tensor: r"""Returns batch-wise graph-level-outputs by taking the channel-wise maximum across the node dimension. For a single graph :math:`\mathcal{G}_i`, its output is computed by .. math:: \mathbf{r}_i = \mathrm{max}_{n=1}^{N_i} \, \mathbf{x}_n. Functional method of the :class:`~torch_geometric.nn.aggr.MaxAggregation` module. Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. 
(default: :obj:`None`) """ dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2 if batch is None: return x.max(dim=dim, keepdim=x.dim() <= 2)[0] return scatter(x, batch, dim=dim, dim_size=size, reduce='max') ================================================ FILE: torch_geometric/nn/pool/graclus.py ================================================ from typing import Optional from torch import Tensor import torch_geometric.typing if torch_geometric.typing.WITH_TORCH_CLUSTER: from torch_cluster import graclus_cluster else: graclus_cluster = None def graclus(edge_index: Tensor, weight: Optional[Tensor] = None, num_nodes: Optional[int] = None): r"""A greedy clustering algorithm from the `"Weighted Graph Cuts without Eigenvectors: A Multilevel Approach" `_ paper of picking an unmarked vertex and matching it with one of its unmarked neighbors (that maximizes its edge weight). The GPU algorithm is adapted from the `"A GPU Algorithm for Greedy Graph Matching" `_ paper. Args: edge_index (torch.Tensor): The edge indices. weight (torch.Tensor, optional): One-dimensional edge weights. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) :rtype: :class:`torch.Tensor` """ if graclus_cluster is None: raise ImportError('`graclus` requires `torch-cluster`.') return graclus_cluster(edge_index[0], edge_index[1], weight, num_nodes) ================================================ FILE: torch_geometric/nn/pool/knn.py ================================================ import warnings from typing import NamedTuple, Optional import torch from torch import Tensor from torch_geometric.utils import cumsum, degree, to_dense_batch class KNNOutput(NamedTuple): score: Tensor index: Tensor class KNNIndex: r"""A base class to perform fast :math:`k`-nearest neighbor search (:math:`k`-NN) via the :obj:`faiss` library. Please ensure that :obj:`faiss` is installed by running .. 
code-block:: bash pip install faiss-cpu # or pip install faiss-gpu depending on whether to plan to use GPU-processing for :math:`k`-NN search. Args: index_factory (str, optional): The name of the index factory to use, *e.g.*, :obj:`"IndexFlatL2"` or :obj:`"IndexFlatIP"`. See `here `_ for more information. emb (torch.Tensor, optional): The data points to add. (default: :obj:`None`) reserve (int, optional): The number of elements to reserve memory for before re-allocating (GPU-only). (default: :obj:`None`) """ def __init__( self, index_factory: Optional[str] = None, emb: Optional[Tensor] = None, reserve: Optional[int] = None, ): warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*') import faiss self.index_factory = index_factory self.index: Optional[faiss.Index] = None self.reserve = reserve if emb is not None: self.add(emb) @property def numel(self) -> int: r"""The number of data points to search in.""" if self.index is None: return 0 return self.index.ntotal def _create_index(self, channels: int): import faiss return faiss.index_factory(channels, self.index_factory) def add(self, emb: Tensor): r"""Adds new data points to the :class:`KNNIndex` to search in. Args: emb (torch.Tensor): The data points to add. 
""" import faiss import faiss.contrib.torch_utils if emb.dim() != 2: raise ValueError(f"'emb' needs to be two-dimensional " f"(got {emb.dim()} dimensions)") if self.index is None: self.index = self._create_index(emb.size(1)) if emb.device != torch.device('cpu'): self.index = faiss.index_cpu_to_gpu( faiss.StandardGpuResources(), emb.device.index, self.index, ) if self.reserve is not None: if hasattr(self.index, 'reserveMemory'): self.index.reserveMemory(self.reserve) else: warnings.warn( f"'{self.index.__class__.__name__}' " f"does not support pre-allocation of " f"memory", stacklevel=2) self.index.train(emb) self.index.add(emb.detach()) def search( self, emb: Tensor, k: int, exclude_links: Optional[Tensor] = None, ) -> KNNOutput: r"""Search for the :math:`k` nearest neighbors of the given data points. Returns the distance/similarity score of the nearest neighbors and their indices. Args: emb (torch.Tensor): The data points to add. k (int): The number of nearest neighbors to return. exclude_links (torch.Tensor): The links to exclude from searching. Needs to be a COO tensor of shape :obj:`[2, num_links]`, where :obj:`exclude_links[0]` refers to indices in :obj:`emb`, and :obj:`exclude_links[1]` refers to the data points in the :class:`KNNIndex`. (default: :obj:`None`) """ if self.index is None: raise RuntimeError(f"'{self.__class__.__name__}' is not yet " "initialized. Please call `add(...)` first.") if emb.dim() != 2: raise ValueError(f"'emb' needs to be two-dimensional " f"(got {emb.dim()} dimensions)") query_k = k if exclude_links is not None: deg = degree(exclude_links[0], num_nodes=emb.size(0)).max() query_k = k + int(deg.max() if deg.numel() > 0 else 0) query_k = min(query_k, self.numel) if k > 2048: # `faiss` supports up-to `k=2048`: warnings.warn( f"Capping 'k' to faiss' upper limit of 2048 " f"(got {k}). 
This may cause some relevant items to " f"not be retrieved.", stacklevel=2) elif query_k > 2048: warnings.warn( f"Capping 'k' to faiss' upper limit of 2048 " f"(got {k} which got extended to {query_k} due to " f"the exclusion of existing links). This may cause " f"some relevant items to not be retrieved.", stacklevel=2) query_k = 2048 score, index = self.index.search(emb.detach(), query_k) if exclude_links is not None: # Drop indices to exclude by converting to flat vector: flat_exclude = self.numel * exclude_links[0] + exclude_links[1] offset = torch.arange( start=0, end=self.numel * index.size(0), step=self.numel, device=index.device, ).view(-1, 1) flat_index = (index + offset).view(-1) notin = torch.isin(flat_index, flat_exclude).logical_not_() score = score.view(-1)[notin] index = index.view(-1)[notin] # Only maintain top-k scores: count = notin.view(-1, query_k).sum(dim=1) cum_count = cumsum(count) batch = torch.arange(count.numel(), device=count.device) batch = batch.repeat_interleave(count, output_size=cum_count[-1]) batch_arange = torch.arange(count.sum(), device=count.device) batch_arange = batch_arange - cum_count[batch] mask = batch_arange < k score = score[mask] index = index[mask] if count.min() < k: # Fill with dummy scores: batch = batch[mask] score, _ = to_dense_batch( score, batch, fill_value=float('-inf'), max_num_nodes=k, batch_size=emb.size(0), ) index, _ = to_dense_batch( index, batch, fill_value=-1, max_num_nodes=k, batch_size=emb.size(0), ) score = score.view(-1, k) index = index.view(-1, k) return KNNOutput(score, index) def get_emb(self) -> Tensor: r"""Returns the data points stored in the :class:`KNNIndex`.""" if self.index is None: raise RuntimeError(f"'{self.__class__.__name__}' is not yet " "initialized. Please call `add(...)` first.") return self.index.reconstruct_n(0, self.numel) class L2KNNIndex(KNNIndex): r"""Performs fast :math:`k`-nearest neighbor search (:math:`k`-NN) based on the :math:`L_2` metric via the :obj:`faiss` library. 
Args: emb (torch.Tensor, optional): The data points to add. (default: :obj:`None`) """ def __init__(self, emb: Optional[Tensor] = None): super().__init__(index_factory=None, emb=emb) def _create_index(self, channels: int): import faiss return faiss.IndexFlatL2(channels) class MIPSKNNIndex(KNNIndex): r"""Performs fast :math:`k`-nearest neighbor search (:math:`k`-NN) based on the maximum inner product via the :obj:`faiss` library. Args: emb (torch.Tensor, optional): The data points to add. (default: :obj:`None`) """ def __init__(self, emb: Optional[Tensor] = None): super().__init__(index_factory=None, emb=emb) def _create_index(self, channels: int): import faiss return faiss.IndexFlatIP(channels) class ApproxL2KNNIndex(KNNIndex): r"""Performs fast approximate :math:`k`-nearest neighbor search (:math:`k`-NN) based on the :math:`L_2` metric via the :obj:`faiss` library. Hyperparameters need to be tuned for speed-accuracy trade-off. Args: num_cells (int): The number of cells. num_cells_to_visit (int): The number of cells that are visited to perform the search. bits_per_vector (int): The number of bits per sub-vector. emb (torch.Tensor, optional): The data points to add. (default: :obj:`None`) reserve (int, optional): The number of elements to reserve memory for before re-allocating (GPU only).
(default: :obj:`None`) """ def __init__( self, num_cells: int, num_cells_to_visit: int, bits_per_vector: int, emb: Optional[Tensor] = None, reserve: Optional[int] = None, ): self.num_cells = num_cells self.num_cells_to_visit = num_cells_to_visit self.bits_per_vector = bits_per_vector super().__init__(index_factory=None, emb=emb, reserve=reserve) def _create_index(self, channels: int): import faiss index = faiss.IndexIVFPQ( faiss.IndexFlatL2(channels), channels, self.num_cells, self.bits_per_vector, 8, faiss.METRIC_L2, ) index.nprobe = self.num_cells_to_visit return index class ApproxMIPSKNNIndex(KNNIndex): r"""Performs fast approximate :math:`k`-nearest neighbor search (:math:`k`-NN) based on the maximum inner product via the :obj:`faiss` library. Hyperparameters need to be tuned for speed-accuracy trade-off. Args: num_cells (int): The number of cells. num_cells_to_visit (int): The number of cells that are visited to perform the search. bits_per_vector (int): The number of bits per sub-vector. emb (torch.Tensor, optional): The data points to add. (default: :obj:`None`) reserve (int, optional): The number of elements to reserve memory for before re-allocating (GPU only).
(default: :obj:`None`) """ def __init__( self, num_cells: int, num_cells_to_visit: int, bits_per_vector: int, emb: Optional[Tensor] = None, reserve: Optional[int] = None, ): self.num_cells = num_cells self.num_cells_to_visit = num_cells_to_visit self.bits_per_vector = bits_per_vector super().__init__(index_factory=None, emb=emb, reserve=reserve) def _create_index(self, channels: int): import faiss index = faiss.IndexIVFPQ( faiss.IndexFlatIP(channels), channels, self.num_cells, self.bits_per_vector, 8, faiss.METRIC_INNER_PRODUCT, ) index.nprobe = self.num_cells_to_visit return index ================================================ FILE: torch_geometric/nn/pool/max_pool.py ================================================ from typing import Callable, Optional, Tuple from torch import Tensor from torch_geometric.data import Batch, Data from torch_geometric.nn.pool.consecutive import consecutive_cluster from torch_geometric.nn.pool.pool import pool_batch, pool_edge, pool_pos from torch_geometric.utils import add_self_loops, scatter def _max_pool_x( cluster: Tensor, x: Tensor, size: Optional[int] = None, ) -> Tensor: return scatter(x, cluster, dim=0, dim_size=size, reduce='max') def max_pool_x( cluster: Tensor, x: Tensor, batch: Tensor, batch_size: Optional[int] = None, size: Optional[int] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r"""Max-Pools node features according to the clustering defined in :attr:`cluster`. Args: cluster (torch.Tensor): The cluster vector :math:`\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N`, which assigns each node to a specific cluster. x (Tensor): The node feature matrix. batch (torch.Tensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) size (int, optional): The maximum number of clusters in a single example. 
This property is useful to obtain a batch-wise dense representation, *e.g.* for applying FC layers, but should only be used if the size of the maximum number of clusters per example is known in advance. (default: :obj:`None`) :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`) if :attr:`size` is :obj:`None`, else :class:`torch.Tensor` """ if size is not None: if batch_size is None: batch_size = int(batch.max().item()) + 1 return _max_pool_x(cluster, x, batch_size * size), None cluster, perm = consecutive_cluster(cluster) x = _max_pool_x(cluster, x) batch = pool_batch(perm, batch) return x, batch def max_pool( cluster: Tensor, data: Data, transform: Optional[Callable] = None, ) -> Data: r"""Pools and coarsens a graph given by the :class:`torch_geometric.data.Data` object according to the clustering defined in :attr:`cluster`. All nodes within the same cluster will be represented as one node. Final node features are defined by the *maximum* features of all nodes within the same cluster, node positions are averaged and edge indices are defined to be the union of the edge indices of all nodes within the same cluster. Args: cluster (torch.Tensor): The cluster vector :math:`\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N`, which assigns each node to a specific cluster. data (Data): Graph data object. transform (callable, optional): A function/transform that takes in the coarsened and pooled :obj:`torch_geometric.data.Data` object and returns a transformed version. 
(default: :obj:`None`) :rtype: :class:`torch_geometric.data.Data` """ cluster, perm = consecutive_cluster(cluster) x = None if data.x is None else _max_pool_x(cluster, data.x) index, attr = pool_edge(cluster, data.edge_index, data.edge_attr) batch = None if data.batch is None else pool_batch(perm, data.batch) pos = None if data.pos is None else pool_pos(cluster, data.pos) data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos) if transform is not None: data = transform(data) return data def max_pool_neighbor_x( data: Data, flow: Optional[str] = 'source_to_target', ) -> Data: r"""Max pools neighboring node features, where each feature in :obj:`data.x` is replaced by the feature value with the maximum value from the central node and its neighbors. """ x, edge_index = data.x, data.edge_index edge_index, _ = add_self_loops(edge_index, num_nodes=data.num_nodes) row, col = edge_index row, col = (row, col) if flow == 'source_to_target' else (col, row) data.x = scatter(x[row], col, dim=0, dim_size=data.num_nodes, reduce='max') return data ================================================ FILE: torch_geometric/nn/pool/mem_pool.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch.nn import Conv2d, KLDivLoss, Linear, Parameter from torch_geometric.utils import to_dense_batch EPS = 1e-15 class MemPooling(torch.nn.Module): r"""Memory based pooling layer from `"Memory-Based Graph Networks" `_ paper, which learns a coarsened graph representation based on soft cluster assignments. .. 
math:: S_{i,j}^{(h)} &= \frac{ (1+{\| \mathbf{x}_i-\mathbf{k}^{(h)}_j \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}{ \sum_{k=1}^K (1 + {\| \mathbf{x}_i-\mathbf{k}^{(h)}_k \|}^2 / \tau)^{ -\frac{1+\tau}{2}}} \mathbf{S} &= \textrm{softmax}(\textrm{Conv2d} (\Vert_{h=1}^H \mathbf{S}^{(h)})) \in \mathbb{R}^{N \times K} \mathbf{X}^{\prime} &= \mathbf{S}^{\top} \mathbf{X} \mathbf{W} \in \mathbb{R}^{K \times F^{\prime}} where :math:`H` denotes the number of heads, and :math:`K` denotes the number of clusters. Args: in_channels (int): Size of each input sample :math:`F`. out_channels (int): Size of each output sample :math:`F^{\prime}`. heads (int): The number of heads :math:`H`. num_clusters (int): number of clusters :math:`K` per head. tau (int, optional): The temperature :math:`\tau`. (default: :obj:`1.`) """ def __init__(self, in_channels: int, out_channels: int, heads: int, num_clusters: int, tau: float = 1.): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.num_clusters = num_clusters self.tau = tau self.k = Parameter(torch.empty(heads, num_clusters, in_channels)) self.conv = Conv2d(heads, 1, kernel_size=1, padding=0, bias=False) self.lin = Linear(in_channels, out_channels, bias=False) self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" torch.nn.init.uniform_(self.k.data, -1., 1.) self.conv.reset_parameters() self.lin.reset_parameters() @staticmethod def kl_loss(S: Tensor) -> Tensor: r"""The additional KL divergence-based loss. .. 
math:: P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^N S_{n,j}}{\sum_{k=1}^K S_{i,k}^2 / \sum_{n=1}^N S_{n,k}} \mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert \mathbf{S}) """ S_2 = S**2 P = S_2 / S.sum(dim=1, keepdim=True) denom = P.sum(dim=2, keepdim=True) denom[S.sum(dim=2, keepdim=True) == 0.0] = 1.0 P /= denom loss = KLDivLoss(reduction='batchmean', log_target=False) return loss(S.clamp(EPS).log(), P.clamp(EPS)) def forward( self, x: Tensor, batch: Optional[Tensor] = None, mask: Optional[Tensor] = None, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: r"""Forward pass. Args: x (torch.Tensor): The node feature tensor of shape :math:`\mathbf{X} \in \mathbb{R}^{N \times F}` or :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. Should not be provided in case node features already have shape :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`. (default: :obj:`None`) mask (torch.Tensor, optional): A mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}`, which indicates valid nodes for each graph when using node features of shape :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`. (default: :obj:`None`) max_num_nodes (int, optional): The size of the :math:`B` node dimension. Automatically calculated if not given. (default: :obj:`None`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ if x.dim() <= 2: x, mask = to_dense_batch(x, batch, max_num_nodes=max_num_nodes, batch_size=batch_size) elif mask is None: mask = x.new_ones((x.size(0), x.size(1)), dtype=torch.bool) (B, N, _), H, K = x.size(), self.heads, self.num_clusters dist = torch.cdist(self.k.view(H * K, -1), x.view(B * N, -1), p=2)**2 dist = (1. 
+ dist / self.tau).pow(-(self.tau + 1.0) / 2.0) dist = dist.view(H, K, B, N).permute(2, 0, 3, 1) # [B, H, N, K] S = dist / dist.sum(dim=-1, keepdim=True) S = self.conv(S).squeeze(dim=1).softmax(dim=-1) # [B, N, K] S = S * mask.view(B, N, 1) x = self.lin(S.transpose(1, 2) @ x) return x, S def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads}, ' f'num_clusters={self.num_clusters})') ================================================ FILE: torch_geometric/nn/pool/pan_pool.py ================================================ from typing import Callable, Optional, Tuple, Union import torch from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.pool.connect import FilterEdges from torch_geometric.nn.pool.select import SelectTopK from torch_geometric.typing import OptTensor, SparseTensor from torch_geometric.utils import scatter class PANPooling(torch.nn.Module): r"""The path integral based pooling operator from the `"Path Integral Based Convolution and Pooling for Graph Neural Networks" `_ paper. PAN pooling performs top-:math:`k` pooling where global node importance is measured based on node features and the MET matrix: .. math:: {\rm score} = \beta_1 \mathbf{X} \cdot \mathbf{p} + \beta_2 {\rm deg}(\mathbf{M}) Args: in_channels (int): Size of each input sample. ratio (float): Graph pooling ratio, which is used to compute :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`. This value is ignored if min_score is not None. (default: :obj:`0.5`) min_score (float, optional): Minimal node score :math:`\tilde{\alpha}` which is used to compute indices of pooled nodes :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`. When this value is not :obj:`None`, the :obj:`ratio` argument is ignored. (default: :obj:`None`) multiplier (float, optional): Coefficient by which features gets multiplied after pooling. This can be useful for large graphs and when :obj:`min_score` is used. 
(default: :obj:`1.0`) nonlinearity (str or callable, optional): The non-linearity to use. (default: :obj:`"tanh"`) """ def __init__( self, in_channels: int, ratio: float = 0.5, min_score: Optional[float] = None, multiplier: float = 1.0, nonlinearity: Union[str, Callable] = 'tanh', ): super().__init__() self.in_channels = in_channels self.ratio = ratio self.min_score = min_score self.multiplier = multiplier self.p = Parameter(torch.empty(in_channels)) self.beta = Parameter(torch.empty(2)) self.select = SelectTopK(1, ratio, min_score, nonlinearity) self.connect = FilterEdges() self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.p.data.fill_(1) self.beta.data.fill_(0.5) self.select.reset_parameters() def forward( self, x: Tensor, M: SparseTensor, batch: OptTensor = None, ) -> Tuple[Tensor, Tensor, Tensor, OptTensor, Tensor, Tensor]: r"""Forward pass. Args: x (torch.Tensor): The node feature matrix. M (SparseTensor): The MET matrix :math:`\mathbf{M}`. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. 
(default: :obj:`None`) """ if batch is None: batch = x.new_zeros(x.size(0), dtype=torch.long) row, col, edge_weight = M.coo() assert edge_weight is not None score1 = (x * self.p).sum(dim=-1) score2 = scatter(edge_weight, col, 0, dim_size=x.size(0), reduce='sum') score = self.beta[0] * score1 + self.beta[1] * score2 select_out = self.select(score, batch) perm = select_out.node_index score = select_out.weight assert score is not None x = x[perm] * score.view(-1, 1) x = self.multiplier * x if self.multiplier != 1 else x edge_index = torch.stack([col, row], dim=0) connect_out = self.connect(select_out, edge_index, edge_weight, batch) edge_weight = connect_out.edge_attr assert edge_weight is not None return (x, connect_out.edge_index, edge_weight, connect_out.batch, perm, score) def __repr__(self) -> str: if self.min_score is None: ratio = f'ratio={self.ratio}' else: ratio = f'min_score={self.min_score}' return (f'{self.__class__.__name__}({self.in_channels}, {ratio}, ' f'multiplier={self.multiplier})') ================================================ FILE: torch_geometric/nn/pool/pool.py ================================================ from typing import Optional import torch from torch_geometric.utils import coalesce, remove_self_loops, scatter def pool_edge( cluster, edge_index, edge_attr: Optional[torch.Tensor] = None, reduce: Optional[str] = 'sum', ): num_nodes = cluster.size(0) edge_index = cluster[edge_index.view(-1)].view(2, -1) edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) if edge_index.numel() > 0: edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, reduce=reduce) return edge_index, edge_attr def pool_batch(perm, batch): return batch[perm] def pool_pos(cluster, pos): return scatter(pos, cluster, dim=0, reduce='mean') ================================================ FILE: torch_geometric/nn/pool/sag_pool.py ================================================ from typing import Callable, Optional, Tuple, Union import torch from 
torch import Tensor from torch_geometric.nn import GraphConv from torch_geometric.nn.pool.connect import FilterEdges from torch_geometric.nn.pool.select import SelectTopK from torch_geometric.typing import OptTensor class SAGPooling(torch.nn.Module): r"""The self-attention pooling operator from the `"Self-Attention Graph Pooling" `_ and `"Understanding Attention and Generalization in Graph Neural Networks" `_ papers. If :obj:`min_score` :math:`\tilde{\alpha}` is :obj:`None`, computes: .. math:: \mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A}) \mathbf{i} &= \mathrm{top}_k(\mathbf{y}) \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}} \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}} If :obj:`min_score` :math:`\tilde{\alpha}` is a value in :obj:`[0, 1]`, computes: .. math:: \mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A})) \mathbf{i} &= \mathbf{y}_i > \tilde{\alpha} \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}} \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}. Projections scores are learned based on a graph neural network layer. Args: in_channels (int): Size of each input sample. ratio (float or int): Graph pooling ratio, which is used to compute :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`, or the value of :math:`k` itself, depending on whether the type of :obj:`ratio` is :obj:`float` or :obj:`int`. This value is ignored if :obj:`min_score` is not :obj:`None`. (default: :obj:`0.5`) GNN (torch.nn.Module, optional): A graph neural network layer for calculating projection scores (one of :class:`torch_geometric.nn.conv.GraphConv`, :class:`torch_geometric.nn.conv.GCNConv`, :class:`torch_geometric.nn.conv.GATConv` or :class:`torch_geometric.nn.conv.SAGEConv`). (default: :class:`torch_geometric.nn.conv.GraphConv`) min_score (float, optional): Minimal node score :math:`\tilde{\alpha}` which is used to compute indices of pooled nodes :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`. 
When this value is not :obj:`None`, the :obj:`ratio` argument is ignored. (default: :obj:`None`) multiplier (float, optional): Coefficient by which features gets multiplied after pooling. This can be useful for large graphs and when :obj:`min_score` is used. (default: :obj:`1`) nonlinearity (str or callable, optional): The non-linearity to use. (default: :obj:`"tanh"`) **kwargs (optional): Additional parameters for initializing the graph neural network layer. """ def __init__( self, in_channels: int, ratio: Union[float, int] = 0.5, GNN: torch.nn.Module = GraphConv, min_score: Optional[float] = None, multiplier: float = 1.0, nonlinearity: Union[str, Callable] = 'tanh', **kwargs, ): super().__init__() self.in_channels = in_channels self.ratio = ratio self.min_score = min_score self.multiplier = multiplier self.gnn = GNN(in_channels, 1, **kwargs) self.select = SelectTopK(1, ratio, min_score, nonlinearity) self.connect = FilterEdges() self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.gnn.reset_parameters() self.select.reset_parameters() def forward( self, x: Tensor, edge_index: Tensor, edge_attr: OptTensor = None, batch: OptTensor = None, attn: OptTensor = None, ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor]: r"""Forward pass. Args: x (torch.Tensor): The node feature matrix. edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`) batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) attn (torch.Tensor, optional): Optional node-level matrix to use for computing attention scores instead of using the node feature matrix :obj:`x`. 
(default: :obj:`None`) """ if batch is None: batch = edge_index.new_zeros(x.size(0)) attn = x if attn is None else attn attn = attn.view(-1, 1) if attn.dim() == 1 else attn attn = self.gnn(attn, edge_index) select_out = self.select(attn, batch) perm = select_out.node_index score = select_out.weight assert score is not None x = x[perm] * score.view(-1, 1) x = self.multiplier * x if self.multiplier != 1 else x connect_out = self.connect(select_out, edge_index, edge_attr, batch) return (x, connect_out.edge_index, connect_out.edge_attr, connect_out.batch, perm, score) def __repr__(self) -> str: if self.min_score is None: ratio = f'ratio={self.ratio}' else: ratio = f'min_score={self.min_score}' return (f'{self.__class__.__name__}({self.gnn.__class__.__name__}, ' f'{self.in_channels}, {ratio}, multiplier={self.multiplier})') ================================================ FILE: torch_geometric/nn/pool/select/__init__.py ================================================ r"""Node-selection package. This package provides classes for node selection methods in graph pooling scenarios. """ from .base import Select, SelectOutput from .topk import SelectTopK __all__ = [ 'Select', 'SelectOutput', 'SelectTopK', ] ================================================ FILE: torch_geometric/nn/pool/select/base.py ================================================ from dataclasses import dataclass from typing import Optional import torch from torch import Tensor @dataclass(init=False) class SelectOutput: r"""The output of the :class:`Select` method, which holds an assignment from selected nodes to their respective cluster(s). Args: node_index (torch.Tensor): The indices of the selected nodes. num_nodes (int): The number of nodes. cluster_index (torch.Tensor): The indices of the clusters each node in :obj:`node_index` is assigned to. num_clusters (int): The number of clusters. weight (torch.Tensor, optional): A weight vector, denoting the strength of the assignment of a node to its cluster. 
(default: :obj:`None`) """ node_index: Tensor num_nodes: int cluster_index: Tensor num_clusters: int weight: Optional[Tensor] = None def __init__( self, node_index: Tensor, num_nodes: int, cluster_index: Tensor, num_clusters: int, weight: Optional[Tensor] = None, ): if node_index.dim() != 1: raise ValueError(f"Expected 'node_index' to be one-dimensional " f"(got {node_index.dim()} dimensions)") if cluster_index.dim() != 1: raise ValueError(f"Expected 'cluster_index' to be one-dimensional " f"(got {cluster_index.dim()} dimensions)") if node_index.numel() != cluster_index.numel(): raise ValueError(f"Expected 'node_index' and 'cluster_index' to " f"hold the same number of values (got " f"{node_index.numel()} and " f"{cluster_index.numel()} values)") if weight is not None and weight.dim() != 1: raise ValueError(f"Expected 'weight' vector to be one-dimensional " f"(got {weight.dim()} dimensions)") if weight is not None and weight.numel() != node_index.numel(): raise ValueError(f"Expected 'weight' to hold {node_index.numel()} " f"values (got {weight.numel()} values)") self.node_index = node_index self.num_nodes = num_nodes self.cluster_index = cluster_index self.num_clusters = num_clusters self.weight = weight SelectOutput = torch.jit.script(SelectOutput) class Select(torch.nn.Module): r"""An abstract base class for implementing custom node selections as described in the `"Understanding Pooling in Graph Neural Networks" `_ paper, which maps the nodes of an input graph to supernodes in the coarsened graph. Specifically, :class:`Select` returns a :class:`SelectOutput` output, which holds a (sparse) mapping :math:`\mathbf{C} \in {[0, 1]}^{N \times C}` that assigns selected nodes to one or more of :math:`C` super nodes. 
""" def reset_parameters(self): r"""Resets all learnable parameters of the module.""" def forward(self, *args, **kwargs) -> SelectOutput: raise NotImplementedError def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/nn/pool/select/topk.py ================================================ from typing import Callable, Optional, Union import torch from torch import Tensor from torch_geometric.nn.inits import uniform from torch_geometric.nn.pool.select import Select, SelectOutput from torch_geometric.nn.resolver import activation_resolver from torch_geometric.utils import cumsum, scatter, softmax # TODO (matthias) Document this method. def topk( x: Tensor, ratio: Optional[Union[float, int]], batch: Tensor, min_score: Optional[float] = None, tol: float = 1e-7, ) -> Tensor: if min_score is not None: # Make sure that we do not drop all nodes in a graph. scores_max = scatter(x, batch, reduce='max')[batch] - tol scores_min = scores_max.clamp(max=min_score) perm = (x > scores_min).nonzero().view(-1) return perm if ratio is not None: num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum') if ratio >= 1: k = num_nodes.new_full((num_nodes.size(0), ), int(ratio)) else: k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long) x, x_perm = torch.sort(x.view(-1), descending=True) batch = batch[x_perm] batch, batch_perm = torch.sort(batch, descending=False, stable=True) arange = torch.arange(x.size(0), dtype=torch.long, device=x.device) ptr = cumsum(num_nodes) batched_arange = arange - ptr[batch] mask = batched_arange < k[batch] return x_perm[batch_perm[mask]] raise ValueError("At least one of the 'ratio' and 'min_score' parameters " "must be specified") class SelectTopK(Select): r"""Selects the top-:math:`k` nodes with highest projection scores from the `"Graph U-Nets" `_, `"Towards Sparse Hierarchical Graph Classifiers" `_ and `"Understanding Attention and Generalization in Graph 
Neural Networks" `_ papers. If :obj:`min_score` :math:`\tilde{\alpha}` is :obj:`None`, computes: .. math:: \mathbf{y} &= \sigma \left( \frac{\mathbf{X}\mathbf{p}}{\| \mathbf{p} \|} \right) \mathbf{i} &= \mathrm{top}_k(\mathbf{y}) If :obj:`min_score` :math:`\tilde{\alpha}` is a value in :obj:`[0, 1]`, computes: .. math:: \mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p}) \mathbf{i} &= \mathbf{y}_i > \tilde{\alpha} where :math:`\mathbf{p}` is the learnable projection vector. Args: in_channels (int): Size of each input sample. ratio (float or int): The graph pooling ratio, which is used to compute :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`, or the value of :math:`k` itself, depending on whether the type of :obj:`ratio` is :obj:`float` or :obj:`int`. This value is ignored if :obj:`min_score` is not :obj:`None`. (default: :obj:`0.5`) min_score (float, optional): Minimal node score :math:`\tilde{\alpha}` which is used to compute indices of pooled nodes :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`. When this value is not :obj:`None`, the :obj:`ratio` argument is ignored. (default: :obj:`None`) act (str or callable, optional): The non-linearity :math:`\sigma`. 
(default: :obj:`"tanh"`) """ def __init__( self, in_channels: int, ratio: Union[int, float] = 0.5, min_score: Optional[float] = None, act: Union[str, Callable] = 'tanh', ): super().__init__() if ratio is None and min_score is None: raise ValueError(f"At least one of the 'ratio' and 'min_score' " f"parameters must be specified in " f"'{self.__class__.__name__}'") self.in_channels = in_channels self.ratio = ratio self.min_score = min_score self.act = activation_resolver(act) self.weight = torch.nn.Parameter(torch.empty(1, in_channels)) self.reset_parameters() def reset_parameters(self): uniform(self.in_channels, self.weight) def forward( self, x: Tensor, batch: Optional[Tensor] = None, ) -> SelectOutput: """""" # noqa: D419 if batch is None: batch = x.new_zeros(x.size(0), dtype=torch.long) x = x.view(-1, 1) if x.dim() == 1 else x score = (x * self.weight).sum(dim=-1) if self.min_score is None: score = self.act(score / self.weight.norm(p=2, dim=-1)) else: score = softmax(score, batch) node_index = topk(score, self.ratio, batch, self.min_score) return SelectOutput( node_index=node_index, num_nodes=x.size(0), cluster_index=torch.arange(node_index.size(0), device=x.device), num_clusters=node_index.size(0), weight=score[node_index], ) def __repr__(self) -> str: if self.min_score is None: arg = f'ratio={self.ratio}' else: arg = f'min_score={self.min_score}' return f'{self.__class__.__name__}({self.in_channels}, {arg})' ================================================ FILE: torch_geometric/nn/pool/topk_pool.py ================================================ from typing import Callable, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.nn.pool.connect import FilterEdges from torch_geometric.nn.pool.select import SelectTopK from torch_geometric.typing import OptTensor class TopKPooling(torch.nn.Module): r""":math:`\mathrm{top}_k` pooling operator from the `"Graph U-Nets" `_, `"Towards Sparse Hierarchical Graph Classifiers" `_ and 
`"Understanding Attention and Generalization in Graph Neural Networks" `_ papers. If :obj:`min_score` :math:`\tilde{\alpha}` is :obj:`None`, computes: .. math:: \mathbf{y} &= \sigma \left( \frac{\mathbf{X}\mathbf{p}}{\| \mathbf{p} \|} \right) \mathbf{i} &= \mathrm{top}_k(\mathbf{y}) \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}} \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}} If :obj:`min_score` :math:`\tilde{\alpha}` is a value in :obj:`[0, 1]`, computes: .. math:: \mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p}) \mathbf{i} &= \mathbf{y}_i > \tilde{\alpha} \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}} \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}, where nodes are dropped based on a learnable projection score :math:`\mathbf{p}`. Args: in_channels (int): Size of each input sample. ratio (float or int): The graph pooling ratio, which is used to compute :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`, or the value of :math:`k` itself, depending on whether the type of :obj:`ratio` is :obj:`float` or :obj:`int`. This value is ignored if :obj:`min_score` is not :obj:`None`. (default: :obj:`0.5`) min_score (float, optional): Minimal node score :math:`\tilde{\alpha}` which is used to compute indices of pooled nodes :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`. When this value is not :obj:`None`, the :obj:`ratio` argument is ignored. (default: :obj:`None`) multiplier (float, optional): Coefficient by which features gets multiplied after pooling. This can be useful for large graphs and when :obj:`min_score` is used. (default: :obj:`1`) nonlinearity (str or callable, optional): The non-linearity :math:`\sigma`. 
(default: :obj:`"tanh"`) """ def __init__( self, in_channels: int, ratio: Union[int, float] = 0.5, min_score: Optional[float] = None, multiplier: float = 1., nonlinearity: Union[str, Callable] = 'tanh', ): super().__init__() self.in_channels = in_channels self.ratio = ratio self.min_score = min_score self.multiplier = multiplier self.select = SelectTopK(in_channels, ratio, min_score, nonlinearity) self.connect = FilterEdges() self.reset_parameters() def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.select.reset_parameters() def forward( self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, attn: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor]: r"""Forward pass. Args: x (torch.Tensor): The node feature matrix. edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`) batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) attn (torch.Tensor, optional): Optional node-level matrix to use for computing attention scores instead of using the node feature matrix :obj:`x`. 
(default: :obj:`None`) """ if batch is None: batch = edge_index.new_zeros(x.size(0)) attn = x if attn is None else attn select_out = self.select(attn, batch) perm = select_out.node_index score = select_out.weight assert score is not None x = x[perm] * score.view(-1, 1) x = self.multiplier * x if self.multiplier != 1 else x connect_out = self.connect(select_out, edge_index, edge_attr, batch) return (x, connect_out.edge_index, connect_out.edge_attr, connect_out.batch, perm, score) def __repr__(self) -> str: if self.min_score is None: ratio = f'ratio={self.ratio}' else: ratio = f'min_score={self.min_score}' return (f'{self.__class__.__name__}({self.in_channels}, {ratio}, ' f'multiplier={self.multiplier})') ================================================ FILE: torch_geometric/nn/pool/voxel_grid.py ================================================ from typing import List, Optional, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric.utils.repeat import repeat if torch_geometric.typing.WITH_TORCH_CLUSTER: from torch_cluster import grid_cluster else: grid_cluster = None def voxel_grid( pos: Tensor, size: Union[float, List[float], Tensor], batch: Optional[Tensor] = None, start: Optional[Union[float, List[float], Tensor]] = None, end: Optional[Union[float, List[float], Tensor]] = None, ) -> Tensor: r"""Voxel grid pooling from the, *e.g.*, `Dynamic Edge-Conditioned Filters in Convolutional Networks on Graphs `_ paper, which overlays a regular grid of user-defined size over a point cloud and clusters all points within the same voxel. Args: pos (torch.Tensor): Node position matrix :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}`. size (float or [float] or Tensor): Size of a voxel (in each dimension). batch (torch.Tensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,B-1\}}^N`, which assigns each node to a specific example. 
(default: :obj:`None`) start (float or [float] or Tensor, optional): Start coordinates of the grid (in each dimension). If set to :obj:`None`, will be set to the minimum coordinates found in :attr:`pos`. (default: :obj:`None`) end (float or [float] or Tensor, optional): End coordinates of the grid (in each dimension). If set to :obj:`None`, will be set to the maximum coordinates found in :attr:`pos`. (default: :obj:`None`) :rtype: :class:`torch.Tensor` """ if grid_cluster is None: raise ImportError('`voxel_grid` requires `torch-cluster`.') pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos dim = pos.size(1) if batch is None: batch = pos.new_zeros(pos.size(0), dtype=torch.long) pos = torch.cat([pos, batch.view(-1, 1).to(pos.dtype)], dim=-1) if not isinstance(size, Tensor): size = torch.tensor(size, dtype=pos.dtype, device=pos.device) size = repeat(size, dim) size = torch.cat([size, size.new_ones(1)]) # Add additional batch dim. if start is not None: if not isinstance(start, Tensor): start = torch.tensor(start, dtype=pos.dtype, device=pos.device) start = repeat(start, dim) start = torch.cat([start, start.new_zeros(1)]) if end is not None: if not isinstance(end, Tensor): end = torch.tensor(end, dtype=pos.dtype, device=pos.device) end = repeat(end, dim) end = torch.cat([end, batch.max().unsqueeze(0)]) return grid_cluster(pos, size, start, end) ================================================ FILE: torch_geometric/nn/reshape.py ================================================ import torch from torch import Tensor class Reshape(torch.nn.Module): def __init__(self, *shape): super().__init__() self.shape = shape def forward(self, x: Tensor) -> Tensor: """""" # noqa: D419 x = x.view(*self.shape) return x def __repr__(self) -> str: shape = ', '.join([str(dim) for dim in self.shape]) return f'{self.__class__.__name__}({shape})' ================================================ FILE: torch_geometric/nn/resolver.py ================================================ import inspect 
from typing import Any, Optional, Union import torch from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import ReduceLROnPlateau from torch_geometric.nn.lr_scheduler import ( ConstantWithWarmupLR, CosineWithWarmupLR, CosineWithWarmupRestartsLR, LinearWithWarmupLR, PolynomialWithWarmupLR, ) from torch_geometric.resolver import normalize_string, resolver try: from torch.optim.lr_scheduler import LRScheduler except ImportError: # PyTorch < 2.0 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # Activation Resolver ######################################################### def swish(x: Tensor) -> Tensor: return x * x.sigmoid() def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs): base_cls = torch.nn.Module base_cls_repr = 'Act' acts = [ act for act in vars(torch.nn.modules.activation).values() if isinstance(act, type) and issubclass(act, base_cls) ] acts += [ swish, ] act_dict = {} return resolver(acts, act_dict, query, base_cls, base_cls_repr, *args, **kwargs) # Normalization Resolver ###################################################### def normalization_resolver(query: Union[Any, str], *args, **kwargs): import torch_geometric.nn.norm as norm base_cls = torch.nn.Module base_cls_repr = 'Norm' norms = [ norm for norm in vars(norm).values() if isinstance(norm, type) and issubclass(norm, base_cls) ] norm_dict = {} return resolver(norms, norm_dict, query, base_cls, base_cls_repr, *args, **kwargs) # Aggregation Resolver ######################################################## def aggregation_resolver(query: Union[Any, str], *args, **kwargs): import torch_geometric.nn.aggr as aggr if isinstance(query, (list, tuple)): return aggr.MultiAggregation(query, *args, **kwargs) base_cls = aggr.Aggregation aggrs = [ aggr for aggr in vars(aggr).values() if isinstance(aggr, type) and issubclass(aggr, base_cls) ] aggr_dict = { 'add': aggr.SumAggregation, } return resolver(aggrs, aggr_dict, query, base_cls, None, 
*args, **kwargs) # Optimizer Resolver ########################################################## def optimizer_resolver(query: Union[Any, str], *args, **kwargs): base_cls = Optimizer optimizers = [ optimizer for optimizer in vars(torch.optim).values() if isinstance(optimizer, type) and issubclass(optimizer, base_cls) ] return resolver(optimizers, {}, query, base_cls, None, *args, **kwargs) # Learning Rate Scheduler Resolver ############################################ def lr_scheduler_resolver( query: Union[Any, str], optimizer: Optimizer, warmup_ratio_or_steps: Optional[Union[float, int]] = 0.1, num_training_steps: Optional[int] = None, **kwargs, ) -> Union[LRScheduler, ReduceLROnPlateau]: r"""A resolver to obtain a learning rate scheduler implemented in either PyG or PyTorch from its name or type. Args: query (Any or str): The query name of the learning rate scheduler. optimizer (Optimizer): The optimizer to be scheduled. warmup_ratio_or_steps (float or int, optional): The number of warmup steps. If given as a `float`, it will act as a ratio that gets multiplied with the number of training steps to obtain the number of warmup steps. Only required for warmup-based LR schedulers. (default: :obj:`0.1`) num_training_steps (int, optional): The total number of training steps. (default: :obj:`None`) **kwargs (optional): Additional arguments of the LR scheduler. 
""" if not isinstance(query, str): return query if isinstance(warmup_ratio_or_steps, float): if warmup_ratio_or_steps < 0 or warmup_ratio_or_steps > 1: raise ValueError(f"`warmup_ratio_or_steps` needs to be between " f"0.0 and 1.0 when given as a floating point " f"number (got {warmup_ratio_or_steps}).") if num_training_steps is not None: warmup_steps = round(warmup_ratio_or_steps * num_training_steps) elif isinstance(warmup_ratio_or_steps, int): if warmup_ratio_or_steps < 0: raise ValueError(f"`warmup_ratio_or_steps` needs to be positive " f"when given as an integer " f"(got {warmup_ratio_or_steps}).") warmup_steps = warmup_ratio_or_steps else: raise ValueError(f"Found invalid type of `warmup_ratio_or_steps` " f"(got {type(warmup_ratio_or_steps)})") base_cls = LRScheduler classes = [ scheduler for scheduler in vars(torch.optim.lr_scheduler).values() if isinstance(scheduler, type) and issubclass(scheduler, base_cls) ] + [ReduceLROnPlateau] customized_lr_schedulers = [ ConstantWithWarmupLR, LinearWithWarmupLR, CosineWithWarmupLR, CosineWithWarmupRestartsLR, PolynomialWithWarmupLR, ] classes += customized_lr_schedulers query_repr = normalize_string(query) base_cls_repr = normalize_string('LR') for cls in classes: cls_repr = normalize_string(cls.__name__) if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]: if inspect.isclass(cls): if cls in customized_lr_schedulers: cls_keys = inspect.signature(cls).parameters.keys() if 'num_warmup_steps' in cls_keys: kwargs['num_warmup_steps'] = warmup_steps if 'num_training_steps' in cls_keys: kwargs['num_training_steps'] = num_training_steps obj = cls(optimizer, **kwargs) return obj return cls choices = {cls.__name__ for cls in classes} raise ValueError(f"Could not resolve '{query}' among choices {choices}") ================================================ FILE: torch_geometric/nn/sequential.jinja ================================================ import typing import torch from torch import Tensor import 
torch_geometric.typing
{% for module in modules %}
from {{module}} import *
{%- endfor %}


def forward(
    self,
{%- for param in signature.param_dict.values() %}
    {{param.name}}: {{param.type_repr}},
{%- endfor %}
) -> {{signature.return_type_repr}}:

{%- for child in children %}
    {{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}})
{%- endfor %}
    return {{children[-1].return_names|join(', ')}}


================================================
FILE: torch_geometric/nn/sequential.py
================================================
import copy
import inspect
import os.path as osp
import random
import sys
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import torch
from torch import Tensor

from torch_geometric.inspector import Parameter, Signature, eval_type, split
from torch_geometric.template import module_from_template


class Child(NamedTuple):
    # One call inside `Sequential.forward`: the attribute name the module is
    # stored under, the variable names it reads, and the variable names its
    # output(s) get bound to.
    name: str
    param_names: List[str]
    return_names: List[str]


class Sequential(torch.nn.Module):
    r"""An extension of the :class:`torch.nn.Sequential` container in order to
    define a sequential GNN model.

    Since GNN operators take in multiple input arguments,
    :class:`torch_geometric.nn.Sequential` additionally expects both global
    input arguments, and function header definitions of individual
    operators.
    If omitted, an intermediate module will operate on the *output* of its
    preceding module:

    .. code-block:: python

        from torch.nn import Linear, ReLU
        from torch_geometric.nn import Sequential, GCNConv

        model = Sequential('x, edge_index', [
            (GCNConv(in_channels, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            Linear(64, out_channels),
        ])

    Here, :obj:`'x, edge_index'` defines the input arguments of :obj:`model`,
    and :obj:`'x, edge_index -> x'` defines the function header, *i.e.* input
    arguments *and* return types of :class:`~torch_geometric.nn.conv.GCNConv`.

    In particular, this also allows to create more sophisticated models,
    such as utilizing :class:`~torch_geometric.nn.models.JumpingKnowledge`:

    .. code-block:: python

        from torch.nn import Linear, ReLU, Dropout
        from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
        from torch_geometric.nn import global_mean_pool

        model = Sequential('x, edge_index, batch', [
            (Dropout(p=0.5), 'x -> x'),
            (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x1, edge_index -> x2'),
            ReLU(inplace=True),
            (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
            (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
            (global_mean_pool, 'x, batch -> x'),
            Linear(2 * 64, dataset.num_classes),
        ])

    Args:
        input_args (str): The input arguments of the model.
        modules ([(Callable, str) or Callable]): A list of modules (with
            optional function header definitions). Alternatively, an
            :obj:`OrderedDict` of modules (and function header definitions) can
            be passed.
    """
    _children: List[Child]

    def __init__(
        self,
        input_args: str,
        modules: List[Union[Tuple[Callable, str], Callable]],
    ) -> None:
        super().__init__()

        # Remember the module name of the caller so that the generated
        # `forward` template can `import *` from it (see
        # `_set_jittable_template`).
        caller_path = inspect.stack()[1].filename
        self._caller_module = osp.splitext(osp.basename(caller_path))[0]

        # Namespace used to evaluate type annotations given as strings:
        # this module's globals, the `__main__` namespace, and the caller's
        # module (if it is registered in `sys.modules`).
        _globals = copy.copy(globals())
        _globals.update(sys.modules['__main__'].__dict__)
        if self._caller_module in sys.modules:
            _globals.update(sys.modules[self._caller_module].__dict__)

        # `input_args` may carry an optional return type annotation, e.g.
        # 'x, edge_index -> Tensor'; without one, `Tensor` is assumed.
        signature = input_args.split('->')
        if len(signature) == 1:
            args_repr = signature[0]
            return_type_repr = 'Tensor'
            return_type = Tensor
        elif len(signature) == 2:
            args_repr = signature[0]
            return_type_repr = signature[1].strip()
            return_type = eval_type(return_type_repr, _globals)
        else:
            raise ValueError(f"Failed to parse arguments (got '{input_args}')")

        # Each input argument is either 'name' (typed as `Tensor`) or
        # 'name: Type' (type evaluated in `_globals`).
        param_dict: Dict[str, Parameter] = {}
        for arg in split(args_repr, sep=','):
            signature = arg.split(':')
            if len(signature) == 1:
                name = signature[0].strip()
                param_dict[name] = Parameter(
                    name=name,
                    type=Tensor,
                    type_repr='Tensor',
                    default=inspect._empty,
                )
            elif len(signature) == 2:
                name = signature[0].strip()
                param_dict[name] = Parameter(
                    name=name,
                    type=eval_type(signature[1].strip(), _globals),
                    type_repr=signature[1].strip(),
                    default=inspect._empty,
                )
            else:
                raise ValueError(f"Failed to parse argument "
                                 f"(got '{arg.strip()}')")

        self.signature = Signature(param_dict, return_type, return_type_repr)

        # A plain list is keyed as 'module_{i}'; the keys become attribute
        # names via `setattr` below.
        if not isinstance(modules, dict):
            modules = {
                f'module_{i}': module
                for i, module in enumerate(modules)
            }
        if len(modules) == 0:
            raise ValueError(f"'{self.__class__.__name__}' expects a "
                             f"non-empty list of modules")

        self._children: List[Child] = []
        for i, (name, module) in enumerate(modules.items()):
            desc: Optional[str] = None
            if isinstance(module, (tuple, list)):
                if len(module) == 1:
                    module = module[0]
                elif len(module) == 2:
                    module, desc = module
                else:
                    raise ValueError(f"Expected tuple of length 2 "
                                     f"(got {module})")

            if i == 0 and desc is None:
                raise ValueError("Signature for first module required")
            if not callable(module):
                raise ValueError(f"Expected callable module (got {module})")
            if desc is not None and not isinstance(desc, str):
                raise ValueError(f"Expected type hint representation "
                                 f"(got {desc})")

            if desc is not None:
                signature = desc.split('->')
                if len(signature) != 2:
                    raise ValueError(
                        f"Failed to parse arguments (got '{desc}')")
                param_names = [v.strip() for v in signature[0].split(',')]
                return_names = [v.strip() for v in signature[1].split(',')]
                child = Child(name, param_names, return_names)
            else:
                # No header given: consume the previous child's outputs and
                # re-bind the result to the same variable names.
                param_names = self._children[-1].return_names
                child = Child(name, param_names, param_names)

            setattr(self, name, module)
            self._children.append(child)

        self._set_jittable_template()

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        for child in self._children:
            module = getattr(self, child.name)
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    def __len__(self) -> int:
        return len(self._children)

    def __getitem__(self, idx: int) -> torch.nn.Module:
        return getattr(self, self._children[idx].name)

    def __setstate__(self, data: Dict[str, Any]) -> None:
        super().__setstate__(data)
        # Re-generate the templated `forward` after unpickling/copying.
        self._set_jittable_template()

    def __repr__(self) -> str:
        module_descs = [
            f"{', '.join(c.param_names)} -> {', '.join(c.return_names)}"
            for c in self._children
        ]
        module_reprs = [
            f' ({i}) - {self[i]}: {module_descs[i]}' for i in range(len(self))
        ]
        return '{}(\n{}\n)'.format(
            self.__class__.__name__,
            '\n'.join(module_reprs),
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """"""  # noqa: D419
        # Generic (non-templated) fallback: positional arguments are matched
        # to the declared input names in order, keyword arguments by name.
        value_dict = {
            name: arg
            for name, arg in zip(self.signature.param_dict.keys(), args)
        }
        for key, arg in kwargs.items():
            if key in value_dict:
                raise TypeError(f"'{self.__class__.__name__}' got multiple "
                                f"values for argument '{key}'")
            value_dict[key] = arg

        for child in self._children:
            args = [value_dict[name] for name in child.param_names]
            outs = getattr(self, child.name)(*args)
            if len(child.return_names) == 1:
                value_dict[child.return_names[0]] = outs
            else:
                for name, out in zip(child.return_names, outs):
                    value_dict[name] = out

        # The raw output of the last child is returned.
        return outs

    # TorchScript Support #####################################################

    def _set_jittable_template(self, raise_on_error: bool = False) -> None:
        # Renders `sequential.jinja` into a module and binds its `forward`.
        # Any failure is swallowed unless `raise_on_error` is set, in which
        # case the generic `forward` defined above stays in effect.
        try:  # Optimize `forward()` via `*.jinja` templates:
            if ('forward' in self.__class__.__dict__ and
                    self.__class__.__dict__['forward'] != Sequential.forward):
                raise ValueError("Cannot compile custom 'forward' method")

            root_dir = osp.dirname(osp.realpath(__file__))
            # Random suffix keeps generated module names distinct:
            uid = '%06x' % random.randrange(16**6)
            jinja_prefix = f'{self.__module__}_{self.__class__.__name__}_{uid}'
            module = module_from_template(
                module_name=jinja_prefix,
                template_path=osp.join(root_dir, 'sequential.jinja'),
                tmp_dirname='sequential',
                # Keyword arguments:
                modules=[self._caller_module],
                signature=self.signature,
                children=self._children,
            )

            self.forward = module.forward.__get__(self)

            # NOTE We override `forward` on the class level here in order to
            # support `torch.jit.trace` - this is generally dangerous to do,
            # and limits `torch.jit.trace` to a single `Sequential` module:
            self.__class__.forward = module.forward
        except Exception as e:  # pragma: no cover
            if raise_on_error:
                raise e

    def __prepare_scriptable__(self) -> 'Sequential':
        # Prevent type sharing when scripting `Sequential` modules:
        type_store = torch.jit._recursive.concrete_type_store.type_store
        type_store.pop(self.__class__, None)
        return self


================================================
FILE: torch_geometric/nn/summary.py
================================================
from collections import defaultdict
from typing import Any, List, Optional, Union

import torch
from torch.jit import ScriptModule
from torch.nn import Module

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import is_uninitialized_parameter
from torch_geometric.typing import SparseTensor


def summary(
    model: torch.nn.Module,
    *args,
    max_depth: int = 3,
    leaf_module: Optional[Union[Module, List[Module]]] = 'MessagePassing',
    **kwargs,
) -> str:
    r"""Summarizes a given :class:`torch.nn.Module`.
    The summarized information includes (1) layer names, (2) input and output
    shapes, and (3) the number of parameters.

    .. code-block:: python

        import torch
        from torch_geometric.nn import GCN, summary

        model = GCN(128, 64, num_layers=2, out_channels=32)
        x = torch.randn(100, 128)
        edge_index = torch.randint(100, size=(2, 20))

        print(summary(model, x, edge_index))

    ..
code-block:: +---------------------+---------------------+--------------+--------+ | Layer | Input Shape | Output Shape | #Param | |---------------------+---------------------+--------------+--------| | GCN | [100, 128], [2, 20] | [100, 32] | 10,336 | | ├─(act)ReLU | [100, 64] | [100, 64] | -- | | ├─(convs)ModuleList | -- | -- | 10,336 | | │ └─(0)GCNConv | [100, 128], [2, 20] | [100, 64] | 8,256 | | │ └─(1)GCNConv | [100, 64], [2, 20] | [100, 32] | 2,080 | +---------------------+---------------------+--------------+--------+ Args: model (torch.nn.Module): The model to summarize. *args: The arguments of the :obj:`model`. max_depth (int, optional): The depth of nested layers to display. Any layers deeper than this depth will not be displayed in the summary. (default: :obj:`3`) leaf_module (torch.nn.Module or [torch.nn.Module], optional): The modules to be treated as leaf modules, whose submodules are excluded from the summary. (default: :class:`~torch_geometric.nn.conv.MessagePassing`) **kwargs: Additional arguments of the :obj:`model`. """ # NOTE This is just for the doc-string to render nicely: if leaf_module == 'MessagePassing': leaf_module = MessagePassing def register_hook(info): def hook(module, inputs, output): info['input_shape'].append(get_shape(inputs)) info['output_shape'].append(get_shape(output)) return hook hooks = {} depth = 0 stack = [(model.__class__.__name__, model, depth)] info_list = [] input_shape = defaultdict(list) output_shape = defaultdict(list) while stack: name, module, depth = stack.pop() module_id = id(module) if name.startswith('(_'): # Do not summarize private modules. continue if module_id in hooks: # Avoid duplicated hooks. 
hooks[module_id].remove() info = {} info['name'] = name info['input_shape'] = input_shape[module_id] info['output_shape'] = output_shape[module_id] info['depth'] = depth if any([is_uninitialized_parameter(p) for p in module.parameters()]): info['#param'] = '-1' else: num_params = sum(p.numel() for p in module.parameters()) info['#param'] = f'{num_params:,}' if num_params > 0 else '--' info_list.append(info) if not isinstance(module, ScriptModule): hooks[module_id] = module.register_forward_hook( register_hook(info)) if depth >= max_depth: continue if (leaf_module is not None and isinstance(module, leaf_module)): continue module_items = reversed(module._modules.items()) stack += [(f"({name}){mod.__class__.__name__}", mod, depth + 1) for name, mod in module_items if mod is not None] training = model.training model.eval() with torch.no_grad(): model(*args, **kwargs) model.train(training) for h in hooks.values(): # Remove hooks. h.remove() info_list = postprocess(info_list) return make_table(info_list, max_depth=max_depth) def get_shape(inputs: Any) -> str: if not isinstance(inputs, (tuple, list)): inputs = (inputs, ) out = [] for x in inputs: if isinstance(x, SparseTensor): out.append(str(list(x.sizes()))) elif hasattr(x, 'size'): out.append(str(list(x.size()))) return ', '.join(out) def postprocess(info_list: List[dict]) -> List[dict]: for idx, info in enumerate(info_list): depth = info['depth'] if idx > 0: # root module (0) is excluded if depth == 1: prefix = '├─' else: prefix = f"{'│ '*(depth-1)}└─" info['name'] = prefix + info['name'] if info['input_shape']: info['input_shape'] = info['input_shape'].pop(0) info['output_shape'] = info['output_shape'].pop(0) else: info['input_shape'] = '--' info['output_shape'] = '--' return info_list def make_table(info_list: List[dict], max_depth: int) -> str: from tabulate import tabulate content = [['Layer', 'Input Shape', 'Output Shape', '#Param']] for info in info_list: content.append([ info['name'], info['input_shape'], 
info['output_shape'], info['#param'], ]) return tabulate(content, headers='firstrow', tablefmt='psql') ================================================ FILE: torch_geometric/nn/to_fixed_size_transformer.py ================================================ from typing import Any from torch.nn import Module from torch_geometric.nn.fx import Transformer try: from torch.fx import Graph, GraphModule, Node except (ImportError, ModuleNotFoundError, AttributeError): GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node' def to_fixed_size(module: Module, batch_size: int, debug: bool = False) -> GraphModule: r"""Converts a model and injects a pre-computed and fixed batch size to all global pooling operators. Args: module (torch.nn.Module): The model to transform. batch_size (int): The fixed batch size used in global pooling modules. debug (bool, optional): If set to :obj:`True`, will perform transformation in debug mode. (default: :obj:`False`) """ transformer = ToFixedSizeTransformer(module, batch_size, debug) return transformer.transform() class ToFixedSizeTransformer(Transformer): def __init__(self, module: Module, batch_size: int, debug: bool = False): super().__init__(module, debug=debug) self.batch_size = batch_size def call_global_pooling_module(self, node: Node, target: Any, name: str): kwargs = node.kwargs.copy() kwargs['dim_size'] = self.batch_size node.kwargs = kwargs ================================================ FILE: torch_geometric/nn/to_hetero_module.py ================================================ import copy import warnings from typing import Dict, List, Optional, Union import torch import torch.nn.functional as F from torch import Tensor import torch_geometric from torch_geometric import is_compiling from torch_geometric.typing import EdgeType, NodeType, OptTensor from torch_geometric.utils import cumsum, scatter class ToHeteroLinear(torch.nn.Module): def __init__( self, module: torch.nn.Module, types: Union[List[NodeType], List[EdgeType]], ): from 
torch_geometric.nn import HeteroLinear, Linear super().__init__() self.types = types if isinstance(module, Linear): in_channels = module.in_channels out_channels = module.out_channels bias = module.bias is not None elif isinstance(module, torch.nn.Linear): in_channels = module.in_features out_channels = module.out_features bias = module.bias is not None else: raise ValueError(f"Expected 'Linear' module (got '{type(module)}'") # TODO We currently assume that `x` is sorted according to `type`. self.hetero_module = HeteroLinear( in_channels, out_channels, num_types=len(types), is_sorted=True, bias=bias, ) def fused_forward(self, x: Tensor, type_vec: Tensor) -> Tensor: return self.hetero_module(x, type_vec) def dict_forward( self, x_dict: Dict[Union[NodeType, EdgeType], Tensor], ) -> Dict[Union[NodeType, EdgeType], Tensor]: if not torch_geometric.typing.WITH_PYG_LIB or is_compiling(): return { key: F.linear(x_dict[key], self.hetero_module.weight[i].t()) + self.hetero_module.bias[i] for i, key in enumerate(self.types) } x = torch.cat([x_dict[key] for key in self.types], dim=0) sizes = [x_dict[key].size(0) for key in self.types] type_vec = torch.arange(len(self.types), device=x.device) size = torch.tensor(sizes, device=x.device) type_vec = type_vec.repeat_interleave(size) outs = self.hetero_module(x, type_vec).split(sizes) return {key: out for key, out in zip(self.types, outs)} def forward( self, x: Union[Tensor, Dict[Union[NodeType, EdgeType], Tensor]], type_vec: Optional[Tensor] = None, ) -> Union[Tensor, Dict[Union[NodeType, EdgeType], Tensor]]: if isinstance(x, dict): return self.dict_forward(x) elif isinstance(x, Tensor) and type_vec is not None: return self.fused_forward(x, type_vec) raise ValueError(f"Encountered invalid forward types in " f"'{self.__class__.__name__}'") class ToHeteroMessagePassing(torch.nn.Module): def __init__( self, module: torch.nn.Module, node_types: List[NodeType], edge_types: List[NodeType], aggr: str = 'sum', ): from torch_geometric.nn 
import HeteroConv, MessagePassing super().__init__() self.node_types = node_types self.node_type_to_index = {key: i for i, key in enumerate(node_types)} self.edge_types = edge_types if not isinstance(module, MessagePassing): raise ValueError(f"Expected 'MessagePassing' module " f"(got '{type(module)}'") if (not hasattr(module, 'reset_parameters') and sum([p.numel() for p in module.parameters()]) > 0): warnings.warn( f"'{module}' will be duplicated, but its parameters " f"cannot be reset. To suppress this warning, add a " f"'reset_parameters()' method to '{module}'", stacklevel=2) convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types} self.hetero_module = HeteroConv(convs, aggr) self.hetero_module.reset_parameters() def fused_forward(self, x: Tensor, edge_index: Tensor, node_type: Tensor, edge_type: Tensor) -> Tensor: # TODO This currently does not fuse at all :( # TODO We currently assume that `x` and `edge_index` are both sorted # according to `type`. node_sizes = scatter(torch.ones_like(node_type), node_type, dim=0, dim_size=len(self.node_types), reduce='sum') edge_sizes = scatter(torch.ones_like(edge_type), edge_type, dim=0, dim_size=len(self.edge_types), reduce='sum') ptr = cumsum(node_sizes) xs = x.split(node_sizes.tolist()) x_dict = {node_type: x for node_type, x in zip(self.node_types, xs)} # TODO Consider out-sourcing to its own function. 
edge_indices = edge_index.clone().split(edge_sizes.tolist(), dim=1) for (src, _, dst), index in zip(self.edge_types, edge_indices): index[0] -= ptr[self.node_type_to_index[src]] index[1] -= ptr[self.node_type_to_index[dst]] edge_index_dict = { edge_type: edge_index for edge_type, edge_index in zip(self.edge_types, edge_indices) } out_dict = self.hetero_module(x_dict, edge_index_dict) return torch.cat([out_dict[key] for key in self.node_types], dim=0) def dict_forward( self, x_dict: Dict[NodeType, Tensor], edge_index_dict: Dict[EdgeType, Tensor], **kwargs, ) -> Dict[NodeType, Tensor]: return self.hetero_module(x_dict, edge_index_dict, **kwargs) def forward( self, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], node_type: OptTensor = None, edge_type: OptTensor = None, **kwargs, ) -> Union[Tensor, Dict[NodeType, Tensor]]: if isinstance(x, dict) and isinstance(edge_index, dict): return self.dict_forward(x, edge_index, **kwargs) elif (isinstance(x, Tensor) and isinstance(edge_index, Tensor) and node_type is not None and edge_type is not None): if len(kwargs) > 0: raise ValueError("Additional forward arguments not yet " "supported in fused mode") return self.fused_forward(x, edge_index, node_type, edge_type) raise ValueError(f"Encountered invalid forward types in " f"'{self.__class__.__name__}'") ================================================ FILE: torch_geometric/nn/to_hetero_transformer.py ================================================ import copy import warnings from collections import defaultdict, deque from typing import Any, Dict, Optional, Tuple, Union import torch from torch.nn import Module from torch_geometric.nn.dense.linear import is_uninitialized_parameter from torch_geometric.nn.fx import Transformer, get_submodule from torch_geometric.typing import EdgeType, Metadata, NodeType from torch_geometric.utils.hetero import ( check_add_self_loops, get_unused_node_types, ) try: from torch.fx import Graph, 
GraphModule, Node
except (ImportError, ModuleNotFoundError, AttributeError):
    GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'


def get_dict(mapping: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # Used as the `call_function` target inserted by
    # `ToHeteroTransformer.placeholder`: maps a `None` input to an empty dict
    # so that the per-type `.get(key, None)` lookups created there still work.
    return mapping if mapping is not None else {}


def to_hetero(module: Module, metadata: Metadata, aggr: str = "sum",
              input_map: Optional[Dict[str, str]] = None,
              debug: bool = False) -> GraphModule:
    r"""Converts a homogeneous GNN model into its heterogeneous equivalent in
    which node representations are learned for each node type in
    :obj:`metadata[0]`, and messages are exchanged between each edge type in
    :obj:`metadata[1]`, as denoted in the `"Modeling Relational Data with
    Graph Convolutional Networks" <https://arxiv.org/abs/1703.06103>`_ paper.

    .. code-block:: python

        import torch
        from torch_geometric.nn import SAGEConv, to_hetero

        class GNN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = SAGEConv((-1, -1), 32)
                self.conv2 = SAGEConv((32, 32), 32)

            def forward(self, x, edge_index):
                x = self.conv1(x, edge_index).relu()
                x = self.conv2(x, edge_index).relu()
                return x

        model = GNN()

        node_types = ['paper', 'author']
        edge_types = [
            ('paper', 'cites', 'paper'),
            ('paper', 'written_by', 'author'),
            ('author', 'writes', 'paper'),
        ]
        metadata = (node_types, edge_types)

        model = to_hetero(model, metadata)
        model(x_dict, edge_index_dict)

    where :obj:`x_dict` and :obj:`edge_index_dict` denote dictionaries that
    hold node features and edge connectivity information for each node type
    and edge type, respectively.

    The below illustration shows the original computation graph of the
    homogeneous model on the left, and the newly obtained computation graph of
    the heterogeneous model on the right:

    .. figure:: ../_figures/to_hetero.svg
      :align: center
      :width: 90%

      Transforming a model via :func:`to_hetero`.

    Here, each :class:`~torch_geometric.nn.conv.MessagePassing` instance
    :math:`f_{\theta}^{(\ell)}` is duplicated and stored in a set
    :math:`\{ f_{\theta}^{(\ell, r)} : r \in \mathcal{R} \}` (one instance for
    each relation in :math:`\mathcal{R}`), and message passing in layer
    :math:`\ell` is performed via

    .. math::

        \mathbf{h}^{(\ell)}_v = \bigoplus_{r \in \mathcal{R}}
        f_{\theta}^{(\ell, r)} ( \mathbf{h}^{(\ell - 1)}_v, \{
        \mathbf{h}^{(\ell - 1)}_w : w \in \mathcal{N}^{(r)}(v) \}),

    where :math:`\mathcal{N}^{(r)}(v)` denotes the neighborhood of :math:`v \in
    \mathcal{V}` under relation :math:`r \in \mathcal{R}`, and
    :math:`\bigoplus` denotes the aggregation scheme :attr:`aggr` to use for
    grouping node embeddings generated by different relations
    (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"` or :obj:`"mul"`).

    Args:
        module (torch.nn.Module): The homogeneous model to transform.
        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata
            of the heterogeneous graph, *i.e.* its node and edge types given
            by a list of strings and a list of string triplets, respectively.
            See :meth:`torch_geometric.data.HeteroData.metadata` for more
            information.
        aggr (str, optional): The aggregation scheme to use for grouping node
            embeddings generated by different relations
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"mul"`). (default: :obj:`"sum"`)
        input_map (Dict[str, str], optional): A dictionary holding information
            about the type of input arguments of :obj:`module.forward`.
            For example, in case :obj:`arg` is a node-level argument, then
            :obj:`input_map['arg'] = 'node'`, and
            :obj:`input_map['arg'] = 'edge'` otherwise.
            In case :obj:`input_map` is not further specified, will try to
            automatically determine the correct type of input arguments.
            (default: :obj:`None`)
        debug (bool, optional): If set to :obj:`True`, will perform
            transformation in debug mode. (default: :obj:`False`)
    """
    transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)
    return transformer.transform()


class ToHeteroTransformer(Transformer):

    # Binary functions used to combine per-relation (or per-type) outputs
    # pairwise.
    aggrs = {
        'sum': torch.add,
        # For 'mean' aggregation, we first sum up all feature matrices, and
        # divide by the number of matrices in a later step.
        'mean': torch.add,
        'max': torch.max,
        'min': torch.min,
        'mul': torch.mul,
    }

    def __init__(
        self,
        module: Module,
        metadata: Metadata,
        aggr: str = 'sum',
        input_map: Optional[Dict[str, str]] = None,
        debug: bool = False,
    ):
        super().__init__(module, input_map, debug)

        self.metadata = metadata
        self.aggr = aggr
        assert len(metadata) == 2
        assert len(metadata[0]) > 0 and len(metadata[1]) > 0
        assert aggr in self.aggrs.keys()

        self.validate()

    def validate(self):
        # Only warns; never raises.
        unused_node_types = get_unused_node_types(*self.metadata)
        if len(unused_node_types) > 0:
            warnings.warn(
                f"There exist node types ({unused_node_types}) whose "
                f"representations do not get updated during message passing "
                f"as they do not occur as destination type in any edge type. "
                f"This may lead to unexpected behavior.", stacklevel=2)

        # Node types and relation names end up in generated FX node names,
        # hence the identifier check:
        names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
        for name in names:
            if not name.isidentifier():
                warnings.warn(
                    f"The type '{name}' contains invalid characters which "
                    f"may lead to unexpected behavior. To avoid any issues, "
                    f"ensure that your types only contain letters, numbers "
                    f"and underscores.", stacklevel=2)

    def placeholder(self, node: Node, target: Any, name: str):
        # Adds a `get` call to the input dictionary for every node-type or
        # edge-type.
        if node.type is not None:
            Type = EdgeType if self.is_edge_level(node) else NodeType
            node.type = Dict[Type, node.type]

        self.graph.inserting_after(node)

        dict_node = self.graph.create_node('call_function', target=get_dict,
                                           args=(node, ), name=f'{name}_dict')
        self.graph.inserting_after(dict_node)

        for key in self.metadata[int(self.is_edge_level(node))]:
            out = self.graph.create_node('call_method', target='get',
                                         args=(dict_node, key, None),
                                         name=f'{name}__{key2str(key)}')
            self.graph.inserting_after(out)

    def get_attr(self, node: Node, target: Any, name: str):
        raise NotImplementedError

    def call_message_passing_module(self, node: Node, target: Any, name: str):
        # Add calls to edge type-wise `MessagePassing` modules and aggregate
        # the outputs to node type-wise embeddings afterwards.

        module = get_submodule(self.module, target)
        check_add_self_loops(module, self.metadata[1])

        # Group edge-wise keys per destination:
        # `key_name` maps each edge type to the FX node name its output is
        # bound to, numbered per destination as `{name}__{dst}{k}`.
        key_name, keys_per_dst = {}, defaultdict(list)
        for key in self.metadata[1]:
            keys_per_dst[key[-1]].append(key)
            key_name[key] = f'{name}__{key[-1]}{len(keys_per_dst[key[-1]])}'

        for dst, keys in dict(keys_per_dst).items():
            # In case there is only a single edge-wise connection, there is no
            # need for any destination-wise aggregation, and we can already set
            # the intermediate variable name to the final output name.
            if len(keys) == 1:
                key_name[keys[0]] = f'{name}__{dst}'
                del keys_per_dst[dst]

        self.graph.inserting_after(node)
        for key in self.metadata[1]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_module',
                                         target=f'{target}.{key2str(key)}',
                                         args=args, kwargs=kwargs,
                                         name=key_name[key])
            self.graph.inserting_after(out)

        # Perform destination-wise aggregation.
        # Here, we aggregate in pairs, popping the first two elements of
        # `keys_per_dst` and append the result to the list.
        for dst, keys in keys_per_dst.items():
            queue = deque([key_name[key] for key in keys])
            i = 1
            while len(queue) >= 2:
                key1, key2 = queue.popleft(), queue.popleft()
                args = (self.find_by_name(key1), self.find_by_name(key2))

                # The last pairwise combination takes the final output name,
                # unless a 'mean' division still follows:
                new_name = f'{name}__{dst}'
                if self.aggr == 'mean' or len(queue) > 0:
                    new_name = f'{new_name}_{i}'

                out = self.graph.create_node('call_function',
                                             target=self.aggrs[self.aggr],
                                             args=args, name=new_name)
                self.graph.inserting_after(out)
                queue.append(new_name)
                i += 1

            if self.aggr == 'mean':
                key = queue.popleft()
                out = self.graph.create_node(
                    'call_function', target=torch.div,
                    args=(self.find_by_name(key), len(keys_per_dst[dst])),
                    name=f'{name}__{dst}')
                self.graph.inserting_after(out)

    def call_global_pooling_module(self, node: Node, target: Any, name: str):
        # Add calls to node type-wise `GlobalPooling` modules and aggregate
        # the outputs to graph type-wise embeddings afterwards.
        self.graph.inserting_after(node)
        for key in self.metadata[0]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_module',
                                         target=f'{target}.{key2str(key)}',
                                         args=args, kwargs=kwargs,
                                         name=f'{node.name}__{key2str(key)}')
            self.graph.inserting_after(out)

        # Perform node-wise aggregation.
        queue = deque(
            [f'{node.name}__{key2str(key)}' for key in self.metadata[0]])
        i = 1
        while len(queue) >= 2:
            key1, key2 = queue.popleft(), queue.popleft()
            args = (self.find_by_name(key1), self.find_by_name(key2))
            out = self.graph.create_node('call_function',
                                         target=self.aggrs[self.aggr],
                                         args=args, name=f'{name}_{i}')
            self.graph.inserting_after(out)
            queue.append(f'{name}_{i}')
            i += 1

        if self.aggr == 'mean':
            key = queue.popleft()
            out = self.graph.create_node(
                'call_function', target=torch.div,
                args=(self.find_by_name(key), len(self.metadata[0])),
                name=f'{name}_{i}')
            self.graph.inserting_after(out)
        # `out` is the last node created above: the final aggregate, or the
        # single per-type pooling call when there is only one node type and
        # `aggr != 'mean'`.
        self.replace_all_uses_with(node, out)

    def call_module(self, node: Node, target: Any, name: str):
        if self.is_graph_level(node):
            return

        # Add calls to node type-wise or edge type-wise modules.
        self.graph.inserting_after(node)
        for key in self.metadata[int(self.is_edge_level(node))]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_module',
                                         target=f'{target}.{key2str(key)}',
                                         args=args, kwargs=kwargs,
                                         name=f'{name}__{key2str(key)}')
            self.graph.inserting_after(out)

    def call_method(self, node: Node, target: Any, name: str):
        if self.is_graph_level(node):
            return

        # Add calls to node type-wise or edge type-wise methods.
        self.graph.inserting_after(node)
        for key in self.metadata[int(self.is_edge_level(node))]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_method', target=target,
                                         args=args, kwargs=kwargs,
                                         name=f'{name}__{key2str(key)}')
            self.graph.inserting_after(out)

    def call_function(self, node: Node, target: Any, name: str):
        if self.is_graph_level(node):
            return

        # Add calls to node type-wise or edge type-wise functions.
self.graph.inserting_after(node) for key in self.metadata[int(self.is_edge_level(node))]: args, kwargs = self.map_args_kwargs(node, key) out = self.graph.create_node('call_function', target=target, args=args, kwargs=kwargs, name=f'{name}__{key2str(key)}') self.graph.inserting_after(out) def output(self, node: Node, target: Any, name: str): # Replace the output by dictionaries, holding either node type-wise or # edge type-wise data. def _recurse(value: Any) -> Any: if isinstance(value, Node): if self.is_graph_level(value): return value return { key: self.find_by_name(f'{value.name}__{key2str(key)}') for key in self.metadata[int(self.is_edge_level(value))] } elif isinstance(value, dict): return {k: _recurse(v) for k, v in value.items()} elif isinstance(value, list): return [_recurse(v) for v in value] elif isinstance(value, tuple): return tuple(_recurse(v) for v in value) else: return value if node.type is not None and isinstance(node.args[0], Node): output = node.args[0] if self.is_node_level(output): node.type = Dict[NodeType, node.type] elif self.is_edge_level(output): node.type = Dict[EdgeType, node.type] else: node.type = None node.args = (_recurse(node.args[0]), ) def init_submodule(self, module: Module, target: str) -> Module: # Replicate each module for each node type or edge type. 
has_node_level_target = bool( self.find_by_target(f'{target}.{key2str(self.metadata[0][0])}')) has_edge_level_target = bool( self.find_by_target(f'{target}.{key2str(self.metadata[1][0])}')) if not has_node_level_target and not has_edge_level_target: return module module_dict = torch.nn.ModuleDict() for key in self.metadata[int(has_edge_level_target)]: module_dict[key2str(key)] = copy.deepcopy(module) if len(self.metadata[int(has_edge_level_target)]) <= 1: continue if hasattr(module, 'reset_parameters'): module_dict[key2str(key)].reset_parameters() elif sum([ is_uninitialized_parameter(p) or p.numel() for p in module.parameters() ]) > 0: warnings.warn( f"'{target}' will be duplicated, but its parameters " f"cannot be reset. To suppress this warning, add a " f"'reset_parameters()' method to '{target}'", stacklevel=2) return module_dict # Helper methods ########################################################## def map_args_kwargs(self, node: Node, key: Union[NodeType, EdgeType]) -> Tuple[Tuple, Dict]: def _recurse(value: Any) -> Any: if isinstance(value, Node): out = self.find_by_name(f'{value.name}__{key2str(key)}') if out is not None: return out elif isinstance(key, tuple) and key[0] == key[-1]: name = f'{value.name}__{key2str(key[0])}' return self.find_by_name(name) elif isinstance(key, tuple) and key[0] != key[-1]: return ( self.find_by_name(f'{value.name}__{key2str(key[0])}'), self.find_by_name(f'{value.name}__{key2str(key[-1])}'), ) else: raise ValueError(f"Cannot generate a graph node '{node}' " f"for type '{key}' since it does not " f"exist. 
Please make sure that all " f"node types get updated during message " f"passing.") elif isinstance(value, dict): return {k: _recurse(v) for k, v in value.items()} elif isinstance(value, list): return [_recurse(v) for v in value] elif isinstance(value, tuple): return tuple(_recurse(v) for v in value) else: return value args = tuple(_recurse(v) for v in node.args) kwargs = {k: _recurse(v) for k, v in node.kwargs.items()} return args, kwargs def key2str(key: Union[NodeType, EdgeType]) -> str: key = '__'.join(key) if isinstance(key, tuple) else key return key.replace(' ', '_').replace('-', '_').replace(':', '_') ================================================ FILE: torch_geometric/nn/to_hetero_with_bases_transformer.py ================================================ import copy import warnings from typing import Any, Dict, List, Optional, Union import torch from torch import Tensor from torch.nn import Module, Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense import Linear from torch_geometric.nn.fx import Transformer from torch_geometric.typing import EdgeType, Metadata, NodeType, SparseTensor from torch_geometric.utils.hetero import get_unused_node_types try: from torch.fx import Graph, GraphModule, Node except (ImportError, ModuleNotFoundError, AttributeError): GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node' def to_hetero_with_bases(module: Module, metadata: Metadata, num_bases: int, in_channels: Optional[Dict[str, int]] = None, input_map: Optional[Dict[str, str]] = None, debug: bool = False) -> GraphModule: r"""Converts a homogeneous GNN model into its heterogeneous equivalent via the basis-decomposition technique introduced in the `"Modeling Relational Data with Graph Convolutional Networks" `_ paper. For this, the heterogeneous graph is mapped to a typed homogeneous graph, in which its feature representations are aligned and grouped to a single representation. 
All GNN layers inside the model will then perform message passing via basis-decomposition regularization. This transformation is especially useful in highly multi-relational data, such that the number of parameters no longer depend on the number of relations of the input graph: .. code-block:: python import torch from torch_geometric.nn import SAGEConv, to_hetero_with_bases class GNN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = SAGEConv((16, 16), 32) self.conv2 = SAGEConv((32, 32), 32) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index).relu() return x model = GNN() node_types = ['paper', 'author'] edge_types = [ ('paper', 'cites', 'paper'), ('paper', 'written_by', 'author'), ('author', 'writes', 'paper'), ] metadata = (node_types, edge_types) model = to_hetero_with_bases(model, metadata, num_bases=3, in_channels={'x': 16}) model(x_dict, edge_index_dict) where :obj:`x_dict` and :obj:`edge_index_dict` denote dictionaries that hold node features and edge connectivity information for each node type and edge type, respectively. In case :obj:`in_channels` is given for a specific input argument, its heterogeneous feature information is first aligned to the given dimensionality. The below illustration shows the original computation graph of the homogeneous model on the left, and the newly obtained computation graph of the regularized heterogeneous model on the right: .. figure:: ../_figures/to_hetero_with_bases.svg :align: center :width: 90% Transforming a model via :func:`to_hetero_with_bases`. Here, each :class:`~torch_geometric.nn.conv.MessagePassing` instance :math:`f_{\theta}^{(\ell)}` is duplicated :obj:`num_bases` times and stored in a set :math:`\{ f_{\theta}^{(\ell, b)} : b \in \{ 1, \ldots, B \} \}` (one instance for each basis in :obj:`num_bases`), and message passing in layer :math:`\ell` is performed via .. 
math:: \mathbf{h}^{(\ell)}_v = \sum_{r \in \mathcal{R}} \sum_{b=1}^B f_{\theta}^{(\ell, b)} ( \mathbf{h}^{(\ell - 1)}_v, \{ a^{(\ell)}_{r, b} \cdot \mathbf{h}^{(\ell - 1)}_w : w \in \mathcal{N}^{(r)}(v) \}), where :math:`\mathcal{N}^{(r)}(v)` denotes the neighborhood of :math:`v \in \mathcal{V}` under relation :math:`r \in \mathcal{R}`. Notably, only the trainable basis coefficients :math:`a^{(\ell)}_{r, b}` depend on the relations in :math:`\mathcal{R}`. Args: module (torch.nn.Module): The homogeneous model to transform. metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata of the heterogeneous graph, *i.e.* its node and edge types given by a list of strings and a list of string triplets, respectively. See :meth:`torch_geometric.data.HeteroData.metadata` for more information. num_bases (int): The number of bases to use. in_channels (Dict[str, int], optional): A dictionary holding information about the desired input feature dimensionality of input arguments of :obj:`module.forward`. In case :obj:`in_channels` is given for a specific input argument, its heterogeneous feature information is first aligned to the given dimensionality. This allows handling of node and edge features with varying feature dimensionality across different types. (default: :obj:`None`) input_map (Dict[str, str], optional): A dictionary holding information about the type of input arguments of :obj:`module.forward`. For example, in case :obj:`arg` is a node-level argument, then :obj:`input_map['arg'] = 'node'`, and :obj:`input_map['arg'] = 'edge'` otherwise. In case :obj:`input_map` is not further specified, will try to automatically determine the correct type of input arguments. (default: :obj:`None`) debug (bool, optional): If set to :obj:`True`, will perform transformation in debug mode. 
(default: :obj:`False`) """ transformer = ToHeteroWithBasesTransformer(module, metadata, num_bases, in_channels, input_map, debug) return transformer.transform() class ToHeteroWithBasesTransformer(Transformer): def __init__( self, module: Module, metadata: Metadata, num_bases: int, in_channels: Optional[Dict[str, int]] = None, input_map: Optional[Dict[str, str]] = None, debug: bool = False, ): super().__init__(module, input_map, debug) self.metadata = metadata self.num_bases = num_bases self.in_channels = in_channels or {} assert len(metadata) == 2 assert len(metadata[0]) > 0 and len(metadata[1]) > 0 self.validate() # Compute IDs for each node and edge type: self.node_type2id = {k: i for i, k in enumerate(metadata[0])} self.edge_type2id = {k: i for i, k in enumerate(metadata[1])} def validate(self): unused_node_types = get_unused_node_types(*self.metadata) if len(unused_node_types) > 0: warnings.warn( f"There exist node types ({unused_node_types}) whose " f"representations do not get updated during message passing " f"as they do not occur as destination type in any edge type. " f"This may lead to unexpected behavior.", stacklevel=2) names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]] for name in names: if not name.isidentifier(): warnings.warn( f"The type '{name}' contains invalid characters which " f"may lead to unexpected behavior. 
To avoid any issues, " f"ensure that your types only contain letters, numbers " f"and underscores.", stacklevel=2) def transform(self) -> GraphModule: self._node_offset_dict_initialized = False self._edge_offset_dict_initialized = False self._edge_type_initialized = False out = super().transform() del self._node_offset_dict_initialized del self._edge_offset_dict_initialized del self._edge_type_initialized return out def placeholder(self, node: Node, target: Any, name: str): if node.type is not None: Type = EdgeType if self.is_edge_level(node) else NodeType node.type = Dict[Type, node.type] out = node # Create `node_offset_dict` and `edge_offset_dict` dictionaries in case # they are not yet initialized. These dictionaries hold the cumulated # sizes used to create a unified graph representation and to split the # output data. if self.is_edge_level(node) and not self._edge_offset_dict_initialized: self.graph.inserting_after(out) out = self.graph.create_node('call_function', target=get_edge_offset_dict, args=(node, self.edge_type2id), name='edge_offset_dict') self._edge_offset_dict_initialized = True elif not self._node_offset_dict_initialized: self.graph.inserting_after(out) out = self.graph.create_node('call_function', target=get_node_offset_dict, args=(node, self.node_type2id), name='node_offset_dict') self._node_offset_dict_initialized = True # Create a `edge_type` tensor used as input to `HeteroBasisConv`: if self.is_edge_level(node) and not self._edge_type_initialized: self.graph.inserting_after(out) out = self.graph.create_node('call_function', target=get_edge_type, args=(node, self.edge_type2id), name='edge_type') self._edge_type_initialized = True # Add `Linear` operation to align features to the same dimensionality: if name in self.in_channels: self.graph.inserting_after(out) out = self.graph.create_node('call_module', target=f'align_lin__{name}', args=(node, ), name=f'{name}__aligned') self._state[out.name] = self._state[name] lin = 
LinearAlign(self.metadata[int(self.is_edge_level(node))], self.in_channels[name]) setattr(self.module, f'align_lin__{name}', lin) # Perform grouping of type-wise values into a single tensor: if self.is_edge_level(node): self.graph.inserting_after(out) out = self.graph.create_node( 'call_function', target=group_edge_placeholder, args=(out if name in self.in_channels else node, self.edge_type2id, self.find_by_name('node_offset_dict')), name=f'{name}__grouped') self._state[out.name] = 'edge' else: self.graph.inserting_after(out) out = self.graph.create_node( 'call_function', target=group_node_placeholder, args=(out if name in self.in_channels else node, self.node_type2id), name=f'{name}__grouped') self._state[out.name] = 'node' self.replace_all_uses_with(node, out) def call_message_passing_module(self, node: Node, target: Any, name: str): # Call the `HeteroBasisConv` wrapper instead instead of a single # message passing layer. We need to inject the `edge_type` as first # argument in order to do so. node.args = (self.find_by_name('edge_type'), ) + node.args def output(self, node: Node, target: Any, name: str): # Split the output to dictionaries, holding either node type-wise or # edge type-wise data. 
def _recurse(value: Any) -> Any: if isinstance(value, Node) and self.is_edge_level(value): self.graph.inserting_before(node) return self.graph.create_node( 'call_function', target=split_output, args=(value, self.find_by_name('edge_offset_dict')), name=f'{value.name}__split') elif isinstance(value, Node): self.graph.inserting_before(node) return self.graph.create_node( 'call_function', target=split_output, args=(value, self.find_by_name('node_offset_dict')), name=f'{value.name}__split') elif isinstance(value, dict): return {k: _recurse(v) for k, v in value.items()} elif isinstance(value, list): return [_recurse(v) for v in value] elif isinstance(value, tuple): return tuple(_recurse(v) for v in value) else: return value if node.type is not None and isinstance(node.args[0], Node): output = node.args[0] Type = EdgeType if self.is_edge_level(output) else NodeType node.type = Dict[Type, node.type] else: node.type = None node.args = (_recurse(node.args[0]), ) def init_submodule(self, module: Module, target: str) -> Module: if not isinstance(module, MessagePassing): return module # Replace each `MessagePassing` module by a `HeteroBasisConv` wrapper: return HeteroBasisConv(module, len(self.metadata[1]), self.num_bases) ############################################################################### # We make use of a post-message computation hook to inject the # basis re-weighting for each individual edge type. # This currently requires us to set `conv.fuse = False`, which leads # to a materialization of messages. def hook(module, inputs, output): assert isinstance(module._edge_type, Tensor) if module._edge_type.size(0) != output.size(-2): raise ValueError( f"Number of messages ({output.size(0)}) does not match " f"with the number of original edges " f"({module._edge_type.size(0)}). Does your message " f"passing layer create additional self-loops? 
Try to " f"remove them via 'add_self_loops=False'") weight = module.edge_type_weight.view(-1)[module._edge_type] weight = weight.view([1] * (output.dim() - 2) + [-1, 1]) return weight * output class HeteroBasisConv(torch.nn.Module): # A wrapper layer that applies the basis-decomposition technique to a # heterogeneous graph. def __init__(self, module: MessagePassing, num_relations: int, num_bases: int): super().__init__() self.num_relations = num_relations self.num_bases = num_bases params = list(module.parameters()) device = params[0].device if len(params) > 0 else 'cpu' self.convs = torch.nn.ModuleList() for _ in range(num_bases): conv = copy.deepcopy(module) conv.fuse = False # Disable `message_and_aggregate` functionality. # We learn a single scalar weight for each individual edge type, # which is used to weight the output message based on edge type: conv.edge_type_weight = Parameter( torch.empty(1, num_relations, device=device)) conv.register_message_forward_hook(hook) self.convs.append(conv) if self.num_bases > 1: self.reset_parameters() def reset_parameters(self): for conv in self.convs: if hasattr(conv, 'reset_parameters'): conv.reset_parameters() elif sum([p.numel() for p in conv.parameters()]) > 0: warnings.warn( f"'{conv}' will be duplicated, but its parameters cannot " f"be reset. To suppress this warning, add a " f"'reset_parameters()' method to '{conv}'", stacklevel=2) torch.nn.init.xavier_uniform_(conv.edge_type_weight) def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor: out = None # Call message passing modules and perform aggregation: for conv in self.convs: conv._edge_type = edge_type res = conv(*args, **kwargs) del conv._edge_type out = res if out is None else out.add_(res) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}(num_relations=' f'{self.num_relations}, num_bases={self.num_bases})') class LinearAlign(torch.nn.Module): # Aligns representations to the same dimensionality. 
Note that this will # create lazy modules, and as such requires a forward pass in order to # initialize parameters. def __init__(self, keys: List[Union[NodeType, EdgeType]], out_channels: int): super().__init__() self.out_channels = out_channels self.lins = torch.nn.ModuleDict() for key in keys: self.lins[key2str(key)] = Linear(-1, out_channels, bias=False) def forward( self, x_dict: Dict[Union[NodeType, EdgeType], Tensor] ) -> Dict[Union[NodeType, EdgeType], Tensor]: return {key: self.lins[key2str(key)](x) for key, x in x_dict.items()} def __repr__(self) -> str: return (f'{self.__class__.__name__}(num_relations={len(self.lins)}, ' f'out_channels={self.out_channels})') ############################################################################### # These methods are used in order to receive the cumulated sizes of input # dictionaries. We make use of them for creating a unified homogeneous graph # representation, as well as to split the final output data once again. def get_node_offset_dict( input_dict: Dict[NodeType, Union[Tensor, SparseTensor]], type2id: Dict[NodeType, int], ) -> Dict[NodeType, int]: cumsum = 0 out: Dict[NodeType, int] = {} for key in type2id.keys(): out[key] = cumsum cumsum += input_dict[key].size(-2) return out def get_edge_offset_dict( input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]], type2id: Dict[EdgeType, int], ) -> Dict[EdgeType, int]: cumsum = 0 out: Dict[EdgeType, int] = {} for key in type2id.keys(): out[key] = cumsum value = input_dict[key] if isinstance(value, SparseTensor): cumsum += value.nnz() elif value.dtype == torch.long and value.size(0) == 2: cumsum += value.size(-1) else: cumsum += value.size(-2) return out ############################################################################### # This method computes the edge type of the final homogeneous graph # representation. It will be used in the `HeteroBasisConv` wrapper. 
def get_edge_type( input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]], type2id: Dict[EdgeType, int], ) -> Tensor: inputs = [input_dict[key] for key in type2id.keys()] outs = [] for i, value in enumerate(inputs): if value.size(0) == 2 and value.dtype == torch.long: # edge_index out = value.new_full((value.size(-1), ), i, dtype=torch.long) elif isinstance(value, SparseTensor): out = torch.full((value.nnz(), ), i, dtype=torch.long, device=value.device()) else: out = value.new_full((value.size(-2), ), i, dtype=torch.long) outs.append(out) return outs[0] if len(outs) == 1 else torch.cat(outs, dim=0) ############################################################################### # These methods are used to group the individual type-wise components into a # unified single representation. def group_node_placeholder(input_dict: Dict[NodeType, Tensor], type2id: Dict[NodeType, int]) -> Tensor: inputs = [input_dict[key] for key in type2id.keys()] return inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=-2) def group_edge_placeholder( input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]], type2id: Dict[EdgeType, int], offset_dict: Dict[NodeType, int] = None, ) -> Union[Tensor, SparseTensor]: inputs = [input_dict[key] for key in type2id.keys()] if len(inputs) == 1: return inputs[0] # In case of grouping a graph connectivity tensor `edge_index` or `adj_t`, # we need to increment its indices: elif inputs[0].size(0) == 2 and inputs[0].dtype == torch.long: if offset_dict is None: raise AttributeError( "Can not infer node-level offsets. 
Please ensure that there " "exists a node-level argument before the 'edge_index' " "argument in your forward header.") outputs = [] for value, (src_type, _, dst_type) in zip(inputs, type2id): value = value.clone() value[0, :] += offset_dict[src_type] value[1, :] += offset_dict[dst_type] outputs.append(value) return torch.cat(outputs, dim=-1) elif isinstance(inputs[0], SparseTensor): if offset_dict is None: raise AttributeError( "Can not infer node-level offsets. Please ensure that there " "exists a node-level argument before the 'SparseTensor' " "argument in your forward header.") # For grouping a list of SparseTensors, we convert them into a # unified `edge_index` representation in order to avoid conflicts # induced by re-shuffling the data. rows, cols = [], [] for value, (src_type, _, dst_type) in zip(inputs, type2id): col, row, value = value.coo() assert value is None rows.append(row + offset_dict[src_type]) cols.append(col + offset_dict[dst_type]) row = torch.cat(rows, dim=0) col = torch.cat(cols, dim=0) return torch.stack([row, col], dim=0) else: return torch.cat(inputs, dim=-2) ############################################################################### # This method is used to split the output tensors into individual type-wise # components: def split_output( output: Tensor, offset_dict: Union[Dict[NodeType, int], Dict[EdgeType, int]], ) -> Union[Dict[NodeType, Tensor], Dict[EdgeType, Tensor]]: cumsums = list(offset_dict.values()) + [output.size(-2)] sizes = [cumsums[i + 1] - cumsums[i] for i in range(len(offset_dict))] outputs = output.split(sizes, dim=-2) return {key: output for key, output in zip(offset_dict, outputs)} ############################################################################### def key2str(key: Union[NodeType, EdgeType]) -> str: key = '__'.join(key) if isinstance(key, tuple) else key return key.replace(' ', '_').replace('-', '_').replace(':', '_') ================================================ FILE: 
torch_geometric/nn/unpool/__init__.py ================================================ r"""Unpooling package.""" from .knn_interpolate import knn_interpolate __all__ = [ 'knn_interpolate', ] classes = __all__ ================================================ FILE: torch_geometric/nn/unpool/knn_interpolate.py ================================================ import torch from torch_geometric.nn import knn from torch_geometric.typing import OptTensor from torch_geometric.utils import scatter def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, batch_x: OptTensor = None, batch_y: OptTensor = None, k: int = 3, num_workers: int = 1): r"""The k-NN interpolation from the `"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" `_ paper. For each point :math:`y` with position :math:`\mathbf{p}(y)`, its interpolated features :math:`\mathbf{f}(y)` are given by .. math:: \mathbf{f}(y) = \frac{\sum_{i=1}^k w(x_i) \mathbf{f}(x_i)}{\sum_{i=1}^k w(x_i)} \textrm{, where } w(x_i) = \frac{1}{d(\mathbf{p}(y), \mathbf{p}(x_i))^2} and :math:`\{ x_1, \ldots, x_k \}` denoting the :math:`k` nearest points to :math:`y`. Args: x (torch.Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. pos_x (torch.Tensor): Node position matrix :math:`\in \mathbb{R}^{N \times d}`. pos_y (torch.Tensor): Upsampled node position matrix :math:`\in \mathbb{R}^{M \times d}`. batch_x (torch.Tensor, optional): Batch vector :math:`\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node from :math:`\mathbf{X}` to a specific example. (default: :obj:`None`) batch_y (torch.Tensor, optional): Batch vector :math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node from :math:`\mathbf{Y}` to a specific example. (default: :obj:`None`) k (int, optional): Number of neighbors. (default: :obj:`3`) num_workers (int, optional): Number of workers to use for computation. 
Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) """ with torch.no_grad(): assign_index = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y, num_workers=num_workers) y_idx, x_idx = assign_index[0], assign_index[1] diff = pos_x[x_idx] - pos_y[y_idx] squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') y = y / scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') return y ================================================ FILE: torch_geometric/profile/__init__.py ================================================ r"""GNN profiling package.""" from .benchmark import benchmark from .profile import ( get_stats_summary, print_time_total, profileit, rename_profile_file, timeit, torch_profile, trace_handler, xpu_profile, ) from .utils import ( count_parameters, get_cpu_memory_from_gc, get_data_size, get_gpu_memory_from_gc, get_gpu_memory_from_ipex, get_gpu_memory_from_nvidia_smi, get_model_size, ) from .nvtx import nvtxit __all__ = [ 'profileit', 'timeit', 'get_stats_summary', 'trace_handler', 'print_time_total', 'rename_profile_file', 'torch_profile', 'xpu_profile', 'count_parameters', 'get_model_size', 'get_data_size', 'get_cpu_memory_from_gc', 'get_gpu_memory_from_gc', 'get_gpu_memory_from_nvidia_smi', 'get_gpu_memory_from_ipex', 'benchmark', 'nvtxit', ] classes = __all__ ================================================ FILE: torch_geometric/profile/benchmark.py ================================================ import time from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.utils import is_torch_sparse_tensor def require_grad(x: Any, requires_grad: bool = True) -> Any: if (isinstance(x, Tensor) and x.is_floating_point() and not is_torch_sparse_tensor(x)): return 
x.detach().requires_grad_(requires_grad) elif isinstance(x, list): return [require_grad(v, requires_grad) for v in x] elif isinstance(x, tuple): return tuple(require_grad(v, requires_grad) for v in x) elif isinstance(x, dict): return {k: require_grad(v, requires_grad) for k, v in x.items()} return x def benchmark( funcs: List[Callable], args: Union[Tuple[Any], List[Tuple[Any]]], num_steps: int, func_names: Optional[List[str]] = None, num_warmups: int = 10, backward: bool = False, per_step: bool = False, progress_bar: bool = False, ): r"""Benchmark a list of functions :obj:`funcs` that receive the same set of arguments :obj:`args`. Args: funcs ([Callable]): The list of functions to benchmark. args ((Any, ) or [(Any, )]): The arguments to pass to the functions. Can be a list of arguments for each function in :obj:`funcs` in case their headers differ. Alternatively, you can pass in functions that generate arguments on-the-fly (e.g., useful for benchmarking models on various sizes). num_steps (int): The number of steps to run the benchmark. func_names ([str], optional): The names of the functions. If not given, will try to infer the name from the function itself. (default: :obj:`None`) num_warmups (int, optional): The number of warmup steps. (default: :obj:`10`) backward (bool, optional): If set to :obj:`True`, will benchmark both forward and backward passes. (default: :obj:`False`) per_step (bool, optional): If set to :obj:`True`, will report runtimes per step. (default: :obj:`False`) progress_bar (bool, optional): If set to :obj:`True`, will print a progress bar during benchmarking. 
(default: :obj:`False`) """ from tabulate import tabulate if num_steps <= 0: raise ValueError(f"'num_steps' must be a positive integer " f"(got {num_steps})") if num_warmups <= 0: raise ValueError(f"'num_warmups' must be a positive integer " f"(got {num_warmups})") if func_names is None: func_names = [get_func_name(func) for func in funcs] if len(funcs) != len(func_names): raise ValueError(f"Length of 'funcs' (got {len(funcs)}) and " f"'func_names' (got {len(func_names)}) must be equal") # Zero-copy `args` for each function (if necessary): args_list = [args] * len(funcs) if not isinstance(args, list) else args iterator = zip(funcs, args_list, func_names) if progress_bar: from tqdm import tqdm iterator = tqdm(iterator, total=len(funcs)) ts: List[List[str]] = [] for func, inputs, name in iterator: t_forward = t_backward = 0 for i in range(num_warmups + num_steps): args = inputs() if callable(inputs) else inputs args = require_grad(args, backward) if torch.cuda.is_available(): torch.cuda.synchronize() t_start = time.perf_counter() out = func(*args) if torch.cuda.is_available(): torch.cuda.synchronize() if i >= num_warmups: t_forward += time.perf_counter() - t_start if backward: if isinstance(out, (tuple, list)): out = sum(o.sum() for o in out if isinstance(o, Tensor)) elif isinstance(out, dict): out = out.values() out = sum(o.sum() for o in out if isinstance(o, Tensor)) out_grad = torch.randn_like(out) t_start = time.perf_counter() out.backward(out_grad) if torch.cuda.is_available(): torch.cuda.synchronize() if i >= num_warmups: t_backward += time.perf_counter() - t_start if per_step: ts.append([name, f'{t_forward/num_steps:.6f}s']) else: ts.append([name, f'{t_forward:.4f}s']) if backward: if per_step: ts[-1].append(f'{t_backward/num_steps:.6f}s') ts[-1].append(f'{(t_forward + t_backward)/num_steps:.6f}s') else: ts[-1].append(f'{t_backward:.4f}s') ts[-1].append(f'{t_forward + t_backward:.4f}s') header = ['Name', 'Forward'] if backward: header.extend(['Backward', 
'Total']) print(tabulate(ts, headers=header, tablefmt='psql')) def get_func_name(func: Callable) -> str: if hasattr(func, '__name__'): return func.__name__ elif hasattr(func, '__class__'): return func.__class__.__name__ raise ValueError("Could not infer name for function '{func}'") ================================================ FILE: torch_geometric/profile/nvtx.py ================================================ from functools import wraps from typing import Optional import torch CUDA_PROFILE_STARTED = False def begin_cuda_profile(): global CUDA_PROFILE_STARTED prev_state = CUDA_PROFILE_STARTED if prev_state is False: CUDA_PROFILE_STARTED = True torch.cuda.cudart().cudaProfilerStart() return prev_state def end_cuda_profile(prev_state: bool): global CUDA_PROFILE_STARTED CUDA_PROFILE_STARTED = prev_state if prev_state is False: torch.cuda.cudart().cudaProfilerStop() def nvtxit(name: Optional[str] = None, n_warmups: int = 0, n_iters: Optional[int] = None): """Enables NVTX profiling for a function. Args: name (Optional[str], optional): Name to give the reference frame for the function being wrapped. Defaults to the name of the function in code. n_warmups (int, optional): Number of iters to call that function before starting. Defaults to 0. n_iters (Optional[int], optional): Number of iters of that function to record. Defaults to all of them. 
""" def nvtx(func): nonlocal name iters_so_far = 0 if name is None: name = func.__name__ @wraps(func) def wrapper(*args, **kwargs): nonlocal iters_so_far if not torch.cuda.is_available(): return func(*args, **kwargs) elif iters_so_far < n_warmups: iters_so_far += 1 return func(*args, **kwargs) elif n_iters is None or iters_so_far < n_iters + n_warmups: prev_state = begin_cuda_profile() torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}") result = func(*args, **kwargs) torch.cuda.nvtx.range_pop() end_cuda_profile(prev_state) iters_so_far += 1 return result else: return func(*args, **kwargs) return wrapper return nvtx ================================================ FILE: torch_geometric/profile/profile.py ================================================ import os import pathlib import time from contextlib import ContextDecorator, contextmanager from dataclasses import dataclass from typing import Any, List, Tuple, Union import torch from torch.autograd.profiler import EventList from torch.profiler import ProfilerActivity, profile from torch_geometric.profile.utils import ( byte_to_megabyte, get_gpu_memory_from_ipex, get_gpu_memory_from_nvidia_smi, ) @dataclass class GPUStats: time: float max_allocated_gpu: float max_reserved_gpu: float max_active_gpu: float @dataclass class CUDAStats(GPUStats): nvidia_smi_free_cuda: float nvidia_smi_used_cuda: float @dataclass class GPUStatsSummary: time_mean: float time_std: float max_allocated_gpu: float max_reserved_gpu: float max_active_gpu: float @dataclass class CUDAStatsSummary(GPUStatsSummary): min_nvidia_smi_free_cuda: float max_nvidia_smi_used_cuda: float def profileit(device: str): # pragma: no cover r"""A decorator to facilitate profiling a function, *e.g.*, obtaining training runtime and memory statistics of a specific model on a specific dataset. Returns a :obj:`GPUStats` if :obj:`device` is :obj:`xpu` or extended object :obj:`CUDAStats`, if :obj:`device` is :obj:`cuda`. Args: device (str): Target device for profiling. 
Options are: :obj:`cuda` and obj:`xpu`. .. code-block:: python @profileit("cuda") def train(model, optimizer, x, edge_index, y): optimizer.zero_grad() out = model(x, edge_index) loss = criterion(out, y) loss.backward() optimizer.step() return float(loss) loss, stats = train(model, x, edge_index, y) """ def decorator(func): def wrapper( *args, **kwargs ) -> Union[Tuple[Any, GPUStats], Tuple[Any, CUDAStats]]: model = args[0] if not isinstance(model, torch.nn.Module): raise AttributeError( 'First argument for profiling needs to be torch.nn.Module') if device not in ['cuda', 'xpu']: raise AttributeError( "The profiling decorator supports only CUDA and " "XPU devices") device_id = None for arg in list(args) + list(kwargs.values()): if isinstance(arg, torch.Tensor): device_id = arg.get_device() break if device_id is None: raise AttributeError( "Could not infer GPU device from the args in the " "function being profiled") if device_id == -1: raise RuntimeError( "The profiling decorator does not support profiling " "on non GPU devices") is_cuda = device == 'cuda' torch_gpu = torch.cuda if is_cuda else torch.xpu # `pytorch_memlab` supports only CUDA devices if is_cuda: from pytorch_memlab import LineProfiler # Init `pytorch_memlab` for analyzing the model forward pass: line_profiler = LineProfiler(target_gpu=device_id) line_profiler.enable() line_profiler.add_function(args[0].forward) start = torch_gpu.Event(enable_timing=True) end = torch_gpu.Event(enable_timing=True) start.record() out = func(*args, **kwargs) end.record() torch_gpu.synchronize() time = start.elapsed_time(end) / 1000 if is_cuda: # Get the global memory statistics collected # by `pytorch_memlab`: memlab = read_from_memlab(line_profiler) max_allocated, max_reserved, max_active = memlab line_profiler.disable() # Get additional information from `nvidia-smi`: free_cuda, used_cuda = get_gpu_memory_from_nvidia_smi( device=device_id) stats = CUDAStats(time, max_allocated, max_reserved, max_active, free_cuda, 
used_cuda) return out, stats else: stats = GPUStats(time, *get_gpu_memory_from_ipex(device_id)) return out, stats return wrapper return decorator class timeit(ContextDecorator): r"""A context decorator to facilitate timing a function, *e.g.*, obtaining the runtime of a specific model on a specific dataset. .. code-block:: python @torch.no_grad() def test(model, x, edge_index): return model(x, edge_index) with timeit() as t: z = test(model, x, edge_index) time = t.duration Args: log (bool, optional): If set to :obj:`False`, will not log any runtime to the console. (default: :obj:`True`) avg_time_divisor (int, optional): If set to a value greater than :obj:`1`, will divide the total time by this value. Useful for calculating the average of runtimes within a for-loop. (default: :obj:`0`) """ def __init__(self, log: bool = True, avg_time_divisor: int = 0): self.log = log self.avg_time_divisor = avg_time_divisor def __enter__(self): if torch.cuda.is_available(): torch.cuda.synchronize() self.t_start = time.time() return self def __exit__(self, *args): if torch.cuda.is_available(): torch.cuda.synchronize() self.t_end = time.time() self.duration = self.t_end - self.t_start if self.avg_time_divisor > 1: self.duration = self.duration / self.avg_time_divisor if self.log: # pragma: no cover print(f'Time: {self.duration:.8f}s', flush=True) def reset(self): r"""Prints the duration and resets current timer.""" if self.t_start is None: raise RuntimeError("Timer wasn't started.") else: self.__exit__() self.__enter__() def get_stats_summary( stats_list: Union[List[GPUStats], List[CUDAStats]] ) -> Union[GPUStatsSummary, CUDAStatsSummary]: # pragma: no cover r"""Creates a summary of collected runtime and memory statistics. Returns a :obj:`GPUStatsSummary` if list of :obj:`GPUStats` was passed, otherwise (list of :obj:`CUDAStats` was passed), returns a :obj:`CUDAStatsSummary`. 
Args: stats_list (Union[List[GPUStats], List[CUDAStats]]): A list of :obj:`GPUStats` or :obj:`CUDAStats` objects, as returned by :meth:`~torch_geometric.profile.profileit`. """ # calculate common statistics kwargs = dict( time_mean=float(torch.tensor([s.time for s in stats_list]).mean()), time_std=float(torch.tensor([s.time for s in stats_list]).std()), max_allocated_gpu=max([s.max_allocated_gpu for s in stats_list]), max_reserved_gpu=max([s.max_reserved_gpu for s in stats_list]), max_active_gpu=max([s.max_active_gpu for s in stats_list])) if all(isinstance(s, CUDAStats) for s in stats_list): return CUDAStatsSummary( **kwargs, min_nvidia_smi_free_cuda=min( [s.nvidia_smi_free_cuda for s in stats_list]), max_nvidia_smi_used_cuda=max( [s.nvidia_smi_used_cuda for s in stats_list]), ) else: return GPUStatsSummary(**kwargs) ############################################################################### def read_from_memlab(line_profiler: Any) -> List[float]: # pragma: no cover from pytorch_memlab.line_profiler.line_records import LineRecords # See: https://pytorch.org/docs/stable/cuda.html#torch.cuda.memory_stats track_stats = [ # Different statistic can be collected as needed. 
'allocated_bytes.all.peak', 'reserved_bytes.all.peak', 'active_bytes.all.peak', ] records = LineRecords(line_profiler._raw_line_records, line_profiler._code_infos) stats = records.display(None, track_stats)._line_records return [byte_to_megabyte(x) for x in stats.values.max(axis=0).tolist()] def trace_handler(p): print_time_total(p) profile_dir = str(pathlib.Path.cwd()) + '/' timeline_file = profile_dir + 'timeline' + '.json' p.export_chrome_trace(timeline_file) def print_time_total(p): if torch.cuda.is_available(): profile_sort = 'self_cuda_time_total' else: profile_sort = 'self_cpu_time_total' output = p.key_averages().table(sort_by=profile_sort) print(output) def rename_profile_file(*args): profile_dir = str(pathlib.Path.cwd()) + '/' timeline_file = profile_dir + 'profile' for arg in args: timeline_file += '-' + arg timeline_file += '.json' os.rename('timeline.json', timeline_file) @contextmanager def torch_profile(export_chrome_trace=True, csv_data=None, write_csv=None): use_cuda = torch.cuda.is_available() activities = [ProfilerActivity.CPU] if use_cuda: activities.append(ProfilerActivity.CUDA) if export_chrome_trace: p_trace_handler = trace_handler else: p_trace_handler = print_time_total p = profile(activities=activities, on_trace_ready=p_trace_handler) with p: yield p.step() if csv_data is not None and write_csv == 'prof': if use_cuda: profile_sort = 'self_cuda_time_total' else: profile_sort = 'self_cpu_time_total' events = EventList( sorted( p.key_averages(), key=lambda evt: getattr(evt, profile_sort), reverse=True, ), use_cuda=use_cuda) save_profile_data(csv_data, events, use_cuda) @contextmanager def xpu_profile(export_chrome_trace=True): with torch.autograd.profiler_legacy.profile(use_xpu=True) as profile: yield print(profile.key_averages().table(sort_by='self_xpu_time_total')) if export_chrome_trace: profile.export_chrome_trace('timeline.json') def format_prof_time(time): # Profile time is in micro seconds, so format it appropriately: return round(time 
/ 1e6, 3) def save_profile_data(csv_data, events, use_cuda): sum_self_cpu_time_total = sum( [event.self_cpu_time_total for event in events]) sum_cpu_time_total = sum([event.self_cpu_time_total for event in events]) sum_self_cuda_time_total = sum( [event.self_cuda_time_total for event in events]) if use_cuda else 0 for e in events[:5]: # Save top 5 most time consuming operations: csv_data['NAME'].append(e.key) csv_data['SELF CPU %'].append( round(e.self_cpu_time_total * 100.0 / sum_self_cpu_time_total, 3)) csv_data['SELF CPU'].append(format_prof_time(e.self_cpu_time_total)) csv_data['CPU TOTAL %'].append( round(e.cpu_time_total * 100.0 / sum_cpu_time_total, 3)) csv_data['CPU TOTAL'].append(format_prof_time(e.cpu_time_total)) csv_data['CPU TIME AVG'].append(format_prof_time(e.cpu_time_total)) if use_cuda: csv_data['SELF CUDA %'].append(e.self_cuda_time_total * 100.0 / sum_self_cuda_time_total) csv_data['SELF CUDA'].append( format_prof_time(e.self_cuda_time_total)) csv_data['CUDA TOTAL'].append(format_prof_time(e.cpu_time_total)) csv_data['CUDA TIME AVG'].append(format_prof_time( e.cpu_time_total)) csv_data['# OF CALLS'].append(e.count) ================================================ FILE: torch_geometric/profile/profiler.py ================================================ import functools from collections import OrderedDict, defaultdict, namedtuple from typing import Any, List, NamedTuple, Optional, Tuple import torch import torch.profiler as torch_profiler import torch_geometric.typing # predefined namedtuple for variable setting (global template) Trace = namedtuple('Trace', ['path', 'leaf', 'module']) # the metrics returned from the torch profiler Measure = namedtuple('Measure', [ 'self_cpu_total', 'cpu_total', 'self_cuda_total', 'cuda_total', 'self_cpu_memory', 'cpu_memory', 'self_cuda_memory', 'cuda_memory', 'occurrences', ]) class Profiler: r"""Layer by layer profiling of PyTorch models, using the PyTorch profiler for memory profiling. 
Parts of the code are adapted from :obj:`torchprof` for layer-wise grouping. Args: model (torch.nn.Module): The underlying model to be profiled. enabled (bool, optional): If set to :obj:`True`, turn on the profiler. (default: :obj:`False`) use_cuda (bool, optional): Whether to profile CUDA execution. (default: :obj:`False`) profile_memory (bool, optional): If set to :obj:`True`, also profile memory usage. (default: :obj:`False`) paths ([str], optional): Pre-defined paths for fast loading. (default: :obj:`None`) """ def __init__( self, model: torch.nn.Module, enabled: bool = True, use_cuda: bool = False, profile_memory: bool = False, paths: Optional[List[str]] = None, ): self._model = model self.enabled = enabled self.use_cuda = use_cuda self.profile_memory = profile_memory self.paths = paths self.entered = False self.exited = False self.traces = () self._ids = set() self.trace_profile_events = defaultdict(list) def __enter__(self): if not self.enabled: return self if self.entered: raise RuntimeError("the profiler can be initialized only once") self.entered = True self._forwards = {} # store the original forward functions # generate the trace and conduct profiling self.traces = tuple(map(self._hook_trace, _walk_modules(self._model))) return self def __exit__(self, exc_type, exc_val, exc_tb): if not self.enabled: return tuple(map(self._remove_hook_trace, self.traces)) del self._forwards # remove unnecessary forwards self.exited = True def get_trace(self): return _layer_trace(self.traces, self.trace_profile_events) def __repr__(self) -> str: return self.get_trace()[0] def __call__(self, *args, **kwargs): return self._model(*args, **kwargs) def _hook_trace(self, trace): """Add hooks to torch modules for profiling. The underlying model's forward pass is hooked/decorated here. 
""" [path, leaf, module] = trace # the id of the model is guaranteed to be unique _id = id(module) if (self.paths is not None and path in self.paths) or (self.paths is None and leaf): if _id in self._ids: # already wrapped return trace self._ids.add(_id) _forward = module.forward self._forwards[path] = _forward @functools.wraps(_forward) def wrap_forward(*args, **kwargs): """The forward pass is decorated and profiled here.""" # only torch 1.8.1+ is supported torch_version = torch.__version__ if torch_version <= '1.8.1': raise NotImplementedError( "Profiler requires at least torch 1.8.1") activities = [torch.profiler.ProfilerActivity.CPU] if self.use_cuda: activities.append(torch.profiler.ProfilerActivity.CUDA) with torch_profiler.profile( activities=activities, profile_memory=self.profile_memory, ) as prof: res = _forward(*args, **kwargs) event_list = prof.events() # each profile call should be contained in its own list self.trace_profile_events[path].append(event_list) return res # decorate the underlying model's forward pass module.forward = wrap_forward return trace def _remove_hook_trace(self, trace): """Clean it up after the profiling is done.""" [path, leaf, module] = trace _id = id(module) if _id in self._ids: self._ids.discard(_id) else: return if (self.paths is not None and path in self.paths) or (self.paths is None and leaf): module.forward = self._forwards[path] def _layer_trace( traces: NamedTuple, trace_events: Any, show_events: bool = True, paths: List[str] = None, use_cuda: bool = False, profile_memory: bool = False, dt: Tuple[str, ...] = ('-', '-', '-', ' '), ) -> object: """Construct human readable output of the profiler traces and events. The information is presented in layers, and each layer contains its underlying operators. Args: traces (trace object): Raw trace to be parsed. trace_events (trace object): Raw events to be parsed. show_events (bool, optional): If True, show detailed event information. 
(default: :obj:`True`) paths (str, optional): Predefine path for fast loading. By default, it will not be used. (default: :obj:`False`) use_cuda (bool, optional): Enables timing of CUDA events. (default: :obj:`False`) profile_memory (bool, optional): If True, also profile for the memory usage information. (default: :obj:`False`) dt (object, optional): Delimiters for showing the events. """ tree = OrderedDict() for trace in traces: [path, leaf, module] = trace current_tree = tree # unwrap all of the events, in case model is called multiple times events = [te for t_events in trace_events[path] for te in t_events] for depth, name in enumerate(path, 1): if name not in current_tree: current_tree[name] = OrderedDict() if depth == len(path) and ((paths is None and leaf) or (paths is not None and path in paths)): # tree measurements have key None, avoiding name conflict if show_events: for event_name, event_group in _group_by( events, lambda e: e.name): event_group = list(event_group) current_tree[name][event_name] = { None: _build_measure_tuple(event_group, len(event_group)) } else: current_tree[name][None] = _build_measure_tuple( events, len(trace_events[path])) current_tree = current_tree[name] tree_lines = _flatten_tree(tree) format_lines = [] has_self_cuda_total = False has_self_cpu_memory = False has_cpu_memory = False has_self_cuda_memory = False has_cuda_memory = False raw_results = {} for idx, tree_line in enumerate(tree_lines): depth, name, measures = tree_line next_depths = [pl[0] for pl in tree_lines[idx + 1:]] pre = "-" if depth > 0: pre = dt[1] if depth in next_depths and next_depths[ 0] >= depth else dt[2] depth -= 1 while depth > 0: pre = (dt[0] + pre) if depth in next_depths else (dt[3] + pre) depth -= 1 format_lines.append([pre + name, *_format_measure_tuple(measures)]) if measures: has_self_cuda_total = (has_self_cuda_total or measures.self_cuda_total is not None) has_self_cpu_memory = (has_self_cpu_memory or measures.self_cpu_memory is not None) 
has_cpu_memory = has_cpu_memory or measures.cpu_memory is not None has_self_cuda_memory = (has_self_cuda_memory or measures.self_cuda_memory is not None) has_cuda_memory = (has_cuda_memory or measures.cuda_memory is not None) raw_results[name] = [ measures.self_cpu_total, measures.cpu_total, measures.self_cuda_total, measures.cuda_total, measures.self_cpu_memory, measures.cpu_memory, measures.self_cuda_memory, measures.cuda_memory, measures.occurrences ] # construct the table (this is pretty ugly and can probably be optimized) heading = ( "Module", "Self CPU total", "CPU total", "Self CUDA total", "CUDA total", "Self CPU Mem", "CPU Mem", "Self CUDA Mem", "CUDA Mem", "Number of Calls", ) # get the output aligned max_lens = [max(map(len, col)) for col in zip(*([heading] + format_lines))] # not all columns should be displayed, specify kept indexes keep_indexes = [0, 1, 2, 9] if profile_memory: if has_self_cpu_memory: keep_indexes.append(5) if has_cpu_memory: keep_indexes.append(6) if use_cuda: if has_self_cuda_total: keep_indexes.append(3) keep_indexes.append(4) if profile_memory: if has_self_cuda_memory: keep_indexes.append(7) if has_cuda_memory: keep_indexes.append(8) # the final columns to be shown keep_indexes = tuple(sorted(keep_indexes)) heading_list = list(heading) display = ( # table heading " | ".join([ "{:<{}s}".format(heading[keep_index], max_lens[keep_index]) for keep_index in keep_indexes ]) + "\n") display += ( # separator "-|-".join([ "-" * max_len for val_idx, max_len in enumerate(max_lens) if val_idx in keep_indexes ]) + "\n") for format_line in format_lines: # body display += (" | ".join([ "{:<{}s}".format(value, max_lens[val_idx]) for val_idx, value in enumerate(format_line) if val_idx in keep_indexes ]) + "\n") # layer information readable key_dict = {} layer_names = [] layer_stats = [] for format_line in format_lines: # body if format_line[1] == '': # key line key_dict[format_line[0].count("-")] = format_line[0] else: # must print # get current 
line's level curr_level = format_line[0].count("-") par_str = "" for i in range(1, curr_level): par_str += key_dict[i] curr_key = par_str + format_line[0] layer_names.append(curr_key) layer_stats.append(format_line[1:]) return display, heading_list, raw_results, layer_names, layer_stats def _flatten_tree(t, depth=0): flat = [] for name, st in t.items(): measures = st.pop(None, None) flat.append([depth, name, measures]) flat.extend(_flatten_tree(st, depth=depth + 1)) return flat def _build_measure_tuple(events: List, occurrences: List) -> NamedTuple: device_str = 'device' if torch_geometric.typing.WITH_PT24 else 'cuda' # memory profiling supported in torch >= 1.6 self_cpu_memory = None has_self_cpu_memory = any( hasattr(e, "self_cpu_memory_usage") for e in events) if has_self_cpu_memory: self_cpu_memory = sum( [getattr(e, "self_cpu_memory_usage", 0) or 0 for e in events]) cpu_memory = None has_cpu_memory = any(hasattr(e, "cpu_memory_usage") for e in events) if has_cpu_memory: cpu_memory = sum( [getattr(e, "cpu_memory_usage", 0) or 0 for e in events]) self_cuda_memory = None has_self_cuda_memory = any( hasattr(e, f"self_{device_str}_memory_usage") for e in events) if has_self_cuda_memory: self_cuda_memory = sum([ getattr(e, f"self_{device_str}_memory_usage", 0) or 0 for e in events ]) cuda_memory = None has_cuda_memory = any( hasattr(e, f"{device_str}_memory_usage") for e in events) if has_cuda_memory: cuda_memory = sum( [getattr(e, f"{device_str}_memory_usage", 0) or 0 for e in events]) # self CUDA time supported in torch >= 1.7 self_cuda_total = None has_self_cuda_time = any( hasattr(e, f"self_{device_str}_time_total") for e in events) if has_self_cuda_time: self_cuda_total = sum([ getattr(e, f"self_{device_str}_time_total", 0) or 0 for e in events ]) return Measure( self_cpu_total=sum([e.self_cpu_time_total or 0 for e in events]), cpu_total=sum([e.cpu_time_total or 0 for e in events]), self_cuda_total=self_cuda_total, cuda_total=sum( [getattr(e, 
f"{device_str}_time_total") or 0 for e in events]), self_cpu_memory=self_cpu_memory, cpu_memory=cpu_memory, self_cuda_memory=self_cuda_memory, cuda_memory=cuda_memory, occurrences=occurrences, ) def _format_measure_tuple(measure: NamedTuple) -> NamedTuple: self_cpu_total = (format_time(measure.self_cpu_total) if measure else "") cpu_total = format_time(measure.cpu_total) if measure else "" self_cuda_total = (format_time(measure.self_cuda_total) if measure and measure.self_cuda_total is not None else "") cuda_total = format_time(measure.cuda_total) if measure else "" self_cpu_memory = (format_memory(measure.self_cpu_memory) if measure and measure.self_cpu_memory is not None else "") cpu_memory = (format_memory(measure.cpu_memory) if measure and measure.cpu_memory is not None else "") self_cuda_memory = (format_memory(measure.self_cuda_memory) if measure and measure.self_cuda_memory is not None else "") cuda_memory = (format_memory(measure.cuda_memory) if measure and measure.cuda_memory is not None else "") occurrences = str(measure.occurrences) if measure else "" return Measure( self_cpu_total=self_cpu_total, cpu_total=cpu_total, self_cuda_total=self_cuda_total, cuda_total=cuda_total, self_cpu_memory=self_cpu_memory, cpu_memory=cpu_memory, self_cuda_memory=self_cuda_memory, cuda_memory=cuda_memory, occurrences=occurrences, ) def _group_by(events, keyfn): event_groups = OrderedDict() for event in events: key = keyfn(event) key_events = event_groups.get(key, []) key_events.append(event) event_groups[key] = key_events return event_groups.items() def _walk_modules(module, name: str = "", path=()): # Walk through a PyTorch model and output trace tuples (its path, leafe # node, model). 
if not name: name = module.__class__.__name__ # This will track the children of the module (layers) # for instance, [('conv1', GCNConv(10, 16)), ('conv2', GCNConv(16, 3))] named_children = list(module.named_children()) # it builds the path of the structure # for instance, ('GCN', 'conv1', 'lin') path = path + (name, ) # create namedtuple [path, (whether has) leaf, module] yield Trace(path, len(named_children) == 0, module) # recursively walk into all submodules for name, child_module in named_children: yield from _walk_modules(child_module, name=name, path=path) def format_time(time_us: int) -> str: r"""Returns a formatted time string.""" US_IN_SECOND = 1000.0 * 1000.0 US_IN_MS = 1000.0 if time_us >= US_IN_SECOND: return f'{time_us / US_IN_SECOND:.3f}s' if time_us >= US_IN_MS: return f'{time_us / US_IN_MS:.3f}ms' return f'{time_us:.3f}us' def format_memory(nbytes: int) -> str: """Returns a formatted memory size string.""" KB = 1024 MB = 1024 * KB GB = 1024 * MB if (abs(nbytes) >= GB): return f'{nbytes * 1.0 / GB:.2f} Gb' elif (abs(nbytes) >= MB): return f'{nbytes * 1.0 / MB:.2f} Mb' elif (abs(nbytes) >= KB): return f'{nbytes * 1.0 / KB:.2f} Kb' else: return str(nbytes) + ' b' ================================================ FILE: torch_geometric/profile/utils.py ================================================ import gc import os import os.path as osp import random import subprocess as sp import sys import warnings from collections.abc import Mapping, Sequence from typing import Any, Tuple import torch from torch import Tensor from torch_geometric.data.data import BaseData from torch_geometric.typing import SparseTensor def count_parameters(model: torch.nn.Module) -> int: r"""Given a :class:`torch.nn.Module`, count its trainable parameters. Args: model (torch.nn.Model): The model. 
""" return sum([p.numel() for p in model.parameters() if p.requires_grad]) def get_model_size(model: torch.nn.Module) -> int: r"""Given a :class:`torch.nn.Module`, get its actual disk size in bytes. Args: model (torch model): The model. """ path = f'{random.randrange(sys.maxsize)}.pt' torch.save(model.state_dict(), path) model_size = osp.getsize(path) os.remove(path) return model_size def get_data_size(data: BaseData) -> int: r"""Given a :class:`torch_geometric.data.Data` object, get its theoretical memory usage in bytes. Args: data (torch_geometric.data.Data or torch_geometric.data.HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. """ data_ptrs = set() def _get_size(obj: Any) -> int: if isinstance(obj, Tensor): if obj.data_ptr() in data_ptrs: return 0 data_ptrs.add(obj.data_ptr()) return obj.numel() * obj.element_size() elif isinstance(obj, SparseTensor): return _get_size(obj.csr()) elif isinstance(obj, Sequence) and not isinstance(obj, str): return sum([_get_size(x) for x in obj]) elif isinstance(obj, Mapping): return sum([_get_size(x) for x in obj.values()]) else: return 0 return sum([_get_size(store) for store in data.stores]) def get_cpu_memory_from_gc() -> int: r"""Returns the used CPU memory in bytes, as reported by the :python:`Python` garbage collector. """ warnings.filterwarnings('ignore', '.*torch.distributed.reduce_op.*') mem = 0 for obj in gc.get_objects(): try: if isinstance(obj, Tensor) and not obj.is_cuda: mem += obj.numel() * obj.element_size() except Exception: pass return mem def get_gpu_memory_from_gc(device: int = 0) -> int: # pragma: no cover r"""Returns the used GPU memory in bytes, as reported by the :python:`Python` garbage collector. Args: device (int, optional): The GPU device identifier. 
(default: :obj:`1`) """ warnings.filterwarnings('ignore', '.*torch.distributed.reduce_op.*') mem = 0 for obj in gc.get_objects(): try: if isinstance(obj, Tensor) and obj.get_device() == device: mem += obj.numel() * obj.element_size() except Exception: pass return mem def get_gpu_memory_from_nvidia_smi( # pragma: no cover device: int = 0, digits: int = 2, ) -> Tuple[float, float]: r"""Returns the free and used GPU memory in megabytes, as reported by :obj:`nivdia-smi`. .. note:: :obj:`nvidia-smi` will generally overestimate the amount of memory used by the actual program, see `here `__. Args: device (int, optional): The GPU device identifier. (default: :obj:`1`) digits (int): The number of decimals to use for megabytes. (default: :obj:`2`) """ def parse_memory(output: str) -> list: lines = output.decode('utf-8').split('\n')[1:-1] mem_list = [] for line in lines: try: mem_list.append(int(line.split()[0])) except (TypeError, ValueError): mem_list.append(None) return mem_list def get_gpu_memory(out_device, digits): if out_device is None: return 0 return medibyte_to_megabyte(out_device, digits) CMD = 'nvidia-smi --query-gpu=memory.free --format=csv' free_out = parse_memory(sp.check_output(CMD.split())) CMD = 'nvidia-smi --query-gpu=memory.used --format=csv' used_out = parse_memory(sp.check_output(CMD.split())) if device < 0 or device >= len(free_out): raise AttributeError( f'GPU {device} not available (found {len(free_out)} GPUs)') free_mem = get_gpu_memory(free_out[device], digits) used_mem = get_gpu_memory(used_out[device], digits) return free_mem, used_mem def get_gpu_memory_from_ipex( device: int = 0, digits=2) -> Tuple[float, float, float]: # pragma: no cover r"""Returns the XPU memory statistics. Args: device (int, optional): The GPU device identifier. (default: :obj:`0`) digits (int): The number of decimals to use for megabytes. 
(default: :obj:`2`) """ import intel_extension_for_pytorch as ipex stats = ipex.xpu.memory_stats_as_nested_dict(device) max_allocated = stats['allocated_bytes']['all']['peak'] max_reserved = stats['reserved_bytes']['all']['peak'] max_active = stats['active_bytes']['all']['peak'] max_allocated = byte_to_megabyte(max_allocated, digits) max_reserved = byte_to_megabyte(max_reserved, digits) max_active = byte_to_megabyte(max_active, digits) ipex.xpu.reset_peak_memory_stats(device) return max_allocated, max_reserved, max_active ############################################################################### def byte_to_megabyte(value: int, digits: int = 2) -> float: return round(value / (1024 * 1024), digits) def medibyte_to_megabyte(value: int, digits: int = 2) -> float: return round(1.0485 * value, digits) ================================================ FILE: torch_geometric/resolver.py ================================================ import inspect from typing import Any, Dict, List, Optional, Union def normalize_string(s: str) -> str: return s.lower().replace('-', '').replace('_', '').replace(' ', '') def resolver( classes: List[Any], class_dict: Dict[str, Any], query: Union[Any, str], base_cls: Optional[Any], base_cls_repr: Optional[str], *args: Any, **kwargs: Any, ) -> Any: if not isinstance(query, str): return query query_repr = normalize_string(query) if base_cls_repr is None: base_cls_repr = base_cls.__name__ if base_cls else '' base_cls_repr = normalize_string(base_cls_repr) for key_repr, cls in class_dict.items(): if query_repr == key_repr: if inspect.isclass(cls): obj = cls(*args, **kwargs) return obj return cls for cls in classes: cls_repr = normalize_string(cls.__name__) if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]: if inspect.isclass(cls): obj = cls(*args, **kwargs) return obj return cls choices = {cls.__name__ for cls in classes} | set(class_dict.keys()) raise ValueError(f"Could not resolve '{query}' among choices {choices}") 
================================================
FILE: torch_geometric/sampler/__init__.py
================================================
r"""Graph sampler package."""

from .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput,
                   SamplerOutput, HeteroSamplerOutput, NegativeSampling,
                   NumNeighbors)
from .neighbor_sampler import NeighborSampler, BidirectionalNeighborSampler
from .hgt_sampler import HGTSampler

# `classes` is bound to the same list object as `__all__`.
__all__ = classes = [
    'BaseSampler',
    'NodeSamplerInput',
    'EdgeSamplerInput',
    'SamplerOutput',
    'HeteroSamplerOutput',
    'NumNeighbors',
    'NegativeSampling',
    'NeighborSampler',
    'BidirectionalNeighborSampler',
    'HGTSampler',
]

================================================
FILE: torch_geometric/sampler/base.py
================================================
import copy
import math
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from torch import Tensor

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.sampler.utils import (
    global_to_local_node_idx,
    local_to_global_node_idx,
    to_bidirectional,
    unique_unsorted,
)
from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor
from torch_geometric.utils.mixin import CastMixin


class DataType(Enum):
    r"""The data type a sampler is operating on."""
    homogeneous = 'homogeneous'
    heterogeneous = 'heterogeneous'
    remote = 'remote'

    @classmethod
    def from_data(cls, data: Any):
        r"""Infers the :class:`DataType` of :obj:`data`: a :obj:`Data`
        object is homogeneous, a :obj:`HeteroData` object is heterogeneous,
        and a 2-element list/tuple of :obj:`(FeatureStore, GraphStore)` is
        remote.

        Raises:
            ValueError: If :obj:`data` matches none of the above.
        """
        if isinstance(data, Data):
            return cls.homogeneous
        elif isinstance(data, HeteroData):
            return cls.heterogeneous
        elif (isinstance(data, (list, tuple)) and len(data) == 2
              and isinstance(data[0], FeatureStore)
              and isinstance(data[1], GraphStore)):
            return cls.remote

        raise ValueError(f"Expected a 'Data', 'HeteroData', or a tuple of "
                         f"'FeatureStore' and 'GraphStore' "
                         f"(got '{type(data)}')")


class SubgraphType(Enum):
    r"""The type of the returned subgraph."""
    directional = 'directional'
    bidirectional = 'bidirectional'
    induced = 'induced'


@dataclass(init=False)
class NodeSamplerInput(CastMixin):
    r"""The sampling input of
    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes`.

    Args:
        input_id (torch.Tensor, optional): The indices of the data loader input
            of the current mini-batch.
        node (torch.Tensor): The indices of seed nodes to start sampling from.
        time (torch.Tensor, optional): The timestamp for the seed nodes.
            (default: :obj:`None`)
        input_type (str, optional): The input node type (in case of sampling in
            a heterogeneous graph). (default: :obj:`None`)
    """
    input_id: OptTensor
    node: Tensor
    time: OptTensor = None
    input_type: Optional[NodeType] = None

    def __init__(
        self,
        input_id: OptTensor,
        node: Tensor,
        time: OptTensor = None,
        input_type: Optional[NodeType] = None,
    ):
        # All given tensors are moved to CPU:
        if input_id is not None:
            input_id = input_id.cpu()
        node = node.cpu()
        if time is not None:
            time = time.cpu()

        self.input_id = input_id
        self.node = node
        self.time = time
        self.input_type = input_type

    def __getitem__(self, index: Union[Tensor, Any]) -> 'NodeSamplerInput':
        r"""Returns a new :class:`NodeSamplerInput` restricted to
        :obj:`index`. Non-tensor indices are converted to a
        :obj:`torch.long` tensor first.
        """
        if not isinstance(index, Tensor):
            index = torch.tensor(index, dtype=torch.long)

        # Without an `input_id`, the selected positions themselves are used
        # as the input ids of the slice:
        return NodeSamplerInput(
            self.input_id[index] if self.input_id is not None else index,
            self.node[index],
            self.time[index] if self.time is not None else None,
            self.input_type,
        )


@dataclass(init=False)
class EdgeSamplerInput(CastMixin):
    r"""The sampling input of
    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.

    Args:
        input_id (torch.Tensor, optional): The indices of the data loader input
            of the current mini-batch.
        row (torch.Tensor): The source node indices of seed links to start
            sampling from.
        col (torch.Tensor): The destination node indices of seed links to start
            sampling from.
        label (torch.Tensor, optional): The label for the seed links.
            (default: :obj:`None`)
        time (torch.Tensor, optional): The timestamp for the seed links.
            (default: :obj:`None`)
        input_type (Tuple[str, str, str], optional): The input edge type (in
            case of sampling in a heterogeneous graph). (default: :obj:`None`)
    """
    input_id: OptTensor
    row: Tensor
    col: Tensor
    label: OptTensor = None
    time: OptTensor = None
    input_type: Optional[EdgeType] = None

    def __init__(
        self,
        input_id: OptTensor,
        row: Tensor,
        col: Tensor,
        label: OptTensor = None,
        time: OptTensor = None,
        input_type: Optional[EdgeType] = None,
    ):
        if input_id is not None:
            input_id = input_id.cpu()
        # `.clone()` guarantees `row`/`col` never alias the caller's tensors,
        # even when they already live on CPU (where `.cpu()` is a no-op):
        row = row.clone().cpu()
        col = col.clone().cpu()
        if label is not None:
            label = label.cpu()
        if time is not None:
            time = time.cpu()

        self.input_id = input_id
        self.row = row
        self.col = col
        self.label = label
        self.time = time
        self.input_type = input_type

    def __getitem__(self, index: Union[Tensor, Any]) -> 'EdgeSamplerInput':
        r"""Returns a new :class:`EdgeSamplerInput` restricted to
        :obj:`index`. Non-tensor indices are converted to a
        :obj:`torch.long` tensor first.
        """
        if not isinstance(index, Tensor):
            index = torch.tensor(index, dtype=torch.long)

        # Without an `input_id`, the selected positions themselves are used
        # as the input ids of the slice:
        return EdgeSamplerInput(
            self.input_id[index] if self.input_id is not None else index,
            self.row[index],
            self.col[index],
            self.label[index] if self.label is not None else None,
            self.time[index] if self.time is not None else None,
            self.input_type,
        )


@dataclass
class SamplerOutput(CastMixin):
    r"""The sampling output of a :class:`~torch_geometric.sampler.BaseSampler`
    on homogeneous graphs.

    Args:
        node (torch.Tensor): The sampled nodes in the original graph.
        row (torch.Tensor): The source node indices of the sampled subgraph.
            Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }`
            corresponding to the nodes in the :obj:`node` tensor.
        col (torch.Tensor): The destination node indices of the sampled
            subgraph.
            Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }`
            corresponding to the nodes in the :obj:`node` tensor.
        edge (torch.Tensor, optional): The sampled edges in the original graph.
            This tensor is used to obtain edge features from the original
            graph. If no edge attributes are present, it may be omitted.
batch (torch.Tensor, optional): The vector to identify the seed node for each sampled node. Can be present in case of disjoint subgraph sampling per seed node. (default: :obj:`None`) num_sampled_nodes (List[int], optional): The number of sampled nodes per hop. (default: :obj:`None`) num_sampled_edges (List[int], optional): The number of sampled edges per hop. (default: :obj:`None`) orig_row (torch.Tensor, optional): The original source node indices returned by the sampler. Filled in case :meth:`to_bidirectional` is called with the :obj:`keep_orig_edges` option. (default: :obj:`None`) orig_col (torch.Tensor, optional): The original destination node indices indices returned by the sampler. Filled in case :meth:`to_bidirectional` is called with the :obj:`keep_orig_edges` option. (default: :obj:`None`) metadata: (Any, optional): Additional metadata information. (default: :obj:`None`) """ node: Tensor row: Tensor col: Tensor edge: OptTensor batch: OptTensor = None num_sampled_nodes: Optional[List[int]] = None num_sampled_edges: Optional[List[int]] = None orig_row: Tensor = None orig_col: Tensor = None # TODO(manan): refine this further; it does not currently define a proper # API for the expected output of a sampler. 
metadata: Optional[Any] = None _seed_node: OptTensor = field(repr=False, default=None) @property def global_row(self) -> Tensor: return local_to_global_node_idx(self.node, self.row) @property def global_col(self) -> Tensor: return local_to_global_node_idx(self.node, self.col) @property def seed_node(self) -> Tensor: # can be set manually if the seed nodes are not contained in the # sampled nodes if self._seed_node is None: self._seed_node = local_to_global_node_idx( self.node, self.batch) if self.batch is not None else None return self._seed_node @seed_node.setter def seed_node(self, value: Tensor): assert len(value) == len(self.node) self._seed_node = value @property def global_orig_row(self) -> Tensor: return local_to_global_node_idx( self.node, self.orig_row) if self.orig_row is not None else None @property def global_orig_col(self) -> Tensor: return local_to_global_node_idx( self.node, self.orig_col) if self.orig_col is not None else None def to_bidirectional( self, keep_orig_edges: bool = False, ) -> 'SamplerOutput': r"""Converts the sampled subgraph into a bidirectional variant, in which all sampled edges are guaranteed to be bidirectional. Args: keep_orig_edges (bool, optional): If specified, directional edges are still maintained. (default: :obj:`False`) """ out = copy.copy(self) if keep_orig_edges: out.orig_row = self.row out.orig_col = self.col else: out.num_sampled_nodes = out.num_sampled_edges = None out.row, out.col, out.edge = to_bidirectional( row=self.row, col=self.col, rev_row=self.row, rev_col=self.col, edge_id=self.edge, rev_edge_id=self.edge, ) return out @classmethod def collate(cls, outputs: List['SamplerOutput'], replace: bool = True) -> 'SamplerOutput': r"""Collate a list of :class:`~torch_geometric.sampler.SamplerOutput` objects into a single :class:`~torch_geometric.sampler.SamplerOutput` object. Requires that they all have the same fields. 
""" if len(outputs) == 0: raise ValueError("Cannot collate an empty list of SamplerOutputs") out = outputs[0] has_edge = out.edge is not None has_orig_row = out.orig_row is not None has_orig_col = out.orig_col is not None has_batch = out.batch is not None has_num_sampled_nodes = out.num_sampled_nodes is not None has_num_sampled_edges = out.num_sampled_edges is not None try: for i, sample_output in enumerate(outputs): # noqa assert not has_edge == (sample_output.edge is None) assert not has_orig_row == (sample_output.orig_row is None) assert not has_orig_col == (sample_output.orig_col is None) assert not has_batch == (sample_output.batch is None) assert not has_num_sampled_nodes == ( sample_output.num_sampled_nodes is None) assert not has_num_sampled_edges == ( sample_output.num_sampled_edges is None) except AssertionError: error_str = f"Output {i+1} has a different field than the first output" # noqa raise ValueError(error_str) # noqa for other in outputs[1:]: out = out.merge_with(other, replace=replace) return out def merge_with(self, other: 'SamplerOutput', replace: bool = True) -> 'SamplerOutput': """Merges two SamplerOutputs. If replace is True, self's nodes and edges take precedence. 
""" if not replace: return SamplerOutput( node=torch.cat([self.node, other.node], dim=0), row=torch.cat([self.row, len(self.node) + other.row], dim=0), col=torch.cat([self.col, len(self.node) + other.col], dim=0), edge=torch.cat([self.edge, other.edge], dim=0) if self.edge is not None and other.edge is not None else None, batch=torch.cat( [self.batch, len(self.node) + other.batch], dim=0) if self.batch is not None and other.batch is not None else None, num_sampled_nodes=self.num_sampled_nodes + other.num_sampled_nodes if self.num_sampled_nodes is not None and other.num_sampled_nodes is not None else None, num_sampled_edges=self.num_sampled_edges + other.num_sampled_edges if self.num_sampled_edges is not None and other.num_sampled_edges is not None else None, orig_row=torch.cat( [self.orig_row, len(self.node) + other.orig_row], dim=0) if self.orig_row is not None and other.orig_row is not None else None, orig_col=torch.cat( [self.orig_col, len(self.node) + other.orig_col], dim=0) if self.orig_col is not None and other.orig_col is not None else None, metadata=[self.metadata, other.metadata], ) else: # NODES old_nodes, new_nodes = self.node, other.node old_node_uid, new_node_uid = [old_nodes], [new_nodes] # batch tracks disjoint subgraph samplings if self.batch is not None and other.batch is not None: # Transform the batch indices to be global node ids old_batch_nodes = self.seed_node new_batch_nodes = other.seed_node old_node_uid.append(old_batch_nodes) new_node_uid.append(new_batch_nodes) # NOTE: if any new node fields are added, # they need to be merged here old_node_uid = torch.stack(old_node_uid, dim=1) new_node_uid = torch.stack(new_node_uid, dim=1) merged_node_uid = unique_unsorted( torch.cat([old_node_uid, new_node_uid], dim=0)) num_old_nodes = old_node_uid.shape[0] # Recompute num sampled nodes for second output, # subtracting out nodes already seen in first output merged_node_num_sampled_nodes = None if (self.num_sampled_nodes is not None and 
other.num_sampled_nodes is not None): merged_node_num_sampled_nodes = copy.copy( self.num_sampled_nodes) curr_index = 0 # NOTE: There's an assumption here that no two nodes will be # sampled twice in the same SampleOutput object for minibatch in other.num_sampled_nodes: size_of_intersect = torch.cat([ old_node_uid, new_node_uid[curr_index:curr_index + minibatch] ]).unique(dim=0, sorted=False).shape[0] - num_old_nodes merged_node_num_sampled_nodes.append(size_of_intersect) curr_index += minibatch merged_nodes = merged_node_uid[:, 0] merged_batch = None if self.batch is not None and other.batch is not None: # Restore the batch indices to be relative to the nodes field ref_merged_batch_nodes = merged_node_uid[:, 1].unsqueeze( -1).expand(-1, 2) # num_nodes x 2 merged_batch = global_to_local_node_idx( merged_node_uid, ref_merged_batch_nodes) # EDGES is_bidirectional = self.orig_row is not None \ and self.orig_col is not None \ and other.orig_row is not None \ and other.orig_col is not None if is_bidirectional: old_row, old_col = self.orig_row, self.orig_col new_row, new_col = other.orig_row, other.orig_col else: old_row, old_col = self.row, self.col new_row, new_col = other.row, other.col # Transform the row and col indices to be global node ids # instead of relative indices to nodes field # Edge uids build off of node uids old_row_idx, old_col_idx = local_to_global_node_idx( old_node_uid, old_row), local_to_global_node_idx(old_node_uid, old_col) new_row_idx, new_col_idx = local_to_global_node_idx( new_node_uid, new_row), local_to_global_node_idx(new_node_uid, new_col) old_edge_uid, new_edge_uid = [old_row_idx, old_col_idx ], [new_row_idx, new_col_idx] row_idx = 0 col_idx = old_row_idx.shape[1] edge_idx = old_row_idx.shape[1] + old_col_idx.shape[1] if self.edge is not None and other.edge is not None: if is_bidirectional: # bidirectional duplicates edge ids old_edge_uid_ref = torch.stack([self.row, self.col], dim=1) # num_edges x 2 old_orig_edge_uid_ref = torch.stack( 
[self.orig_row, self.orig_col], dim=1) # num_orig_edges x 2 old_edge_idx = global_to_local_node_idx( old_edge_uid_ref, old_orig_edge_uid_ref) old_edge = self.edge[old_edge_idx] new_edge_uid_ref = torch.stack([other.row, other.col], dim=1) # num_edges x 2 new_orig_edge_uid_ref = torch.stack( [other.orig_row, other.orig_col], dim=1) # num_orig_edges x 2 new_edge_idx = global_to_local_node_idx( new_edge_uid_ref, new_orig_edge_uid_ref) new_edge = other.edge[new_edge_idx] else: old_edge, new_edge = self.edge, other.edge old_edge_uid.append(old_edge.unsqueeze(-1)) new_edge_uid.append(new_edge.unsqueeze(-1)) old_edge_uid = torch.cat(old_edge_uid, dim=1) new_edge_uid = torch.cat(new_edge_uid, dim=1) merged_edge_uid = unique_unsorted( torch.cat([old_edge_uid, new_edge_uid], dim=0)) num_old_edges = old_edge_uid.shape[0] merged_edge_num_sampled_edges = None if (self.num_sampled_edges is not None and other.num_sampled_edges is not None): merged_edge_num_sampled_edges = copy.copy( self.num_sampled_edges) curr_index = 0 # NOTE: There's an assumption here that no two edges will be # sampled twice in the same SampleOutput object for minibatch in other.num_sampled_edges: size_of_intersect = torch.cat([ old_edge_uid, new_edge_uid[curr_index:curr_index + minibatch] ]).unique(dim=0, sorted=False).shape[0] - num_old_edges merged_edge_num_sampled_edges.append(size_of_intersect) curr_index += minibatch merged_row = merged_edge_uid[:, row_idx:col_idx] merged_col = merged_edge_uid[:, col_idx:edge_idx] merged_edge = merged_edge_uid[:, edge_idx:].squeeze() \ if self.edge is not None and other.edge is not None else None # restore to row and col indices relative to nodes field merged_row = global_to_local_node_idx(merged_node_uid, merged_row) merged_col = global_to_local_node_idx(merged_node_uid, merged_col) out = SamplerOutput( node=merged_nodes, row=merged_row, col=merged_col, edge=merged_edge, batch=merged_batch, num_sampled_nodes=merged_node_num_sampled_nodes, 
num_sampled_edges=merged_edge_num_sampled_edges, metadata=[self.metadata, other.metadata], ) # Restores orig_row and orig_col if they existed before merging if is_bidirectional: out = out.to_bidirectional(keep_orig_edges=True) return out @dataclass class HeteroSamplerOutput(CastMixin): r"""The sampling output of a :class:`~torch_geometric.sampler.BaseSampler` on heterogeneous graphs. Args: node (Dict[str, torch.Tensor]): The sampled nodes in the original graph for each node type. row (Dict[Tuple[str, str, str], torch.Tensor]): The source node indices of the sampled subgraph for each edge type. Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }` corresponding to the nodes in the :obj:`node` tensor of the source node type. col (Dict[Tuple[str, str, str], torch.Tensor]): The destination node indices of the sampled subgraph for each edge type. Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }` corresponding to the nodes in the :obj:`node` tensor of the destination node type. edge (Dict[Tuple[str, str, str], torch.Tensor], optional): The sampled edges in the original graph for each edge type. This tensor is used to obtain edge features from the original graph. If no edge attributes are present, it may be omitted. batch (Dict[str, torch.Tensor], optional): The vector to identify the seed node for each sampled node for each node type. Can be present in case of disjoint subgraph sampling per seed node. (default: :obj:`None`) num_sampled_nodes (Dict[str, List[int]], optional): The number of sampled nodes for each node type and each layer. (default: :obj:`None`) num_sampled_edges (Dict[EdgeType, List[int]], optional): The number of sampled edges for each edge type and each layer. (default: :obj:`None`) orig_row (Dict[EdgeType, torch.Tensor], optional): The original source node indices returned by the sampler. Filled in case :meth:`to_bidirectional` is called with the :obj:`keep_orig_edges` option. 
(default: :obj:`None`) orig_col (Dict[EdgeType, torch.Tensor], optional): The original destination node indices returned by the sampler. Filled in case :meth:`to_bidirectional` is called with the :obj:`keep_orig_edges` option. (default: :obj:`None`) metadata: (Any, optional): Additional metadata information. (default: :obj:`None`) """ node: Dict[NodeType, Tensor] row: Dict[EdgeType, Tensor] col: Dict[EdgeType, Tensor] edge: Dict[EdgeType, OptTensor] batch: Optional[Dict[NodeType, Tensor]] = None num_sampled_nodes: Optional[Dict[NodeType, List[int]]] = None num_sampled_edges: Optional[Dict[EdgeType, List[int]]] = None orig_row: Optional[Dict[EdgeType, Tensor]] = None orig_col: Optional[Dict[EdgeType, Tensor]] = None # TODO(manan): refine this further; it does not currently define a proper # API for the expected output of a sampler. metadata: Optional[Any] = None @property def global_row(self) -> Dict[EdgeType, Tensor]: return { edge_type: local_to_global_node_idx(self.node[edge_type[0]], row) for edge_type, row in self.row.items() } @property def global_col(self) -> Dict[EdgeType, Tensor]: return { edge_type: local_to_global_node_idx(self.node[edge_type[2]], col) for edge_type, col in self.col.items() } @property def seed_node(self) -> Optional[Dict[NodeType, Tensor]]: return { node_type: local_to_global_node_idx(self.node[node_type], batch) for node_type, batch in self.batch.items() } if self.batch is not None else None @property def global_orig_row(self) -> Optional[Dict[EdgeType, Tensor]]: return { edge_type: local_to_global_node_idx(self.node[edge_type[0]], orig_row) for edge_type, orig_row in self.orig_row.items() } if self.orig_row is not None else None @property def global_orig_col(self) -> Optional[Dict[EdgeType, Tensor]]: return { edge_type: local_to_global_node_idx(self.node[edge_type[2]], orig_col) for edge_type, orig_col in self.orig_col.items() } if self.orig_col is not None else None def to_bidirectional( self, keep_orig_edges: bool = False, ) -> 
'SamplerOutput': r"""Converts the sampled subgraph into a bidirectional variant, in which all sampled edges are guaranteed to be bidirectional. Args: keep_orig_edges (bool, optional): If specified, directional edges are still maintained. (default: :obj:`False`) """ out = copy.copy(self) out.row = copy.copy(self.row) out.col = copy.copy(self.col) out.edge = copy.copy(self.edge) if keep_orig_edges: out.orig_row = {} out.orig_col = {} for key in self.row.keys(): out.orig_row[key] = self.row[key] out.orig_col[key] = self.col[key] else: out.num_sampled_nodes = out.num_sampled_edges = None src_dst_dict = defaultdict(list) edge_types = self.row.keys() edge_types = [k for k in edge_types if not k[1].startswith('rev_')] for edge_type in edge_types: src, rel, dst = edge_type rev_edge_type = (dst, f'rev_{rel}', src) if src == dst and rev_edge_type not in self.row: out.row[edge_type], out.col[edge_type], _ = to_bidirectional( row=self.row[edge_type], col=self.col[edge_type], rev_row=self.row[edge_type], rev_col=self.col[edge_type], ) if out.edge is not None: out.edge[edge_type] = None elif rev_edge_type in self.row: out.row[edge_type], out.col[edge_type], _ = to_bidirectional( row=self.row[edge_type], col=self.col[edge_type], rev_row=self.row[rev_edge_type], rev_col=self.col[rev_edge_type], ) out.row[rev_edge_type] = out.col[edge_type] out.col[rev_edge_type] = out.row[edge_type] if out.edge is not None: out.edge[edge_type] = None out.edge[rev_edge_type] = None else: # Find the reverse edge type (if it is unique): if len(src_dst_dict) == 0: # Create mapping lazily. 
for key in self.row.keys(): v1, _, v2 = key src_dst_dict[(v1, v2)].append(key) if len(src_dst_dict[(dst, src)]) == 1: rev_edge_type = src_dst_dict[(dst, src)][0] row, col, _ = to_bidirectional( row=self.row[edge_type], col=self.col[edge_type], rev_row=self.row[rev_edge_type], rev_col=self.col[rev_edge_type], ) out.row[edge_type] = row out.col[edge_type] = col if out.edge is not None: out.edge[edge_type] = None else: warnings.warn( f"Cannot convert to bidirectional graph " f"since the edge type {edge_type} does not " f"seem to have a reverse edge type", stacklevel=2) return out @classmethod def collate(cls, outputs: List['HeteroSamplerOutput'], replace: bool = True) -> 'HeteroSamplerOutput': r"""Collate a list of :class:`~torch_geometric.sampler.HeteroSamplerOutput`objects into a single :class:`~torch_geometric.sampler.HeteroSamplerOutput` object. Requires that they all have the same fields. """ # TODO(zaristei) raise NotImplementedError def merge_with(self, other: 'HeteroSamplerOutput', replace: bool = True) -> 'HeteroSamplerOutput': """Merges two HeteroSamplerOutputs. If replace is True, self's nodes and edges take precedence. """ # TODO(zaristei) raise NotImplementedError @dataclass(frozen=True) class NumNeighbors: r"""The number of neighbors to sample in a homogeneous or heterogeneous graph. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for individual edge types. Args: values (List[int] or Dict[Tuple[str, str, str], List[int]]): The number of neighbors to sample. If an entry is set to :obj:`-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for individual edge types. default (List[int], optional): The default number of neighbors for edge types not specified in :obj:`values`. 
(default: :obj:`None`) """ values: Union[List[int], Dict[EdgeTypeStr, List[int]]] default: Optional[List[int]] = None def __init__( self, values: Union[List[int], Dict[EdgeType, List[int]]], default: Optional[List[int]] = None, ): if isinstance(values, (tuple, list)) and default is not None: raise ValueError(f"'default' must be set to 'None' in case a " f"single list is given as the number of " f"neighbors (got '{type(default)})'") if isinstance(values, dict): values = {EdgeTypeStr(key): value for key, value in values.items()} # Write to `__dict__` since dataclass is annotated with `frozen=True`: self.__dict__['values'] = values self.__dict__['default'] = default def _get_values( self, edge_types: Optional[List[EdgeType]] = None, mapped: bool = False, ) -> Union[List[int], Dict[Union[EdgeType, EdgeTypeStr], List[int]]]: if edge_types is not None: if isinstance(self.values, (tuple, list)): default = self.values elif isinstance(self.values, dict): default = self.default else: raise AssertionError() # Confirm that `values` only hold valid edge types: if isinstance(self.values, dict): edge_types_str = {EdgeTypeStr(key) for key in edge_types} invalid_edge_types = set(self.values.keys()) - edge_types_str if len(invalid_edge_types) > 0: raise ValueError("Not all edge types specified in " "'num_neighbors' exist in the graph") out = {} for edge_type in edge_types: edge_type_str = EdgeTypeStr(edge_type) if edge_type_str in self.values: out[edge_type_str if mapped else edge_type] = ( self.values[edge_type_str]) else: if default is None: raise ValueError(f"Missing number of neighbors for " f"edge type '{edge_type}'") out[edge_type_str if mapped else edge_type] = default elif isinstance(self.values, dict) and not mapped: out = {key.to_tuple(): value for key, value in self.values.items()} else: out = copy.copy(self.values) if isinstance(out, dict): num_hops = {len(v) for v in out.values()} if len(num_hops) > 1: raise ValueError(f"Number of hops must be the same across all " 
f"edge types (got {len(num_hops)} different " f"number of hops)") return out def get_values( self, edge_types: Optional[List[EdgeType]] = None, ) -> Union[List[int], Dict[EdgeType, List[int]]]: r"""Returns the number of neighbors. Args: edge_types (List[Tuple[str, str, str]], optional): The edge types to generate the number of neighbors for. (default: :obj:`None`) """ if '_values' in self.__dict__: return self.__dict__['_values'] values = self._get_values(edge_types, mapped=False) self.__dict__['_values'] = values return values def get_mapped_values( self, edge_types: Optional[List[EdgeType]] = None, ) -> Union[List[int], Dict[str, List[int]]]: r"""Returns the number of neighbors. For heterogeneous graphs, a dictionary is returned in which edge type tuples are converted to strings. Args: edge_types (List[Tuple[str, str, str]], optional): The edge types to generate the number of neighbors for. (default: :obj:`None`) """ if '_mapped_values' in self.__dict__: return self.__dict__['_mapped_values'] values = self._get_values(edge_types, mapped=True) self.__dict__['_mapped_values'] = values return values @property def num_hops(self) -> int: r"""Returns the number of hops.""" if '_num_hops' in self.__dict__: return self.__dict__['_num_hops'] if isinstance(self.values, (tuple, list)): num_hops = max(len(self.values), len(self.default or [])) else: # isinstance(self.values, dict): num_hops = max([0] + [len(v) for v in self.values.values()]) num_hops = max(num_hops, len(self.default or [])) self.__dict__['_num_hops'] = num_hops return num_hops def __len__(self) -> int: r"""Returns the number of hops.""" return self.num_hops class NegativeSamplingMode(Enum): # 'binary': Randomly sample negative edges in the graph. binary = 'binary' # 'triplet': Randomly sample negative destination nodes for each positive # source node. 
triplet = 'triplet' @dataclass class NegativeSampling(CastMixin): r"""The negative sampling configuration of a :class:`~torch_geometric.sampler.BaseSampler` when calling :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`. Args: mode (str): The negative sampling mode (:obj:`"binary"` or :obj:`"triplet"`). If set to :obj:`"binary"`, will randomly sample negative links from the graph. If set to :obj:`"triplet"`, will randomly sample negative destination nodes for each positive source node. amount (int or float, optional): The ratio of sampled negative edges to the number of positive edges. (default: :obj:`1`) src_weight (torch.Tensor, optional): A node-level vector determining the sampling of source nodes. Does not necessarily need to sum up to one. If not given, negative nodes will be sampled uniformly. (default: :obj:`None`) dst_weight (torch.Tensor, optional): A node-level vector determining the sampling of destination nodes. Does not necessarily need to sum up to one. If not given, negative nodes will be sampled uniformly. 
(default: :obj:`None`) """ mode: NegativeSamplingMode amount: Union[int, float] = 1 src_weight: Optional[Tensor] = None dst_weight: Optional[Tensor] = None def __init__( self, mode: Union[NegativeSamplingMode, str], amount: Union[int, float] = 1, src_weight: Optional[Tensor] = None, dst_weight: Optional[Tensor] = None, ): self.mode = NegativeSamplingMode(mode) self.amount = amount self.src_weight = src_weight self.dst_weight = dst_weight if self.amount <= 0: raise ValueError(f"The attribute 'amount' needs to be positive " f"for '{self.__class__.__name__}' " f"(got {self.amount})") if self.is_triplet(): if self.amount != math.ceil(self.amount): raise ValueError(f"The attribute 'amount' needs to be an " f"integer for '{self.__class__.__name__}' " f"with 'triplet' negative sampling " f"(got {self.amount}).") self.amount = math.ceil(self.amount) def is_binary(self) -> bool: return self.mode == NegativeSamplingMode.binary def is_triplet(self) -> bool: return self.mode == NegativeSamplingMode.triplet def sample( self, num_samples: int, endpoint: Literal['src', 'dst'], num_nodes: Optional[int] = None, ) -> Tensor: r"""Generates :obj:`num_samples` negative samples.""" weight = self.src_weight if endpoint == 'src' else self.dst_weight if weight is None: if num_nodes is None: raise ValueError( f"Cannot sample negatives in '{self.__class__.__name__}' " f"without passing the 'num_nodes' argument") return torch.randint(num_nodes, (num_samples, )) if num_nodes is not None and weight.numel() != num_nodes: raise ValueError( f"The 'weight' attribute in '{self.__class__.__name__}' " f"needs to match the number of nodes {num_nodes} " f"(got {self.weight.numel()})") return torch.multinomial(weight, num_samples, replacement=True) class BaseSampler(ABC): r"""An abstract base class that initializes a graph sampler and provides :meth:`sample_from_nodes` and :meth:`sample_from_edges` routines. .. 
note :: Any data stored in the sampler will be *replicated* across data loading workers that use the sampler since each data loading worker holds its own instance of a sampler. As such, it is recommended to limit the amount of information stored in the sampler. """ @abstractmethod def sample_from_nodes( self, index: NodeSamplerInput, **kwargs, ) -> Union[HeteroSamplerOutput, SamplerOutput]: r"""Performs sampling from the nodes specified in :obj:`index`, returning a sampled subgraph in the specified output format. The :obj:`index` is a tuple holding the following information: 1. The example indices of the seed nodes 2. The node indices to start sampling from 3. The timestamps of the given seed nodes (optional) Args: index (NodeSamplerInput): The node sampler input object. **kwargs (optional): Additional keyword arguments. """ raise NotImplementedError @abstractmethod def sample_from_edges( self, index: EdgeSamplerInput, neg_sampling: Optional[NegativeSampling] = None, ) -> Union[HeteroSamplerOutput, SamplerOutput]: r"""Performs sampling from the edges specified in :obj:`index`, returning a sampled subgraph in the specified output format. The :obj:`index` is a tuple holding the following information: 1. The example indices of the seed links 2. The source node indices to start sampling from 3. The destination node indices to start sampling from 4. The labels of the seed links (optional) 5. The timestamps of the given seed nodes (optional) Args: index (EdgeSamplerInput): The edge sampler input object. neg_sampling (NegativeSampling, optional): The negative sampling configuration. (default: :obj:`None`) """ raise NotImplementedError @property def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]: r"""If the sampler performs any modification of edge ordering in the original graph, this function is expected to return the permutation tensor that defines the permutation from the edges in the original graph and the edges used in the sampler. 
If no such permutation was applied, :obj:`None` is returned. For heterogeneous graphs, the expected return type is a permutation tensor for each edge type. """ return None ================================================ FILE: torch_geometric/sampler/hgt_sampler.py ================================================ from typing import Dict, List, Optional, Union import torch from torch_geometric.data import Data, HeteroData from torch_geometric.sampler import ( BaseSampler, EdgeSamplerInput, HeteroSamplerOutput, NegativeSampling, NodeSamplerInput, SamplerOutput, ) from torch_geometric.sampler.utils import remap_keys, to_hetero_csc from torch_geometric.typing import ( WITH_TORCH_SPARSE, EdgeType, NodeType, OptTensor, ) class HGTSampler(BaseSampler): r"""An implementation of an in-memory heterogeneous layer-wise sampler user by :class:`~torch_geometric.loader.HGTLoader`. """ def __init__( self, data: HeteroData, num_samples: Union[List[int], Dict[NodeType, List[int]]], is_sorted: bool = False, share_memory: bool = False, ): if not WITH_TORCH_SPARSE: raise ImportError( f"'{self.__class__.__name__}' requires 'torch-sparse'") if isinstance(data, Data) or isinstance(data, tuple): raise NotImplementedError( f'{self.__class__.__name__} does not support a data object of ' f'type {type(data)}.') if isinstance(num_samples, (list, tuple)): num_samples = {key: num_samples for key in data.node_types} self.node_types, self.edge_types = data.metadata() self.num_samples = num_samples self.num_hops = max([len(v) for v in num_samples.values()]) # Conversion to/from C++ string type (see `NeighborSampler`): self.to_rel_type = {k: '__'.join(k) for k in self.edge_types} self.to_edge_type = {v: k for k, v in self.to_rel_type.items()} # Convert the graph data into a suitable format for sampling: colptr_dict, row_dict, self.perm = to_hetero_csc( data, device='cpu', share_memory=share_memory, is_sorted=is_sorted) self.row_dict = remap_keys(row_dict, self.to_rel_type) self.colptr_dict = 
remap_keys(colptr_dict, self.to_rel_type) def sample_from_nodes( self, inputs: NodeSamplerInput, ) -> HeteroSamplerOutput: node, row, col, edge = torch.ops.torch_sparse.hgt_sample( self.colptr_dict, self.row_dict, {inputs.input_type: inputs.node}, self.num_samples, self.num_hops, ) return HeteroSamplerOutput( node=node, row=remap_keys(row, self.to_edge_type), col=remap_keys(col, self.to_edge_type), edge=remap_keys(edge, self.to_edge_type), batch=None, metadata=(inputs.input_id, inputs.time), ) def sample_from_edges( self, index: EdgeSamplerInput, neg_sampling: Optional[NegativeSampling] = None, ) -> Union[HeteroSamplerOutput, SamplerOutput]: pass @property def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]: return self.perm ================================================ FILE: torch_geometric/sampler/neighbor_sampler.py ================================================ import copy import math import sys import warnings from typing import Callable, Dict, List, Literal, Optional, Tuple, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric.data import ( Data, FeatureStore, GraphStore, HeteroData, remote_backend_utils, ) from torch_geometric.data.graph_store import EdgeLayout from torch_geometric.sampler import ( BaseSampler, EdgeSamplerInput, HeteroSamplerOutput, NegativeSampling, NodeSamplerInput, SamplerOutput, ) from torch_geometric.sampler.base import DataType, NumNeighbors, SubgraphType from torch_geometric.sampler.utils import ( global_to_local_node_idx, remap_keys, reverse_edge_type, to_csc, to_hetero_csc, ) from torch_geometric.typing import EdgeType, NodeType, OptTensor NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]] class NeighborSampler(BaseSampler): r"""An implementation of an in-memory (heterogeneous) neighbor sampler used by :class:`~torch_geometric.loader.NeighborLoader`. 
""" def __init__( self, data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: NumNeighborsType, subgraph_type: Union[SubgraphType, str] = 'directional', replace: bool = False, disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, weight_attr: Optional[str] = None, is_sorted: bool = False, share_memory: bool = False, directed: bool = True, # Deprecated sample_direction: Literal['forward', 'backward'] = 'forward', ): if not directed: subgraph_type = SubgraphType.induced warnings.warn( f"The usage of the 'directed' argument in " f"'{self.__class__.__name__}' is deprecated. Use " f"`subgraph_type='induced'` instead.", stacklevel=2) if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux' and subgraph_type != SubgraphType.induced): warnings.warn( f"Using '{self.__class__.__name__}' without a " f"'pyg-lib' installation is deprecated and will be " f"removed soon. Please install 'pyg-lib' for " f"accelerated neighborhood sampling", stacklevel=2) self.data_type = DataType.from_data(data) self.sample_direction = sample_direction if self.sample_direction == 'backward': # TODO(zaristei) if time_attr is not None: raise NotImplementedError( "Temporal Sampling not yet supported for backward sampling" ) if self.data_type == DataType.homogeneous: self.num_nodes = data.num_nodes self.node_time: Optional[Tensor] = None self.edge_time: Optional[Tensor] = None if time_attr is not None: if data.is_node_attr(time_attr): self.node_time = data[time_attr] elif data.is_edge_attr(time_attr): self.edge_time = data[time_attr] else: raise ValueError( f"The time attribute '{time_attr}' is neither a " f"node-level or edge-level attribute") # Convert the graph data into CSC format for sampling: self.colptr, self.row, self.perm = to_csc( data, device='cpu', share_memory=share_memory, is_sorted=is_sorted, src_node_time=self.node_time, edge_time=self.edge_time, to_transpose=self.sample_direction == 'backward') if 
self.edge_time is not None and self.perm is not None: self.edge_time = self.edge_time[self.perm] self.edge_weight: Optional[Tensor] = None if weight_attr is not None: self.edge_weight = data[weight_attr] if self.perm is not None: self.edge_weight = self.edge_weight[self.perm] elif self.data_type == DataType.heterogeneous: self.node_types, self.edge_types = data.metadata() # reverse edge types if sample_direction is backward if self.sample_direction == 'backward': self.edge_types = [ reverse_edge_type(edge_type) for edge_type in self.edge_types ] self.to_restored_edge_type = { k: reverse_edge_type(k) for k in self.edge_types } self.num_nodes = {k: data[k].num_nodes for k in self.node_types} self.node_time: Optional[Dict[NodeType, Tensor]] = None self.edge_time: Optional[Dict[EdgeType, Tensor]] = None if time_attr is not None: is_node_level_time = is_edge_level_time = False for store in data.node_stores: if time_attr in store: is_node_level_time = True for store in data.edge_stores: if time_attr in store: is_edge_level_time = True if is_node_level_time and is_edge_level_time: raise ValueError( f"The time attribute '{time_attr}' holds both " f"node-level and edge-level information") if not is_node_level_time and not is_edge_level_time: raise ValueError( f"The time attribute '{time_attr}' is neither a " f"node-level or edge-level attribute") if is_node_level_time: self.node_time = data.collect(time_attr) else: self.edge_time = data.collect(time_attr) # Conversion to/from C++ string type: Since C++ cannot take # dictionaries with tuples as key as input, edge type triplets need # to be converted into single strings. 
self.to_rel_type = {k: '__'.join(k) for k in self.edge_types} self.to_edge_type = {v: k for k, v in self.to_rel_type.items()} # Convert the graph data into CSC format for sampling: colptr_dict, row_dict, self.perm = to_hetero_csc( data, device='cpu', share_memory=share_memory, is_sorted=is_sorted, node_time_dict=self.node_time, edge_time_dict=self.edge_time, to_transpose=self.sample_direction == 'backward') self.row_dict = remap_keys(row_dict, self.to_rel_type) self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type) if self.edge_time is not None: for edge_type, edge_time in self.edge_time.items(): if self.perm.get(edge_type, None) is not None: edge_time = edge_time[self.perm[edge_type]] self.edge_time[edge_type] = edge_time self.edge_time = remap_keys(self.edge_time, self.to_rel_type) self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None if weight_attr is not None: self.edge_weight = data.collect(weight_attr) for edge_type, edge_weight in self.edge_weight.items(): if self.perm.get(edge_type, None) is not None: edge_weight = edge_weight[self.perm[edge_type]] self.edge_weight[edge_type] = edge_weight self.edge_weight = remap_keys(self.edge_weight, self.to_rel_type) else: # self.data_type == DataType.remote feature_store, graph_store = data # Obtain graph metadata: attrs = [attr for attr in feature_store.get_all_tensor_attrs()] edge_attrs = graph_store.get_all_edge_attrs() self.edge_types = list({attr.edge_type for attr in edge_attrs}) # reverse edge types if sample_direction is backward if self.sample_direction == 'backward': self.edge_types = [ reverse_edge_type(edge_type) for edge_type in self.edge_types ] self.to_restored_edge_type = { k: reverse_edge_type(k) for k in self.edge_types } self.to_backward_edge_type = { v: k for k, v in self.to_restored_edge_type.items() } if weight_attr is not None: raise NotImplementedError( f"'weight_attr' argument not yet supported within " f"'{self.__class__.__name__}' for " f"'(FeatureStore, GraphStore)' inputs") if 
time_attr is not None: # If the `time_attr` is present, we expect that `GraphStore` # holds all edges sorted by destination, and within local # neighborhoods, node indices should be sorted by time. # TODO (matthias, manan) Find an alternative way to ensure. for edge_attr in edge_attrs: if edge_attr.layout == EdgeLayout.CSR: raise ValueError( "Temporal sampling requires that edges are stored " "in either COO or CSC layout") if not edge_attr.is_sorted: raise ValueError( "Temporal sampling requires that edges are " "sorted by destination, and by source time " "within local neighborhoods") # We obtain all features with `node_attr.name=time_attr`: time_attrs = [ copy.copy(attr) for attr in attrs if attr.attr_name == time_attr ] if not self.is_hetero: self.node_types = [None] self.num_nodes = max(edge_attrs[0].size) self.edge_weight: Optional[Tensor] = None self.node_time: Optional[Tensor] = None self.edge_time: Optional[Tensor] = None if time_attr is not None: if len(time_attrs) != 1: raise ValueError("Temporal sampling specified but did " "not find any temporal data") time_attrs[0].index = None # Reset index for full data. time_tensor = feature_store.get_tensor(time_attrs[0]) # Currently, we determine whether to use node-level or # edge-level temporal sampling based on the attribute name. 
if time_attr == 'time': self.node_time = time_tensor else: self.edge_time = time_tensor if self.sample_direction == 'forward': self.row, self.colptr, self.perm = graph_store.csc() elif self.sample_direction == 'backward': self.colptr, self.row, self.perm = graph_store.csr() else: node_types = [ attr.group_name for attr in attrs if isinstance(attr.group_name, str) ] self.node_types = list(set(node_types)) self.num_nodes = { node_type: remote_backend_utils.size(*data, node_type) for node_type in self.node_types } self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None self.node_time: Optional[Dict[NodeType, Tensor]] = None self.edge_time: Optional[Dict[EdgeType, Tensor]] = None if time_attr is not None: for attr in time_attrs: # Reset index for full data. attr.index = None time_tensors = feature_store.multi_get_tensor(time_attrs) time = { attr.group_name: time_tensor for attr, time_tensor in zip(time_attrs, time_tensors) } group_names = [attr.group_name for attr in time_attrs] if all([isinstance(g, str) for g in group_names]): self.node_time = time elif all([isinstance(g, tuple) for g in group_names]): self.edge_time = time else: raise ValueError( f"Found time attribute '{time_attr}' for both " f"node-level and edge-level types") # Conversion to/from C++ string type (see above): self.to_rel_type = {k: '__'.join(k) for k in self.edge_types} self.to_edge_type = {v: k for k, v in self.to_rel_type.items()} if self.sample_direction == 'forward': row_dict, colptr_dict, self.perm = graph_store.csc() elif self.sample_direction == 'backward': colptr_dict, row_dict, self.perm = graph_store.csr() colptr_dict = remap_keys(colptr_dict, self.to_backward_edge_type) row_dict = remap_keys(row_dict, self.to_backward_edge_type) self.perm = remap_keys(self.perm, self.to_backward_edge_type) self.row_dict = remap_keys(row_dict, self.to_rel_type) self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type) if (self.edge_time is not None and not 
torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE): raise ImportError("Edge-level temporal sampling requires a " "more recent 'pyg-lib' installation") if (self.edge_weight is not None and not torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE): raise ImportError("Weighted neighbor sampling requires " "'pyg-lib>=0.3.0'") self.num_neighbors = num_neighbors self.replace = replace self.subgraph_type = SubgraphType(subgraph_type) self.disjoint = disjoint self.temporal_strategy = temporal_strategy self.keep_orig_edges = False @property def num_neighbors(self) -> NumNeighbors: if self.sample_direction == 'backward': return self._input_num_neighbors \ if self._input_num_neighbors is not None \ else self._num_neighbors return self._num_neighbors @num_neighbors.setter def num_neighbors(self, num_neighbors: NumNeighborsType): # only used if sample direction is backward and num_neighbors has edge # keys self._input_num_neighbors = None if isinstance(num_neighbors, NumNeighbors): num_neighbors_values = num_neighbors.values if isinstance(num_neighbors_values, dict) and self.sample_direction == 'backward': # reverse the edge_types if sample_direction is backward self._input_num_neighbors = num_neighbors num_neighbors_values = remap_keys(num_neighbors_values, self.to_backward_edge_type) self._num_neighbors = NumNeighbors(num_neighbors_values) else: self._num_neighbors = num_neighbors else: if isinstance(num_neighbors, dict) and self.sample_direction == 'backward': # intentionally recursing here to make sure num_neighbors is # set as expected for the user self.num_neighbors = NumNeighbors( remap_keys(num_neighbors, self.to_backward_edge_type)) else: self._num_neighbors = NumNeighbors(num_neighbors) @property def is_hetero(self) -> bool: if self.data_type == DataType.homogeneous: return False if self.data_type == DataType.heterogeneous: return True # self.data_type == DataType.remote return self.edge_types != [None] @property def is_temporal(self) -> bool: return 
self.node_time is not None or self.edge_time is not None @property def disjoint(self) -> bool: return self._disjoint or self.is_temporal @disjoint.setter def disjoint(self, disjoint: bool): self._disjoint = disjoint # Node-based sampling ##################################################### def sample_from_nodes( self, inputs: NodeSamplerInput, ) -> Union[SamplerOutput, HeteroSamplerOutput]: out = node_sample(inputs, self._sample) if self.subgraph_type == SubgraphType.bidirectional: out = out.to_bidirectional(keep_orig_edges=self.keep_orig_edges) return out # Edge-based sampling ##################################################### def sample_from_edges( self, inputs: EdgeSamplerInput, neg_sampling: Optional[NegativeSampling] = None, ) -> Union[SamplerOutput, HeteroSamplerOutput]: out = edge_sample(inputs, self._sample, self.num_nodes, self.disjoint, self.node_time, neg_sampling) if self.subgraph_type == SubgraphType.bidirectional: out = out.to_bidirectional(keep_orig_edges=self.keep_orig_edges) return out # Other Utilities ######################################################### @property def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]: return self.perm # Helper functions ######################################################## def _sample( self, seed: Union[Tensor, Dict[NodeType, Tensor]], seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None, **kwargs, ) -> Union[SamplerOutput, HeteroSamplerOutput]: r"""Implements neighbor sampling by calling either :obj:`pyg-lib` (if installed) or :obj:`torch-sparse` (if installed) sampling routines. """ if isinstance(seed, dict): # Heterogeneous sampling: # TODO Support induced subgraph sampling in `pyg-lib`. 
if (torch_geometric.typing.WITH_PYG_LIB and self.subgraph_type != SubgraphType.induced): # TODO (matthias) Ideally, `seed` inherits dtype from `colptr` colptrs = list(self.colptr_dict.values()) dtype = colptrs[0].dtype if len(colptrs) > 0 else torch.int64 seed = {k: v.to(dtype) for k, v in seed.items()} args = ( self.node_types, self.edge_types, self.colptr_dict, self.row_dict, seed, self.num_neighbors.get_mapped_values(self.edge_types), self.node_time, ) if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE: args += (self.edge_time, ) args += (seed_time, ) if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE: args += (self.edge_weight, ) args += ( True, # csc self.replace, self.subgraph_type != SubgraphType.induced, self.disjoint, self.temporal_strategy, # TODO (matthias) `return_edge_id` if edge features present True, # return_edge_id ) out = torch.ops.pyg.hetero_neighbor_sample(*args) row, col, node, edge, batch = out[:4] + (None, ) # `pyg-lib>0.1.0` returns sampled number of nodes/edges: num_sampled_nodes = num_sampled_edges = None if len(out) >= 6: num_sampled_nodes, num_sampled_edges = out[4:6] if self.disjoint: node = {k: v.t().contiguous() for k, v in node.items()} batch = {k: v[0] for k, v in node.items()} node = {k: v[1] for k, v in node.items()} elif torch_geometric.typing.WITH_TORCH_SPARSE: if self.disjoint: if self.subgraph_type == SubgraphType.induced: raise ValueError("'disjoint' sampling not supported " "for neighbor sampling with " "`subgraph_type='induced'`") else: raise ValueError("'disjoint' sampling not supported " "for neighbor sampling via " "'torch-sparse'. 
Please install " "'pyg-lib' for improved and " "optimized sampling routines.") out = torch.ops.torch_sparse.hetero_neighbor_sample( self.node_types, self.edge_types, self.colptr_dict, self.row_dict, seed, # seed_dict self.num_neighbors.get_mapped_values(self.edge_types), self.num_neighbors.num_hops, self.replace, self.subgraph_type != SubgraphType.induced, ) node, row, col, edge, batch = out + (None, ) num_sampled_nodes = num_sampled_edges = None else: raise ImportError(f"'{self.__class__.__name__}' requires " f"either 'pyg-lib' or 'torch-sparse'") if self.sample_direction == 'backward': row, col = col, row row = remap_keys(row, self.to_edge_type) col = remap_keys(col, self.to_edge_type) edge = remap_keys(edge, self.to_edge_type) # In the case of backward sampling, we need to restore the edges # keys to be forward facing in the HeteroSamplerOutput object. if self.sample_direction == 'backward': row = remap_keys(row, self.to_restored_edge_type) col = remap_keys(col, self.to_restored_edge_type) edge = remap_keys(edge, self.to_restored_edge_type) if num_sampled_edges is not None: num_sampled_edges = remap_keys( num_sampled_edges, self.to_edge_type, ) if self.sample_direction == 'backward': num_sampled_edges = remap_keys(num_sampled_edges, self.to_restored_edge_type) return HeteroSamplerOutput( node=node, row=row, col=col, edge=edge, batch=batch, num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, ) else: # Homogeneous sampling: # TODO Support induced subgraph sampling in `pyg-lib`. 
if (torch_geometric.typing.WITH_PYG_LIB and self.subgraph_type != SubgraphType.induced): args = ( self.colptr, self.row, # TODO (matthias) `seed` should inherit dtype from `colptr` seed.to(self.colptr.dtype), self.num_neighbors.get_mapped_values(), self.node_time, ) if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE: args += (self.edge_time, ) args += (seed_time, ) if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE: args += (self.edge_weight, ) args += ( True, # csc self.replace, self.subgraph_type != SubgraphType.induced, self.disjoint, self.temporal_strategy, # TODO (matthias) `return_edge_id` if edge features present True, # return_edge_id ) out = torch.ops.pyg.neighbor_sample(*args) row, col, node, edge, batch = out[:4] + (None, ) # `pyg-lib>0.1.0` returns sampled number of nodes/edges: num_sampled_nodes = num_sampled_edges = None if len(out) >= 6: num_sampled_nodes, num_sampled_edges = out[4:6] if self.disjoint: batch, node = node.t().contiguous() elif torch_geometric.typing.WITH_TORCH_SPARSE: if self.disjoint: raise ValueError("'disjoint' sampling not supported for " "neighbor sampling via 'torch-sparse'. 
" "Please install 'pyg-lib' for improved " "and optimized sampling routines.") out = torch.ops.torch_sparse.neighbor_sample( self.colptr, self.row, seed, # seed self.num_neighbors.get_mapped_values(), self.replace, self.subgraph_type != SubgraphType.induced, ) node, row, col, edge, batch = out + (None, ) num_sampled_nodes = num_sampled_edges = None else: raise ImportError(f"'{self.__class__.__name__}' requires " f"either 'pyg-lib' or 'torch-sparse'") if self.sample_direction == 'backward': row, col = col, row return SamplerOutput( node=node, row=row, col=col, edge=edge, batch=batch, num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, ) class BidirectionalNeighborSampler(NeighborSampler): """A sampler that allows for both upstream and downstream sampling.""" def __init__( self, data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: NumNeighborsType, subgraph_type: Union[SubgraphType, str] = 'directional', replace: bool = False, disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, weight_attr: Optional[str] = None, is_sorted: bool = False, share_memory: bool = False, # Deprecated: directed: bool = True, ): # TODO(zaristei) if isinstance(num_neighbors, NumNeighbors) and isinstance( num_neighbors.values, dict) or isinstance(num_neighbors, dict): raise RuntimeError( "BidirectionalNeighborSampler does not yet support edge " "delimited sampling.") self.forward_sampler = NeighborSampler( data, num_neighbors, subgraph_type, replace, disjoint, temporal_strategy, time_attr, weight_attr, is_sorted, share_memory, sample_direction='forward', directed=directed) self.backward_sampler = NeighborSampler( data, num_neighbors, subgraph_type, replace, disjoint, temporal_strategy, time_attr, weight_attr, is_sorted, share_memory, sample_direction='backward', directed=directed) # Trigger warnings on init if number of hops is greater than 1 self.num_neighbors = num_neighbors self.subgraph_type = 
subgraph_type @property def num_neighbors(self) -> NumNeighbors: return self._num_neighbors @num_neighbors.setter def num_neighbors(self, num_neighbors: NumNeighborsType): if not isinstance(num_neighbors, NumNeighbors): num_neighbors = NumNeighbors(num_neighbors) if num_neighbors.num_hops > 1: print("Warning: Number of hops is greater than 1, resulting in " "memory-expensive recursive calls.") self._num_neighbors = num_neighbors @property def is_hetero(self) -> bool: return self.forward_sampler.is_hetero @property def is_temporal(self) -> bool: return self.forward_sampler.is_temporal @property def disjoint(self) -> bool: return self.forward_sampler.disjoint @disjoint.setter def disjoint(self, disjoint: bool): self.forward_sampler.disjoint = disjoint self.backward_sampler.disjoint = disjoint def sample_from_nodes( self, inputs: NodeSamplerInput, ) -> Union[SamplerOutput, HeteroSamplerOutput]: return super().sample_from_nodes(inputs) def sample_from_edges( self, inputs: EdgeSamplerInput, neg_sampling: Optional[NegativeSampling] = None, ) -> Union[SamplerOutput, HeteroSamplerOutput]: # TODO(zaristei) Figure out what exactly regular and negative sampling # imply for bidirectional sampling case if neg_sampling is not None: raise RuntimeError( "BidirectionalNeighborSampler does not yet support " "negative sampling.") # Not thoroughly tested yet! 
return super().sample_from_edges(inputs) @property def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]: return self.forward_sampler.edge_permutation def _sample( self, seed: Union[Tensor, Dict[NodeType, Tensor]], seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None, **kwargs, ) -> Union[SamplerOutput, HeteroSamplerOutput]: if seed_time is not None: raise NotImplementedError( "BidirectionalNeighborSampler does not yet support " "temporal sampling.") if self.is_hetero: raise NotImplementedError( "BidirectionalNeighborSampler does not yet support " "heterogeneous sampling.") else: current_seed = seed current_seed_batch = None current_seed_time = seed_time seen_seed_set = {int(node) for node in current_seed} if self.disjoint: current_seed_batch = torch.arange(len(current_seed)) seen_seed_set = { (int(node), int(batch)) for node, batch in zip(current_seed, current_seed_batch) } iter_results = [] for n_neighbors in self.num_neighbors.values: current_n_neighbors = [n_neighbors] self.forward_sampler.num_neighbors = current_n_neighbors self.backward_sampler.num_neighbors = current_n_neighbors fwd_result = self.forward_sampler._sample( current_seed, current_seed_time, **kwargs) bwd_result = self.backward_sampler._sample( current_seed, current_seed_time, **kwargs) # The seeds for the next iteration will be the new nodes in # this iteration iter_result = fwd_result.merge_with(bwd_result) iter_results.append(iter_result) # Find the nodes not yet seen to set a seed for next iteration if self.disjoint: iter_seed_global_batch = global_to_local_node_idx( current_seed_batch, iter_result.batch) iter_result.seed_node = seed[iter_seed_global_batch] keep_mask = torch.tensor([ (int(node), int(batch)) not in seen_seed_set for node, batch in zip(iter_result.node, iter_seed_global_batch) ]) next_seed = [(int(node), int(batch)) for node, batch in zip( iter_result.node[keep_mask], iter_seed_global_batch[keep_mask]) ] if keep_mask.any() else [] current_seed, 
current_seed_batch = torch.tensor( next_seed).reshape(-1, 2).transpose(0, 1).contiguous() else: keep_mask = torch.tensor([ int(node) not in seen_seed_set for node in iter_result.node ]) next_seed = [ int(node) for node in iter_result.node[keep_mask] ] if keep_mask.any() else [] current_seed = torch.tensor(next_seed) seen_seed_set |= set(next_seed) # TODO(zaristei) figure out how to update seed times for # temporal sampling return SamplerOutput.collate(iter_results) # Sampling Utilities ########################################################## def node_sample( inputs: NodeSamplerInput, sample_fn: Callable, ) -> Union[SamplerOutput, HeteroSamplerOutput]: r"""Performs sampling from a :class:`NodeSamplerInput`, leveraging a sampling function that accepts a seed and (optionally) a seed time as input. Returns the output of this sampling procedure. """ if inputs.input_type is not None: # Heterogeneous sampling: seed = {inputs.input_type: inputs.node} seed_time = None if inputs.time is not None: seed_time = {inputs.input_type: inputs.time} else: # Homogeneous sampling: seed = inputs.node seed_time = inputs.time out = sample_fn(seed, seed_time) out.metadata = (inputs.input_id, inputs.time) return out def edge_sample( inputs: EdgeSamplerInput, sample_fn: Callable, num_nodes: Union[int, Dict[NodeType, int]], disjoint: bool, node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None, neg_sampling: Optional[NegativeSampling] = None, ) -> Union[SamplerOutput, HeteroSamplerOutput]: r"""Performs sampling from an edge sampler input, leveraging a sampling function of the same signature as `node_sample`. 
""" input_id = inputs.input_id src = inputs.row dst = inputs.col edge_label = inputs.label edge_label_time = inputs.time input_type = inputs.input_type src_time = dst_time = edge_label_time assert edge_label_time is None or disjoint assert isinstance(num_nodes, (dict, int)) if not isinstance(num_nodes, dict): num_src_nodes = num_dst_nodes = num_nodes else: num_src_nodes = num_nodes[input_type[0]] num_dst_nodes = num_nodes[input_type[-1]] num_pos = src.numel() num_neg = 0 # Negative Sampling ####################################################### if neg_sampling is not None: # When we are doing negative sampling, we append negative information # of nodes/edges to `src`, `dst`, `src_time`, `dst_time`. # Later on, we can easily reconstruct what belongs to positive and # negative examples by slicing via `num_pos`. num_neg = math.ceil(num_pos * neg_sampling.amount) if neg_sampling.is_binary(): # In the "binary" case, we randomly sample negative pairs of nodes. if isinstance(node_time, dict): src_node_time = node_time.get(input_type[0]) else: src_node_time = node_time src_neg = neg_sample(src, neg_sampling, num_src_nodes, src_time, src_node_time, endpoint='src') src = torch.cat([src, src_neg], dim=0) if isinstance(node_time, dict): dst_node_time = node_time.get(input_type[-1]) else: dst_node_time = node_time dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time, dst_node_time, endpoint='dst') dst = torch.cat([dst, dst_neg], dim=0) if edge_label is None: edge_label = torch.ones(num_pos) size = (num_neg, ) + edge_label.size()[1:] edge_neg_label = edge_label.new_zeros(size) edge_label = torch.cat([edge_label, edge_neg_label]) if edge_label_time is not None: src_time = dst_time = edge_label_time.repeat( 1 + math.ceil(neg_sampling.amount))[:num_pos + num_neg] elif neg_sampling.is_triplet(): # In the "triplet" case, we randomly sample negative destinations. 
if isinstance(node_time, dict): dst_node_time = node_time.get(input_type[-1]) else: dst_node_time = node_time dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time, dst_node_time, endpoint='dst') dst = torch.cat([dst, dst_neg], dim=0) assert edge_label is None if edge_label_time is not None: dst_time = edge_label_time.repeat(1 + neg_sampling.amount) # Heterogeneous Neighborhood Sampling ##################################### if input_type is not None: seed_time_dict = None if input_type[0] != input_type[-1]: # Two distinct node types: if not disjoint: src, inverse_src = src.unique(return_inverse=True) dst, inverse_dst = dst.unique(return_inverse=True) seed_dict = {input_type[0]: src, input_type[-1]: dst} if edge_label_time is not None: # Always disjoint. seed_time_dict = { input_type[0]: src_time, input_type[-1]: dst_time, } else: # Only a single node type: Merge both source and destination. seed = torch.cat([src, dst], dim=0) if not disjoint: seed, inverse_seed = seed.unique(return_inverse=True) seed_dict = {input_type[0]: seed} if edge_label_time is not None: # Always disjoint. 
seed_time_dict = { input_type[0]: torch.cat([src_time, dst_time], dim=0), } out = sample_fn(seed_dict, seed_time_dict) # Enhance `out` by label information ################################## if disjoint: for key, batch in out.batch.items(): out.batch[key] = batch % num_pos if neg_sampling is None or neg_sampling.is_binary(): if disjoint: if input_type[0] != input_type[-1]: edge_label_index = torch.arange(num_pos + num_neg) edge_label_index = edge_label_index.repeat(2).view(2, -1) else: edge_label_index = torch.arange(2 * (num_pos + num_neg)) edge_label_index = edge_label_index.view(2, -1) else: if input_type[0] != input_type[-1]: edge_label_index = torch.stack([ inverse_src, inverse_dst, ], dim=0) else: edge_label_index = inverse_seed.view(2, -1) out.metadata = (input_id, edge_label_index, edge_label, src_time) elif neg_sampling.is_triplet(): if disjoint: src_index = torch.arange(num_pos) if input_type[0] != input_type[-1]: dst_pos_index = torch.arange(num_pos) # `dst_neg_index` needs to be offset such that indices with # offset `num_pos` belong to the same triplet: dst_neg_index = torch.arange( num_pos, seed_dict[input_type[-1]].numel()) dst_neg_index = dst_neg_index.view(-1, num_pos).t() else: dst_pos_index = torch.arange(num_pos, 2 * num_pos) dst_neg_index = torch.arange( 2 * num_pos, seed_dict[input_type[-1]].numel()) dst_neg_index = dst_neg_index.view(-1, num_pos).t() else: if input_type[0] != input_type[-1]: src_index = inverse_src dst_pos_index = inverse_dst[:num_pos] dst_neg_index = inverse_dst[num_pos:] else: src_index = inverse_seed[:num_pos] dst_pos_index = inverse_seed[num_pos:2 * num_pos] dst_neg_index = inverse_seed[2 * num_pos:] dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1) out.metadata = ( input_id, src_index, dst_pos_index, dst_neg_index, src_time, ) # Homogeneous Neighborhood Sampling ####################################### else: seed = torch.cat([src, dst], dim=0) seed_time = None if not disjoint: seed, inverse_seed = 
seed.unique(return_inverse=True) if edge_label_time is not None: # Always disjoint. seed_time = torch.cat([src_time, dst_time]) out = sample_fn(seed, seed_time) # Enhance `out` by label information ################################## if neg_sampling is None or neg_sampling.is_binary(): if disjoint: out.batch = out.batch % num_pos edge_label_index = torch.arange(seed.numel()).view(2, -1) else: edge_label_index = inverse_seed.view(2, -1) out.metadata = (input_id, edge_label_index, edge_label, src_time) elif neg_sampling.is_triplet(): if disjoint: out.batch = out.batch % num_pos src_index = torch.arange(num_pos) dst_pos_index = torch.arange(num_pos, 2 * num_pos) # `dst_neg_index` needs to be offset such that indices with # offset `num_pos` belong to the same triplet: dst_neg_index = torch.arange(2 * num_pos, seed.numel()) dst_neg_index = dst_neg_index.view(-1, num_pos).t() else: src_index = inverse_seed[:num_pos] dst_pos_index = inverse_seed[num_pos:2 * num_pos] dst_neg_index = inverse_seed[2 * num_pos:] dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1) out.metadata = ( input_id, src_index, dst_pos_index, dst_neg_index, src_time, ) return out def neg_sample( seed: Tensor, neg_sampling: NegativeSampling, num_nodes: int, seed_time: Optional[Tensor], node_time: Optional[Tensor], endpoint: Literal['str', 'dst'], ) -> Tensor: num_neg = math.ceil(seed.numel() * neg_sampling.amount) # TODO: Do not sample false negatives. if node_time is None: return neg_sampling.sample(num_neg, endpoint, num_nodes) # If we are in a temporal-sampling scenario, we need to respect the # timestamp of the given nodes we can use as negative examples. # That is, we can only sample nodes for which `node_time <= seed_time`. # For now, we use a greedy algorithm which randomly samples negative # nodes and discard any which do not respect the temporal constraint. # We iteratively repeat this process until we have sampled a valid node for # each seed. 
# TODO See if this greedy algorithm here can be improved. assert seed_time is not None num_samples = math.ceil(neg_sampling.amount) seed_time = seed_time.view(1, -1).expand(num_samples, -1) out = neg_sampling.sample(num_samples * seed.numel(), endpoint, num_nodes) out = out.view(num_samples, seed.numel()) mask = node_time[out] > seed_time # holds all invalid samples. neg_sampling_complete = False for _ in range(5): # pragma: no cover num_invalid = int(mask.sum()) if num_invalid == 0: neg_sampling_complete = True break # Greedily search for alternative negatives. out[mask] = tmp = neg_sampling.sample(num_invalid, endpoint, num_nodes) mask[mask.clone()] = node_time[tmp] >= seed_time[mask] if not neg_sampling_complete: # pragma: no cover # Not much options left. In that case, we set remaining negatives # to the node with minimum timestamp. out[mask] = node_time.argmin() return out.view(-1)[:num_neg] ================================================ FILE: torch_geometric/sampler/utils.py ================================================ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.storage import EdgeStorage from torch_geometric.index import index2ptr from torch_geometric.typing import EdgeType, NodeType, OptTensor from torch_geometric.utils import coalesce, index_sort, lexsort def reverse_edge_type(edge_type: EdgeType) -> EdgeType: """Reverses edge types for heterogeneous graphs. Useful in cases of backward sampling. 
""" return (edge_type[2], edge_type[1], edge_type[0]) if edge_type is not None else None # Edge Layout Conversion ###################################################### def sort_csc( row: Tensor, col: Tensor, src_node_time: OptTensor = None, edge_time: OptTensor = None, ) -> Tuple[Tensor, Tensor, Tensor]: if src_node_time is None and edge_time is None: col, perm = index_sort(col) return row[perm], col, perm elif edge_time is not None: assert src_node_time is None perm = lexsort([edge_time, col]) return row[perm], col[perm], perm else: # src_node_time is not None perm = lexsort([src_node_time[row], col]) return row[perm], col[perm], perm # TODO(manan) deprecate when FeatureStore / GraphStore unification is complete def to_csc( data: Union[Data, EdgeStorage], device: Optional[torch.device] = None, share_memory: bool = False, is_sorted: bool = False, src_node_time: Optional[Tensor] = None, edge_time: Optional[Tensor] = None, to_transpose: bool = False, ) -> Tuple[Tensor, Tensor, OptTensor]: # Convert the graph data into a suitable format for sampling (CSC format). # Returns the `colptr` and `row` indices of the graph, as well as an # `perm` vector that denotes the permutation of edges. # Since no permutation of edges is applied when using `SparseTensor`, # `perm` can be of type `None`. perm: Optional[Tensor] = None if hasattr(data, 'adj'): if src_node_time is not None: raise NotImplementedError("Temporal sampling via 'SparseTensor' " "format not yet supported") if to_transpose: row, colptr, _ = data.adj.csr() else: colptr, row, _ = data.adj.csc() elif hasattr(data, 'adj_t'): if src_node_time is not None: # TODO (matthias) This only works when instantiating a # `SparseTensor` with `is_sorted=True`. Otherwise, the # `SparseTensor` will by default re-sort the neighbors according to # column index. 
# As such, we probably want to consider re-adding error: # raise NotImplementedError("Temporal sampling via 'SparseTensor' " # "format not yet supported") pass if to_transpose: row, colptr, _ = data.adj_t.csc() else: colptr, row, _ = data.adj_t.csr() elif data.edge_index is not None: if to_transpose: col, row = data.edge_index else: row, col = data.edge_index if not is_sorted: row, col, perm = sort_csc(row, col, src_node_time, edge_time) colptr = index2ptr(col, data.size(1) if not to_transpose else data.size(0)) else: row = torch.empty(0, dtype=torch.long, device=device) colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long, device=device) colptr = colptr.to(device) row = row.to(device) perm = perm.to(device) if perm is not None else None if not colptr.is_cuda and share_memory: colptr.share_memory_() row.share_memory_() if perm is not None: perm.share_memory_() return colptr, row, perm def to_hetero_csc( data: HeteroData, device: Optional[torch.device] = None, share_memory: bool = False, is_sorted: bool = False, node_time_dict: Optional[Dict[NodeType, Tensor]] = None, edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None, to_transpose: bool = False, ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]: # Convert the heterogeneous graph data into a suitable format for sampling # (CSC format). # Returns dictionaries holding `colptr` and `row` indices as well as edge # permutations for each edge type, respectively. 
colptr_dict, row_dict, perm_dict = {}, {}, {} for edge_type, store in data.edge_items(): src_node_time = (node_time_dict or {}).get(edge_type[0], None) edge_time = (edge_time_dict or {}).get(edge_type, None) out = to_csc(store, device, share_memory, is_sorted, src_node_time, edge_time, to_transpose) # Edge types need to be reversed for backward sampling: if to_transpose: edge_type = reverse_edge_type(edge_type) colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out return colptr_dict, row_dict, perm_dict def to_bidirectional( row: Tensor, col: Tensor, rev_row: Tensor, rev_col: Tensor, edge_id: OptTensor = None, rev_edge_id: OptTensor = None, ) -> Tuple[Tensor, Tensor, OptTensor]: assert row.numel() == col.numel() assert rev_row.numel() == rev_col.numel() edge_index = row.new_empty(2, row.numel() + rev_row.numel()) edge_index[0, :row.numel()] = row edge_index[1, :row.numel()] = col edge_index[0, row.numel():] = rev_col edge_index[1, row.numel():] = rev_row if edge_id is not None: edge_id = torch.cat([edge_id, rev_edge_id], dim=0) (row, col), edge_id = coalesce( edge_index, edge_id, sort_by_row=False, reduce='any', ) return row, col, edge_id ############################################################################### X, Y = TypeVar('X'), TypeVar('Y') def remap_keys( inputs: Dict[X, Any], mapping: Dict[X, Y], exclude: Optional[List[X]] = None, ) -> Dict[Union[X, Y], Any]: exclude = exclude or [] return { k if k in exclude else mapping.get(k, k): v for k, v in inputs.items() } def local_to_global_node_idx(node_values: Tensor, local_indices: Tensor) -> Tensor: """Convert a tensor of indices referring to elements in the node_values tensor to their values. Args: node_values (Tensor): The node values. (num_nodes, feature_dim) local_indices (Tensor): The local indices. (num_indices) Returns: Tensor: The values of the node_values tensor at the local indices. 
(num_indices, feature_dim) """ return torch.index_select(node_values, dim=0, index=local_indices) def global_to_local_node_idx(node_values: Tensor, local_values: Tensor) -> Tensor: """Converts a tensor of values that are contained in the node_values tensor to their indices in that tensor. Args: node_values (Tensor): The node values. (num_nodes, feature_dim) local_values (Tensor): The local values. (num_indices, feature_dim) Returns: Tensor: The indices of the local values in the node_values tensor. (num_indices) """ if node_values.dim() == 1: node_values = node_values.unsqueeze(1) if local_values.dim() == 1: local_values = local_values.unsqueeze(1) node_values_expand = node_values.unsqueeze(-1).expand( *node_values.shape, local_values.shape[0]) # (num_nodes, feature_dim, num_indices) local_values_expand = local_values.transpose(0, 1).unsqueeze(0).expand( *node_values_expand.shape) # (num_nodes, feature_dim, num_indices) idx_match = torch.all(node_values_expand == local_values_expand, dim=1).nonzero() # (num_indices, 2) sort_idx = torch.argsort(idx_match[:, 1]) return idx_match[:, 0][sort_idx] def unique_unsorted(tensor: Tensor) -> Tensor: """Returns the unique elements of a tensor while preserving the original order. Necessary because torch.unique() ignores sort parameter. """ seen = set() output = [] for val in tensor: val = tuple(val.tolist()) if val not in seen: seen.add(val) output.append(val) return torch.tensor(output, dtype=tensor.dtype, device=tensor.device).reshape((-1, *tensor.shape[1:])) ================================================ FILE: torch_geometric/seed.py ================================================ import random import numpy as np import torch def seed_everything(seed: int) -> None: r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, :obj:`numpy` and :python:`Python`. Args: seed (int): The desired seed. 
""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) ================================================ FILE: torch_geometric/template.py ================================================ import importlib import os.path as osp import sys import tempfile from typing import Any from jinja2 import Environment, FileSystemLoader def module_from_template( module_name: str, template_path: str, tmp_dirname: str, **kwargs: Any, ) -> Any: if module_name in sys.modules: # If module is already loaded, return it: return sys.modules[module_name] env = Environment(loader=FileSystemLoader(osp.dirname(template_path))) template = env.get_template(osp.basename(template_path)) module_repr = template.render(**kwargs) with tempfile.NamedTemporaryFile( mode='w', prefix=f'{module_name}_', suffix='.py', delete=False, ) as tmp: tmp.write(module_repr) tmp.flush() spec = importlib.util.spec_from_file_location(module_name, tmp.name) assert spec is not None module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module assert spec.loader is not None spec.loader.exec_module(module) return module ================================================ FILE: torch_geometric/testing/__init__.py ================================================ r"""Testing package. This package provides helper methods and decorators to ease testing. 
""" from .decorators import ( is_full_test, onlyFullTest, is_distributed_test, onlyDistributedTest, onlyLinux, noWindows, noMac, minPython, onlyCUDA, onlyXPU, onlyOnline, onlyGraphviz, onlyNeighborSampler, has_package, withPackage, withDevice, withCUDA, withMETIS, withHashTensor, disableExtensions, withoutExtensions, ) from .asserts import assert_module from .feature_store import MyFeatureStore from .graph_store import MyGraphStore from .data import ( get_random_edge_index, get_random_tensor_frame, FakeHeteroDataset, ) __all__ = [ 'is_full_test', 'onlyFullTest', 'is_distributed_test', 'onlyDistributedTest', 'onlyLinux', 'noWindows', 'noMac', 'minPython', 'onlyCUDA', 'onlyXPU', 'onlyOnline', 'onlyGraphviz', 'onlyNeighborSampler', 'has_package', 'withPackage', 'withDevice', 'withCUDA', 'withMETIS', 'withHashTensor', 'disableExtensions', 'withoutExtensions', 'assert_module', 'MyFeatureStore', 'MyGraphStore', 'get_random_edge_index', 'get_random_tensor_frame', 'FakeHeteroDataset', ] ================================================ FILE: torch_geometric/testing/asserts.py ================================================ import copy import warnings from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor from torch_geometric.utils import to_torch_coo_tensor, to_torch_csc_tensor SPARSE_LAYOUTS: List[Union[str, torch.layout]] = [ 'torch_sparse', torch.sparse_csc, torch.sparse_coo ] def assert_module( module: torch.nn.Module, x: Any, edge_index: Tensor, *, expected_size: Tuple[int, ...], test_edge_permutation: bool = True, test_node_permutation: bool = False, test_sparse_layouts: Optional[List[Union[str, torch.layout]]] = None, sparse_size: Optional[Tuple[int, int]] = None, atol: float = 1e-08, rtol: float = 1e-05, equal_nan: bool = False, **kwargs: Any, ) -> Any: r"""Asserts that the output of a :obj:`module` is correct. Specifically, this method tests that: 1. 
The module output has the correct shape. 2. The module is invariant to the permutation of edges. 3. The module is invariant to the permutation of nodes. 4. The module is invariant to the layout of :obj:`edge_index`. Args: module (torch.nn.Module): The module to test. x (Any): The input features to the module. edge_index (torch.Tensor): The input edge indices. expected_size (Tuple[int, ...]): The expected output size. test_edge_permutation (bool, optional): If set to :obj:`False`, will not test the module for edge permutation invariance. test_node_permutation (bool, optional): If set to :obj:`False`, will not test the module for node permutation invariance. test_sparse_layouts (List[str or int], optional): The sparse layouts to test for module invariance. (default: :obj:`["torch_sparse", torch.sparse_csc, torch.sparse_coo]`) sparse_size (Tuple[int, int], optional): The size of the sparse adjacency matrix. If not given, will try to automatically infer it. (default: :obj:`None`) atol (float, optional): Absolute tolerance. (default: :obj:`1e-08`) rtol (float, optional): Relative tolerance. (default: :obj:`1e-05`) equal_nan (bool, optional): If set to :obj:`True`, then two :obj:`NaN`s will be considered equal. (default: :obj:`False`) **kwargs (optional): Additional arguments passed to :meth:`module.forward`. 
""" if test_sparse_layouts is None: test_sparse_layouts = SPARSE_LAYOUTS if sparse_size is None: if 'size' in kwargs: sparse_size = kwargs['size'] elif isinstance(x, Tensor): sparse_size = (x.size(0), x.size(0)) elif (isinstance(x, (tuple, list)) and isinstance(x[0], Tensor) and isinstance(x[1], Tensor)): sparse_size = (x[0].size(0), x[1].size(0)) if len(test_sparse_layouts) > 0 and sparse_size is None: raise ValueError(f"Got sparse layouts {test_sparse_layouts}, but no " f"'sparse_size' were specified") expected = module(x, edge_index=edge_index, **kwargs) assert expected.size() == expected_size if test_edge_permutation: perm = torch.randperm(edge_index.size(1)) perm_kwargs = copy.copy(kwargs) for key, value in kwargs.items(): if isinstance(value, Tensor) and value.size(0) == perm.numel(): perm_kwargs[key] = value[perm] out = module(x, edge_index[:, perm], **perm_kwargs) assert torch.allclose(out, expected, rtol, atol, equal_nan) if test_node_permutation: raise NotImplementedError for layout in (test_sparse_layouts or []): # TODO Add support for values. 
if layout == 'torch_sparse': if not WITH_TORCH_SPARSE: continue adj = SparseTensor.from_edge_index( edge_index, sparse_sizes=sparse_size, ) adj_t = adj.t() elif layout == torch.sparse_csc: adj = to_torch_csc_tensor(edge_index, size=sparse_size) adj_t = adj.t() elif layout == torch.sparse_coo: warnings.filterwarnings('ignore', ".*to CSR format.*") adj = to_torch_coo_tensor(edge_index, size=sparse_size) adj_t = adj.t().coalesce() else: raise ValueError(f"Got invalid sparse layout '{layout}'") out = module(x, adj_t, **kwargs) assert torch.allclose(out, expected, rtol, atol, equal_nan) return expected ================================================ FILE: torch_geometric/testing/data.py ================================================ from typing import Callable, Optional import torch from torch import Tensor from torch_geometric.data import HeteroData, InMemoryDataset from torch_geometric.typing import TensorFrame, torch_frame from torch_geometric.utils import coalesce as coalesce_fn def get_random_edge_index( num_src_nodes: int, num_dst_nodes: int, num_edges: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, coalesce: bool = False, ) -> Tensor: row = torch.randint(num_src_nodes, (num_edges, ), dtype=dtype, device=device) col = torch.randint(num_dst_nodes, (num_edges, ), dtype=dtype, device=device) edge_index = torch.stack([row, col], dim=0) if coalesce: edge_index = coalesce_fn(edge_index) return edge_index def get_random_tensor_frame( num_rows: int, device: Optional[torch.device] = None, ) -> TensorFrame: feat_dict = { torch_frame.categorical: torch.randint(0, 3, size=(num_rows, 3), device=device), torch_frame.numerical: torch.randn(size=(num_rows, 2), device=device), } col_names_dict = { torch_frame.categorical: ['a', 'b', 'c'], torch_frame.numerical: ['x', 'y'], } y = torch.randn(num_rows, device=device) return torch_frame.TensorFrame( feat_dict=feat_dict, col_names_dict=col_names_dict, y=y, ) class 
FakeHeteroDataset(InMemoryDataset): def __init__(self, transform: Optional[Callable] = None): super().__init__(transform=transform) data = HeteroData() num_papers = 100 num_authors = 10 data['paper'].x = torch.randn(num_papers, 16) data['author'].x = torch.randn(num_authors, 8) edge_index = get_random_edge_index( num_src_nodes=num_papers, num_dst_nodes=num_authors, num_edges=300, ) data['paper', 'author'].edge_index = edge_index data['author', 'paper'].edge_index = edge_index.flip([0]) data['paper'].y = torch.randint(0, 4, (num_papers, )) perm = torch.randperm(num_papers) data['paper'].train_mask = torch.zeros(num_papers, dtype=torch.bool) data['paper'].train_mask[perm[0:60]] = True data['paper'].val_mask = torch.zeros(num_papers, dtype=torch.bool) data['paper'].val_mask[perm[60:80]] = True data['paper'].test_mask = torch.zeros(num_papers, dtype=torch.bool) data['paper'].test_mask[perm[80:100]] = True self.data, self.slices = self.collate([data]) ================================================ FILE: torch_geometric/testing/decorators.py ================================================ import os import sys import warnings from importlib import import_module from importlib.util import find_spec from typing import Callable import torch from packaging.requirements import Requirement from packaging.version import Version import torch_geometric import torch_geometric.typing from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE from torch_geometric.visualization.graph import has_graphviz def is_full_test() -> bool: r"""Whether to run the full but time-consuming test suite.""" return os.getenv('FULL_TEST', '0') == '1' def onlyFullTest(func: Callable) -> Callable: r"""A decorator to specify that this function belongs to the full test suite. 
""" import pytest return pytest.mark.skipif( not is_full_test(), reason="Fast test run", )(func) def is_distributed_test() -> bool: r"""Whether to run the distributed test suite.""" return (os.getenv('DIST_TEST', '0') == '1' and sys.platform == 'linux' and has_package('pyg_lib')) def onlyDistributedTest(func: Callable) -> Callable: r"""A decorator to specify that this function belongs to the distributed test suite. """ import pytest return pytest.mark.skipif( not is_distributed_test(), reason="Fast test run", )(func) def onlyLinux(func: Callable) -> Callable: r"""A decorator to specify that this function should only execute on Linux systems. """ import pytest return pytest.mark.skipif( sys.platform != 'linux', reason="No Linux system", )(func) def noWindows(func: Callable) -> Callable: r"""A decorator to specify that this function should not execute on Windows systems. """ import pytest return pytest.mark.skipif( os.name == 'nt', reason="Windows system", )(func) def noMac(func: Callable) -> Callable: r"""A decorator to specify that this function should not execute on macOS systems. 
""" import pytest return pytest.mark.skipif( sys.platform == 'darwin', reason="macOS system", )(func) def minPython(version: str) -> Callable: r"""A decorator to run tests on specific :python:`Python` versions only.""" def decorator(func: Callable) -> Callable: import pytest major, minor = version.split('.') skip = False if sys.version_info.major < int(major): skip = True if (sys.version_info.major == int(major) and sys.version_info.minor < int(minor)): skip = True return pytest.mark.skipif( skip, reason=f"Python {version} required", )(func) return decorator def onlyCUDA(func: Callable) -> Callable: r"""A decorator to skip tests if CUDA is not found.""" import pytest return pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA not available", )(func) def onlyXPU(func: Callable) -> Callable: r"""A decorator to skip tests if XPU is not found.""" import pytest return pytest.mark.skipif( not torch_geometric.is_xpu_available(), reason="XPU not available", )(func) def onlyOnline(func: Callable) -> Callable: r"""A decorator to skip tests if there exists no connection to the internet. """ import http.client as httplib import pytest has_connection = True connection = httplib.HTTPSConnection('8.8.8.8', timeout=5) try: connection.request('HEAD', '/') except Exception: has_connection = False finally: connection.close() return pytest.mark.skipif( not has_connection, reason="No internet connection", )(func) def onlyGraphviz(func: Callable) -> Callable: r"""A decorator to specify that this function should only execute in case :obj:`graphviz` is installed. """ import pytest return pytest.mark.skipif( not has_graphviz(), reason="Graphviz not installed", )(func) def onlyNeighborSampler(func: Callable) -> Callable: r"""A decorator to skip tests if no neighborhood sampler package is installed. 
""" import pytest return pytest.mark.skipif( not WITH_PYG_LIB and not WITH_TORCH_SPARSE, reason="No neighbor sampler installed", )(func) def has_package(package: str) -> bool: r"""Returns :obj:`True` in case :obj:`package` is installed.""" if '|' in package: return any(has_package(p) for p in package.split('|')) req = Requirement(package) if find_spec(req.name) is None: return False try: module = import_module(req.name) if not hasattr(module, '__version__'): return True version = Version(module.__version__).base_version return version in req.specifier except Exception: return False def withPackage(*args: str) -> Callable: r"""A decorator to skip tests if certain packages are not installed. Also supports version specification. """ na_packages = {package for package in args if not has_package(package)} if len(na_packages) == 1: reason = f"Package {list(na_packages)[0]} not found" else: reason = f"Packages {na_packages} not found" def decorator(func: Callable) -> Callable: import pytest return pytest.mark.skipif(len(na_packages) > 0, reason=reason)(func) return decorator def withCUDA(func: Callable) -> Callable: r"""A decorator to test both on CPU and CUDA (if available).""" import pytest devices = [pytest.param(torch.device('cpu'), id='cpu')] if torch.cuda.is_available(): devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0')) return pytest.mark.parametrize('device', devices)(func) def withDevice(func: Callable) -> Callable: r"""A decorator to test on all available tensor processing devices.""" import pytest devices = [pytest.param(torch.device('cpu'), id='cpu')] if torch.cuda.is_available(): devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0')) if torch_geometric.is_mps_available(): devices.append(pytest.param(torch.device('mps:0'), id='mps')) if torch_geometric.is_xpu_available(): devices.append(pytest.param(torch.device('xpu:0'), id='xpu')) # Additional devices can be registered through environment variables: device = 
os.getenv('TORCH_DEVICE') if device: backend = os.getenv('TORCH_BACKEND') if backend is None: warnings.warn( f"Please specify the backend via 'TORCH_BACKEND' in" f"order to test against '{device}'", stacklevel=2) else: import_module(backend) devices.append(pytest.param(torch.device(device), id=device)) return pytest.mark.parametrize('device', devices)(func) def withMETIS(func: Callable) -> Callable: r"""A decorator to only test in case a valid METIS method is available.""" import pytest with_metis = WITH_METIS if with_metis: try: # Test that METIS can successfully execute: # TODO Using `pyg-lib` metis partitioning leads to some weird bugs # in the # CI. As such, we require `torch-sparse` for now. rowptr = torch.tensor([0, 2, 4, 6]) col = torch.tensor([1, 2, 0, 2, 1, 0]) torch.ops.torch_sparse.partition(rowptr, col, None, 2, True) except Exception: with_metis = False return pytest.mark.skipif( not with_metis, reason="METIS not enabled", )(func) def withHashTensor(func: Callable) -> Callable: r"""A decorator to only test in case :class:`HashTensor` is available.""" import pytest return pytest.mark.skipif( not torch_geometric.typing.WITH_CPU_HASH_MAP and not has_package('pandas'), reason="HashTensor dependencies not available", )(func) def disableExtensions(func: Callable) -> Callable: r"""A decorator to temporarily disable the usage of the :obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension packages. """ import pytest return pytest.mark.usefixtures('disable_extensions')(func) def withoutExtensions(func: Callable) -> Callable: r"""A decorator to test both with and without the usage of extension packages such as :obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib`. 
""" import pytest return pytest.mark.parametrize( 'without_extensions', ['enable_extensions', 'disable_extensions'], indirect=True, )(func) ================================================ FILE: torch_geometric/testing/distributed.py ================================================ import sys import traceback from dataclasses import dataclass from io import StringIO from typing import Any, Callable, List, Tuple import pytest from torch.multiprocessing import Manager, Queue from typing_extensions import Self @dataclass class ProcArgs: target: Callable args: Tuple[Any, ...] class MPCaptOutput: def __enter__(self) -> Self: self.stdout = StringIO() self.stderr = StringIO() self.old_stdout = sys.stdout self.old_stderr = sys.stderr sys.stdout = self.stdout sys.stderr = self.stderr return self def __exit__(self, *args: Any) -> None: sys.stdout = self.old_stdout sys.stderr = self.old_stderr @property def stdout_str(self) -> str: return self.stdout.getvalue() @property def stderr_str(self) -> str: return self.stderr.getvalue() def ps_std_capture( func: Callable, queue: Queue, *args: Any, **kwargs: Any, ) -> None: with MPCaptOutput() as capt: try: func(*args, **kwargs) except Exception as e: traceback.print_exc(file=sys.stderr) raise e finally: queue.put((capt.stdout_str, capt.stderr_str)) def assert_run_mproc( mp_context: Any, pargs: List[ProcArgs], full_trace: bool = False, timeout: int = 5, ) -> None: manager = Manager() world_size = len(pargs) queues = [manager.Queue() for _ in pargs] procs = [ mp_context.Process( target=ps_std_capture, args=[p.target, q, world_size] + list(p.args), ) for p, q in zip(pargs, queues) ] results = [] for p, _ in zip(procs, queues): p.start() for p, q in zip(procs, queues): p.join() stdout, stderr = q.get(timeout=timeout) results.append((p, stdout, stderr)) for p, stdout, stderr in results: if stdout: print(stdout) if stderr: # can be a warning as well => exitcode == 0 print(stderr) if p.exitcode != 0: pytest.fail( pytrace=full_trace, 
reason=stderr.splitlines()[-1] if stderr else f"exitcode {p.exitcode}") ================================================ FILE: torch_geometric/testing/feature_store.py ================================================ from typing import Dict, List, Optional, Tuple import torch from torch import Tensor from torch_geometric.data import FeatureStore, TensorAttr from torch_geometric.typing import FeatureTensorType KeyType = Tuple[Optional[str], Optional[str]] class MyFeatureStore(FeatureStore): def __init__(self) -> None: super().__init__() self.store: Dict[KeyType, Tuple[Tensor, Tensor]] = {} @staticmethod def key(attr: TensorAttr) -> KeyType: return (attr.group_name, attr.attr_name) def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: index = attr.index # None indices define the obvious index: if index is None: index = torch.arange(0, tensor.shape[0]) # Store the index: assert isinstance(index, Tensor) assert isinstance(tensor, Tensor) self.store[self.key(attr)] = (index, tensor) return True def _get_tensor(self, attr: TensorAttr) -> Optional[Tensor]: index, tensor = self.store.get(self.key(attr), (None, None)) if tensor is None: raise KeyError(f"Could not find tensor for '{attr}'") assert isinstance(tensor, Tensor) # None indices return the whole tensor: if attr.index is None: return tensor # Empty slices return the whole tensor: if (isinstance(attr.index, slice) and attr.index == slice(None, None, None)): return tensor assert isinstance(attr.index, Tensor) if attr.index.numel() == 0: return tensor[attr.index] idx = torch.cat([(index == v).nonzero() for v in attr.index]).view(-1) return tensor[idx] def _remove_tensor(self, attr: TensorAttr) -> bool: return self.store.pop(self.key(attr), None) is not None def _get_tensor_size(self, attr: TensorAttr) -> Optional[Tuple[int, ...]]: tensor = self._get_tensor(attr) return tensor.size() if tensor is not None else None def get_all_tensor_attrs(self) -> List[TensorAttr]: return 
[self._tensor_attr_cls.cast(*key) for key in self.store.keys()] ================================================ FILE: torch_geometric/testing/graph_store.py ================================================ from typing import Dict, List, Optional, Tuple from torch import Tensor from torch_geometric.data import EdgeAttr, GraphStore from torch_geometric.typing import EdgeTensorType class MyGraphStore(GraphStore): def __init__(self) -> None: super().__init__() self.store: Dict[Tuple, Tuple[Tensor, Tensor]] = {} @staticmethod def key(attr: EdgeAttr) -> Tuple: return (attr.edge_type, attr.layout.value, attr.is_sorted, attr.size) def _put_edge_index( self, edge_index: EdgeTensorType, edge_attr: EdgeAttr, ) -> bool: self.store[self.key(edge_attr)] = edge_index return True def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]: return self.store.get(self.key(edge_attr), None) def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool: return self.store.pop(self.key(edge_attr), None) is not None def get_all_edge_attrs(self) -> List[EdgeAttr]: return [EdgeAttr(*key) for key in self.store.keys()] ================================================ FILE: torch_geometric/transforms/__init__.py ================================================ # flake8: noqa from .base_transform import BaseTransform from .compose import Compose, ComposeFilters from .to_device import ToDevice from .to_sparse_tensor import ToSparseTensor from .constant import Constant from .normalize_features import NormalizeFeatures from .svd_feature_reduction import SVDFeatureReduction from .remove_training_classes import RemoveTrainingClasses from .random_node_split import RandomNodeSplit from .random_link_split import RandomLinkSplit from .node_property_split import NodePropertySplit from .mask import IndexToMask, MaskToIndex from .pad import Pad from .to_undirected import ToUndirected from .one_hot_degree import OneHotDegree from .target_indegree import TargetIndegree from .local_degree_profile 
import LocalDegreeProfile from .add_self_loops import AddSelfLoops from .add_remaining_self_loops import AddRemainingSelfLoops from .remove_self_loops import RemoveSelfLoops from .remove_isolated_nodes import RemoveIsolatedNodes from .remove_duplicated_edges import RemoveDuplicatedEdges from .knn_graph import KNNGraph from .radius_graph import RadiusGraph from .to_dense import ToDense from .two_hop import TwoHop from .line_graph import LineGraph from .laplacian_lambda_max import LaplacianLambdaMax from .gdc import GDC from .sign import SIGN from .gcn_norm import GCNNorm from .add_metapaths import AddMetaPaths, AddRandomMetaPaths from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph from .largest_connected_components import LargestConnectedComponents from .virtual_node import VirtualNode from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE from .add_gpse import AddGPSE from .feature_propagation import FeaturePropagation from .half_hop import HalfHop from .distance import Distance from .cartesian import Cartesian from .local_cartesian import LocalCartesian from .polar import Polar from .spherical import Spherical from .point_pair_features import PointPairFeatures from .center import Center from .normalize_rotation import NormalizeRotation from .normalize_scale import NormalizeScale from .random_jitter import RandomJitter from .random_flip import RandomFlip from .linear_transformation import LinearTransformation from .random_scale import RandomScale from .random_rotate import RandomRotate from .random_shear import RandomShear from .face_to_edge import FaceToEdge from .sample_points import SamplePoints from .fixed_points import FixedPoints from .generate_mesh_normals import GenerateMeshNormals from .delaunay import Delaunay from .to_superpixels import ToSLIC from .grid_sampling import GridSampling general_transforms = [ 'BaseTransform', 'Compose', 'ComposeFilters', 'ToDevice', 'ToSparseTensor', 'Constant', 'NormalizeFeatures', 
'SVDFeatureReduction', 'RemoveTrainingClasses', 'RandomNodeSplit', 'RandomLinkSplit', 'NodePropertySplit', 'IndexToMask', 'MaskToIndex', 'Pad', ] graph_transforms = [ 'ToUndirected', 'OneHotDegree', 'TargetIndegree', 'LocalDegreeProfile', 'AddSelfLoops', 'AddRemainingSelfLoops', 'RemoveSelfLoops', 'RemoveIsolatedNodes', 'RemoveDuplicatedEdges', 'KNNGraph', 'RadiusGraph', 'ToDense', 'TwoHop', 'LineGraph', 'LaplacianLambdaMax', 'GDC', 'SIGN', 'GCNNorm', 'AddMetaPaths', 'AddRandomMetaPaths', 'RootedEgoNets', 'RootedRWSubgraph', 'LargestConnectedComponents', 'VirtualNode', 'AddLaplacianEigenvectorPE', 'AddRandomWalkPE', 'AddGPSE', 'FeaturePropagation', 'HalfHop', ] vision_transforms = [ 'Distance', 'Cartesian', 'LocalCartesian', 'Polar', 'Spherical', 'PointPairFeatures', 'Center', 'NormalizeRotation', 'NormalizeScale', 'RandomJitter', 'RandomFlip', 'LinearTransformation', 'RandomScale', 'RandomRotate', 'RandomShear', 'FaceToEdge', 'SamplePoints', 'FixedPoints', 'GenerateMeshNormals', 'Delaunay', 'ToSLIC', 'GridSampling', ] __all__ = general_transforms + graph_transforms + vision_transforms from torch_geometric.deprecation import deprecated RandomTranslate = deprecated("use 'transforms.RandomJitter' instead", 'transforms.RandomTranslate')(RandomJitter) ================================================ FILE: torch_geometric/transforms/add_gpse.py ================================================ from typing import Any from torch.nn import Module from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform, VirtualNode @functional_transform('add_gpse') class AddGPSE(BaseTransform): r"""Adds the GPSE encoding from the `"Graph Positional and Structural Encoder" `_ paper to the given graph (functional name: :obj:`add_gpse`). To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates the actual encodings. Args: model (Module): The pre-trained GPSE model. 
use_vn (bool, optional): Whether to use virtual nodes. (default: :obj:`True`) rand_type (str, optional): Type of random features to use. Options are :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`. (default: :obj:`NormalSE`) """ def __init__( self, model: Module, use_vn: bool = True, rand_type: str = 'NormalSE', ): self.model = model self.use_vn = use_vn self.vn = VirtualNode() self.rand_type = rand_type def forward(self, data: Data) -> Any: pass def __call__(self, data: Data) -> Data: from torch_geometric.nn.models.gpse import gpse_process data_vn = self.vn(data.clone()) if self.use_vn else data.clone() batch_out = gpse_process(self.model, data_vn, 'NormalSE', self.use_vn) batch_out = batch_out.to('cpu', non_blocking=True) data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out return data ================================================ FILE: torch_geometric/transforms/add_metapaths.py ================================================ import warnings from typing import List, Optional, Tuple, Union, cast import torch from torch import Tensor from torch_geometric import EdgeIndex from torch_geometric.data import HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.typing import EdgeType from torch_geometric.utils import coalesce, degree @functional_transform('add_metapaths') class AddMetaPaths(BaseTransform): r"""Adds additional edge types to a :class:`~torch_geometric.data.HeteroData` object between the source node type and the destination node type of a given :obj:`metapath`, as described in the `"Heterogenous Graph Attention Networks" `_ paper (functional name: :obj:`add_metapaths`). Meta-path based neighbors can exploit different aspects of structure information in heterogeneous graphs. Formally, a metapath is a path of the form .. 
math:: \mathcal{V}_1 \xrightarrow{R_1} \mathcal{V}_2 \xrightarrow{R_2} \ldots \xrightarrow{R_{\ell-1}} \mathcal{V}_{\ell} in which :math:`\mathcal{V}_i` represents node types, and :math:`R_j` represents the edge type connecting two node types. The added edge type is given by the sequential multiplication of adjacency matrices along the metapath, and is added to the :class:`~torch_geometric.data.HeteroData` object as edge type :obj:`(src_node_type, "metapath_*", dst_node_type)`, where :obj:`src_node_type` and :obj:`dst_node_type` denote :math:`\mathcal{V}_1` and :math:`\mathcal{V}_{\ell}`, respectively. In addition, a :obj:`metapath_dict` object is added to the :class:`~torch_geometric.data.HeteroData` object which maps the metapath-based edge type to its original metapath. .. code-block:: python from torch_geometric.datasets import DBLP from torch_geometric.data import HeteroData from torch_geometric.transforms import AddMetaPaths data = DBLP(root)[0] # 4 node types: "paper", "author", "conference", and "term" # 6 edge types: ("paper","author"), ("author", "paper"), # ("paper, "term"), ("paper", "conference"), # ("term, "paper"), ("conference", "paper") # Add two metapaths: # 1. From "paper" to "paper" through "conference" # 2. 
From "author" to "conference" through "paper" metapaths = [[("paper", "conference"), ("conference", "paper")], [("author", "paper"), ("paper", "conference")]] data = AddMetaPaths(metapaths)(data) print(data.edge_types) >>> [("author", "to", "paper"), ("paper", "to", "author"), ("paper", "to", "term"), ("paper", "to", "conference"), ("term", "to", "paper"), ("conference", "to", "paper"), ("paper", "metapath_0", "paper"), ("author", "metapath_1", "conference")] print(data.metapath_dict) >>> {("paper", "metapath_0", "paper"): [("paper", "conference"), ("conference", "paper")], ("author", "metapath_1", "conference"): [("author", "paper"), ("paper", "conference")]} Args: metapaths (List[List[Tuple[str, str, str]]]): The metapaths described by a list of lists of :obj:`(src_node_type, rel_type, dst_node_type)` tuples. drop_orig_edge_types (bool, optional): If set to :obj:`True`, existing edge types will be dropped. (default: :obj:`False`) keep_same_node_type (bool, optional): If set to :obj:`True`, existing edge types between the same node type are not dropped even in case :obj:`drop_orig_edge_types` is set to :obj:`True`. (default: :obj:`False`) drop_unconnected_node_types (bool, optional): If set to :obj:`True`, will drop node types not connected by any edge type. (default: :obj:`False`) max_sample (int, optional): If set, will sample at maximum :obj:`max_sample` neighbors within metapaths. Useful in order to tackle very dense metapath edges. (default: :obj:`None`) weighted (bool, optional): If set to :obj:`True`, computes weights for each metapath edge and stores them in :obj:`edge_weight`. The weight of each metapath edge is computed as the number of metapaths from the start to the end of the metapath edge. 
(default :obj:`False`) """ def __init__( self, metapaths: List[List[EdgeType]], drop_orig_edge_types: bool = False, keep_same_node_type: bool = False, drop_unconnected_node_types: bool = False, max_sample: Optional[int] = None, weighted: bool = False, **kwargs: bool, ) -> None: if 'drop_orig_edges' in kwargs: warnings.warn( "'drop_orig_edges' is deprecated. Use " "'drop_orig_edge_types' instead", stacklevel=2) drop_orig_edge_types = kwargs['drop_orig_edges'] if 'drop_unconnected_nodes' in kwargs: warnings.warn( "'drop_unconnected_nodes' is deprecated. Use " "'drop_unconnected_node_types' instead", stacklevel=2) drop_unconnected_node_types = kwargs['drop_unconnected_nodes'] for path in metapaths: assert len(path) >= 2, f"Invalid metapath '{path}'" assert all([ j[-1] == path[i + 1][0] for i, j in enumerate(path[:-1]) ]), f"Invalid sequence of node types in '{path}'" self.metapaths = metapaths self.drop_orig_edge_types = drop_orig_edge_types self.keep_same_node_type = keep_same_node_type self.drop_unconnected_node_types = drop_unconnected_node_types self.max_sample = max_sample self.weighted = weighted def forward(self, data: HeteroData) -> HeteroData: edge_types = data.edge_types # Save original edge types. 
data.metapath_dict = {} for j, metapath in enumerate(self.metapaths): for edge_type in metapath: assert data._to_canonical(edge_type) in edge_types edge_type = metapath[0] edge_index, edge_weight = self._edge_index(data, edge_type) if self.max_sample is not None: edge_index, edge_weight = self._sample(edge_index, edge_weight) for edge_type in metapath[1:]: edge_index2, edge_weight2 = self._edge_index(data, edge_type) edge_index, edge_weight = edge_index.matmul( edge_index2, edge_weight, edge_weight2) if not self.weighted: edge_weight = None if self.max_sample is not None: edge_index, edge_weight = self._sample( edge_index, edge_weight) new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1]) data[new_edge_type].edge_index = edge_index.as_tensor() if self.weighted: data[new_edge_type].edge_weight = edge_weight data.metapath_dict[new_edge_type] = metapath postprocess(data, edge_types, self.drop_orig_edge_types, self.keep_same_node_type, self.drop_unconnected_node_types) return data def _edge_index( self, data: HeteroData, edge_type: EdgeType, ) -> Tuple[EdgeIndex, Optional[Tensor]]: edge_index = EdgeIndex( data[edge_type].edge_index, sparse_size=data[edge_type].size(), ) edge_index, perm = edge_index.sort_by('row') if not self.weighted: return edge_index, None edge_weight = data[edge_type].get('edge_weight') if edge_weight is not None: assert edge_weight.dim() == 1 edge_weight = edge_weight[perm] return edge_index, edge_weight def _sample( self, edge_index: EdgeIndex, edge_weight: Optional[Tensor], ) -> Tuple[EdgeIndex, Optional[Tensor]]: assert self.max_sample is not None deg = degree(edge_index[0], num_nodes=edge_index.get_sparse_size(0)) prob = (self.max_sample * (1. 
/ deg))[edge_index[0]] mask = torch.rand_like(prob) < prob edge_index = cast(EdgeIndex, edge_index[:, mask]) assert isinstance(edge_index, EdgeIndex) if edge_weight is not None: edge_weight = edge_weight[mask] return edge_index, edge_weight @functional_transform('add_random_metapaths') class AddRandomMetaPaths(BaseTransform): r"""Adds additional edge types similar to :class:`AddMetaPaths`. The key difference is that the added edge type is given by multiple random walks along the metapath. One might want to increase the number of random walks via :obj:`walks_per_node` to achieve competitive performance with :class:`AddMetaPaths`. Args: metapaths (List[List[Tuple[str, str, str]]]): The metapaths described by a list of lists of :obj:`(src_node_type, rel_type, dst_node_type)` tuples. drop_orig_edge_types (bool, optional): If set to :obj:`True`, existing edge types will be dropped. (default: :obj:`False`) keep_same_node_type (bool, optional): If set to :obj:`True`, existing edge types between the same node type are not dropped even in case :obj:`drop_orig_edge_types` is set to :obj:`True`. (default: :obj:`False`) drop_unconnected_node_types (bool, optional): If set to :obj:`True`, will drop node types not connected by any edge type. (default: :obj:`False`) walks_per_node (int, List[int], optional): The number of random walks for each starting node in a metapath. (default: :obj:`1`) sample_ratio (float, optional): The ratio of source nodes to start random walks from. 
(default: :obj:`1.0`) """ def __init__( self, metapaths: List[List[EdgeType]], drop_orig_edge_types: bool = False, keep_same_node_type: bool = False, drop_unconnected_node_types: bool = False, walks_per_node: Union[int, List[int]] = 1, sample_ratio: float = 1.0, ): for path in metapaths: assert len(path) >= 2, f"Invalid metapath '{path}'" assert all([ j[-1] == path[i + 1][0] for i, j in enumerate(path[:-1]) ]), f"Invalid sequence of node types in '{path}'" self.metapaths = metapaths self.drop_orig_edge_types = drop_orig_edge_types self.keep_same_node_type = keep_same_node_type self.drop_unconnected_node_types = drop_unconnected_node_types self.sample_ratio = sample_ratio if isinstance(walks_per_node, int): walks_per_node = [walks_per_node] * len(metapaths) assert len(walks_per_node) == len(metapaths) self.walks_per_node = walks_per_node def forward(self, data: HeteroData) -> HeteroData: edge_types = data.edge_types # save original edge types data.metapath_dict = {} for j, metapath in enumerate(self.metapaths): for edge_type in metapath: assert data._to_canonical( edge_type) in edge_types, f"'{edge_type}' not present" src_node = metapath[0][0] num_nodes = data[src_node].num_nodes num_starts = round(num_nodes * self.sample_ratio) row = start = torch.randperm(num_nodes)[:num_starts].repeat( self.walks_per_node[j]) for edge_type in metapath: edge_index = EdgeIndex( data[edge_type].edge_index, sparse_size=data[edge_type].size(), ) col, mask = self.sample(edge_index, start) row, col = row[mask], col[mask] start = col new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1]) data[new_edge_type].edge_index = coalesce(torch.vstack([row, col])) data.metapath_dict[new_edge_type] = metapath postprocess(data, edge_types, self.drop_orig_edge_types, self.keep_same_node_type, self.drop_unconnected_node_types) return data @staticmethod def sample(edge_index: EdgeIndex, subset: Tensor) -> Tuple[Tensor, Tensor]: """Sample neighbors from :obj:`edge_index` for each node in 
:obj:`subset`. """ edge_index, _ = edge_index.sort_by('row') rowptr = edge_index.get_indptr() rowcount = rowptr.diff()[subset] mask = rowcount > 0 offset = torch.zeros_like(subset) offset[mask] = rowptr[subset[mask]] rand = torch.rand((rowcount.size(0), 1), device=subset.device) rand.mul_(rowcount.to(rand.dtype).view(-1, 1)) rand = rand.to(torch.long) rand.add_(offset.view(-1, 1)) col = edge_index[1][rand].squeeze() return col, mask def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'sample_ratio={self.sample_ratio}, ' f'walks_per_node={self.walks_per_node})') def postprocess( data: HeteroData, edge_types: List[EdgeType], drop_orig_edge_types: bool, keep_same_node_type: bool, drop_unconnected_node_types: bool, ) -> None: if drop_orig_edge_types: for i in edge_types: if keep_same_node_type and i[0] == i[-1]: continue else: del data[i] # Remove nodes not connected by any edge type: if drop_unconnected_node_types: new_edge_types = data.edge_types node_types = data.node_types connected_nodes = set() for i in new_edge_types: connected_nodes.add(i[0]) connected_nodes.add(i[-1]) for node in node_types: if node not in connected_nodes: del data[node] ================================================ FILE: torch_geometric/transforms/add_positional_encoding.py ================================================ from typing import Any, Optional import numpy as np import torch from torch import Tensor import torch_geometric.typing from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import ( get_laplacian, get_self_loop_attr, is_torch_sparse_tensor, scatter, to_edge_index, to_scipy_sparse_matrix, to_torch_coo_tensor, to_torch_csr_tensor, ) def add_node_attr( data: Data, value: Any, attr_name: Optional[str] = None, ) -> Data: # TODO Move to `BaseTransform`. 
if attr_name is None: if data.x is not None: x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1) else: data.x = value else: data[attr_name] = value return data @functional_transform('add_laplacian_eigenvector_pe') class AddLaplacianEigenvectorPE(BaseTransform): r"""Adds the Laplacian eigenvector positional encoding from the `"Benchmarking Graph Neural Networks" `_ paper to the given graph (functional name: :obj:`add_laplacian_eigenvector_pe`). Args: k (int): The number of non-trivial eigenvectors to consider. attr_name (str, optional): The attribute name of the data object to add positional encodings to. If set to :obj:`None`, will be concatenated to :obj:`data.x`. (default: :obj:`"laplacian_eigenvector_pe"`) is_undirected (bool, optional): If set to :obj:`True`, this transform expects undirected graphs as input, and can hence speed up the computation of eigenvectors. (default: :obj:`False`) **kwargs (optional): Additional arguments of :meth:`scipy.sparse.linalg.eigs` (when :attr:`is_undirected` is :obj:`False`) or :meth:`scipy.sparse.linalg.eigsh` (when :attr:`is_undirected` is :obj:`True`). 
""" # Number of nodes from which to use sparse eigenvector computation: SPARSE_THRESHOLD: int = 100 def __init__( self, k: int, attr_name: Optional[str] = 'laplacian_eigenvector_pe', is_undirected: bool = False, **kwargs: Any, ) -> None: self.k = k self.attr_name = attr_name self.is_undirected = is_undirected self.kwargs = kwargs def forward(self, data: Data) -> Data: assert data.edge_index is not None num_nodes = data.num_nodes assert num_nodes is not None edge_index, edge_weight = get_laplacian( data.edge_index, data.edge_weight, normalization='sym', num_nodes=num_nodes, ) L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes) if num_nodes < self.SPARSE_THRESHOLD: from numpy.linalg import eig, eigh eig_fn = eig if not self.is_undirected else eigh eig_vals, eig_vecs = eig_fn(L.todense()) else: from scipy.sparse.linalg import eigs, eigsh eig_fn = eigs if not self.is_undirected else eigsh eig_vals, eig_vecs = eig_fn( L, k=self.k + 1, which='SR' if not self.is_undirected else 'SA', return_eigenvectors=True, **self.kwargs, ) eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()]) pe = torch.from_numpy(eig_vecs[:, 1:self.k + 1]) sign = -1 + 2 * torch.randint(0, 2, (self.k, )) pe *= sign data = add_node_attr(data, pe, attr_name=self.attr_name) return data @functional_transform('add_random_walk_pe') class AddRandomWalkPE(BaseTransform): r"""Adds the random walk positional encoding from the `"Graph Neural Networks with Learnable Structural and Positional Representations" `_ paper to the given graph (functional name: :obj:`add_random_walk_pe`). Args: walk_length (int): The number of random walk steps. attr_name (str, optional): The attribute name of the data object to add positional encodings to. If set to :obj:`None`, will be concatenated to :obj:`data.x`. 
(default: :obj:`"random_walk_pe"`) """ def __init__( self, walk_length: int, attr_name: Optional[str] = 'random_walk_pe', ) -> None: self.walk_length = walk_length self.attr_name = attr_name def forward(self, data: Data) -> Data: assert data.edge_index is not None row, col = data.edge_index N = data.num_nodes assert N is not None if data.edge_weight is None: value = torch.ones(data.num_edges, device=row.device) else: value = data.edge_weight value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row] value = 1.0 / value if N <= 2_000: # Dense code path for faster computation: adj = torch.zeros((N, N), device=row.device) adj[row, col] = value loop_index = torch.arange(N, device=row.device) elif torch_geometric.typing.NO_MKL: # pragma: no cover adj = to_torch_coo_tensor(data.edge_index, value, size=data.size()) else: adj = to_torch_csr_tensor(data.edge_index, value, size=data.size()) def get_pe(out: Tensor) -> Tensor: if is_torch_sparse_tensor(out): return get_self_loop_attr(*to_edge_index(out), num_nodes=N) return out[loop_index, loop_index] out = adj pe_list = [get_pe(out)] for _ in range(self.walk_length - 1): out = out @ adj pe_list.append(get_pe(out)) pe = torch.stack(pe_list, dim=-1) data = add_node_attr(data, pe, attr_name=self.attr_name) return data ================================================ FILE: torch_geometric/transforms/add_remaining_self_loops.py ================================================ from typing import Union from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import add_remaining_self_loops @functional_transform('add_remaining_self_loops') class AddRemainingSelfLoops(BaseTransform): r"""Adds remaining self-loops to the given homogeneous or heterogeneous graph (functional name: :obj:`add_remaining_self_loops`). 
Args: attr (str, optional): The name of the attribute of edge weights or multi-dimensional edge features to pass to :meth:`torch_geometric.utils.add_remaining_self_loops`. (default: :obj:`"edge_weight"`) fill_value (float or Tensor or str, optional): The way to generate edge features of self-loops (in case :obj:`attr != None`). If given as :obj:`float` or :class:`torch.Tensor`, edge features of self-loops will be directly given by :obj:`fill_value`. If given as :obj:`str`, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`1.`) """ def __init__( self, attr: str = 'edge_weight', fill_value: Union[float, Tensor, str] = 1.0, ): self.attr = attr self.fill_value = fill_value def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.edge_stores: if store.is_bipartite() or 'edge_index' not in store: continue store.edge_index, store[self.attr] = add_remaining_self_loops( store.edge_index, edge_attr=store.get(self.attr, None), fill_value=self.fill_value, num_nodes=store.size(0), ) return data ================================================ FILE: torch_geometric/transforms/add_self_loops.py ================================================ from typing import Union from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import add_self_loops @functional_transform('add_self_loops') class AddSelfLoops(BaseTransform): r"""Adds self-loops to the given homogeneous or heterogeneous graph (functional name: :obj:`add_self_loops`). Args: attr (str, optional): The name of the attribute of edge weights or multi-dimensional edge features to pass to :meth:`torch_geometric.utils.add_self_loops`. 
(default: :obj:`"edge_weight"`) fill_value (float or Tensor or str, optional): The way to generate edge features of self-loops (in case :obj:`attr != None`). If given as :obj:`float` or :class:`torch.Tensor`, edge features of self-loops will be directly given by :obj:`fill_value`. If given as :obj:`str`, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`1.`) """ def __init__( self, attr: str = 'edge_weight', fill_value: Union[float, Tensor, str] = 1.0, ) -> None: self.attr = attr self.fill_value = fill_value def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.edge_stores: if store.is_bipartite() or 'edge_index' not in store: continue store.edge_index, store[self.attr] = add_self_loops( store.edge_index, edge_attr=store.get(self.attr, None), fill_value=self.fill_value, num_nodes=store.size(0), ) return data ================================================ FILE: torch_geometric/transforms/base_transform.py ================================================ import copy from abc import ABC, abstractmethod from typing import Any class BaseTransform(ABC): r"""An abstract base class for writing transforms. Transforms are a general way to modify and customize :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects, either by implicitly passing them as an argument to a :class:`~torch_geometric.data.Dataset`, or by applying them explicitly to individual :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects: .. 
code-block:: python import torch_geometric.transforms as T from torch_geometric.datasets import TUDataset transform = T.Compose([T.ToUndirected(), T.AddSelfLoops()]) dataset = TUDataset(path, name='MUTAG', transform=transform) data = dataset[0] # Implicitly transform data on every access. data = TUDataset(path, name='MUTAG')[0] data = transform(data) # Explicitly transform data. """ def __call__(self, data: Any) -> Any: # Shallow-copy the data so that we prevent in-place data modification. return self.forward(copy.copy(data)) @abstractmethod def forward(self, data: Any) -> Any: pass def __repr__(self) -> str: return f'{self.__class__.__name__}()' ================================================ FILE: torch_geometric/transforms/cartesian.py ================================================ from typing import Optional, Tuple import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('cartesian') class Cartesian(BaseTransform): r"""Saves the relative Cartesian coordinates of linked nodes in its edge attributes (functional name: :obj:`cartesian`). Each coordinate gets globally normalized to a specified interval (:math:`[0, 1]` by default). Args: norm (bool, optional): If set to :obj:`False`, the output will not be normalized. (default: :obj:`True`) max_value (float, optional): If set and :obj:`norm=True`, normalization will be performed based on this value instead of the maximum value found in the data. (default: :obj:`None`) cat (bool, optional): If set to :obj:`False`, all existing edge attributes will be replaced. (default: :obj:`True`) interval ((float, float), optional): A tuple specifying the lower and upper bound for normalization. 
(default: :obj:`(0.0, 1.0)`) """ def __init__( self, norm: bool = True, max_value: Optional[float] = None, cat: bool = True, interval: Tuple[float, float] = (0.0, 1.0), ): self.norm = norm self.max = max_value self.cat = cat self.interval = interval def forward(self, data: Data) -> Data: assert data.pos is not None assert data.edge_index is not None (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr cart = pos[row] - pos[col] cart = cart.view(-1, 1) if cart.dim() == 1 else cart if self.norm and cart.numel() > 0: max_val = float(cart.abs().max()) if self.max is None else self.max length = self.interval[1] - self.interval[0] center = (self.interval[0] + self.interval[1]) / 2 cart = length * cart / (2 * max_val) + center if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1) else: data.edge_attr = cart return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(norm={self.norm}, ' f'max_value={self.max})') ================================================ FILE: torch_geometric/transforms/center.py ================================================ from typing import Union from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('center') class Center(BaseTransform): r"""Centers node positions :obj:`data.pos` around the origin (functional name: :obj:`center`). 
""" def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.node_stores: if hasattr(store, 'pos'): store.pos = store.pos - store.pos.mean(dim=-2, keepdim=True) return data ================================================ FILE: torch_geometric/transforms/compose.py ================================================ from typing import Callable, List, Union from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import BaseTransform class Compose(BaseTransform): r"""Composes several transforms together. Args: transforms (List[Callable]): List of transforms to compose. """ def __init__(self, transforms: List[Callable]): self.transforms = transforms def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for transform in self.transforms: if isinstance(data, (list, tuple)): data = [transform(d) for d in data] else: data = transform(data) return data def __repr__(self) -> str: args = [f' {transform}' for transform in self.transforms] return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args)) class ComposeFilters: r"""Composes several filters together. Args: filters (List[Callable]): List of filters to compose. 
""" def __init__(self, filters: List[Callable]): self.filters = filters def __call__( self, data: Union[Data, HeteroData], ) -> bool: for filter_fn in self.filters: if isinstance(data, (list, tuple)): if not all([filter_fn(d) for d in data]): return False elif not filter_fn(data): return False return True def __repr__(self) -> str: args = [f' {filter_fn}' for filter_fn in self.filters] return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args)) ================================================ FILE: torch_geometric/transforms/constant.py ================================================ from typing import List, Optional, Union import torch from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('constant') class Constant(BaseTransform): r"""Appends a constant value to each node feature :obj:`x` (functional name: :obj:`constant`). Args: value (float, optional): The value to add. (default: :obj:`1.0`) cat (bool, optional): If set to :obj:`False`, existing node features will be replaced. (default: :obj:`True`) node_types (str or List[str], optional): The specified node type(s) to append constant values for if used on heterogeneous graphs. If set to :obj:`None`, constants will be added to each node feature :obj:`x` for all existing node types. 
(default: :obj:`None`) """ def __init__( self, value: float = 1.0, cat: bool = True, node_types: Optional[Union[str, List[str]]] = None, ): if isinstance(node_types, str): node_types = [node_types] self.value = value self.cat = cat self.node_types = node_types def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.node_stores: if self.node_types is None or store._key in self.node_types: num_nodes = store.num_nodes assert num_nodes is not None c = torch.full((num_nodes, 1), self.value, dtype=torch.float) if hasattr(store, 'x') and self.cat: x = store.x.view(-1, 1) if store.x.dim() == 1 else store.x store.x = torch.cat([x, c.to(x.device, x.dtype)], dim=-1) else: store.x = c return data def __repr__(self) -> str: return f'{self.__class__.__name__}(value={self.value})' ================================================ FILE: torch_geometric/transforms/delaunay.py ================================================ from typing import List import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform class _QhullTransform(BaseTransform): r"""Q-hull implementation of delaunay triangulation.""" def forward(self, data: Data) -> Data: assert data.pos is not None import scipy.spatial pos = data.pos.cpu().numpy() tri = scipy.spatial.Delaunay(pos, qhull_options='QJ') face = torch.from_numpy(tri.simplices) data.face = face.t().contiguous().to(data.pos.device, torch.long) return data class _ShullTransform(BaseTransform): r"""Sweep-hull implementation of delaunay triangulation.""" def forward(self, data: Data) -> Data: assert data.pos is not None from torch_delaunay.functional import shull2d face = shull2d(data.pos.cpu()) data.face = face.t().contiguous().to(data.pos.device) return data class _SequentialTransform(BaseTransform): r"""Runs the first successful transformation. All intermediate exceptions are suppressed except the last. 
""" def __init__(self, transforms: List[BaseTransform]) -> None: assert len(transforms) > 0 self.transforms = transforms def forward(self, data: Data) -> Data: for i, transform in enumerate(self.transforms): try: return transform.forward(data) except ImportError as e: if i == len(self.transforms) - 1: raise e return data @functional_transform('delaunay') class Delaunay(BaseTransform): r"""Computes the delaunay triangulation of a set of points (functional name: :obj:`delaunay`). .. hint:: Consider installing the `torch_delaunay `_ package to speed up computation. """ def __init__(self) -> None: self._transform = _SequentialTransform([ _ShullTransform(), _QhullTransform(), ]) def forward(self, data: Data) -> Data: assert data.pos is not None device = data.pos.device if data.pos.size(0) < 2: data.edge_index = torch.empty(2, 0, dtype=torch.long, device=device) elif data.pos.size(0) == 2: data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device) elif data.pos.size(0) == 3: data.face = torch.tensor([[0], [1], [2]], device=device) else: data = self._transform.forward(data) return data ================================================ FILE: torch_geometric/transforms/distance.py ================================================ from typing import Optional, Tuple import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('distance') class Distance(BaseTransform): r"""Saves the Euclidean distance of linked nodes in its edge attributes (functional name: :obj:`distance`). Each distance gets globally normalized to a specified interval (:math:`[0, 1]` by default). Args: norm (bool, optional): If set to :obj:`False`, the output will not be normalized. (default: :obj:`True`) max_value (float, optional): If set and :obj:`norm=True`, normalization will be performed based on this value instead of the maximum value found in the data. 
(default: :obj:`None`) cat (bool, optional): If set to :obj:`False`, all existing edge attributes will be replaced. (default: :obj:`True`) interval ((float, float), optional): A tuple specifying the lower and upper bound for normalization. (default: :obj:`(0.0, 1.0)`) """ def __init__( self, norm: bool = True, max_value: Optional[float] = None, cat: bool = True, interval: Tuple[float, float] = (0.0, 1.0), ): self.norm = norm self.max = max_value self.cat = cat self.interval = interval def forward(self, data: Data) -> Data: assert data.pos is not None assert data.edge_index is not None (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1) if self.norm and dist.numel() > 0: max_val = float(dist.max()) if self.max is None else self.max length = self.interval[1] - self.interval[0] dist = length * (dist / max_val) + self.interval[0] if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo data.edge_attr = torch.cat([pseudo, dist.type_as(pseudo)], dim=-1) else: data.edge_attr = dist return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(norm={self.norm}, ' f'max_value={self.max})') ================================================ FILE: torch_geometric/transforms/face_to_edge.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import to_undirected @functional_transform('face_to_edge') class FaceToEdge(BaseTransform): r"""Converts mesh faces of shape :obj:`[3, num_faces]` or :obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]` (functional name: :obj:`face_to_edge`). 
This transform supports both 2D triangular faces, represented by a tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces, represented by a tensor of shape :obj:`[4, num_faces]`. It will convert these faces into edge indices, where each edge is defined by the indices of its two endpoints. Args: remove_faces (bool, optional): If set to :obj:`False`, the face tensor will not be removed. """ def __init__(self, remove_faces: bool = True) -> None: self.remove_faces = remove_faces def forward(self, data: Data) -> Data: if hasattr(data, 'face'): assert data.face is not None face = data.face if face.size(0) not in [3, 4]: raise RuntimeError(f"Expected 'face' tensor with shape " f"[3, num_faces] or [4, num_faces] " f"(got {list(face.size())})") if face.size()[0] == 3: edge_index = torch.cat([ face[:2], face[1:], face[::2], ], dim=1) else: assert face.size()[0] == 4 edge_index = torch.cat([ face[:2], face[1:3], face[2:4], face[::2], face[1::2], face[::3], ], dim=1) edge_index = to_undirected(edge_index, num_nodes=data.num_nodes) data.edge_index = edge_index if self.remove_faces: data.face = None return data ================================================ FILE: torch_geometric/transforms/feature_propagation.py ================================================ from torch import Tensor import torch_geometric from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import is_torch_sparse_tensor, to_torch_csc_tensor @functional_transform('feature_propagation') class FeaturePropagation(BaseTransform): r"""The feature propagation operator from the `"On the Unreasonable Effectiveness of Feature propagation in Learning on Graphs with Missing Node Features" `_ paper (functional name: :obj:`feature_propagation`). .. 
math:: \mathbf{X}^{(0)} &= (1 - \mathbf{M}) \cdot \mathbf{X} \mathbf{X}^{(\ell + 1)} &= \mathbf{X}^{(0)} + \mathbf{M} \cdot (\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{X}^{(\ell)}) where missing node features are inferred by known features via propagation. .. code-block:: python from torch_geometric.transforms import FeaturePropagation transform = FeaturePropagation(missing_mask=torch.isnan(data.x)) data = transform(data) Args: missing_mask (torch.Tensor): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{N\times F}` indicating missing node features. num_iterations (int, optional): The number of propagations. (default: :obj:`40`) """ def __init__(self, missing_mask: Tensor, num_iterations: int = 40) -> None: self.missing_mask = missing_mask self.num_iterations = num_iterations def forward(self, data: Data) -> Data: assert data.x is not None assert data.edge_index is not None or data.adj_t is not None assert data.x.size() == self.missing_mask.size() gcn_norm = torch_geometric.nn.conv.gcn_conv.gcn_norm missing_mask = self.missing_mask.to(data.x.device) known_mask = ~missing_mask if data.edge_index is not None: edge_weight = data.edge_attr if 'edge_weight' in data: edge_weight = data.edge_weight adj_t = to_torch_csc_tensor( edge_index=data.edge_index, edge_attr=edge_weight, size=data.size(0), ).t() adj_t, _ = gcn_norm(adj_t, add_self_loops=False) elif is_torch_sparse_tensor(data.adj_t): adj_t, _ = gcn_norm(data.adj_t, add_self_loops=False) else: adj_t = gcn_norm(data.adj_t, add_self_loops=False) x = data.x.clone() x[missing_mask] = 0. out = x for _ in range(self.num_iterations): out = adj_t @ out out[known_mask] = x[known_mask] # Reset. 
data.x = out return data def __repr__(self) -> str: na_values = int(self.missing_mask.sum()) / self.missing_mask.numel() return (f'{self.__class__.__name__}(' f'missing_features={100 * na_values:.1f}%, ' f'num_iterations={self.num_iterations})') ================================================ FILE: torch_geometric/transforms/fixed_points.py ================================================ import math import re import numpy as np import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('fixed_points') class FixedPoints(BaseTransform): r"""Samples a fixed number of points and features from a point cloud (functional name: :obj:`fixed_points`). Args: num (int): The number of points to sample. replace (bool, optional): If set to :obj:`False`, samples points without replacement. (default: :obj:`True`) allow_duplicates (bool, optional): In case :obj:`replace` is :obj`False` and :obj:`num` is greater than the number of points, this option determines whether to add duplicated nodes to the output points or not. In case :obj:`allow_duplicates` is :obj:`False`, the number of output points might be smaller than :obj:`num`. In case :obj:`allow_duplicates` is :obj:`True`, the number of duplicated points are kept to a minimum. 
(default: :obj:`False`) """ def __init__( self, num: int, replace: bool = True, allow_duplicates: bool = False, ): self.num = num self.replace = replace self.allow_duplicates = allow_duplicates def forward(self, data: Data) -> Data: num_nodes = data.num_nodes assert num_nodes is not None if self.replace: choice = torch.from_numpy( np.random.choice(num_nodes, self.num, replace=True)).long() elif not self.allow_duplicates: choice = torch.randperm(num_nodes)[:self.num] else: choice = torch.cat([ torch.randperm(num_nodes) for _ in range(math.ceil(self.num / num_nodes)) ], dim=0)[:self.num] for key, value in data.items(): if key == 'num_nodes': data.num_nodes = choice.size(0) elif bool(re.search('edge', key)): continue elif (isinstance(value, Tensor) and value.size(0) == num_nodes and value.size(0) != 1): data[key] = value[choice] return data def __repr__(self) -> str: return f'{self.__class__.__name__}({self.num}, replace={self.replace})' ================================================ FILE: torch_geometric/transforms/gcn_norm.py ================================================ import torch_geometric from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('gcn_norm') class GCNNorm(BaseTransform): r"""Applies the GCN normalization from the `"Semi-supervised Classification with Graph Convolutional Networks" `_ paper (functional name: :obj:`gcn_norm`). .. math:: \mathbf{\hat{A}} = \mathbf{\hat{D}}^{-1/2} (\mathbf{A} + \mathbf{I}) \mathbf{\hat{D}}^{-1/2} where :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij} + 1`. 
""" def __init__(self, add_self_loops: bool = True): self.add_self_loops = add_self_loops def forward(self, data: Data) -> Data: gcn_norm = torch_geometric.nn.conv.gcn_conv.gcn_norm assert 'edge_index' in data or 'adj_t' in data if 'edge_index' in data: data.edge_index, data.edge_weight = gcn_norm( data.edge_index, data.edge_weight, data.num_nodes, add_self_loops=self.add_self_loops) else: data.adj_t = gcn_norm(data.adj_t, add_self_loops=self.add_self_loops) return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'add_self_loops={self.add_self_loops})') ================================================ FILE: torch_geometric/transforms/gdc.py ================================================ from typing import Any, Dict, Optional, Tuple import numpy as np import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import ( add_self_loops, coalesce, get_ppr, is_undirected, scatter, sort_edge_index, to_dense_adj, ) @functional_transform('gdc') class GDC(BaseTransform): r"""Processes the graph via Graph Diffusion Convolution (GDC) from the `"Diffusion Improves Graph Learning" `_ paper (functional name: :obj:`gdc`). .. note:: The paper offers additional advice on how to choose the hyperparameters. For an example of using GCN with GDC, see `examples/gcn.py `_. Args: self_loop_weight (float, optional): Weight of the added self-loop. Set to :obj:`None` to add no self-loops. (default: :obj:`1`) normalization_in (str, optional): Normalization of the transition matrix on the original (input) graph. Possible values: :obj:`"sym"`, :obj:`"col"`, and :obj:`"row"`. See :func:`GDC.transition_matrix` for details. (default: :obj:`"sym"`) normalization_out (str, optional): Normalization of the transition matrix on the transformed GDC (output) graph. 
Possible values: :obj:`"sym"`, :obj:`"col"`, :obj:`"row"`, and :obj:`None`.
            See :func:`GDC.transition_matrix` for details.
            (default: :obj:`"col"`)
        diffusion_kwargs (dict, optional): Dictionary containing the parameters
            for diffusion.
            :obj:`method` specifies the diffusion method (:obj:`"ppr"`,
            :obj:`"heat"` or :obj:`"coeff"`).
            Each diffusion method requires different additional parameters.
            See :func:`GDC.diffusion_matrix_exact` or
            :func:`GDC.diffusion_matrix_approx` for details.
            (default: :obj:`dict(method='ppr', alpha=0.15)`)
        sparsification_kwargs (dict, optional): Dictionary containing the
            parameters for sparsification.
            :obj:`method` specifies the sparsification method
            (:obj:`"threshold"` or :obj:`"topk"`).
            Each sparsification method requires different additional
            parameters.
            See :func:`GDC.sparsify_dense` (exact) or
            :func:`GDC.sparsify_sparse` (approximate) for details.
            (default: :obj:`dict(method='threshold', avg_degree=64)`)
        exact (bool, optional): Whether to exactly calculate the diffusion
            matrix.
            Note that the exact variants are not scalable.
            They densify the adjacency matrix and calculate either its inverse
            or its matrix exponential.
            However, the approximate variants do not support edge weights and
            currently only personalized PageRank and sparsification by
            threshold are implemented as fast, approximate versions.
(default: :obj:`True`) :rtype: :class:`torch_geometric.data.Data` """ def __init__( self, self_loop_weight: float = 1., normalization_in: str = 'sym', normalization_out: str = 'col', diffusion_kwargs: Optional[Dict[str, Any]] = None, sparsification_kwargs: Optional[Dict[str, Any]] = None, exact: bool = True, ) -> None: self.self_loop_weight = self_loop_weight self.normalization_in = normalization_in self.normalization_out = normalization_out self.diffusion_kwargs = diffusion_kwargs or dict( method='ppr', alpha=0.15) self.sparsification_kwargs = sparsification_kwargs or dict( method='threshold', avg_degree=64) self.exact = exact if self_loop_weight: assert exact or self_loop_weight == 1 @torch.no_grad() def forward(self, data: Data) -> Data: assert data.edge_index is not None edge_index = data.edge_index N = data.num_nodes assert N is not None if data.edge_attr is None: edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) else: edge_weight = data.edge_attr assert self.exact assert edge_weight.dim() == 1 if self.self_loop_weight: edge_index, edge_weight = add_self_loops( edge_index, edge_weight, fill_value=self.self_loop_weight, num_nodes=N) edge_index, edge_weight = coalesce(edge_index, edge_weight, N) if self.exact: edge_index, edge_weight = self.transition_matrix( edge_index, edge_weight, N, self.normalization_in) diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N, **self.diffusion_kwargs) edge_index, edge_weight = self.sparsify_dense( diff_mat, **self.sparsification_kwargs) else: edge_index, edge_weight = self.diffusion_matrix_approx( edge_index, edge_weight, N, self.normalization_in, **self.diffusion_kwargs) edge_index, edge_weight = self.sparsify_sparse( edge_index, edge_weight, N, **self.sparsification_kwargs) edge_index, edge_weight = coalesce(edge_index, edge_weight, N) edge_index, edge_weight = self.transition_matrix( edge_index, edge_weight, N, self.normalization_out) data.edge_index = edge_index data.edge_attr = 
edge_weight return data def transition_matrix( self, edge_index: Tensor, edge_weight: Tensor, num_nodes: int, normalization: str, ) -> Tuple[Tensor, Tensor]: r"""Calculate the approximate, sparse diffusion on a given sparse matrix. Args: edge_index (LongTensor): The edge indices. edge_weight (Tensor): One-dimensional edge weights. num_nodes (int): Number of nodes. normalization (str): Normalization scheme: 1. :obj:`"sym"`: Symmetric normalization :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}`. 2. :obj:`"col"`: Column-wise normalization :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`. 3. :obj:`"row"`: Row-wise normalization :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`. 4. :obj:`None`: No normalization. :rtype: (:class:`LongTensor`, :class:`Tensor`) """ if normalization == 'sym': row, col = edge_index deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum') deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] elif normalization == 'col': _, col = edge_index deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum') deg_inv = 1. / deg deg_inv[deg_inv == float('inf')] = 0 edge_weight = edge_weight * deg_inv[col] elif normalization == 'row': row, _ = edge_index deg = scatter(edge_weight, row, 0, num_nodes, reduce='sum') deg_inv = 1. / deg deg_inv[deg_inv == float('inf')] = 0 edge_weight = edge_weight * deg_inv[row] elif normalization is None: pass else: raise ValueError( f"Transition matrix normalization '{normalization}' unknown") return edge_index, edge_weight def diffusion_matrix_exact( # noqa: D417 self, edge_index: Tensor, edge_weight: Tensor, num_nodes: int, method: str, **kwargs: Any, ) -> Tensor: r"""Calculate the (dense) diffusion on a given sparse graph. Note that these exact variants are not scalable. They densify the adjacency matrix and calculate either its inverse or its matrix exponential. Args: edge_index (LongTensor): The edge indices. 
edge_weight (Tensor): One-dimensional edge weights. num_nodes (int): Number of nodes. method (str): Diffusion method: 1. :obj:`"ppr"`: Use personalized PageRank as diffusion. Additionally expects the parameter: - **alpha** (*float*) - Return probability in PPR. Commonly lies in :obj:`[0.05, 0.2]`. 2. :obj:`"heat"`: Use heat kernel diffusion. Additionally expects the parameter: - **t** (*float*) - Time of diffusion. Commonly lies in :obj:`[2, 10]`. 3. :obj:`"coeff"`: Freely choose diffusion coefficients. Additionally expects the parameter: - **coeffs** (*List[float]*) - List of coefficients :obj:`theta_k` for each power of the transition matrix (starting at :obj:`0`). :rtype: (:class:`Tensor`) """ if method == 'ppr': # α (I_n + (α - 1) A)^-1 edge_weight = (kwargs['alpha'] - 1) * edge_weight edge_index, edge_weight = add_self_loops(edge_index, edge_weight, fill_value=1, num_nodes=num_nodes) mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() diff_matrix = kwargs['alpha'] * torch.inverse(mat) elif method == 'heat': # exp(t (A - I_n)) edge_index, edge_weight = add_self_loops(edge_index, edge_weight, fill_value=-1, num_nodes=num_nodes) edge_weight = kwargs['t'] * edge_weight mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() undirected = is_undirected(edge_index, edge_weight, num_nodes) diff_matrix = self.__expm__(mat, undirected) elif method == 'coeff': adj_matrix = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() mat = torch.eye(num_nodes, device=edge_index.device) diff_matrix = kwargs['coeffs'][0] * mat for coeff in kwargs['coeffs'][1:]: mat = mat @ adj_matrix diff_matrix += coeff * mat else: raise ValueError(f"Exact GDC diffusion '{method}' unknown") return diff_matrix def diffusion_matrix_approx( # noqa: D417 self, edge_index: Tensor, edge_weight: Tensor, num_nodes: int, normalization: str, method: str, **kwargs: Any, ) -> Tuple[Tensor, Tensor]: r"""Calculate the approximate, sparse diffusion on a given sparse graph. 
Args: edge_index (LongTensor): The edge indices. edge_weight (Tensor): One-dimensional edge weights. num_nodes (int): Number of nodes. normalization (str): Transition matrix normalization scheme (:obj:`"sym"`, :obj:`"row"`, or :obj:`"col"`). See :func:`GDC.transition_matrix` for details. method (str): Diffusion method: 1. :obj:`"ppr"`: Use personalized PageRank as diffusion. Additionally expects the parameters: - **alpha** (*float*) - Return probability in PPR. Commonly lies in :obj:`[0.05, 0.2]`. - **eps** (*float*) - Threshold for PPR calculation stopping criterion (:obj:`edge_weight >= eps * out_degree`). Recommended default: :obj:`1e-4`. :rtype: (:class:`LongTensor`, :class:`Tensor`) """ if method == 'ppr': if normalization == 'sym': # Calculate original degrees. _, col = edge_index deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum') edge_index, edge_weight = get_ppr( edge_index, alpha=kwargs['alpha'], eps=kwargs['eps'], num_nodes=num_nodes, ) if normalization == 'col': edge_index, edge_weight = sort_edge_index( edge_index.flip([0]), edge_weight, num_nodes) if normalization == 'sym': # We can change the normalization from row-normalized to # symmetric by multiplying the resulting matrix with D^{1/2} # from the left and D^{-1/2} from the right. # Since we use the original degrees for this it will be like # we had used symmetric normalization from the beginning # (except for errors due to approximation). row, col = edge_index deg_inv = deg.sqrt() deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 edge_weight = deg_inv[row] * edge_weight * deg_inv_sqrt[col] elif normalization in ['col', 'row']: pass else: raise ValueError( f"Transition matrix normalization '{normalization}' not " f"implemented for non-exact GDC computation") elif method == 'heat': raise NotImplementedError( 'Currently no fast heat kernel is implemented. 
You are ' 'welcome to create one yourself, e.g., based on ' '"Kloster and Gleich: Heat kernel based community detection ' '(KDD 2014)."') else: raise ValueError(f"Approximate GDC diffusion '{method}' unknown") return edge_index, edge_weight def sparsify_dense( # noqa: D417 self, matrix: Tensor, method: str, **kwargs: Any, ) -> Tuple[Tensor, Tensor]: r"""Sparsifies the given dense matrix. Args: matrix (Tensor): Matrix to sparsify. method (str): Method of sparsification. Options: 1. :obj:`"threshold"`: Remove all edges with weights smaller than :obj:`eps`. Additionally expects one of these parameters: - **eps** (*float*) - Threshold to bound edges at. - **avg_degree** (*int*) - If :obj:`eps` is not given, it can optionally be calculated by calculating the :obj:`eps` required to achieve a given :obj:`avg_degree`. 2. :obj:`"topk"`: Keep edges with top :obj:`k` edge weights per node (column). Additionally expects the following parameters: - **k** (*int*) - Specifies the number of edges to keep. - **dim** (*int*) - The axis along which to take the top :obj:`k`. 
:rtype: (:class:`LongTensor`, :class:`Tensor`) """ assert matrix.shape[0] == matrix.shape[1] N = matrix.shape[1] if method == 'threshold': if 'eps' not in kwargs.keys(): kwargs['eps'] = self.__calculate_eps__(matrix, N, kwargs['avg_degree']) edge_index = (matrix >= kwargs['eps']).nonzero(as_tuple=False).t() edge_index_flat = edge_index[0] * N + edge_index[1] edge_weight = matrix.flatten()[edge_index_flat] elif method == 'topk': k, dim = min(N, kwargs['k']), kwargs['dim'] assert dim in [0, 1] sort_idx = torch.argsort(matrix, dim=dim, descending=True) if dim == 0: top_idx = sort_idx[:k] edge_weight = torch.gather(matrix, dim=dim, index=top_idx).flatten() row_idx = torch.arange(0, N, device=matrix.device).repeat(k) edge_index = torch.stack([top_idx.flatten(), row_idx], dim=0) else: top_idx = sort_idx[:, :k] edge_weight = torch.gather(matrix, dim=dim, index=top_idx).flatten() col_idx = torch.arange( 0, N, device=matrix.device).repeat_interleave(k) edge_index = torch.stack([col_idx, top_idx.flatten()], dim=0) else: raise ValueError(f"GDC sparsification '{method}' unknown") return edge_index, edge_weight def sparsify_sparse( # noqa: D417 self, edge_index: Tensor, edge_weight: Tensor, num_nodes: int, method: str, **kwargs: Any, ) -> Tuple[Tensor, Tensor]: r"""Sparsifies a given sparse graph further. Args: edge_index (torch.Tensor): The edge indices. edge_weight (torch.Tensor): One-dimensional edge weights. num_nodes (int): Number of nodes. method (str): Method of sparsification: 1. :obj:`"threshold"`: Remove all edges with weights smaller than :obj:`eps`. Additionally expects one of these parameters: - **eps** (*float*) - Threshold to bound edges at. - **avg_degree** (*int*) - If :obj:`eps` is not given, it can optionally be calculated by calculating the :obj:`eps` required to achieve a given :obj:`avg_degree`. 
:rtype: (:class:`LongTensor`, :class:`Tensor`) """ if method == 'threshold': if 'eps' not in kwargs.keys(): kwargs['eps'] = self.__calculate_eps__( edge_weight, num_nodes, kwargs['avg_degree'], ) remaining_edge_idx = (edge_weight >= kwargs['eps']).nonzero( as_tuple=False).flatten() edge_index = edge_index[:, remaining_edge_idx] edge_weight = edge_weight[remaining_edge_idx] elif method == 'topk': raise NotImplementedError( 'Sparse topk sparsification not implemented') else: raise ValueError(f"GDC sparsification '{method}' unknown") return edge_index, edge_weight def __expm__(self, matrix: Tensor, symmetric: bool) -> Tensor: r"""Calculates matrix exponential. Args: matrix (Tensor): Matrix to take exponential of. symmetric (bool): Specifies whether the matrix is symmetric. :rtype: (:class:`Tensor`) """ from scipy.linalg import expm if symmetric: e, V = torch.linalg.eigh(matrix, UPLO='U') diff_mat = V @ torch.diag(e.exp()) @ V.t() else: diff_mat = torch.from_numpy(expm(matrix.cpu().numpy())) diff_mat = diff_mat.to(matrix.device, matrix.dtype) return diff_mat def __calculate_eps__( self, matrix: Tensor, num_nodes: int, avg_degree: int, ) -> float: r"""Calculates threshold necessary to achieve a given average degree. Args: matrix (Tensor): Adjacency matrix or edge weights. num_nodes (int): Number of nodes. avg_degree (int): Target average degree. 
:rtype: (:class:`float`) """ sorted_edges = torch.sort(matrix.flatten(), descending=True).values if avg_degree * num_nodes > len(sorted_edges): return -np.inf left = sorted_edges[avg_degree * num_nodes - 1] right = sorted_edges[avg_degree * num_nodes] return float(left + right) / 2.0 ================================================ FILE: torch_geometric/transforms/generate_mesh_normals.py ================================================ import torch.nn.functional as F from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import scatter @functional_transform('generate_mesh_normals') class GenerateMeshNormals(BaseTransform): r"""Generate normal vectors for each mesh node based on neighboring faces (functional name: :obj:`generate_mesh_normals`). """ def forward(self, data: Data) -> Data: assert data.pos is not None assert data.face is not None pos, face = data.pos, data.face vec1 = pos[face[1]] - pos[face[0]] vec2 = pos[face[2]] - pos[face[0]] face_norm = F.normalize(vec1.cross(vec2, dim=1), p=2, dim=-1) # [F, 3] face_norm = face_norm.repeat(3, 1) idx = face.view(-1) norm = scatter(face_norm, idx, 0, pos.size(0), reduce='sum') norm = F.normalize(norm, p=2, dim=-1) # [N, 3] data.norm = norm return data ================================================ FILE: torch_geometric/transforms/grid_sampling.py ================================================ import re from typing import List, Optional, Union import torch from torch import Tensor import torch_geometric from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import one_hot, scatter @functional_transform('grid_sampling') class GridSampling(BaseTransform): r"""Clusters points into fixed-sized voxels (functional name: :obj:`grid_sampling`). 
Each cluster returned is a new point based on the mean of all points inside the given cluster. Args: size (float or [float] or Tensor): Size of a voxel (in each dimension). start (float or [float] or Tensor, optional): Start coordinates of the grid (in each dimension). If set to :obj:`None`, will be set to the minimum coordinates found in :obj:`data.pos`. (default: :obj:`None`) end (float or [float] or Tensor, optional): End coordinates of the grid (in each dimension). If set to :obj:`None`, will be set to the maximum coordinates found in :obj:`data.pos`. (default: :obj:`None`) """ def __init__( self, size: Union[float, List[float], Tensor], start: Optional[Union[float, List[float], Tensor]] = None, end: Optional[Union[float, List[float], Tensor]] = None, ) -> None: self.size = size self.start = start self.end = end def forward(self, data: Data) -> Data: num_nodes = data.num_nodes assert data.pos is not None c = torch_geometric.nn.voxel_grid(data.pos, self.size, data.batch, self.start, self.end) c, perm = torch_geometric.nn.pool.consecutive.consecutive_cluster(c) for key, item in data.items(): if bool(re.search('edge', key)): raise ValueError(f"'{self.__class__.__name__}' does not " f"support coarsening of edges") if torch.is_tensor(item) and item.size(0) == num_nodes: if key == 'y': item = scatter(one_hot(item), c, dim=0, reduce='sum') data[key] = item.argmax(dim=-1) elif key == 'batch': data[key] = item[perm] else: data[key] = scatter(item, c, dim=0, reduce='mean') return data def __repr__(self) -> str: return f'{self.__class__.__name__}(size={self.size})' ================================================ FILE: torch_geometric/transforms/half_hop.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('half_hop') class HalfHop(BaseTransform): r"""The graph upsampling augmentation 
from the `"Half-Hop: A Graph Upsampling Approach for Slowing Down Message Passing" `_ paper. The graph is augmented by adding artificial slow nodes between neighbors to slow down message propagation. (functional name: :obj:`half_hop`). .. note:: :class:`HalfHop` augmentation is not supported if :obj:`data` has :attr:`edge_weight` or :attr:`edge_attr`. Args: alpha (float, optional): The interpolation factor used to compute slow node features :math:`x = \alpha*x_src + (1-\alpha)*x_dst` (default: :obj:`0.5`) p (float, optional): The probability of half-hopping an edge. (default: :obj:`1.0`) .. code-block:: python import torch_geometric.transforms as T transform = T.HalfHop(alpha=0.5) data = transform(data) # Apply transformation. out = model(data.x, data.edge_index) # Feed-forward. out = out[~data.slow_node_mask] # Get rid of slow nodes. """ def __init__(self, alpha: float = 0.5, p: float = 1.0) -> None: if alpha < 0. or alpha > 1.: raise ValueError(f"Interpolation factor has to be between 0 and 1 " f"(got '{alpha}'") if p < 0. 
or p > 1.: raise ValueError(f"Ratio of half-hopped edges has to be between " f"0 and 1 (got '{p}'") self.p = p self.alpha = alpha def forward(self, data: Data) -> Data: if data.edge_weight is not None or data.edge_attr is not None: raise ValueError("'HalfHop' augmentation is not supported if " "'data' contains 'edge_weight' or 'edge_attr'") assert data.x is not None assert data.edge_index is not None x, edge_index = data.x, data.edge_index num_nodes = data.num_nodes assert num_nodes is not None # isolate self loops which are not half-hopped self_loop_mask = edge_index[0] == edge_index[1] edge_index_self_loop = edge_index[:, self_loop_mask] edge_index = edge_index[:, ~self_loop_mask] # randomly sample nodes and half-hop their edges node_mask = torch.rand(num_nodes, device=x.device) < self.p edge_mask = node_mask[edge_index[1]] edge_index_to_halfhop = edge_index[:, edge_mask] edge_index_to_keep = edge_index[:, ~edge_mask] # add new slow nodes of which features are initialized # by linear interpolation num_halfhop_edges = edge_index_to_halfhop.size(1) slow_node_ids = torch.arange(num_halfhop_edges, device=x.device) + num_nodes x_src = x[edge_index_to_halfhop[0]] x_dst = x[edge_index_to_halfhop[1]] x_slow_node = self.alpha * x_src + (1 - self.alpha) * x_dst new_x = torch.cat([x, x_slow_node], dim=0) # add new edges between slow nodes and the original nodes edge_index_slow = [ torch.stack([edge_index_to_halfhop[0], slow_node_ids]), torch.stack([slow_node_ids, edge_index_to_halfhop[1]]), torch.stack([edge_index_to_halfhop[1], slow_node_ids]) ] new_edge_index = torch.cat( [edge_index_to_keep, edge_index_self_loop, *edge_index_slow], dim=1) # prepare a mask that distinguishes between original nodes & slow nodes slow_node_mask = torch.cat( [x.new_zeros(x.size(0)), x.new_ones(slow_node_ids.size(0))], dim=0).bool() data.x, data.edge_index = new_x, new_edge_index data.slow_node_mask = slow_node_mask return data def __repr__(self) -> str: return 
f'{self.__class__.__name__}(alpha={self.alpha}, p={self.p})' ================================================ FILE: torch_geometric/transforms/knn_graph.py ================================================ import torch_geometric from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import to_undirected @functional_transform('knn_graph') class KNNGraph(BaseTransform): r"""Creates a k-NN graph based on node positions :obj:`data.pos` (functional name: :obj:`knn_graph`). Args: k (int, optional): The number of neighbors. (default: :obj:`6`) loop (bool, optional): If :obj:`True`, the graph will contain self-loops. (default: :obj:`False`) force_undirected (bool, optional): If set to :obj:`True`, new edges will be undirected. (default: :obj:`False`) flow (str, optional): The flow direction when used in combination with message passing (:obj:`"source_to_target"` or :obj:`"target_to_source"`). If set to :obj:`"source_to_target"`, every target node will have exactly :math:`k` source nodes pointing to it. (default: :obj:`"source_to_target"`) cosine (bool, optional): If :obj:`True`, will use the cosine distance instead of euclidean distance to find nearest neighbors. (default: :obj:`False`) num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. 
(default: :obj:`1`) """ def __init__( self, k: int = 6, loop: bool = False, force_undirected: bool = False, flow: str = 'source_to_target', cosine: bool = False, num_workers: int = 1, ) -> None: self.k = k self.loop = loop self.force_undirected = force_undirected self.flow = flow self.cosine = cosine self.num_workers = num_workers def forward(self, data: Data) -> Data: assert data.pos is not None edge_index = torch_geometric.nn.knn_graph( data.pos, self.k, data.batch, loop=self.loop, flow=self.flow, cosine=self.cosine, num_workers=self.num_workers, ) if self.force_undirected: edge_index = to_undirected(edge_index, num_nodes=data.num_nodes) data.edge_index = edge_index data.edge_attr = None return data def __repr__(self) -> str: return f'{self.__class__.__name__}(k={self.k})' ================================================ FILE: torch_geometric/transforms/laplacian_lambda_max.py ================================================ from typing import Optional from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import get_laplacian, to_scipy_sparse_matrix @functional_transform('laplacian_lambda_max') class LaplacianLambdaMax(BaseTransform): r"""Computes the highest eigenvalue of the graph Laplacian given by :meth:`torch_geometric.utils.get_laplacian` (functional name: :obj:`laplacian_lambda_max`). Args: normalization (str, optional): The normalization scheme for the graph Laplacian (default: :obj:`None`): 1. :obj:`None`: No normalization :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` 2. :obj:`"sym"`: Symmetric normalization :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}` 3. 
:obj:`"rw"`: Random-walk normalization :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` is_undirected (bool, optional): If set to :obj:`True`, this transform expects undirected graphs as input, and can hence speed up the computation of the largest eigenvalue. (default: :obj:`False`) """ def __init__( self, normalization: Optional[str] = None, is_undirected: bool = False, ): assert normalization in [None, 'sym', 'rw'], 'Invalid normalization' self.normalization = normalization self.is_undirected = is_undirected def forward(self, data: Data) -> Data: from scipy.sparse.linalg import eigs, eigsh assert data.edge_index is not None num_nodes = data.num_nodes edge_weight = data.edge_attr if edge_weight is not None and edge_weight.numel() != data.num_edges: edge_weight = None edge_index, edge_weight = get_laplacian( data.edge_index, edge_weight, self.normalization, num_nodes=num_nodes, ) L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes) eig_fn = eigs if self.is_undirected and self.normalization != 'rw': eig_fn = eigsh lambda_max = eig_fn(L, k=1, which='LM', return_eigenvectors=False) data.lambda_max = lambda_max.real.item() return data def __repr__(self) -> str: return f'{self.__class__.__name__}(normalization={self.normalization})' ================================================ FILE: torch_geometric/transforms/largest_connected_components.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import to_scipy_sparse_matrix @functional_transform('largest_connected_components') class LargestConnectedComponents(BaseTransform): r"""Selects the subgraph that corresponds to the largest connected components in the graph (functional name: :obj:`largest_connected_components`). 
Args: num_components (int, optional): Number of largest components to keep (default: :obj:`1`) connection (str, optional): Type of connection to use for directed graphs, can be either :obj:`'strong'` or :obj:`'weak'`. Nodes `i` and `j` are strongly connected if a path exists both from `i` to `j` and from `j` to `i`. A directed graph is weakly connected if replacing all of its directed edges with undirected edges produces a connected (undirected) graph. (default: :obj:`'weak'`) """ def __init__( self, num_components: int = 1, connection: str = 'weak', ) -> None: assert connection in ['strong', 'weak'], 'Unknown connection type' self.num_components = num_components self.connection = connection def forward(self, data: Data) -> Data: import numpy as np import scipy.sparse as sp assert data.edge_index is not None adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes) num_components, component = sp.csgraph.connected_components( adj, connection=self.connection) if num_components <= self.num_components: return data _, count = np.unique(component, return_counts=True) subset_np = np.isin(component, count.argsort()[-self.num_components:]) subset = torch.from_numpy(subset_np) subset = subset.to(data.edge_index.device, torch.bool) return data.subgraph(subset) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.num_components})' ================================================ FILE: torch_geometric/transforms/line_graph.py ================================================ import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import coalesce, cumsum, remove_self_loops, scatter @functional_transform('line_graph') class LineGraph(BaseTransform): r"""Converts a graph to its corresponding line-graph (functional name: :obj:`line_graph`). .. 
math:: L(\mathcal{G}) &= (\mathcal{V}^{\prime}, \mathcal{E}^{\prime}) \mathcal{V}^{\prime} &= \mathcal{E} \mathcal{E}^{\prime} &= \{ (e_1, e_2) : e_1 \cap e_2 \neq \emptyset \} Line-graph node indices are equal to indices in the original graph's coalesced :obj:`edge_index`. For undirected graphs, the maximum line-graph node index is :obj:`(data.edge_index.size(1) // 2) - 1`. New node features are given by old edge attributes. For undirected graphs, edge attributes for reciprocal edges :obj:`(row, col)` and :obj:`(col, row)` get summed together. Args: force_directed (bool, optional): If set to :obj:`True`, the graph will be always treated as a directed graph. (default: :obj:`False`) """ def __init__(self, force_directed: bool = False) -> None: self.force_directed = force_directed def forward(self, data: Data) -> Data: assert data.edge_index is not None edge_index, edge_attr = data.edge_index, data.edge_attr N = data.num_nodes edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes=N) row, col = edge_index if self.force_directed or data.is_directed(): i = torch.arange(row.size(0), dtype=torch.long, device=row.device) count = scatter(torch.ones_like(row), row, dim=0, dim_size=data.num_nodes, reduce='sum') ptr = cumsum(count) cols = [i[ptr[col[j]]:ptr[col[j] + 1]] for j in range(col.size(0))] rows = [row.new_full((c.numel(), ), j) for j, c in enumerate(cols)] row, col = torch.cat(rows, dim=0), torch.cat(cols, dim=0) data.edge_index = torch.stack([row, col], dim=0) data.x = data.edge_attr data.num_nodes = edge_index.size(1) else: # Compute node indices. mask = row < col row, col = row[mask], col[mask] i = torch.arange(row.size(0), dtype=torch.long, device=row.device) (row, col), i = coalesce( torch.stack([ torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) ], dim=0), torch.cat([i, i], dim=0), N, ) # Compute new edge indices according to `i`. 
count = scatter(torch.ones_like(row), row, dim=0, dim_size=data.num_nodes, reduce='sum') joints = list(torch.split(i, count.tolist())) def generate_grid(x: Tensor) -> Tensor: row = x.view(-1, 1).repeat(1, x.numel()).view(-1) col = x.repeat(x.numel()) return torch.stack([row, col], dim=0) joints = [generate_grid(joint) for joint in joints] joint = torch.cat(joints, dim=1) joint, _ = remove_self_loops(joint) N = row.size(0) // 2 joint = coalesce(joint, num_nodes=N) if edge_attr is not None: data.x = scatter(edge_attr, i, dim=0, dim_size=N, reduce='sum') data.edge_index = joint data.num_nodes = edge_index.size(1) // 2 data.edge_attr = None return data ================================================ FILE: torch_geometric/transforms/linear_transformation.py ================================================ from typing import Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('linear_transformation') class LinearTransformation(BaseTransform): r"""Transforms node positions :obj:`data.pos` with a square transformation matrix computed offline (functional name: :obj:`linear_transformation`). Args: matrix (Tensor): Tensor with shape :obj:`[D, D]` where :obj:`D` corresponds to the dimensionality of node positions. """ def __init__(self, matrix: Tensor): if not isinstance(matrix, Tensor): matrix = torch.tensor(matrix) assert matrix.dim() == 2, ( 'Transformation matrix should be two-dimensional.') assert matrix.size(0) == matrix.size(1), ( f'Transformation matrix should be square (got {matrix.size()})') # Store the matrix as its transpose. # We do this to enable post-multiplication in `forward`. 
self.matrix = matrix.t() def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.node_stores: if not hasattr(store, 'pos'): continue pos = store.pos.view(-1, 1) if store.pos.dim() == 1 else store.pos assert pos.size(-1) == self.matrix.size(-2), ( 'Node position matrix and transformation matrix have ' 'incompatible shape') # We post-multiply the points by the transformation matrix instead # of pre-multiplying, because `pos` attribute has shape `[N, D]`, # and we want to preserve this shape. store.pos = pos @ self.matrix.to(pos.device, pos.dtype) return data def __repr__(self) -> str: return f'{self.__class__.__name__}(\n{self.matrix.cpu().numpy()}\n)' ================================================ FILE: torch_geometric/transforms/local_cartesian.py ================================================ from typing import Tuple import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import scatter @functional_transform('local_cartesian') class LocalCartesian(BaseTransform): r"""Saves the relative Cartesian coordinates of linked nodes in its edge attributes (functional name: :obj:`local_cartesian`). Each coordinate gets *neighborhood-normalized* to a specified interval (:math:`[0, 1]` by default). Args: norm (bool, optional): If set to :obj:`False`, the output will not be normalized. (default: :obj:`True`) cat (bool, optional): If set to :obj:`False`, all existing edge attributes will be replaced. (default: :obj:`True`) interval ((float, float), optional): A tuple specifying the lower and upper bound for normalization. 
(default: :obj:`(0.0, 1.0)`) """ def __init__( self, norm: bool = True, cat: bool = True, interval: Tuple[float, float] = (0.0, 1.0), ): self.norm = norm self.cat = cat self.interval = interval def forward(self, data: Data) -> Data: assert data.pos is not None assert data.edge_index is not None (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr cart = pos[row] - pos[col] cart = cart.view(-1, 1) if cart.dim() == 1 else cart if self.norm: max_value = scatter(cart.abs(), col, 0, pos.size(0), reduce='max') max_value = max_value.max(dim=-1, keepdim=True)[0] length = self.interval[1] - self.interval[0] center = (self.interval[0] + self.interval[1]) / 2 cart = length * cart / (2 * max_value[col]) + center if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1) else: data.edge_attr = cart return data ================================================ FILE: torch_geometric/transforms/local_degree_profile.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import degree @functional_transform('local_degree_profile') class LocalDegreeProfile(BaseTransform): r"""Appends the Local Degree Profile (LDP) from the `"A Simple yet Effective Baseline for Non-attribute Graph Classification" `_ paper (functional name: :obj:`local_degree_profile`). .. math:: \mathbf{x}_i = \mathbf{x}_i \, \Vert \, (\deg(i), \min(DN(i)), \max(DN(i)), \textrm{mean}(DN(i)), \textrm{std}(DN(i))) to the node features, where :math:`DN(i) = \{ \deg(j) \mid j \in \mathcal{N}(i) \}`. 
""" def __init__(self) -> None: from torch_geometric.nn.aggr.fused import FusedAggregation self.aggr = FusedAggregation(['min', 'max', 'mean', 'std']) def forward(self, data: Data) -> Data: assert data.edge_index is not None row, col = data.edge_index num_nodes = data.num_nodes deg = degree(row, num_nodes, dtype=torch.float).view(-1, 1) xs = [deg] + self.aggr(deg[col], row, dim_size=num_nodes) if data.x is not None: data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x data.x = torch.cat([data.x] + xs, dim=-1) else: data.x = torch.cat(xs, dim=-1) return data ================================================ FILE: torch_geometric/transforms/mask.py ================================================ from typing import List, Optional, Sequence, Union from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.data.storage import BaseStorage from torch_geometric.transforms import BaseTransform from torch_geometric.utils import index_to_mask, mask_to_index AnyData = Union[Data, HeteroData] def get_attrs_with_suffix( attrs: Optional[List[str]], store: BaseStorage, suffix: str, ) -> List[str]: if attrs is not None: return attrs return [key for key in store.keys() if key.endswith(suffix)] def get_mask_size( attr: str, store: BaseStorage, size: Optional[int], ) -> Optional[int]: if size is not None: return size return store.num_edges if store.is_edge_attr(attr) else store.num_nodes @functional_transform('index_to_mask') class IndexToMask(BaseTransform): r"""Converts indices to a mask representation (functional name: :obj:`index_to_mask`). Args: attrs (str, [str], optional): If given, will only perform index to mask conversion for the given attributes. If omitted, will infer the attributes from the suffix :obj:`_index`. (default: :obj:`None`) sizes (int, [int], optional): The size of the mask. If set to :obj:`None`, an automatically sized tensor is returned. 
The number of nodes will be used by default, except for edge attributes which will use the number of edges as the mask size. (default: :obj:`None`) replace (bool, optional): if set to :obj:`True` replaces the index attributes with mask tensors. (default: :obj:`False`) """ def __init__( self, attrs: Optional[Union[str, List[str]]] = None, sizes: Optional[Union[int, List[int]]] = None, replace: bool = False, ) -> None: self.attrs = [attrs] if isinstance(attrs, str) else attrs self.sizes = sizes self.replace = replace def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.stores: attrs = get_attrs_with_suffix(self.attrs, store, '_index') sizes: Sequence[Optional[int]] if isinstance(self.sizes, int): sizes = [self.sizes] * len(attrs) elif isinstance(self.sizes, (list, tuple)): if len(attrs) != len(self.sizes): raise ValueError( f"The number of attributes (got {len(attrs)}) must " f"match the number of sizes provided " f"(got {len(self.sizes)})") sizes = self.sizes else: sizes = [None] * len(attrs) for attr, size in zip(attrs, sizes): if 'edge_index' in attr: continue if attr not in store: continue size = get_mask_size(attr, store, size) mask = index_to_mask(store[attr], size=size) store[f'{attr[:-6]}_mask'] = mask if self.replace: del store[attr] return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(attrs={self.attrs}, ' f'sizes={self.sizes}, replace={self.replace})') @functional_transform('mask_to_index') class MaskToIndex(BaseTransform): r"""Converts a mask to an index representation (functional name: :obj:`mask_to_index`). Args: attrs (str, [str], optional): If given, will only perform mask to index conversion for the given attributes. If omitted, will infer the attributes from the suffix :obj:`_mask` (default: :obj:`None`) replace (bool, optional): if set to :obj:`True` replaces the mask attributes with index tensors. 
(default: :obj:`False`) """ def __init__( self, attrs: Optional[Union[str, List[str]]] = None, replace: bool = False, ): self.attrs = [attrs] if isinstance(attrs, str) else attrs self.replace = replace def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.stores: attrs = get_attrs_with_suffix(self.attrs, store, '_mask') for attr in attrs: if attr not in store: continue index = mask_to_index(store[attr]) store[f'{attr[:-5]}_index'] = index if self.replace: del store[attr] return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(attrs={self.attrs}, ' f'replace={self.replace})') ================================================ FILE: torch_geometric/transforms/node_property_split.py ================================================ from typing import Any, Dict, List import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import to_networkx @functional_transform('node_property_split') class NodePropertySplit(BaseTransform): r"""Creates a node-level split with distributional shift based on a given node property, as proposed in the `"Evaluating Robustness and Uncertainty of Graph Models Under Structural Distributional Shifts" `__ paper (functional name: :obj:`node_property_split`). It splits the nodes in a given graph into five non-intersecting parts based on their structural properties. This can be used for transductive node prediction tasks with distributional shifts. It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes. The ID subset includes training, validation and testing parts, while the OOD subset includes validation and testing parts. 
As a result, it creates five associated node mask vectors for each graph, three which are for the ID nodes (:obj:`id_train_mask`, :obj:`id_val_mask`, :obj:`id_test_mask`), and two which are for the OOD nodes (:obj:`ood_val_mask`, :obj:`ood_test_mask`). This class implements three particular strategies for inducing distributional shifts in a graph — based on **popularity**, **locality** or **density**. Args: property_name (str): The name of the node property to be used (:obj:`"popularity"`, :obj:`"locality"`, :obj:`"density"`). ratios ([float]): A list of five ratio values for ID training, ID validation, ID test, OOD validation and OOD test parts. The values must sum to :obj:`1.0`. ascending (bool, optional): Whether to sort nodes in ascending order of the node property, so that nodes with greater values of the property are considered to be OOD (default: :obj:`True`) .. code-block:: python from torch_geometric.transforms import NodePropertySplit from torch_geometric.datasets.graph_generator import ERGraph data = ERGraph(num_nodes=1000, edge_prob=0.4)() property_name = 'popularity' ratios = [0.3, 0.1, 0.1, 0.3, 0.2] transform = NodePropertySplit(property_name, ratios) data = transform(data) """ def __init__( self, property_name: str, ratios: List[float], ascending: bool = True, ): if property_name not in {'popularity', 'locality', 'density'}: raise ValueError(f"Unexpected 'property_name' " f"(got '{property_name}')") if len(ratios) != 5: raise ValueError(f"'ratios' must contain 5 values " f"(got {len(ratios)})") if sum(ratios) != 1.0: raise ValueError(f"'ratios' must sum to 1.0 (got {sum(ratios)})") self.property_name = property_name self.compute_fn = _property_name_to_compute_fn[property_name] self.ratios = ratios self.ascending = ascending def forward(self, data: Data) -> Data: G = to_networkx(data, to_undirected=True, remove_self_loops=True) property_values = self.compute_fn(G, self.ascending) mask_dict = self._mask_nodes_by_property(property_values, self.ratios) 
for key, mask in mask_dict.items(): data[key] = mask return data @staticmethod def _compute_popularity_property(G: Any, ascending: bool = True) -> Tensor: import networkx.algorithms as A property_values = torch.tensor(list(A.pagerank(G).values())) property_values *= -1 if ascending else 1 return property_values @staticmethod def _compute_locality_property(G: Any, ascending: bool = True) -> Tensor: import networkx.algorithms as A pagerank_values = torch.tensor(list(A.pagerank(G).values())) num_nodes = G.number_of_nodes() personalization = dict(zip(range(num_nodes), [0.0] * num_nodes)) personalization[int(pagerank_values.argmax())] = 1.0 property_values = torch.tensor( list(A.pagerank(G, personalization=personalization).values())) property_values *= -1 if ascending else 1 return property_values @staticmethod def _compute_density_property(G: Any, ascending: bool = True) -> Tensor: import networkx.algorithms as A property_values = torch.tensor(list(A.clustering(G).values())) property_values *= -1 if ascending else 1 return property_values @staticmethod def _mask_nodes_by_property( property_values: Tensor, ratios: List[float], ) -> Dict[str, Tensor]: num_nodes = property_values.size(0) sizes = (num_nodes * torch.tensor(ratios)).round().long() sizes[-1] -= sizes.sum() - num_nodes perm = torch.randperm(num_nodes) id_size = int(sizes[:3].sum()) perm = perm[property_values[perm].argsort()] perm[:id_size] = perm[:id_size][torch.randperm(id_size)] node_splits = perm.split(sizes.tolist()) names = [ 'id_train_mask', 'id_val_mask', 'id_test_mask', 'ood_val_mask', 'ood_test_mask', ] split_masks = {} for name, node_split in zip(names, node_splits): split_mask = torch.zeros(num_nodes, dtype=torch.bool) split_mask[node_split] = True split_masks[name] = split_mask return split_masks def __repr__(self) -> str: return f'{self.__class__.__name__}({self.property_name})' _property_name_to_compute_fn = { 'popularity': NodePropertySplit._compute_popularity_property, 'locality': 
NodePropertySplit._compute_locality_property, 'density': NodePropertySplit._compute_density_property, } ================================================ FILE: torch_geometric/transforms/normalize_features.py ================================================ from typing import List, Optional, Union from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('normalize_features') class NormalizeFeatures(BaseTransform): r"""Row-normalizes the attributes given in :obj:`attrs` to sum-up to one (functional name: :obj:`normalize_features`). Args: attrs (List[str]): The names of attributes to normalize. (default: :obj:`["x"]`) """ def __init__(self, attrs: Optional[List[str]] = None) -> None: self.attrs = attrs or ["x"] def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.stores: for key, value in store.items(*self.attrs): if value.numel() > 0: value = value - value.min() value.div_(value.sum(dim=-1, keepdim=True).clamp_(min=1.)) store[key] = value return data ================================================ FILE: torch_geometric/transforms/normalize_rotation.py ================================================ import torch import torch.nn.functional as F from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('normalize_rotation') class NormalizeRotation(BaseTransform): r"""Rotates all points according to the eigenvectors of the point cloud (functional name: :obj:`normalize_rotation`). If the data additionally holds normals saved in :obj:`data.normal`, these will be rotated accordingly. Args: max_points (int, optional): If set to a value greater than :obj:`0`, only a random number of :obj:`max_points` points are sampled and used to compute eigenvectors. 
(default: :obj:`-1`) sort (bool, optional): If set to :obj:`True`, will sort eigenvectors according to their eigenvalues. (default: :obj:`False`) """ def __init__(self, max_points: int = -1, sort: bool = False) -> None: self.max_points = max_points self.sort = sort def forward(self, data: Data) -> Data: assert data.pos is not None pos = data.pos if self.max_points > 0 and pos.size(0) > self.max_points: perm = torch.randperm(pos.size(0)) pos = pos[perm[:self.max_points]] pos = pos - pos.mean(dim=0, keepdim=True) C = torch.matmul(pos.t(), pos) e, v = torch.linalg.eig(C) # v[:,j] is j-th eigenvector e, v = torch.view_as_real(e), v.real if self.sort: indices = e[:, 0].argsort(descending=True) v = v.t()[indices].t() data.pos = torch.matmul(data.pos, v) if 'normal' in data: data.normal = F.normalize(torch.matmul(data.normal, v)) return data ================================================ FILE: torch_geometric/transforms/normalize_scale.py ================================================ from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform, Center @functional_transform('normalize_scale') class NormalizeScale(BaseTransform): r"""Centers and normalizes node positions to the interval :math:`(-1, 1)` (functional name: :obj:`normalize_scale`). 
""" def __init__(self) -> None: self.center = Center() def forward(self, data: Data) -> Data: data = self.center(data) assert data.pos is not None scale = (1.0 / data.pos.abs().max()) * 0.999999 data.pos = data.pos * scale return data ================================================ FILE: torch_geometric/transforms/one_hot_degree.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import degree, one_hot @functional_transform('one_hot_degree') class OneHotDegree(BaseTransform): r"""Adds the node degree as one hot encodings to the node features (functional name: :obj:`one_hot_degree`). Args: max_degree (int): Maximum degree. in_degree (bool, optional): If set to :obj:`True`, will compute the in-degree of nodes instead of the out-degree. (default: :obj:`False`) cat (bool, optional): Concat node degrees to node features instead of replacing them. 
(default: :obj:`True`) """ def __init__( self, max_degree: int, in_degree: bool = False, cat: bool = True, ) -> None: self.max_degree = max_degree self.in_degree = in_degree self.cat = cat def forward(self, data: Data) -> Data: assert data.edge_index is not None idx, x = data.edge_index[1 if self.in_degree else 0], data.x deg = degree(idx, data.num_nodes, dtype=torch.long) deg = one_hot(deg, num_classes=self.max_degree + 1) if x is not None and self.cat: x = x.view(-1, 1) if x.dim() == 1 else x data.x = torch.cat([x, deg.to(x.dtype)], dim=-1) else: data.x = deg return data def __repr__(self) -> str: return f'{self.__class__.__name__}({self.max_degree})' ================================================ FILE: torch_geometric/transforms/pad.py ================================================ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.data.storage import EdgeStorage, NodeStorage from torch_geometric.transforms import BaseTransform from torch_geometric.typing import EdgeType, NodeType class Padding(ABC): r"""An abstract class for specifying padding values.""" @abstractmethod def get_value( self, store_type: Optional[Union[NodeType, EdgeType]] = None, attr_name: Optional[str] = None, ) -> Union[int, float]: pass @dataclass(init=False) class UniformPadding(Padding): r"""Uniform padding independent of attribute name or node/edge type. Args: value (int or float, optional): The value to be used for padding. 
(default: :obj:`0.0`) """ value: Union[int, float] = 0.0 def __init__(self, value: Union[int, float] = 0.0): self.value = value if not isinstance(self.value, (int, float)): raise ValueError(f"Expected 'value' to be an integer or float " f"(got '{type(value)}'") def get_value( self, store_type: Optional[Union[NodeType, EdgeType]] = None, attr_name: Optional[str] = None, ) -> Union[int, float]: return self.value @dataclass(init=False) class MappingPadding(Padding): r"""An abstract class for specifying different padding values.""" values: Dict[Any, Padding] default: UniformPadding def __init__( self, values: Dict[Any, Union[int, float, Padding]], default: Union[int, float] = 0.0, ): if not isinstance(values, dict): raise ValueError(f"Expected 'values' to be a dictionary " f"(got '{type(values)}'") self.values = { key: UniformPadding(val) if isinstance(val, (int, float)) else val for key, val in values.items() } self.default = UniformPadding(default) for key, value in self.values.items(): self.validate_key_value(key, value) def validate_key_value(self, key: Any, value: Any) -> None: pass class AttrNamePadding(MappingPadding): r"""Padding dependent on attribute names. Args: values (dict): The mapping from attribute names to padding values. default (int or float, optional): The padding value to use for attribute names not specified in :obj:`values`. 
(default: :obj:`0.0`) """ def validate_key_value(self, key: Any, value: Any) -> None: if not isinstance(key, str): raise ValueError(f"Expected the attribute name '{key}' to be a " f"string (got '{type(key)}')") if not isinstance(value, UniformPadding): raise ValueError(f"Expected the value of '{key}' to be of " f"type 'UniformPadding' (got '{type(value)}')") def get_value( self, store_type: Optional[Union[NodeType, EdgeType]] = None, attr_name: Optional[str] = None, ) -> Union[int, float]: padding = self.values.get(attr_name, self.default) return padding.get_value() class NodeTypePadding(MappingPadding): r"""Padding dependent on node types. Args: values (dict): The mapping from node types to padding values. default (int or float, optional): The padding value to use for node types not specified in :obj:`values`. (default: :obj:`0.0`) """ def validate_key_value(self, key: Any, value: Any) -> None: if not isinstance(key, str): raise ValueError(f"Expected the node type '{key}' to be a string " f"(got '{type(key)}')") if not isinstance(value, (UniformPadding, AttrNamePadding)): raise ValueError(f"Expected the value of '{key}' to be of " f"type 'UniformPadding' or 'AttrNamePadding' " f"(got '{type(value)}')") def get_value( self, store_type: Optional[Union[NodeType, EdgeType]] = None, attr_name: Optional[str] = None, ) -> Union[int, float]: padding = self.values.get(store_type, self.default) return padding.get_value(attr_name=attr_name) class EdgeTypePadding(MappingPadding): r"""Padding dependent on node types. Args: values (dict): The mapping from edge types to padding values. default (int or float, optional): The padding value to use for edge types not specified in :obj:`values`. 
(default: :obj:`0.0`) """ def validate_key_value(self, key: Any, value: Any) -> None: if not isinstance(key, tuple): raise ValueError(f"Expected the edge type '{key}' to be a tuple " f"(got '{type(key)}')") if len(key) != 3: raise ValueError(f"Expected the edge type '{key}' to hold exactly " f"three elements (got {len(key)})") if not isinstance(value, (UniformPadding, AttrNamePadding)): raise ValueError(f"Expected the value of '{key}' to be of " f"type 'UniformPadding' or 'AttrNamePadding' " f"(got '{type(value)}')") def get_value( self, store_type: Optional[Union[NodeType, EdgeType]] = None, attr_name: Optional[str] = None, ) -> Union[int, float]: padding = self.values.get(store_type, self.default) return padding.get_value(attr_name=attr_name) class _NumNodes: def __init__( self, value: Union[int, Dict[NodeType, int], None], ) -> None: self.value = value def get_value(self, key: Optional[NodeType] = None) -> Optional[int]: if self.value is None or isinstance(self.value, int): return self.value assert isinstance(key, str) return self.value[key] class _NumEdges: def __init__( self, value: Union[int, Dict[EdgeType, int], None], num_nodes: _NumNodes, ) -> None: if value is None: if isinstance(num_nodes.value, int): value = num_nodes.value * num_nodes.value else: value = {} self.value = value self.num_nodes = num_nodes def get_value(self, key: Optional[EdgeType] = None) -> Optional[int]: if self.value is None or isinstance(self.value, int): return self.value assert isinstance(key, tuple) and len(key) == 3 if key not in self.value: num_src_nodes = self.num_nodes.get_value(key[0]) num_dst_nodes = self.num_nodes.get_value(key[-1]) assert num_src_nodes is not None and num_dst_nodes is not None self.value[key] = num_src_nodes * num_dst_nodes return self.value[key] @functional_transform('pad') class Pad(BaseTransform): r"""Applies padding to enforce consistent tensor shapes (functional name: :obj:`pad`). 
This transform will pad node and edge features up to a maximum allowed size in the node or edge feature dimension. By default :obj:`0.0` is used as the padding value and can be configured by setting :obj:`node_pad_value` and :obj:`edge_pad_value`. In case of applying :class:`Pad` to a :class:`~torch_geometric.data.Data` object, the :obj:`node_pad_value` value (or :obj:`edge_pad_value`) can be either: * an int, float or object of :class:`UniformPadding` class for cases when all attributes are going to be padded with the same value; * an object of :class:`AttrNamePadding` class for cases when padding is going to differ based on attribute names. In case of applying :class:`Pad` to a :class:`~torch_geometric.data.HeteroData` object, the :obj:`node_pad_value` value (or :obj:`edge_pad_value`) can be either: * an int, float or object of :class:`UniformPadding` class for cases when all attributes of all node (or edge) stores are going to be padded with the same value; * an object of :class:`AttrNamePadding` class for cases when padding is going to differ based on attribute names (but not based on node or edge types); * an object of class :class:`NodeTypePadding` or :class:`EdgeTypePadding` for cases when padding values are going to differ based on node or edge types. Padding values can also differ based on attribute names for a given node or edge type by using :class:`AttrNamePadding` objects as values of its `values` argument. Note that in order to allow for consistent padding across all graphs in a dataset, below conditions must be met: * if :obj:`max_num_nodes` is a single value, it must be greater than or equal to the maximum number of nodes of any graph in the dataset; * if :obj:`max_num_nodes` is a dictionary, value for every node type must be greater than or equal to the maximum number of this type nodes of any graph in the dataset. Example below shows how to create a :class:`Pad` transform for an :class:`~torch_geometric.data.HeteroData` object. 
The object is padded to have :obj:`10` nodes of type :obj:`v0`, :obj:`20` nodes of type :obj:`v1` and :obj:`30` nodes of type :obj:`v2`. It is padded to have :obj:`80` edges of type :obj:`('v0', 'e0', 'v1')`. All the attributes of the :obj:`v0` nodes are padded using a value of :obj:`3.0`. The :obj:`x` attribute of the :obj:`v1` node type is padded using a value of :obj:`-1.0`, and the other attributes of this node type are padded using a value of :obj:`0.5`. All the attributes of node types other than :obj:`v0` and :obj:`v1` are padded using a value of :obj:`1.0`. All the attributes of the :obj:`('v0', 'e0', 'v1')` edge type are padded using a value of :obj:`3.5`. The :obj:`edge_attr` attributes of the :obj:`('v1', 'e0', 'v0')` edge type are padded using a value of :obj:`-1.5`, and any other attributes of this edge type are padded using a value of :obj:`5.5`. All the attributes of edge types other than these two are padded using a value of :obj:`1.5`. .. code-block:: python num_nodes = {'v0': 10, 'v1': 20, 'v2':30} num_edges = {('v0', 'e0', 'v1'): 80} node_padding = NodeTypePadding({ 'v0': 3.0, 'v1': AttrNamePadding({'x': -1.0}, default=0.5), }, default=1.0) edge_padding = EdgeTypePadding({ ('v0', 'e0', 'v1'): 3.5, ('v1', 'e0', 'v0'): AttrNamePadding({'edge_attr': -1.5}, default=5.5), }, default=1.5) transform = Pad(num_nodes, num_edges, node_padding, edge_padding) Args: max_num_nodes (int or dict): The number of nodes after padding. In heterogeneous graphs, may also take in a dictionary denoting the number of nodes for specific node types. max_num_edges (int or dict, optional): The number of edges after padding. In heterogeneous graphs, may also take in a dictionary denoting the number of edges for specific edge types. (default: :obj:`None`) node_pad_value (int or float or Padding, optional): The fill value to use for node features. (default: :obj:`0.0`) edge_pad_value (int or float or Padding, optional): The fill value to use for edge features. 
(default: :obj:`0.0`) The :obj:`edge_index` tensor is padded with with the index of the first padded node (which represents a set of self-loops on the padded node). (default: :obj:`0.0`) mask_pad_value (bool, optional): The fill value to use for :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask` attributes (default: :obj:`False`). add_pad_mask (bool, optional): If set to :obj:`True`, will attach node-level :obj:`pad_node_mask` and edge-level :obj:`pad_edge_mask` attributes to the output which indicates which elements in the data are real (represented by :obj:`True`) and which were added as a result of padding (represented by :obj:`False`). (default: :obj:`False`) exclude_keys ([str], optional): Keys to be removed from the input data object. (default: :obj:`None`) """ def __init__( self, max_num_nodes: Union[int, Dict[NodeType, int]], max_num_edges: Optional[Union[int, Dict[EdgeType, int]]] = None, node_pad_value: Union[int, float, Padding] = 0.0, edge_pad_value: Union[int, float, Padding] = 0.0, mask_pad_value: bool = False, add_pad_mask: bool = False, exclude_keys: Optional[List[str]] = None, ): self.max_num_nodes = _NumNodes(max_num_nodes) self.max_num_edges = _NumEdges(max_num_edges, self.max_num_nodes) self.node_pad: Padding if not isinstance(node_pad_value, Padding): self.node_pad = UniformPadding(node_pad_value) else: self.node_pad = node_pad_value self.edge_pad: Padding if not isinstance(edge_pad_value, Padding): self.edge_pad = UniformPadding(edge_pad_value) else: self.edge_pad = edge_pad_value self.node_additional_attrs_pad = { key: mask_pad_value for key in ['train_mask', 'val_mask', 'test_mask'] } self.add_pad_mask = add_pad_mask self.exclude_keys = set(exclude_keys or []) def __should_pad_node_attr(self, attr_name: str) -> bool: if attr_name in self.node_additional_attrs_pad: return True if self.exclude_keys is None or attr_name not in self.exclude_keys: return True return False def __should_pad_edge_attr(self, attr_name: str) -> bool: if 
self.max_num_edges.value is None: return False if attr_name == 'edge_index': return True if self.exclude_keys is None or attr_name not in self.exclude_keys: return True return False def __get_node_padding( self, attr_name: str, node_type: Optional[NodeType] = None, ) -> Union[int, float]: if attr_name in self.node_additional_attrs_pad: return self.node_additional_attrs_pad[attr_name] return self.node_pad.get_value(node_type, attr_name) def __get_edge_padding( self, attr_name: str, edge_type: Optional[EdgeType] = None, ) -> Union[int, float]: return self.edge_pad.get_value(edge_type, attr_name) def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: if isinstance(data, Data): assert isinstance(self.node_pad, (UniformPadding, AttrNamePadding)) assert isinstance(self.edge_pad, (UniformPadding, AttrNamePadding)) for key in self.exclude_keys: del data[key] num_nodes = data.num_nodes assert num_nodes is not None self.__pad_edge_store(data._store, data.__cat_dim__, num_nodes) self.__pad_node_store(data._store, data.__cat_dim__) data.num_nodes = self.max_num_nodes.get_value() else: assert isinstance( self.node_pad, (UniformPadding, AttrNamePadding, NodeTypePadding)) assert isinstance( self.edge_pad, (UniformPadding, AttrNamePadding, EdgeTypePadding)) for edge_type, edge_store in data.edge_items(): for key in self.exclude_keys: del edge_store[key] src_node_type, _, dst_node_type = edge_type num_src_nodes = data[src_node_type].num_nodes num_dst_nodes = data[dst_node_type].num_nodes assert num_src_nodes is not None and num_dst_nodes is not None self.__pad_edge_store(edge_store, data.__cat_dim__, (num_src_nodes, num_dst_nodes), edge_type) for node_type, node_store in data.node_items(): for key in self.exclude_keys: del node_store[key] self.__pad_node_store(node_store, data.__cat_dim__, node_type) data[node_type].num_nodes = self.max_num_nodes.get_value( node_type) return data def __pad_node_store( self, store: NodeStorage, get_dim_fn: Callable, 
node_type: Optional[NodeType] = None, ) -> None: attrs_to_pad = [key for key in store.keys() if store.is_node_attr(key)] if len(attrs_to_pad) == 0: return num_target_nodes = self.max_num_nodes.get_value(node_type) assert num_target_nodes is not None assert store.num_nodes is not None assert num_target_nodes >= store.num_nodes, \ f'The number of nodes after padding ({num_target_nodes}) cannot ' \ f'be lower than the number of nodes in the data object ' \ f'({store.num_nodes}).' num_pad_nodes = num_target_nodes - store.num_nodes if self.add_pad_mask: pad_node_mask = torch.ones(num_target_nodes, dtype=torch.bool) pad_node_mask[store.num_nodes:] = False store.pad_node_mask = pad_node_mask for attr_name in attrs_to_pad: attr = store[attr_name] pad_value = self.__get_node_padding(attr_name, node_type) dim = get_dim_fn(attr_name, attr) store[attr_name] = self._pad_tensor_dim(attr, dim, num_pad_nodes, pad_value) def __pad_edge_store( self, store: EdgeStorage, get_dim_fn: Callable, num_nodes: Union[int, Tuple[int, int]], edge_type: Optional[EdgeType] = None, ) -> None: attrs_to_pad = { attr for attr in store.keys() if store.is_edge_attr(attr) and self.__should_pad_edge_attr(attr) } if not attrs_to_pad: return num_target_edges = self.max_num_edges.get_value(edge_type) assert num_target_edges is not None assert num_target_edges >= store.num_edges, \ f'The number of edges after padding ({num_target_edges}) cannot ' \ f'be lower than the number of edges in the data object ' \ f'({store.num_edges}).' 
num_pad_edges = num_target_edges - store.num_edges if self.add_pad_mask: pad_edge_mask = torch.ones(num_target_edges, dtype=torch.bool) pad_edge_mask[store.num_edges:] = False store.pad_edge_mask = pad_edge_mask if isinstance(num_nodes, tuple): src_pad_value, dst_pad_value = num_nodes else: src_pad_value = dst_pad_value = num_nodes for attr_name in attrs_to_pad: attr = store[attr_name] dim = get_dim_fn(attr_name, attr) if attr_name == 'edge_index': store[attr_name] = self._pad_edge_index( attr, num_pad_edges, src_pad_value, dst_pad_value) else: pad_value = self.__get_edge_padding(attr_name, edge_type) store[attr_name] = self._pad_tensor_dim( attr, dim, num_pad_edges, pad_value) @staticmethod def _pad_tensor_dim(input: torch.Tensor, dim: int, length: int, pad_value: float) -> torch.Tensor: r"""Pads the input tensor in the specified dim with a constant value of the given length. """ pads = [0] * (2 * input.ndim) pads[-2 * dim - 1] = length return F.pad(input, pads, 'constant', pad_value) @staticmethod def _pad_edge_index(input: torch.Tensor, length: int, src_pad_value: float, dst_pad_value: float) -> torch.Tensor: r"""Pads the edges :obj:`edge_index` feature with values specified separately for src and dst nodes. 
""" pads = [0, length, 0, 0] padded = F.pad(input, pads, 'constant', src_pad_value) if src_pad_value != dst_pad_value: padded[1, input.shape[1]:] = dst_pad_value return padded def __repr__(self) -> str: s = f'{self.__class__.__name__}(' s += f'max_num_nodes={self.max_num_nodes.value}, ' s += f'max_num_edges={self.max_num_edges.value}, ' s += f'node_pad_value={self.node_pad}, ' s += f'edge_pad_value={self.edge_pad})' return s ================================================ FILE: torch_geometric/transforms/point_pair_features.py ================================================ import torch import torch_geometric from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('point_pair_features') class PointPairFeatures(BaseTransform): r"""Computes the rotation-invariant Point Pair Features (functional name: :obj:`point_pair_features`). .. math:: \left( \| \mathbf{d_{j,i}} \|, \angle(\mathbf{n}_i, \mathbf{d_{j,i}}), \angle(\mathbf{n}_j, \mathbf{d_{j,i}}), \angle(\mathbf{n}_i, \mathbf{n}_j) \right) of linked nodes in its edge attributes, where :math:`\mathbf{d}_{j,i}` denotes the difference vector between, and :math:`\mathbf{n}_i` and :math:`\mathbf{n}_j` denote the surface normals of node :math:`i` and :math:`j` respectively. Args: cat (bool, optional): If set to :obj:`False`, all existing edge attributes will be replaced. 
(default: :obj:`True`) """ def __init__(self, cat: bool = True): self.cat = cat def forward(self, data: Data) -> Data: ppf_func = torch_geometric.nn.conv.ppf_conv.point_pair_features assert data.edge_index is not None assert data.pos is not None and data.norm is not None assert data.pos.size(-1) == 3 assert data.pos.size() == data.norm.size() row, col = data.edge_index pos, norm, pseudo = data.pos, data.norm, data.edge_attr ppf = ppf_func(pos[row], pos[col], norm[row], norm[col]) if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo data.edge_attr = torch.cat([pseudo, ppf.type_as(pseudo)], dim=-1) else: data.edge_attr = ppf return data ================================================ FILE: torch_geometric/transforms/polar.py ================================================ from math import pi as PI from typing import Optional import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('polar') class Polar(BaseTransform): r"""Saves the polar coordinates of linked nodes in its edge attributes (functional name: :obj:`polar`). Args: norm (bool, optional): If set to :obj:`False`, the output will not be normalized to the interval :math:`{[0, 1]}^2`. (default: :obj:`True`) max_value (float, optional): If set and :obj:`norm=True`, normalization will be performed based on this value instead of the maximum value found in the data. (default: :obj:`None`) cat (bool, optional): If set to :obj:`False`, all existing edge attributes will be replaced. 
(default: :obj:`True`) """ def __init__( self, norm: bool = True, max_value: Optional[float] = None, cat: bool = True, ) -> None: self.norm = norm self.max = max_value self.cat = cat def forward(self, data: Data) -> Data: assert data.pos is not None assert data.edge_index is not None (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr assert pos.dim() == 2 and pos.size(1) == 2 cart = pos[col] - pos[row] rho = torch.norm(cart, p=2, dim=-1).view(-1, 1) theta = torch.atan2(cart[..., 1], cart[..., 0]).view(-1, 1) theta = theta + (theta < 0).type_as(theta) * (2 * PI) if self.norm: rho = rho / (rho.max() if self.max is None else self.max) theta = theta / (2 * PI) polar = torch.cat([rho, theta], dim=-1) if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo data.edge_attr = torch.cat([pseudo, polar.type_as(pos)], dim=-1) else: data.edge_attr = polar return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(norm={self.norm}, ' f'max_value={self.max})') ================================================ FILE: torch_geometric/transforms/radius_graph.py ================================================ import torch_geometric from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('radius_graph') class RadiusGraph(BaseTransform): r"""Creates edges based on node positions :obj:`data.pos` to all points within a given distance (functional name: :obj:`radius_graph`). Args: r (float): The distance. loop (bool, optional): If :obj:`True`, the graph will contain self-loops. (default: :obj:`False`) max_num_neighbors (int, optional): The maximum number of neighbors to return for each element in :obj:`y`. This flag is only needed for CUDA tensors. 
(default: :obj:`32`) flow (str, optional): The flow direction when using in combination with message passing (:obj:`"source_to_target"` or :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) """ def __init__( self, r: float, loop: bool = False, max_num_neighbors: int = 32, flow: str = 'source_to_target', num_workers: int = 1, ) -> None: self.r = r self.loop = loop self.max_num_neighbors = max_num_neighbors self.flow = flow self.num_workers = num_workers def forward(self, data: Data) -> Data: assert data.pos is not None data.edge_index = torch_geometric.nn.radius_graph( data.pos, self.r, data.batch, self.loop, max_num_neighbors=self.max_num_neighbors, flow=self.flow, num_workers=self.num_workers, ) data.edge_attr = None return data def __repr__(self) -> str: return f'{self.__class__.__name__}(r={self.r})' ================================================ FILE: torch_geometric/transforms/random_flip.py ================================================ import random from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('random_flip') class RandomFlip(BaseTransform): """Flips node positions along a given axis randomly with a given probability (functional name: :obj:`random_flip`). Args: axis (int): The axis along the position of nodes being flipped. p (float, optional): Probability that node positions will be flipped. 
(default: :obj:`0.5`) """ def __init__(self, axis: int, p: float = 0.5) -> None: self.axis = axis self.p = p def forward(self, data: Data) -> Data: assert data.pos is not None if random.random() < self.p: pos = data.pos.clone() pos[..., self.axis] = -pos[..., self.axis] data.pos = pos return data def __repr__(self) -> str: return f'{self.__class__.__name__}(axis={self.axis}, p={self.p})' ================================================ FILE: torch_geometric/transforms/random_jitter.py ================================================ from itertools import repeat from typing import Sequence, Union from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('random_jitter') class RandomJitter(BaseTransform): r"""Translates node positions by randomly sampled translation values within a given interval (functional name: :obj:`random_jitter`). In contrast to other random transformations, translation is applied separately at each position. Args: translate (sequence or float or int): Maximum translation in each dimension, defining the range :math:`(-\mathrm{translate}, +\mathrm{translate})` to sample from. If :obj:`translate` is a number instead of a sequence, the same range is used for each dimension. 
""" def __init__( self, translate: Union[float, int, Sequence[Union[float, int]]], ) -> None: self.translate = translate def forward(self, data: Data) -> Data: assert data.pos is not None num_nodes, dim = data.pos.size() translate: Sequence[Union[float, int]] if isinstance(self.translate, (int, float)): translate = list(repeat(self.translate, times=dim)) else: assert len(self.translate) == dim translate = self.translate jitter = data.pos.new_empty(num_nodes, dim) for d in range(dim): jitter[:, d].uniform_(-abs(translate[d]), abs(translate[d])) data.pos = data.pos + jitter return data def __repr__(self) -> str: return f'{self.__class__.__name__}({self.translate})' ================================================ FILE: torch_geometric/transforms/random_link_split.py ================================================ import copy import warnings from typing import List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.data.storage import EdgeStorage from torch_geometric.transforms import BaseTransform from torch_geometric.typing import EdgeType from torch_geometric.utils import negative_sampling @functional_transform('random_link_split') class RandomLinkSplit(BaseTransform): r"""Performs an edge-level random split into training, validation and test sets of a :class:`~torch_geometric.data.Data` or a :class:`~torch_geometric.data.HeteroData` object (functional name: :obj:`random_link_split`). The split is performed such that the training split does not include edges in validation and test splits; and the validation split does not include edges in the test split. .. code-block:: python from torch_geometric.transforms import RandomLinkSplit transform = RandomLinkSplit(is_undirected=True) train_data, val_data, test_data = transform(data) Args: num_val (int or float, optional): The number of validation edges. 
If set to a floating-point value in :math:`[0, 1]`, it represents the ratio of edges to include in the validation set. (default: :obj:`0.1`) num_test (int or float, optional): The number of test edges. If set to a floating-point value in :math:`[0, 1]`, it represents the ratio of edges to include in the test set. (default: :obj:`0.2`) is_undirected (bool): If set to :obj:`True`, the graph is assumed to be undirected, and positive and negative samples will not leak (reverse) edge connectivity across different splits. This only affects the graph split, label data will not be returned undirected. This option is ignored for bipartite edge types or whenever :obj:`edge_type != rev_edge_type`. (default: :obj:`False`) key (str, optional): The name of the attribute holding ground-truth labels. If :obj:`data[key]` does not exist, it will be automatically created and represents a binary classification task (:obj:`1` = edge, :obj:`0` = no edge). If :obj:`data[key]` exists, it has to be a categorical label from :obj:`0` to :obj:`num_classes - 1`. After negative sampling, label :obj:`0` represents negative edges, and labels :obj:`1` to :obj:`num_classes` represent the labels of positive edges. (default: :obj:`"edge_label"`) split_labels (bool, optional): If set to :obj:`True`, will split positive and negative labels and save them in distinct attributes :obj:`"pos_edge_label"` and :obj:`"neg_edge_label"`, respectively. (default: :obj:`False`) add_negative_train_samples (bool, optional): Whether to add negative training samples for link prediction. If the model already performs negative sampling, then the option should be set to :obj:`False`. Otherwise, the added negative samples will be the same across training iterations unless negative sampling is performed again. (default: :obj:`True`) neg_sampling_ratio (float, optional): The ratio of sampled negative edges to the number of positive edges. 
(default: :obj:`1.0`) disjoint_train_ratio (int or float, optional): If set to a value greater than :obj:`0.0`, training edges will not be shared for message passing and supervision. Instead, :obj:`disjoint_train_ratio` edges are used as ground-truth labels for supervision during training. (default: :obj:`0.0`) edge_types (Tuple[EdgeType] or List[EdgeType], optional): The edge types used for performing edge-level splitting in case of operating on :class:`~torch_geometric.data.HeteroData` objects. (default: :obj:`None`) rev_edge_types (Tuple[EdgeType] or List[Tuple[EdgeType]], optional): The reverse edge types of :obj:`edge_types` in case of operating on :class:`~torch_geometric.data.HeteroData` objects. This will ensure that edges of the reverse direction will be split accordingly to prevent any data leakage. Can be :obj:`None` in case no reverse connection exists. (default: :obj:`None`) """ def __init__( self, num_val: Union[int, float] = 0.1, num_test: Union[int, float] = 0.2, is_undirected: bool = False, key: str = 'edge_label', split_labels: bool = False, add_negative_train_samples: bool = True, neg_sampling_ratio: float = 1.0, disjoint_train_ratio: Union[int, float] = 0.0, edge_types: Optional[Union[EdgeType, List[EdgeType]]] = None, rev_edge_types: Optional[Union[ EdgeType, List[Optional[EdgeType]], ]] = None, ) -> None: if isinstance(edge_types, list): if rev_edge_types is None: rev_edge_types = [None] * len(edge_types) assert isinstance(rev_edge_types, list) assert len(edge_types) == len(rev_edge_types) self.num_val = num_val self.num_test = num_test self.is_undirected = is_undirected self.key = key self.split_labels = split_labels self.add_negative_train_samples = add_negative_train_samples self.neg_sampling_ratio = neg_sampling_ratio self.disjoint_train_ratio = disjoint_train_ratio self.edge_types = edge_types self.rev_edge_types = rev_edge_types def forward( self, data: Union[Data, HeteroData], ) -> Tuple[ Union[Data, HeteroData], Union[Data, 
HeteroData], Union[Data, HeteroData], ]: edge_types = self.edge_types rev_edge_types = self.rev_edge_types train_data = copy.copy(data) val_data = copy.copy(data) test_data = copy.copy(data) if isinstance(data, HeteroData): assert isinstance(train_data, HeteroData) assert isinstance(val_data, HeteroData) assert isinstance(test_data, HeteroData) if edge_types is None: raise ValueError( "The 'RandomLinkSplit' transform expects 'edge_types' to " "be specified when operating on 'HeteroData' objects") if not isinstance(edge_types, list): assert not isinstance(rev_edge_types, list) edge_types = [edge_types] rev_edge_types = [rev_edge_types] stores = [data[edge_type] for edge_type in edge_types] train_stores = [train_data[edge_type] for edge_type in edge_types] val_stores = [val_data[edge_type] for edge_type in edge_types] test_stores = [test_data[edge_type] for edge_type in edge_types] else: assert isinstance(train_data, Data) assert isinstance(val_data, Data) assert isinstance(test_data, Data) rev_edge_types = [None] train_data = copy.copy(data) val_data = copy.copy(data) test_data = copy.copy(data) stores = [data._store] train_stores = [train_data._store] val_stores = [val_data._store] test_stores = [test_data._store] assert isinstance(rev_edge_types, list) for item in zip(stores, train_stores, val_stores, test_stores, rev_edge_types): store, train_store, val_store, test_store, rev_edge_type = item is_undirected = self.is_undirected is_undirected &= not store.is_bipartite() is_undirected &= (rev_edge_type is None or (isinstance(data, HeteroData) and store._key == data[rev_edge_type]._key)) edge_index = store.edge_index if is_undirected: mask = edge_index[0] <= edge_index[1] perm = mask.nonzero(as_tuple=False).view(-1) perm = perm[torch.randperm(perm.size(0), device=perm.device)] else: device = edge_index.device perm = torch.randperm(edge_index.size(1), device=device) num_val = self.num_val if isinstance(num_val, float): num_val = int(num_val * perm.numel()) num_test = 
self.num_test if isinstance(num_test, float): num_test = int(num_test * perm.numel()) num_train = perm.numel() - num_val - num_test if num_train <= 0: raise ValueError("Insufficient number of edges for training") train_edges = perm[:num_train] val_edges = perm[num_train:num_train + num_val] test_edges = perm[num_train + num_val:] train_val_edges = perm[:num_train + num_val] num_disjoint = self.disjoint_train_ratio if isinstance(num_disjoint, float): num_disjoint = int(num_disjoint * train_edges.numel()) if num_train - num_disjoint <= 0: raise ValueError("Insufficient number of edges for training") # Create data splits: self._split(train_store, train_edges[num_disjoint:], is_undirected, rev_edge_type) self._split(val_store, train_edges, is_undirected, rev_edge_type) self._split(test_store, train_val_edges, is_undirected, rev_edge_type) # Create negative samples: num_neg_train = 0 if self.add_negative_train_samples: if num_disjoint > 0: num_neg_train = int(num_disjoint * self.neg_sampling_ratio) else: num_neg_train = int(num_train * self.neg_sampling_ratio) num_neg_val = int(num_val * self.neg_sampling_ratio) num_neg_test = int(num_test * self.neg_sampling_ratio) num_neg = num_neg_train + num_neg_val + num_neg_test size = store.size() if store._key is None or store._key[0] == store._key[-1]: size = size[0] neg_edge_index = negative_sampling(edge_index, size, num_neg_samples=num_neg, method='sparse') # Adjust ratio if not enough negative edges exist if neg_edge_index.size(1) < num_neg: num_neg_found = neg_edge_index.size(1) ratio = num_neg_found / num_neg warnings.warn( f"There are not enough negative edges to satisfy " "the provided sampling ratio. 
The ratio will be "
                    f"adjusted to {ratio:.2f}.", stacklevel=2)
                # Rescale the requested per-split negative counts so that they
                # add up to the number of negative edges actually sampled:
                num_neg_train = int((num_neg_train / num_neg) * num_neg_found)
                num_neg_val = int((num_neg_val / num_neg) * num_neg_found)
                num_neg_test = num_neg_found - num_neg_train - num_neg_val

            # Create labels:
            if num_disjoint > 0:
                # Only the first `num_disjoint` training edges get labels.
                train_edges = train_edges[:num_disjoint]
            # The columns of `neg_edge_index` are consumed in the order
            # [validation | test | training]:
            self._create_label(
                store,
                train_edges,
                neg_edge_index[:, num_neg_val + num_neg_test:],
                out=train_store,
            )
            self._create_label(
                store,
                val_edges,
                neg_edge_index[:, :num_neg_val],
                out=val_store,
            )
            self._create_label(
                store,
                test_edges,
                neg_edge_index[:, num_neg_val:num_neg_val + num_neg_test],
                out=test_store,
            )

        return train_data, val_data, test_data

    def _split(
        self,
        store: EdgeStorage,
        index: Tensor,
        is_undirected: bool,
        rev_edge_type: Optional[EdgeType],
    ) -> EdgeStorage:
        # Restricts `store` in-place to the edges selected by `index`. In the
        # undirected case, every kept edge is appended a second time in
        # reverse direction and its edge-level attributes are repeated to
        # match. If `rev_edge_type` is given, the reverse edge storage is
        # re-synchronized with the flipped result.
        edge_attrs = {key for key in store.keys() if store.is_edge_attr(key)}
        for key, value in store.items():
            if key == 'edge_index':
                continue

            if key in edge_attrs:
                value = value[index]
                if is_undirected:
                    value = torch.cat([value, value], dim=0)
                store[key] = value

        edge_index = store.edge_index[:, index]
        if is_undirected:
            edge_index = torch.cat([edge_index, edge_index.flip([0])], dim=-1)
        store.edge_index = edge_index

        if rev_edge_type is not None:
            rev_store = store._parent()[rev_edge_type]
            for key in rev_store.keys():
                if key not in store:
                    del rev_store[key]  # We delete all outdated attributes.
                elif key == 'edge_index':
                    rev_store.edge_index = store.edge_index.flip([0])
                else:
                    rev_store[key] = store[key]

        return store

    def _create_label(
        self,
        store: EdgeStorage,
        index: Tensor,
        neg_edge_index: Tensor,
        out: EdgeStorage,
    ) -> EdgeStorage:
        # Writes the labels of the positive edges selected by `index` (and of
        # the given negative edges, if any) together with their edge indices
        # into `out`.

        edge_index = store.edge_index[:, index]

        if hasattr(store, self.key):
            edge_label = store[self.key]
            edge_label = edge_label[index]
            # Increment labels by one. Note that there is no need to increment
            # in case no negative edges are added.
if neg_edge_index.numel() > 0: assert edge_label.dtype == torch.long assert edge_label.size(0) == edge_index.size(1) edge_label.add_(1) if hasattr(out, self.key): delattr(out, self.key) else: edge_label = torch.ones(index.numel(), device=index.device) if neg_edge_index.numel() > 0: neg_edge_label = edge_label.new_zeros((neg_edge_index.size(1), ) + edge_label.size()[1:]) if self.split_labels: out[f'pos_{self.key}'] = edge_label out[f'pos_{self.key}_index'] = edge_index if neg_edge_index.numel() > 0: out[f'neg_{self.key}'] = neg_edge_label out[f'neg_{self.key}_index'] = neg_edge_index else: if neg_edge_index.numel() > 0: edge_label = torch.cat([edge_label, neg_edge_label], dim=0) edge_index = torch.cat([edge_index, neg_edge_index], dim=-1) out[self.key] = edge_label out[f'{self.key}_index'] = edge_index return out def __repr__(self) -> str: return (f'{self.__class__.__name__}(num_val={self.num_val}, ' f'num_test={self.num_test})') ================================================ FILE: torch_geometric/transforms/random_node_split.py ================================================ from typing import Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.data.storage import NodeStorage from torch_geometric.transforms import BaseTransform @functional_transform('random_node_split') class RandomNodeSplit(BaseTransform): r"""Performs a node-level random split by adding :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask` attributes to the :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object (functional name: :obj:`random_node_split`). Args: split (str, optional): The type of dataset split (:obj:`"train_rest"`, :obj:`"test_rest"`, :obj:`"random"`). 
If set to :obj:`"train_rest"`, all nodes except those in the validation and test sets will be used for training (as in the `"FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling" `_ paper). If set to :obj:`"test_rest"`, all nodes except those in the training and validation sets will be used for test (as in the `"Pitfalls of Graph Neural Network Evaluation" `_ paper). If set to :obj:`"random"`, train, validation, and test sets will be randomly generated, according to :obj:`num_train_per_class`, :obj:`num_val` and :obj:`num_test` (as in the `"Semi-supervised Classification with Graph Convolutional Networks" `_ paper). (default: :obj:`"train_rest"`) num_splits (int, optional): The number of splits to add. If bigger than :obj:`1`, the shape of masks will be :obj:`[num_nodes, num_splits]`, and :obj:`[num_nodes]` otherwise. (default: :obj:`1`) num_train_per_class (int, optional): The number of training samples per class in case of :obj:`"test_rest"` and :obj:`"random"` split. (default: :obj:`20`) num_val (int or float, optional): The number of validation samples. If float, it represents the ratio of samples to include in the validation set. (default: :obj:`500`) num_test (int or float, optional): The number of test samples in case of :obj:`"train_rest"` and :obj:`"random"` split. If float, it represents the ratio of samples to include in the test set. (default: :obj:`1000`) key (str, optional): The name of the attribute holding ground-truth labels. By default, will only add node-level splits for node-level storages in which :obj:`key` is present. (default: :obj:`"y"`). 
""" def __init__( self, split: str = "train_rest", num_splits: int = 1, num_train_per_class: int = 20, num_val: Union[int, float] = 500, num_test: Union[int, float] = 1000, key: Optional[str] = "y", ) -> None: assert split in ['train_rest', 'test_rest', 'random'] self.split = split self.num_splits = num_splits self.num_train_per_class = num_train_per_class self.num_val = num_val self.num_test = num_test self.key = key def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.node_stores: if self.key is not None and not hasattr(store, self.key): continue train_masks, val_masks, test_masks = zip( *[self._split(store) for _ in range(self.num_splits)]) store.train_mask = torch.stack(train_masks, dim=-1).squeeze(-1) store.val_mask = torch.stack(val_masks, dim=-1).squeeze(-1) store.test_mask = torch.stack(test_masks, dim=-1).squeeze(-1) return data def _split(self, store: NodeStorage) -> Tuple[Tensor, Tensor, Tensor]: num_nodes = store.num_nodes assert num_nodes is not None train_mask = torch.zeros(num_nodes, dtype=torch.bool) val_mask = torch.zeros(num_nodes, dtype=torch.bool) test_mask = torch.zeros(num_nodes, dtype=torch.bool) if isinstance(self.num_val, float): num_val = round(num_nodes * self.num_val) else: num_val = self.num_val if isinstance(self.num_test, float): num_test = round(num_nodes * self.num_test) else: num_test = self.num_test if self.split == 'train_rest': perm = torch.randperm(num_nodes) val_mask[perm[:num_val]] = True test_mask[perm[num_val:num_val + num_test]] = True train_mask[perm[num_val + num_test:]] = True else: assert self.key is not None y = getattr(store, self.key) num_classes = int(y.max().item()) + 1 for c in range(num_classes): idx = (y == c).nonzero(as_tuple=False).view(-1) idx = idx[torch.randperm(idx.size(0))] idx = idx[:self.num_train_per_class] train_mask[idx] = True remaining = (~train_mask).nonzero(as_tuple=False).view(-1) remaining = remaining[torch.randperm(remaining.size(0))] 
val_mask[remaining[:num_val]] = True if self.split == 'test_rest': test_mask[remaining[num_val:]] = True elif self.split == 'random': test_mask[remaining[num_val:num_val + num_test]] = True return train_mask, val_mask, test_mask def __repr__(self) -> str: return f'{self.__class__.__name__}(split={self.split})' ================================================ FILE: torch_geometric/transforms/random_rotate.py ================================================ import math import random from typing import Tuple, Union import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform, LinearTransformation @functional_transform('random_rotate') class RandomRotate(BaseTransform): r"""Rotates node positions around a specific axis by a randomly sampled factor within a given interval (functional name: :obj:`random_rotate`). Args: degrees (tuple or float): Rotation interval from which the rotation angle is sampled. If :obj:`degrees` is a number instead of a tuple, the interval is given by :math:`[-\mathrm{degrees}, \mathrm{degrees}]`. axis (int, optional): The rotation axis. 
(default: :obj:`0`)
    """
    def __init__(
        self,
        degrees: Union[Tuple[float, float], float],
        axis: int = 0,
    ) -> None:
        # A scalar `degrees` is turned into the symmetric interval
        # `(-|degrees|, |degrees|)`:
        if isinstance(degrees, (int, float)):
            degrees = (-abs(degrees), abs(degrees))
        assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
        self.degrees = degrees
        self.axis = axis

    def forward(self, data: Data) -> Data:
        assert data.pos is not None

        # Sample an angle in degrees and convert it to radians:
        degree = math.pi * random.uniform(*self.degrees) / 180.0
        sin, cos = math.sin(degree), math.cos(degree)

        # For 2D positions, `axis` is ignored. Otherwise, `axis == 0` and
        # `axis == 1` keep the first and second coordinate fixed, and any
        # other value keeps the third coordinate fixed:
        if data.pos.size(-1) == 2:
            matrix = [[cos, sin], [-sin, cos]]
        else:
            if self.axis == 0:
                matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]
            elif self.axis == 1:
                matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]
            else:
                matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]
        return LinearTransformation(torch.tensor(matrix))(data)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.degrees}, '
                f'axis={self.axis})')


================================================
FILE: torch_geometric/transforms/random_scale.py
================================================
import random
from typing import Tuple

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('random_scale')
class RandomScale(BaseTransform):
    r"""Scales node positions by a randomly sampled factor :math:`s` within a
    given interval, *e.g.*, resulting in the transformation matrix
    (functional name: :obj:`random_scale`).

    .. math::
        \begin{bmatrix}
            s & 0 & 0 \\
            0 & s & 0 \\
            0 & 0 & s \\
        \end{bmatrix}

    for three-dimensional positions.

    Args:
        scales (tuple): scaling factor interval, e.g. :obj:`(a, b)`, then scale
            is randomly sampled from the range
            :math:`a \leq \mathrm{scale} \leq b`.
    """
    def __init__(self, scales: Tuple[float, float]) -> None:
        assert isinstance(scales, (tuple, list)) and len(scales) == 2
        self.scales = scales

    def forward(self, data: Data) -> Data:
        assert data.pos is not None

        # A single scalar factor is shared by all nodes and dimensions:
        scale = random.uniform(*self.scales)
        data.pos = data.pos * scale
        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.scales})'


================================================
FILE: torch_geometric/transforms/random_shear.py
================================================
from typing import Union

import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform, LinearTransformation


@functional_transform('random_shear')
class RandomShear(BaseTransform):
    r"""Shears node positions by randomly sampled factors :math:`s` within a
    given interval, *e.g.*, resulting in the transformation matrix
    (functional name: :obj:`random_shear`).

    .. math::
        \begin{bmatrix}
            1      & s_{xy} & s_{xz} \\
            s_{yx} & 1      & s_{yz} \\
            s_{zx} & s_{zy} & 1      \\
        \end{bmatrix}

    for three-dimensional positions.

    Args:
        shear (float or int): maximum shearing factor defining the range
            :math:`(-\mathrm{shear}, +\mathrm{shear})` to sample from.
    """
    def __init__(self, shear: Union[float, int]) -> None:
        self.shear = abs(shear)

    def forward(self, data: Data) -> Data:
        assert data.pos is not None

        dim = data.pos.size(-1)

        # All entries are sampled uniformly between `-shear` and `shear`;
        # the diagonal is then reset to one:
        matrix = data.pos.new_empty(dim, dim).uniform_(-self.shear, self.shear)
        eye = torch.arange(dim, dtype=torch.long)
        matrix[eye, eye] = 1

        return LinearTransformation(matrix)(data)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.shear})'


================================================
FILE: torch_geometric/transforms/remove_duplicated_edges.py
================================================
from typing import List, Optional, Union

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import coalesce


@functional_transform('remove_duplicated_edges')
class RemoveDuplicatedEdges(BaseTransform):
    r"""Removes duplicated edges from a given homogeneous or heterogeneous
    graph. Useful to clean-up known repeated edges/self-loops in common
    benchmark datasets, *e.g.*, in :obj:`ogbn-products`
    (functional name: :obj:`remove_duplicated_edges`).

    Args:
        key (str or [str], optional): The name of edge attribute(s) to merge in
            case of duplication.
            (default: :obj:`["edge_attr", "edge_weight"]`)
        reduce (str, optional): The reduce operation to use for merging edge
            attributes (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`,
            :obj:`"max"`, :obj:`"mul"`).
(default: :obj:`"add"`)
    """
    def __init__(
        self,
        key: Optional[Union[str, List[str]]] = None,
        reduce: str = "add",
    ) -> None:
        key = key or ['edge_attr', 'edge_weight']

        if isinstance(key, str):
            key = [key]

        self.keys = key
        self.reduce = reduce

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:

        for store in data.edge_stores:
            # Only merge the configured attributes that this store holds:
            keys = [key for key in self.keys if key in store]

            # Pass the largest known dimension of the adjacency shape to
            # `coalesce` (or `None` if no dimension is known):
            size = [s for s in store.size() if s is not None]
            num_nodes = max(size) if len(size) > 0 else None

            store.edge_index, edge_attrs = coalesce(
                edge_index=store.edge_index,
                edge_attr=[store[key] for key in keys],
                num_nodes=num_nodes,
                reduce=self.reduce,
            )

            for key, edge_attr in zip(keys, edge_attrs):
                store[key] = edge_attr

        return data


================================================
FILE: torch_geometric/transforms/remove_isolated_nodes.py
================================================
import copy
from collections import defaultdict
from typing import Union

import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('remove_isolated_nodes')
class RemoveIsolatedNodes(BaseTransform):
    r"""Removes isolated nodes from the graph
    (functional name: :obj:`remove_isolated_nodes`).
    """
    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        # Gather all nodes that occur in at least one edge (across all types):
        n_ids_dict = defaultdict(list)
        for edge_store in data.edge_stores:
            if 'edge_index' not in edge_store:
                continue

            if edge_store._key is None:
                src = dst = None
            else:
                src, _, dst = edge_store._key

            n_ids_dict[src].append(edge_store.edge_index[0])
            n_ids_dict[dst].append(edge_store.edge_index[1])

        n_id_dict = {k: torch.cat(v).unique() for k, v in n_ids_dict.items()}

        # Build, per node type, a lookup table from old to new (compacted)
        # node indices. Node types that appear in no edge keep no nodes.
        n_map_dict = {}
        for node_store in data.node_stores:
            if node_store._key not in n_id_dict:
                n_id_dict[node_store._key] = torch.empty(0, dtype=torch.long)

            idx = n_id_dict[node_store._key]
            # NOTE(review): every lookup table is sized by `data.num_nodes`;
            # this assumes it is an upper bound on the node count of each
            # node type - confirm for heterogeneous graphs.
            assert data.num_nodes is not None
            mapping = idx.new_zeros(data.num_nodes)
            mapping[idx] = torch.arange(idx.numel(), device=mapping.device)
            n_map_dict[node_store._key] = mapping

        # Re-index all edges with the new node indices:
        for edge_store in data.edge_stores:
            if 'edge_index' not in edge_store:
                continue

            if edge_store._key is None:
                src = dst = None
            else:
                src, _, dst = edge_store._key

            row = n_map_dict[src][edge_store.edge_index[0]]
            col = n_map_dict[dst][edge_store.edge_index[1]]
            edge_store.edge_index = torch.stack([row, col], dim=0)

        # Filter node-level attributes, reading them from a copy of `data`:
        old_data = copy.copy(data)
        for out, node_store in zip(data.node_stores, old_data.node_stores):
            for key, value in node_store.items():
                if key == 'num_nodes':
                    out.num_nodes = n_id_dict[node_store._key].numel()
                elif node_store.is_node_attr(key):
                    out[key] = value[n_id_dict[node_store._key]]

        return data


================================================
FILE: torch_geometric/transforms/remove_self_loops.py
================================================
from typing import Union

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_self_loops


@functional_transform('remove_self_loops')
class RemoveSelfLoops(BaseTransform):
    r"""Removes all self-loops in the
given homogeneous or heterogeneous graph
    (functional name: :obj:`remove_self_loops`).

    Args:
        attr (str, optional): The name of the attribute of edge weights
            or multi-dimensional edge features to pass to
            :meth:`torch_geometric.utils.remove_self_loops`.
            (default: :obj:`"edge_weight"`)
    """
    def __init__(self, attr: str = 'edge_weight') -> None:
        self.attr = attr

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        for store in data.edge_stores:
            # Bipartite edge types and stores without `edge_index` are left
            # untouched:
            if store.is_bipartite() or 'edge_index' not in store:
                continue

            store.edge_index, store[self.attr] = remove_self_loops(
                store.edge_index,
                edge_attr=store.get(self.attr, None),
            )

        return data


================================================
FILE: torch_geometric/transforms/remove_training_classes.py
================================================
from typing import List

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('remove_training_classes')
class RemoveTrainingClasses(BaseTransform):
    r"""Removes classes from the node-level training set as given by
    :obj:`data.train_mask`, *e.g.*, in order to get a zero-shot label scenario
    (functional name: :obj:`remove_training_classes`).

    Args:
        classes (List[int]): The classes to remove from the training set.
    """
    def __init__(self, classes: List[int]):
        self.classes = classes

    def forward(self, data: Data) -> Data:
        # Clone first so that the original mask tensor is not modified:
        data.train_mask = data.train_mask.clone()
        for i in self.classes:
            data.train_mask[data.y == i] = False
        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.classes})'


================================================
FILE: torch_geometric/transforms/rooted_subgraph.py
================================================
import copy
from abc import ABC, abstractmethod
from typing import Any, Tuple

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_torch_csc_tensor


class RootedSubgraphData(Data):
    r"""A data object describing a homogeneous graph together with each node's
    rooted subgraph. It contains several additional properties that hold the
    information to map to batch of every node's rooted subgraph:

    * :obj:`sub_edge_index` (Tensor): The edge indices of all combined rooted
      subgraphs.
    * :obj:`n_id` (Tensor): The indices of nodes in all combined rooted
      subgraphs.
    * :obj:`e_id` (Tensor): The indices of edges in all combined rooted
      subgraphs.
    * :obj:`n_sub_batch` (Tensor): The batch vector to distinguish nodes across
      different subgraphs.
    * :obj:`e_sub_batch` (Tensor): The batch vector to distinguish edges across
      different subgraphs.
    """
    def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any:
        # `sub_edge_index` indexes into `n_id`; both sub-batch vectors count
        # subgraphs (one past the last entry of `n_sub_batch`):
        if key == 'sub_edge_index':
            return self.n_id.size(0)
        if key in ['n_sub_batch', 'e_sub_batch']:
            return 1 + int(self.n_sub_batch[-1])
        elif key == 'n_id':
            return self.num_nodes
        elif key == 'e_id':
            assert self.edge_index is not None
            return self.edge_index.size(1)
        return super().__inc__(key, value, *args, **kwargs)

    def map_data(self) -> Data:
        # Maps all feature information of the :class:`Data` object to each
        # rooted subgraph.
        data = copy.copy(self)

        for key, value in self.items():
            if key in ['sub_edge_index', 'n_id', 'e_id', 'e_sub_batch']:
                del data[key]
            elif key == 'n_sub_batch':
                continue
            elif key == 'num_nodes':
                data.num_nodes = self.n_id.size(0)
            elif key == 'edge_index':
                data.edge_index = self.sub_edge_index
            elif self.is_node_attr(key):
                dim = self.__cat_dim__(key, value)
                data[key] = value.index_select(dim, self.n_id)
            elif self.is_edge_attr(key):
                dim = self.__cat_dim__(key, value)
                data[key] = value.index_select(dim, self.e_id)

        return data


class RootedSubgraph(BaseTransform, ABC):
    r"""Base class for implementing rooted subgraph transformations."""
    @abstractmethod
    def extract(
        self,
        data: Data,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        # Returns the tuple:
        # :obj:`(sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch)`
        # of the :class:`RootedSubgraphData` object.
        pass

    def map(
        self,
        data: Data,
        n_mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        # Converts a `[num_nodes, num_nodes]` boolean mask, in which
        # `n_mask[i, j]` marks node `j` as part of the subgraph rooted at
        # node `i`, into the tuple returned by `extract`.

        assert data.edge_index is not None
        num_nodes = data.num_nodes
        assert num_nodes is not None

        n_sub_batch, n_id = n_mask.nonzero().t()
        # An edge belongs to a subgraph iff both of its endpoints do:
        e_mask = n_mask[:, data.edge_index[0]] & n_mask[:, data.edge_index[1]]
        e_sub_batch, e_id = e_mask.nonzero().t()

        sub_edge_index = data.edge_index[:, e_id]
        # `node_map[i, j]` holds the position of node `j` of subgraph `i`
        # within `n_id`. Edge endpoints are offset by their subgraph row so
        # that they can be looked up in the flattened `node_map`:
        arange = torch.arange(n_id.size(0), device=data.edge_index.device)
        node_map = data.edge_index.new_ones(num_nodes, num_nodes)
        node_map[n_sub_batch, n_id] = arange
        sub_edge_index += (arange * num_nodes)[e_sub_batch]
        sub_edge_index = node_map.view(-1)[sub_edge_index]

        return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch

    def forward(self, data: Data) -> RootedSubgraphData:
        out = self.extract(data)
        d = RootedSubgraphData.from_dict(data.to_dict())
        d.sub_edge_index, d.n_id, d.e_id, d.n_sub_batch, d.e_sub_batch = out
        return d


class RootedEgoNets(RootedSubgraph):
    r"""Collects rooted :math:`k`-hop EgoNets for each node in the graph, as
    described in the `"From Stars to Subgraphs: Uplifting Any GNN with Local
    Structure Awareness" `_ paper.

    Args:
        num_hops (int): the number of hops :math:`k`.
    """
    def __init__(self, num_hops: int) -> None:
        super().__init__()
        self.num_hops = num_hops

    def extract(
        self,
        data: Data,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        assert data.edge_index is not None
        num_nodes = data.num_nodes
        assert num_nodes is not None

        adj_t = to_torch_csc_tensor(data.edge_index, size=data.size()).t()
        # Start with each root on its own and grow the mask by one hop (via
        # `adj_t`) per iteration; any positive entry marks membership:
        n_mask = torch.eye(num_nodes, device=data.edge_index.device)
        for _ in range(self.num_hops):
            n_mask += adj_t @ n_mask

        return self.map(data, n_mask > 0)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_hops={self.num_hops})'


class RootedRWSubgraph(RootedSubgraph):
    """Collects rooted random-walk based subgraphs for each node in the graph,
    as described in the `"From Stars to Subgraphs: Uplifting Any GNN with Local
    Structure Awareness" `_ paper.

    Args:
        walk_length (int): the length of the random walk.
        repeat (int, optional): The number of times of repeating the random
            walk to reduce randomness. (default: :obj:`1`)
    """
    def __init__(self, walk_length: int, repeat: int = 1):
        super().__init__()
        self.walk_length = walk_length
        self.repeat = repeat

    def extract(
        self,
        data: Data,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        from torch_cluster import random_walk

        assert data.edge_index is not None
        num_nodes = data.num_nodes
        assert num_nodes is not None

        # Every node starts `repeat` walks:
        start = torch.arange(num_nodes, device=data.edge_index.device)
        start = start.view(-1, 1).repeat(1, self.repeat).view(-1)
        walk = random_walk(data.edge_index[0], data.edge_index[1], start,
                           self.walk_length, num_nodes=data.num_nodes)

        # Mark every visited node as a member of its start node's subgraph
        # (the code assumes each walk holds `walk_length + 1` nodes):
        n_mask = torch.zeros((num_nodes, num_nodes), dtype=torch.bool,
                             device=walk.device)
        start = start.view(-1, 1).repeat(1, (self.walk_length + 1)).view(-1)
        n_mask[start, walk.view(-1)] = True

        return self.map(data, n_mask)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(walk_length={self.walk_length})'


================================================
FILE: 
torch_geometric/transforms/sample_points.py
================================================
import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('sample_points')
class SamplePoints(BaseTransform):
    r"""Uniformly samples a fixed number of points on the mesh faces according
    to their face area (functional name: :obj:`sample_points`).

    Args:
        num (int): The number of points to sample.
        remove_faces (bool, optional): If set to :obj:`False`, the face tensor
            will not be removed. (default: :obj:`True`)
        include_normals (bool, optional): If set to :obj:`True`, then compute
            normals for each sampled point. (default: :obj:`False`)
    """
    def __init__(
        self,
        num: int,
        remove_faces: bool = True,
        include_normals: bool = False,
    ):
        self.num = num
        self.remove_faces = remove_faces
        self.include_normals = include_normals

    def forward(self, data: Data) -> Data:
        assert data.pos is not None
        assert data.face is not None

        pos, face = data.pos, data.face
        assert pos.size(1) == 3 and face.size(0) == 3

        # Work on positions divided by their largest absolute coordinate; the
        # scaling is undone at the end.
        # NOTE(review): a mesh whose coordinates are all zero divides by zero
        # here.
        pos_max = pos.abs().max()
        pos = pos / pos_max

        # Triangle areas are half the norm of the edge cross product; faces
        # are drawn with probability proportional to their area:
        area = (pos[face[1]] - pos[face[0]]).cross(
            pos[face[2]] - pos[face[0]],
            dim=1,
        )
        area = area.norm(p=2, dim=1).abs() / 2

        prob = area / area.sum()
        sample = torch.multinomial(prob, self.num, replacement=True)
        face = face[:, sample]

        # Random coefficient pairs whose sum exceeds one are reflected so
        # that the sampled point stays inside its triangle:
        frac = torch.rand(self.num, 2, device=pos.device)
        mask = frac.sum(dim=-1) > 1
        frac[mask] = 1 - frac[mask]

        vec1 = pos[face[1]] - pos[face[0]]
        vec2 = pos[face[2]] - pos[face[0]]

        if self.include_normals:
            data.normal = torch.nn.functional.normalize(
                vec1.cross(vec2, dim=1), p=2)

        pos_sampled = pos[face[0]]
        pos_sampled += frac[:, :1] * vec1
        pos_sampled += frac[:, 1:] * vec2

        pos_sampled = pos_sampled * pos_max
        data.pos = pos_sampled

        if self.remove_faces:
            data.face = None

        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.num})'


================================================
FILE: torch_geometric/transforms/sign.py
================================================
import torch

from torch_geometric import EdgeIndex
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import scatter


@functional_transform('sign')
class SIGN(BaseTransform):
    r"""The Scalable Inception Graph Neural Network module (SIGN) from the
    `"SIGN: Scalable Inception Graph Neural Networks" `_ paper
    (functional name: :obj:`sign`), which precomputes the fixed
    representations.

    .. math::
        \mathbf{X}^{(i)} = {\left( \mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2} \right)}^i \mathbf{X}

    for :math:`i \in \{ 1, \ldots, K \}` and saves them in
    :obj:`data.x1`, :obj:`data.x2`, ...

    .. note::

        Since intermediate node representations are pre-computed, this
        operator is able to scale well to large graphs via classic
        mini-batching.
        For an example of using SIGN, see `examples/sign.py `_.

    Args:
        K (int): The number of hops/layer.
    """
    def __init__(self, K: int) -> None:
        self.K = K

    def forward(self, data: Data) -> Data:
        assert data.edge_index is not None
        edge_index = data.edge_index
        row, col = data.edge_index
        num_nodes = data.num_nodes

        # Missing edge weights default to one:
        edge_weight = data.edge_weight
        if edge_weight is None:
            edge_weight = torch.ones(data.num_edges, device=edge_index.device)

        # Symmetric normalization D^{-1/2} A D^{-1/2}, with degrees summed
        # over the target nodes (`col`). Zero-degree nodes produce `inf`
        # after the power, which is replaced by zero:
        deg = scatter(edge_weight, col, dim_size=num_nodes, reduce='sum')
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

        # Sort edges by column and permute the weights accordingly:
        edge_index = EdgeIndex(edge_index, sparse_size=(num_nodes, num_nodes))
        edge_index, perm = edge_index.sort_by('col')
        edge_weight = edge_weight[perm]

        # Repeatedly propagate the previous representation; `x{i}` holds the
        # result after `i` steps:
        assert data.x is not None
        xs = [data.x]
        for i in range(1, self.K + 1):
            xs.append(edge_index.matmul(xs[-1], edge_weight, transpose=True))
            data[f'x{i}'] = xs[-1]

        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(K={self.K})'


================================================
FILE: torch_geometric/transforms/spherical.py
================================================
from math import pi as PI
from typing import Optional

import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('spherical')
class Spherical(BaseTransform):
    r"""Saves the spherical coordinates of linked nodes in its edge attributes
    (functional name: :obj:`spherical`).

    Args:
        norm (bool, optional): If set to :obj:`False`, the output will not be
            normalized to the interval :math:`{[0, 1]}^3`.
            (default: :obj:`True`)
        max_value (float, optional): If set and :obj:`norm=True`,
            normalization will be performed based on this value instead of the
            maximum value found in the data. (default: :obj:`None`)
        cat (bool, optional): If set to :obj:`False`, all existing edge
            attributes will be replaced.
(default: :obj:`True`) """ def __init__( self, norm: bool = True, max_value: Optional[float] = None, cat: bool = True, ): self.norm = norm self.max = max_value self.cat = cat def forward(self, data: Data) -> Data: assert data.pos is not None assert data.edge_index is not None (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr assert pos.dim() == 2 and pos.size(1) == 3 cart = pos[col] - pos[row] rho = torch.norm(cart, p=2, dim=-1).view(-1, 1) theta = torch.atan2(cart[..., 1], cart[..., 0]).view(-1, 1) theta = theta + (theta < 0).type_as(theta) * (2 * PI) phi = torch.acos(cart[..., 2] / rho.view(-1)).view(-1, 1) if self.norm: rho = rho / (rho.max() if self.max is None else self.max) theta = theta / (2 * PI) phi = phi / PI spher = torch.cat([rho, theta, phi], dim=-1) if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo data.edge_attr = torch.cat([pseudo, spher.type_as(pos)], dim=-1) else: data.edge_attr = spher return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(norm={self.norm}, ' f'max_value={self.max})') ================================================ FILE: torch_geometric/transforms/svd_feature_reduction.py ================================================ import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('svd_feature_reduction') class SVDFeatureReduction(BaseTransform): r"""Dimensionality reduction of node features via Singular Value Decomposition (SVD) (functional name: :obj:`svd_feature_reduction`). Args: out_channels (int): The dimensionality of node features after reduction. 
""" def __init__(self, out_channels: int): self.out_channels = out_channels def forward(self, data: Data) -> Data: assert data.x is not None if data.x.size(-1) > self.out_channels: U, S, _ = torch.linalg.svd(data.x) data.x = torch.mm(U[:, :self.out_channels], torch.diag(S[:self.out_channels])) return data def __repr__(self) -> str: return f'{self.__class__.__name__}({self.out_channels})' ================================================ FILE: torch_geometric/transforms/target_indegree.py ================================================ from typing import Optional import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import degree @functional_transform('target_indegree') class TargetIndegree(BaseTransform): r"""Saves the globally normalized degree of target nodes (functional name: :obj:`target_indegree`). .. math:: \mathbf{u}(i,j) = \frac{\deg(j)}{\max_{v \in \mathcal{V}} \deg(v)} in its edge attributes. Args: cat (bool, optional): Concat pseudo-coordinates to edge attributes instead of replacing them. 
(default: :obj:`True`) """ def __init__( self, norm: bool = True, max_value: Optional[float] = None, cat: bool = True, ) -> None: self.norm = norm self.max = max_value self.cat = cat def forward(self, data: Data) -> Data: assert data.edge_index is not None col, pseudo = data.edge_index[1], data.edge_attr deg = degree(col, data.num_nodes) if self.norm: deg = deg / (deg.max() if self.max is None else self.max) deg = deg[col] deg = deg.view(-1, 1) if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo data.edge_attr = torch.cat([pseudo, deg.type_as(pseudo)], dim=-1) else: data.edge_attr = deg return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(norm={self.norm}, ' f'max_value={self.max})') ================================================ FILE: torch_geometric/transforms/to_dense.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('to_dense') class ToDense(BaseTransform): r"""Converts a sparse adjacency matrix to a dense adjacency matrix with shape :obj:`[num_nodes, num_nodes, *]` (functional name: :obj:`to_dense`). Args: num_nodes (int, optional): The number of nodes. If set to :obj:`None`, the number of nodes will get automatically inferred. 
(default: :obj:`None`) """ def __init__(self, num_nodes: Optional[int] = None) -> None: self.num_nodes = num_nodes def forward(self, data: Data) -> Data: assert data.edge_index is not None orig_num_nodes = data.num_nodes assert orig_num_nodes is not None if self.num_nodes is None: num_nodes = orig_num_nodes else: assert orig_num_nodes <= self.num_nodes num_nodes = self.num_nodes if data.edge_attr is None: edge_attr = torch.ones(data.edge_index.size(1), dtype=torch.float) else: edge_attr = data.edge_attr size = torch.Size([num_nodes, num_nodes] + list(edge_attr.size())[1:]) adj = torch.sparse_coo_tensor(data.edge_index, edge_attr, size) data.adj = adj.to_dense() data.edge_index = None data.edge_attr = None data.mask = torch.zeros(num_nodes, dtype=torch.bool) data.mask[:orig_num_nodes] = 1 if data.x is not None: _size = [num_nodes - data.x.size(0)] + list(data.x.size())[1:] data.x = torch.cat([data.x, data.x.new_zeros(_size)], dim=0) if data.pos is not None: _size = [num_nodes - data.pos.size(0)] + list(data.pos.size())[1:] data.pos = torch.cat([data.pos, data.pos.new_zeros(_size)], dim=0) if (data.y is not None and isinstance(data.y, Tensor) and data.y.size(0) == orig_num_nodes): _size = [num_nodes - data.y.size(0)] + list(data.y.size())[1:] data.y = torch.cat([data.y, data.y.new_zeros(_size)], dim=0) return data def __repr__(self) -> str: if self.num_nodes is None: return f'{self.__class__.__name__}()' return f'{self.__class__.__name__}(num_nodes={self.num_nodes})' ================================================ FILE: torch_geometric/transforms/to_device.py ================================================ from typing import List, Optional, Union from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('to_device') class ToDevice(BaseTransform): r"""Performs tensor device conversion, either for all attributes of the 
:obj:`~torch_geometric.data.Data` object or only the ones given by :obj:`attrs` (functional name: :obj:`to_device`). Args: device (torch.device): The destination device. attrs (List[str], optional): If given, will only perform tensor device conversion for the given attributes. (default: :obj:`None`) non_blocking (bool, optional): If set to :obj:`True` and tensor values are in pinned memory, the copy will be asynchronous with respect to the host. (default: :obj:`False`) """ def __init__( self, device: Union[int, str], attrs: Optional[List[str]] = None, non_blocking: bool = False, ) -> None: self.device = device self.attrs = attrs or [] self.non_blocking = non_blocking def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: return data.to(self.device, *self.attrs, non_blocking=self.non_blocking) def __repr__(self) -> str: return f'{self.__class__.__name__}({self.device})' ================================================ FILE: torch_geometric/transforms/to_sparse_tensor.py ================================================ from typing import Optional, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.typing import SparseTensor from torch_geometric.utils import ( sort_edge_index, to_torch_coo_tensor, to_torch_csr_tensor, ) @functional_transform('to_sparse_tensor') class ToSparseTensor(BaseTransform): r"""Converts the :obj:`edge_index` attributes of a homogeneous or heterogeneous data object into a **transposed** :class:`torch_sparse.SparseTensor` or :pytorch:`PyTorch` :class:`torch.sparse.Tensor` object with key :obj:`adj_t` (functional name: :obj:`to_sparse_tensor`). .. 
note:: In case of composing multiple transforms, it is best to convert the :obj:`data` object via :class:`ToSparseTensor` as late as possible, since there exist some transforms that are only able to operate on :obj:`data.edge_index` for now. Args: attr (str, optional): The name of the attribute to add as a value to the :class:`~torch_sparse.SparseTensor` or :class:`torch.sparse.Tensor` object (if present). (default: :obj:`edge_weight`) remove_edge_index (bool, optional): If set to :obj:`False`, the :obj:`edge_index` tensor will not be removed. (default: :obj:`True`) fill_cache (bool, optional): If set to :obj:`True`, will fill the underlying :class:`torch_sparse.SparseTensor` cache (if used). (default: :obj:`True`) layout (torch.layout, optional): Specifies the layout of the returned sparse tensor (:obj:`None`, :obj:`torch.sparse_coo` or :obj:`torch.sparse_csr`). If set to :obj:`None` and the :obj:`torch_sparse` dependency is installed, will convert :obj:`edge_index` into a :class:`torch_sparse.SparseTensor` object. If set to :obj:`None` and the :obj:`torch_sparse` dependency is not installed, will convert :obj:`edge_index` into a :class:`torch.sparse.Tensor` object with layout :obj:`torch.sparse_csr`. 
(default: :obj:`None`) """ def __init__( self, attr: Optional[str] = 'edge_weight', remove_edge_index: bool = True, fill_cache: bool = True, layout: Optional[int] = None, ) -> None: if layout not in {None, torch.sparse_coo, torch.sparse_csr}: raise ValueError(f"Unexpected sparse tensor layout " f"(got '{layout}')") self.attr = attr self.remove_edge_index = remove_edge_index self.fill_cache = fill_cache self.layout = layout def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.edge_stores: if 'edge_index' not in store: continue keys, values = [], [] for key, value in store.items(): if key in {'edge_index', 'edge_label', 'edge_label_index'}: continue if store.is_edge_attr(key): keys.append(key) values.append(value) store.edge_index, values = sort_edge_index( store.edge_index, values, sort_by_row=False, ) for key, value in zip(keys, values): store[key] = value layout = self.layout size = store.size()[::-1] edge_weight: Optional[Tensor] = None if self.attr is not None and self.attr in store: edge_weight = store[self.attr] if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE: store.adj_t = SparseTensor( row=store.edge_index[1], col=store.edge_index[0], value=edge_weight, sparse_sizes=size, is_sorted=True, trust_data=True, ) # TODO Multi-dimensional edge attributes only supported for COO. 
elif ((edge_weight is not None and edge_weight.dim() > 1) or layout == torch.sparse_coo): assert size[0] is not None and size[1] is not None store.adj_t = to_torch_coo_tensor( store.edge_index.flip([0]), edge_attr=edge_weight, size=size, ) elif layout is None or layout == torch.sparse_csr: assert size[0] is not None and size[1] is not None store.adj_t = to_torch_csr_tensor( store.edge_index.flip([0]), edge_attr=edge_weight, size=size, ) if self.remove_edge_index: del store['edge_index'] if self.attr is not None and self.attr in store: del store[self.attr] if self.fill_cache and isinstance(store.adj_t, SparseTensor): # Pre-process some important attributes. store.adj_t.storage.rowptr() store.adj_t.storage.csr2csc() return data def __repr__(self) -> str: return (f'{self.__class__.__name__}(attr={self.attr}, ' f'layout={self.layout})') ================================================ FILE: torch_geometric/transforms/to_superpixels.py ================================================ from typing import Any import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import scatter @functional_transform('to_slic') class ToSLIC(BaseTransform): r"""Converts an image to a superpixel representation using the :meth:`skimage.segmentation.slic` algorithm, resulting in a :obj:`torch_geometric.data.Data` object holding the centroids of superpixels in :obj:`data.pos` and their mean color in :obj:`data.x` (functional name: :obj:`to_slic`). This transform can be used with any :obj:`torchvision` dataset. .. 
code-block:: python from torchvision.datasets import MNIST import torchvision.transforms as T from torch_geometric.transforms import ToSLIC transform = T.Compose([T.ToTensor(), ToSLIC(n_segments=75)]) dataset = MNIST('/tmp/MNIST', download=True, transform=transform) Args: add_seg (bool, optional): If set to `True`, will add the segmentation result to the data object. (default: :obj:`False`) add_img (bool, optional): If set to `True`, will add the input image to the data object. (default: :obj:`False`) **kwargs (optional): Arguments to adjust the output of the SLIC algorithm. See the `SLIC documentation `_ for an overview. """ def __init__( self, add_seg: bool = False, add_img: bool = False, **kwargs: Any, ) -> None: self.add_seg = add_seg self.add_img = add_img self.kwargs = kwargs def forward(self, img: Tensor) -> Data: from skimage.segmentation import slic img = img.permute(1, 2, 0) h, w, c = img.size() seg = slic(img.to(torch.double).numpy(), start_label=0, **self.kwargs) seg = torch.from_numpy(seg) x = scatter(img.view(h * w, c), seg.view(h * w), dim=0, reduce='mean') pos_y = torch.arange(h, dtype=torch.float) pos_y = pos_y.view(-1, 1).repeat(1, w).view(h * w) pos_x = torch.arange(w, dtype=torch.float) pos_x = pos_x.view(1, -1).repeat(h, 1).view(h * w) pos = torch.stack([pos_x, pos_y], dim=-1) pos = scatter(pos, seg.view(h * w), dim=0, reduce='mean') data = Data(x=x, pos=pos) if self.add_seg: data.seg = seg.view(1, h, w) if self.add_img: data.img = img.permute(2, 0, 1).view(1, c, h, w) return data ================================================ FILE: torch_geometric/transforms/to_undirected.py ================================================ from typing import Union import torch from torch import Tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import to_undirected @functional_transform('to_undirected') class 
ToUndirected(BaseTransform): r"""Converts a homogeneous or heterogeneous graph to an undirected graph such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in \mathcal{E}` (functional name: :obj:`to_undirected`). In heterogeneous graphs, will add "reverse" connections for *all* existing edge types. Args: reduce (str, optional): The reduce operation to use for merging edge features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"add"`) merge (bool, optional): If set to :obj:`False`, will create reverse edge types for connections pointing to the same source and target node type. If set to :obj:`True`, reverse edges will be merged into the original relation. This option only has effects in :class:`~torch_geometric.data.HeteroData` graph data. (default: :obj:`True`) """ def __init__(self, reduce: str = "add", merge: bool = True): self.reduce = reduce self.merge = merge def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.edge_stores: if 'edge_index' not in store: continue nnz = store.edge_index.size(1) if isinstance(data, HeteroData) and (store.is_bipartite() or not self.merge): src, rel, dst = store._key # Just reverse the connectivity and add edge attributes: row, col = store.edge_index rev_edge_index = torch.stack([col, row], dim=0) inv_store = data[dst, f'rev_{rel}', src] inv_store.edge_index = rev_edge_index for key, value in store.items(): if key == 'edge_index': continue if isinstance(value, Tensor) and value.size(0) == nnz: inv_store[key] = value else: keys, values = [], [] for key, value in store.items(): if key == 'edge_index': continue if store.is_edge_attr(key): keys.append(key) values.append(value) store.edge_index, values = to_undirected( store.edge_index, values, reduce=self.reduce) for key, value in zip(keys, values): store[key] = value return data ================================================ FILE: torch_geometric/transforms/two_hop.py 
================================================ import torch from torch_geometric import EdgeIndex from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import coalesce, remove_self_loops @functional_transform('two_hop') class TwoHop(BaseTransform): r"""Adds the two hop edges to the edge indices (functional name: :obj:`two_hop`). """ def forward(self, data: Data) -> Data: assert data.edge_index is not None edge_index, edge_attr = data.edge_index, data.edge_attr N = data.num_nodes edge_index = EdgeIndex(edge_index, sparse_size=(N, N)) edge_index = edge_index.sort_by('row')[0] edge_index2 = edge_index.matmul(edge_index)[0].as_tensor() edge_index2, _ = remove_self_loops(edge_index2) edge_index = torch.cat([edge_index, edge_index2], dim=1) if edge_attr is not None: # We treat newly added edge features as "zero-features": edge_attr2 = edge_attr.new_zeros(edge_index2.size(1), *edge_attr.size()[1:]) edge_attr = torch.cat([edge_attr, edge_attr2], dim=0) data.edge_index, data.edge_attr = coalesce(edge_index, edge_attr, N) return data ================================================ FILE: torch_geometric/transforms/virtual_node.py ================================================ import copy import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @functional_transform('virtual_node') class VirtualNode(BaseTransform): r"""Appends a virtual node to the given homogeneous graph that is connected to all other nodes, as described in the `"Neural Message Passing for Quantum Chemistry" `_ paper (functional name: :obj:`virtual_node`). The virtual node serves as a global scratch space that each node both reads from and writes to in every step of message passing. 
This allows information to travel long distances during the propagation phase. Node and edge features of the virtual node are added as zero-filled input features. Furthermore, special edge types will be added both for in-coming and out-going information to and from the virtual node. """ def forward(self, data: Data) -> Data: assert data.edge_index is not None row, col = data.edge_index edge_type = data.get('edge_type', torch.zeros_like(row)) num_nodes = data.num_nodes assert num_nodes is not None arange = torch.arange(num_nodes, device=row.device) full = row.new_full((num_nodes, ), num_nodes) row = torch.cat([row, arange, full], dim=0) col = torch.cat([col, full, arange], dim=0) edge_index = torch.stack([row, col], dim=0) num_edge_types = int(edge_type.max()) if edge_type.numel() > 0 else 0 new_type = edge_type.new_full((num_nodes, ), num_edge_types + 1) edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0) old_data = copy.copy(data) for key, value in old_data.items(): if key == 'edge_index' or key == 'edge_type': continue if isinstance(value, Tensor): dim = old_data.__cat_dim__(key, value) size = list(value.size()) fill_value = None if key == 'edge_weight': size[dim] = 2 * num_nodes fill_value = 1. elif key == 'batch': size[dim] = 1 fill_value = int(value[0]) elif old_data.is_edge_attr(key): size[dim] = 2 * num_nodes fill_value = 0. elif old_data.is_node_attr(key): size[dim] = 1 fill_value = 0. 
if fill_value is not None: new_value = value.new_full(size, fill_value) data[key] = torch.cat([value, new_value], dim=dim) data.edge_index = edge_index data.edge_type = edge_type if 'num_nodes' in data: data.num_nodes = num_nodes + 1 return data ================================================ FILE: torch_geometric/typing.py ================================================ import importlib.util import inspect import os import typing import warnings from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union import numpy as np import torch from torch import Tensor WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2 WITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1 WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2 WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3 WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4 WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5 WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6 WITH_PT27 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 7 WITH_PT28 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 8 WITH_PT29 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 9 WITH_PT210 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 10 WITH_PT211 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 11 WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13 WITH_WINDOWS = os.name == 'nt' NO_MKL = 'USE_MKL=OFF' in torch.__config__.show() or WITH_WINDOWS MAX_INT64 = torch.iinfo(torch.int64).max if WITH_PT20: INDEX_DTYPES: Set[torch.dtype] = { torch.int32, torch.int64, } elif not typing.TYPE_CHECKING: # pragma: no cover INDEX_DTYPES: Set[torch.dtype] = { torch.int64, } if not hasattr(torch, 'sparse_csc'): torch.sparse_csc = torch.sparse_coo try: import pyg_lib # noqa WITH_PYG_LIB = True WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, 'grouped_matmul') WITH_SEGMM = hasattr(pyg_lib.ops, 
'segment_matmul') if (WITH_SEGMM and 'PYTEST_CURRENT_TEST' in os.environ and torch.cuda.is_available()): # NOTE `segment_matmul` is currently bugged on older NVIDIA cards which # let our GPU tests on CI crash. Try if this error is present on the # current GPU and disable `WITH_SEGMM`/`WITH_GMM` if necessary. # TODO Drop this code block once `segment_matmul` is fixed. try: x = torch.randn(3, 4, device='cuda') ptr = torch.tensor([0, 2, 3], device='cuda') weight = torch.randn(2, 4, 4, device='cuda') out = pyg_lib.ops.segment_matmul(x, ptr, weight) except RuntimeError: WITH_GMM = False WITH_SEGMM = False WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add') WITH_SPLINE = hasattr(pyg_lib.ops, 'spline_basis') WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr') WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort') WITH_METIS = hasattr(pyg_lib, 'partition') WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature( pyg_lib.sampler.neighbor_sample).parameters) WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature( pyg_lib.sampler.neighbor_sample).parameters) try: torch.classes.pyg.CPUHashMap # noqa: B018 WITH_CPU_HASH_MAP = True except Exception: WITH_CPU_HASH_MAP = False try: torch.classes.pyg.CUDAHashMap # noqa: B018 WITH_CUDA_HASH_MAP = True except Exception: WITH_CUDA_HASH_MAP = False except Exception as e: if not isinstance(e, ImportError): # pragma: no cover warnings.warn( f"An issue occurred while importing 'pyg-lib'. " f"Disabling its usage. 
Stacktrace: {e}", stacklevel=2) pyg_lib = object WITH_PYG_LIB = False WITH_GMM = False WITH_SEGMM = False WITH_SAMPLED_OP = False WITH_SPLINE = False WITH_SOFTMAX = False WITH_INDEX_SORT = False WITH_METIS = False WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False WITH_WEIGHTED_NEIGHBOR_SAMPLE = False WITH_CPU_HASH_MAP = False WITH_CUDA_HASH_MAP = False if WITH_CPU_HASH_MAP: CPUHashMap: TypeAlias = torch.classes.pyg.CPUHashMap # type: ignore[name-defined] # noqa: E501 else: class CPUHashMap: # type: ignore def __init__(self, key: Tensor) -> None: raise ImportError("'CPUHashMap' requires 'pyg-lib'") def get(self, query: Tensor) -> Tensor: raise ImportError("'CPUHashMap' requires 'pyg-lib'") if WITH_CUDA_HASH_MAP: CUDAHashMap: TypeAlias = torch.classes.pyg.CUDAHashMap # type: ignore[name-defined] # noqa: E501 else: class CUDAHashMap: # type: ignore def __init__(self, key: Tensor) -> None: raise ImportError("'CUDAHashMap' requires 'pyg-lib'") def get(self, query: Tensor) -> Tensor: raise ImportError("'CUDAHashMap' requires 'pyg-lib'") try: import torch_scatter # noqa WITH_TORCH_SCATTER = True except Exception as e: if not isinstance(e, ImportError): # pragma: no cover warnings.warn( f"An issue occurred while importing 'torch-scatter'. " f"Disabling its usage. Stacktrace: {e}", stacklevel=2) torch_scatter = object WITH_TORCH_SCATTER = False try: import torch_cluster # noqa WITH_TORCH_CLUSTER = True WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__ except Exception as e: if not isinstance(e, ImportError): # pragma: no cover warnings.warn( f"An issue occurred while importing 'torch-cluster'. " f"Disabling its usage. 
Stacktrace: {e}", stacklevel=2) WITH_TORCH_CLUSTER = False WITH_TORCH_CLUSTER_BATCH_SIZE = False class TorchCluster: def __getattr__(self, key: str) -> Any: raise ImportError(f"'{key}' requires 'torch-cluster'") torch_cluster = TorchCluster() if importlib.util.find_spec('torch_spline_conv') is not None: warnings.warn( "'torch-spline-conv' is no longer necessary and is being ignored. " "Its functionality has been migrated to 'pyg-lib>=0.6.0'.", DeprecationWarning, stacklevel=2, ) try: import torch_sparse # noqa from torch_sparse import SparseStorage, SparseTensor WITH_TORCH_SPARSE = True except Exception as e: if not isinstance(e, ImportError): # pragma: no cover warnings.warn( f"An issue occurred while importing 'torch-sparse'. " f"Disabling its usage. Stacktrace: {e}", stacklevel=2) WITH_TORCH_SPARSE = False class SparseStorage: # type: ignore def __init__( self, row: Optional[Tensor] = None, rowptr: Optional[Tensor] = None, col: Optional[Tensor] = None, value: Optional[Tensor] = None, sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, rowcount: Optional[Tensor] = None, colptr: Optional[Tensor] = None, colcount: Optional[Tensor] = None, csr2csc: Optional[Tensor] = None, csc2csr: Optional[Tensor] = None, is_sorted: bool = False, trust_data: bool = False, ): raise ImportError("'SparseStorage' requires 'torch-sparse'") def value(self) -> Optional[Tensor]: raise ImportError("'SparseStorage' requires 'torch-sparse'") def rowcount(self) -> Tensor: raise ImportError("'SparseStorage' requires 'torch-sparse'") class SparseTensor: # type: ignore def __init__( self, row: Optional[Tensor] = None, rowptr: Optional[Tensor] = None, col: Optional[Tensor] = None, value: Optional[Tensor] = None, sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, is_sorted: bool = False, trust_data: bool = False, ): raise ImportError("'SparseTensor' requires 'torch-sparse'") @classmethod def from_edge_index( self, edge_index: Tensor, edge_attr: Optional[Tensor] = 
None, sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, is_sorted: bool = False, trust_data: bool = False, ) -> 'SparseTensor': raise ImportError("'SparseTensor' requires 'torch-sparse'") @property def storage(self) -> SparseStorage: raise ImportError("'SparseTensor' requires 'torch-sparse'") @classmethod def from_dense(self, mat: Tensor, has_value: bool = True) -> 'SparseTensor': raise ImportError("'SparseTensor' requires 'torch-sparse'") def size(self, dim: int) -> int: raise ImportError("'SparseTensor' requires 'torch-sparse'") def nnz(self) -> int: raise ImportError("'SparseTensor' requires 'torch-sparse'") def is_cuda(self) -> bool: raise ImportError("'SparseTensor' requires 'torch-sparse'") def has_value(self) -> bool: raise ImportError("'SparseTensor' requires 'torch-sparse'") def set_value(self, value: Optional[Tensor], layout: Optional[str] = None) -> 'SparseTensor': raise ImportError("'SparseTensor' requires 'torch-sparse'") def fill_value(self, fill_value: float, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': raise ImportError("'SparseTensor' requires 'torch-sparse'") def coo(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]: raise ImportError("'SparseTensor' requires 'torch-sparse'") def csr(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]: raise ImportError("'SparseTensor' requires 'torch-sparse'") def requires_grad(self) -> bool: raise ImportError("'SparseTensor' requires 'torch-sparse'") def to_torch_sparse_csr_tensor( self, dtype: Optional[torch.dtype] = None, ) -> Tensor: raise ImportError("'SparseTensor' requires 'torch-sparse'") class torch_sparse: # type: ignore @staticmethod def matmul(src: SparseTensor, other: Tensor, reduce: str = "sum") -> Tensor: raise ImportError("'matmul' requires 'torch-sparse'") @staticmethod def sum(src: SparseTensor, dim: Optional[int] = None) -> Tensor: raise ImportError("'sum' requires 'torch-sparse'") @staticmethod def mul(src: SparseTensor, other: Tensor) -> SparseTensor: raise 
ImportError("'mul' requires 'torch-sparse'") @staticmethod def set_diag(src: SparseTensor, values: Optional[Tensor] = None, k: int = 0) -> SparseTensor: raise ImportError("'set_diag' requires 'torch-sparse'") @staticmethod def fill_diag(src: SparseTensor, fill_value: float, k: int = 0) -> SparseTensor: raise ImportError("'fill_diag' requires 'torch-sparse'") @staticmethod def masked_select_nnz(src: SparseTensor, mask: Tensor, layout: Optional[str] = None) -> SparseTensor: raise ImportError("'masked_select_nnz' requires 'torch-sparse'") try: import torch_frame # noqa WITH_TORCH_FRAME = True from torch_frame import TensorFrame except Exception: torch_frame = object WITH_TORCH_FRAME = False class TensorFrame: # type: ignore pass try: import intel_extension_for_pytorch # noqa WITH_IPEX = True except Exception: WITH_IPEX = False class MockTorchCSCTensor: def __init__( self, edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[int, int]]] = None, ): self.edge_index = edge_index self.edge_attr = edge_attr self.size = size def t(self) -> Tensor: # Only support accessing its transpose: from torch_geometric.utils import to_torch_csr_tensor size = self.size return to_torch_csr_tensor( self.edge_index.flip([0]), self.edge_attr, size[::-1] if isinstance(size, (tuple, list)) else size, ) # Types for accessing data #################################################### # Node-types are denoted by a single string, e.g.: `data['paper']`: NodeType = str # Edge-types are denotes by a triplet of strings, e.g.: # `data[('author', 'writes', 'paper')] EdgeType = Tuple[str, str, str] NodeOrEdgeType = Union[NodeType, EdgeType] DEFAULT_REL = 'to' EDGE_TYPE_STR_SPLIT = '__' class EdgeTypeStr(str): r"""A helper class to construct serializable edge types by merging an edge type tuple into a single string. 
""" edge_type: tuple[str, str, str] def __new__(cls, *args: Any) -> 'EdgeTypeStr': if isinstance(args[0], (list, tuple)): # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`: args = tuple(args[0]) if len(args) == 1 and isinstance(args[0], str): arg = args[0] # An edge type string was passed. edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT)) if len(edge_type) != 3: raise ValueError(f"Cannot convert the edge type '{arg}' to a " f"tuple since it holds invalid characters") elif len(args) == 2 and all(isinstance(arg, str) for arg in args): # A `(src, dst)` edge type was passed - add `DEFAULT_REL`: edge_type = (args[0], DEFAULT_REL, args[1]) arg = EDGE_TYPE_STR_SPLIT.join(edge_type) elif len(args) == 3 and all(isinstance(arg, str) for arg in args): # A `(src, rel, dst)` edge type was passed: edge_type = tuple(args) arg = EDGE_TYPE_STR_SPLIT.join(args) else: raise ValueError(f"Encountered invalid edge type '{args}'") out = str.__new__(cls, arg) out.edge_type = edge_type # type: ignore return out def to_tuple(self) -> EdgeType: r"""Returns the original edge type.""" if len(self.edge_type) != 3: raise ValueError(f"Cannot convert the edge type '{self}' to a " f"tuple since it holds invalid characters") return self.edge_type def __reduce__(self) -> tuple[Any, Any]: return (self.__class__, (self.edge_type, )) # There exist some short-cuts to query edge-types (given that the full triplet # can be uniquely reconstructed, e.g.: # * via str: `data['writes']` # * via Tuple[str, str]: `data[('author', 'paper')]` QueryType = Union[NodeType, EdgeType, str, Tuple[str, str]] Metadata = Tuple[List[NodeType], List[EdgeType]] # A representation of a feature tensor FeatureTensorType = Union[Tensor, np.ndarray] # A representation of an edge index, following the possible formats: # * COO: (row, col) # * CSC: (row, colptr) # * CSR: (rowptr, col) EdgeTensorType = Tuple[Tensor, Tensor] # Types for message passing ################################################### Adj = 
Union[Tensor, SparseTensor] OptTensor = Optional[Tensor] PairTensor = Tuple[Tensor, Tensor] OptPairTensor = Tuple[Tensor, Optional[Tensor]] PairOptTensor = Tuple[Optional[Tensor], Optional[Tensor]] Size = Optional[Tuple[int, int]] NoneType = Optional[Tensor] MaybeHeteroNodeTensor = Union[Tensor, Dict[NodeType, Tensor]] MaybeHeteroAdjTensor = Union[Tensor, Dict[EdgeType, Adj]] MaybeHeteroEdgeTensor = Union[Tensor, Dict[EdgeType, Tensor]] # Types for sampling ########################################################## InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]] InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]] # Serialization ############################################################### if WITH_PT24: torch.serialization.add_safe_globals([ SparseTensor, SparseStorage, TensorFrame, MockTorchCSCTensor, EdgeTypeStr, ]) ================================================ FILE: torch_geometric/utils/__init__.py ================================================ r"""Utility package.""" import copy from ._scatter import scatter, group_argsort, group_cat from ._segment import segment, segment_logsumexp from ._index_sort import index_sort from .functions import cumsum from ._degree import degree from ._softmax import softmax from ._lexsort import lexsort from ._sort_edge_index import sort_edge_index from ._coalesce import coalesce from .undirected import is_undirected, to_undirected from .loop import (contains_self_loops, remove_self_loops, segregate_self_loops, add_self_loops, add_remaining_self_loops, get_self_loop_attr) from .isolated import contains_isolated_nodes, remove_isolated_nodes from ._subgraph import (get_num_hops, subgraph, k_hop_subgraph, bipartite_subgraph) from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path from ._homophily import homophily from ._assortativity import assortativity from ._normalize_edge_index import normalize_edge_index from .laplacian import get_laplacian from .mesh_laplacian import 
get_mesh_laplacian from .mask import mask_select, index_to_mask, mask_to_index from ._select import select, narrow from ._to_dense_batch import to_dense_batch from ._to_dense_adj import to_dense_adj from .nested import to_nested_tensor, from_nested_tensor from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor, to_torch_coo_tensor, to_torch_csr_tensor, to_torch_csc_tensor, to_torch_sparse_tensor, to_edge_index) from ._spmm import spmm from ._unbatch import unbatch, unbatch_edge_index from ._one_hot import one_hot from ._normalized_cut import normalized_cut from ._grid import grid from .geodesic import geodesic_distance from .convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix from .convert import to_networkx, from_networkx from .convert import to_networkit, from_networkit from .convert import to_trimesh, from_trimesh from .convert import to_cugraph, from_cugraph from .convert import to_dgl, from_dgl from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles from .random import (erdos_renyi_graph, stochastic_blockmodel_graph, barabasi_albert_graph) from ._negative_sampling import (negative_sampling, batched_negative_sampling, structured_negative_sampling, structured_negative_sampling_feasible) from .augmentation import shuffle_node, mask_feature, add_random_edge from ._tree_decomposition import tree_decomposition from .embedding import get_embeddings, get_embeddings_hetero from ._trim_to_layer import trim_to_layer from .ppr import get_ppr from ._train_test_split_edges import train_test_split_edges from .influence import total_influence __all__ = [ 'scatter', 'group_argsort', 'group_cat', 'segment', 'segment_logsumexp', 'index_sort', 'cumsum', 'degree', 'softmax', 'lexsort', 'sort_edge_index', 'coalesce', 'is_undirected', 'to_undirected', 'contains_self_loops', 'remove_self_loops', 'segregate_self_loops', 'add_self_loops', 'add_remaining_self_loops', 'get_self_loop_attr', 'contains_isolated_nodes', 'remove_isolated_nodes', 
'get_num_hops', 'subgraph', 'bipartite_subgraph', 'k_hop_subgraph', 'dropout_node', 'dropout_edge', 'dropout_path', 'dropout_adj', 'homophily', 'assortativity', 'normalize_edge_index', 'get_laplacian', 'get_mesh_laplacian', 'mask_select', 'index_to_mask', 'mask_to_index', 'select', 'narrow', 'to_dense_batch', 'to_dense_adj', 'to_nested_tensor', 'from_nested_tensor', 'dense_to_sparse', 'is_torch_sparse_tensor', 'is_sparse', 'to_torch_coo_tensor', 'to_torch_csr_tensor', 'to_torch_csc_tensor', 'to_torch_sparse_tensor', 'to_edge_index', 'spmm', 'unbatch', 'unbatch_edge_index', 'one_hot', 'normalized_cut', 'grid', 'geodesic_distance', 'to_scipy_sparse_matrix', 'from_scipy_sparse_matrix', 'to_networkx', 'from_networkx', 'to_networkit', 'from_networkit', 'to_trimesh', 'from_trimesh', 'to_cugraph', 'from_cugraph', 'to_dgl', 'from_dgl', 'from_rdmol', 'to_rdmol', 'from_smiles', 'to_smiles', 'erdos_renyi_graph', 'stochastic_blockmodel_graph', 'barabasi_albert_graph', 'negative_sampling', 'batched_negative_sampling', 'structured_negative_sampling', 'structured_negative_sampling_feasible', 'shuffle_node', 'mask_feature', 'add_random_edge', 'tree_decomposition', 'get_embeddings', 'get_embeddings_hetero', 'trim_to_layer', 'get_ppr', 'train_test_split_edges', 'total_influence', ] # `structured_negative_sampling_feasible` is a long name and thus destroys the # documentation rendering. 
We remove it for now from the documentation: classes = copy.copy(__all__) classes.remove('structured_negative_sampling_feasible') ================================================ FILE: torch_geometric/utils/_assortativity.py ================================================ import torch from torch import Tensor from torch_geometric.typing import Adj, SparseTensor from torch_geometric.utils import coalesce, degree from torch_geometric.utils._to_dense_adj import to_dense_adj def assortativity(edge_index: Adj) -> float: r"""The degree assortativity coefficient from the `"Mixing patterns in networks" `_ paper. Assortativity in a network refers to the tendency of nodes to connect with other similar nodes over dissimilar nodes. It is computed from Pearson correlation coefficient of the node degrees. Args: edge_index (Tensor or SparseTensor): The graph connectivity. Returns: The value of the degree assortativity coefficient for the input graph :math:`\in [-1, 1]` Example: >>> edge_index = torch.tensor([[0, 1, 2, 3, 2], ... 
[1, 2, 0, 1, 3]]) >>> assortativity(edge_index) -0.666667640209198 """ if isinstance(edge_index, SparseTensor): adj: SparseTensor = edge_index row, col, _ = adj.coo() else: assert isinstance(edge_index, Tensor) row, col = edge_index device = row.device out_deg = degree(row, dtype=torch.long) in_deg = degree(col, dtype=torch.long) degrees = torch.unique(torch.cat([out_deg, in_deg])) mapping = row.new_zeros(degrees.max().item() + 1) mapping[degrees] = torch.arange(degrees.size(0), device=device) # Compute degree mixing matrix (joint probability distribution) `M` num_degrees = degrees.size(0) src_deg = mapping[out_deg[row]] dst_deg = mapping[in_deg[col]] pairs = torch.stack([src_deg, dst_deg], dim=0) occurrence = torch.ones(pairs.size(1), device=device) pairs, occurrence = coalesce(pairs, occurrence) M = to_dense_adj(pairs, edge_attr=occurrence, max_num_nodes=num_degrees)[0] # normalization M /= M.sum() # numeric assortativity coefficient, computed by # Pearson correlation coefficient of the node degrees x = y = degrees.float() a, b = M.sum(0), M.sum(1) vara = (a * x**2).sum() - ((a * x).sum())**2 varb = (b * x**2).sum() - ((b * x).sum())**2 xy = torch.outer(x, y) ab = torch.outer(a, b) out = (xy * (M - ab)).sum() / (vara * varb).sqrt() return out.item() ================================================ FILE: torch_geometric/utils/_coalesce.py ================================================ import typing from typing import List, Optional, Tuple, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric import EdgeIndex from torch_geometric.edge_index import SortOrder from torch_geometric.typing import OptTensor from torch_geometric.utils import index_sort, scatter from torch_geometric.utils.num_nodes import maybe_num_nodes if typing.TYPE_CHECKING: from typing import overload else: from torch.jit import _overload as overload MISSING = '???' 
@overload def coalesce( edge_index: Tensor, edge_attr: str = MISSING, num_nodes: Optional[int] = None, reduce: str = 'sum', is_sorted: bool = False, sort_by_row: bool = True, ) -> Tensor: pass @overload def coalesce( # noqa: F811 edge_index: Tensor, edge_attr: Tensor, num_nodes: Optional[int] = None, reduce: str = 'sum', is_sorted: bool = False, sort_by_row: bool = True, ) -> Tuple[Tensor, Tensor]: pass @overload def coalesce( # noqa: F811 edge_index: Tensor, edge_attr: OptTensor, num_nodes: Optional[int] = None, reduce: str = 'sum', is_sorted: bool = False, sort_by_row: bool = True, ) -> Tuple[Tensor, OptTensor]: pass @overload def coalesce( # noqa: F811 edge_index: Tensor, edge_attr: List[Tensor], num_nodes: Optional[int] = None, reduce: str = 'sum', is_sorted: bool = False, sort_by_row: bool = True, ) -> Tuple[Tensor, List[Tensor]]: pass def coalesce( # noqa: F811 edge_index: Tensor, edge_attr: Union[OptTensor, List[Tensor], str] = MISSING, num_nodes: Optional[int] = None, reduce: str = 'sum', is_sorted: bool = False, sort_by_row: bool = True, ) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]: """Row-wise sorts :obj:`edge_index` and removes its duplicated entries. Duplicate entries in :obj:`edge_attr` are merged by scattering them together according to the given :obj:`reduce` option. Args: edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor or List[torch.Tensor], optional): Edge weights or multi-dimensional edge features. If given as a list, will re-shuffle and remove duplicates for all its entries. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) reduce (str, optional): The reduce operation to use for merging edge features (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`, :obj:`"any"`). 
(default: :obj:`"sum"`) is_sorted (bool, optional): If set to :obj:`True`, will expect :obj:`edge_index` to be already sorted row-wise. sort_by_row (bool, optional): If set to :obj:`False`, will sort :obj:`edge_index` column-wise. :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]]`) .. warning:: From :pyg:`PyG >= 2.3.0` onwards, this function will always return a tuple whenever :obj:`edge_attr` is passed as an argument (even in case it is set to :obj:`None`). Example: >>> edge_index = torch.tensor([[1, 1, 2, 3], ... [3, 3, 1, 2]]) >>> edge_attr = torch.tensor([1., 1., 1., 1.]) >>> coalesce(edge_index) tensor([[1, 2, 3], [3, 1, 2]]) >>> # Sort `edge_index` column-wise >>> coalesce(edge_index, sort_by_row=False) tensor([[2, 3, 1], [1, 2, 3]]) >>> coalesce(edge_index, edge_attr) (tensor([[1, 2, 3], [3, 1, 2]]), tensor([2., 1., 1.])) >>> # Use 'mean' operation to merge edge features >>> coalesce(edge_index, edge_attr, reduce='mean') (tensor([[1, 2, 3], [3, 1, 2]]), tensor([1., 1., 1.])) """ num_edges = edge_index[0].size(0) num_nodes = maybe_num_nodes(edge_index, num_nodes) if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64: raise ValueError("'coalesce' will result in an overflow") idx = edge_index[0].new_empty(num_edges + 1) idx[0] = -1 idx[1:] = edge_index[1 - int(sort_by_row)] idx[1:].mul_(num_nodes).add_(edge_index[int(sort_by_row)]) is_undirected = False if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): is_undirected = edge_index.is_undirected if not is_sorted: idx[1:], perm = index_sort(idx[1:], max_value=num_nodes * num_nodes) if isinstance(edge_index, Tensor): edge_index = edge_index[:, perm] elif isinstance(edge_index, tuple): edge_index = (edge_index[0][perm], edge_index[1][perm]) else: raise NotImplementedError if isinstance(edge_attr, Tensor): edge_attr = edge_attr[perm] elif isinstance(edge_attr, (list, tuple)): edge_attr = [e[perm] for e in 
edge_attr] if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): edge_index._sort_order = SortOrder('row' if sort_by_row else 'col') edge_index._is_undirected = is_undirected mask = idx[1:] > idx[:-1] # Only perform expensive merging in case there exists duplicates: if mask.all(): if edge_attr is None or isinstance(edge_attr, Tensor): return edge_index, edge_attr if isinstance(edge_attr, (list, tuple)): return edge_index, edge_attr return edge_index if isinstance(edge_index, Tensor): edge_index = edge_index[:, mask] if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): edge_index._is_undirected = is_undirected elif isinstance(edge_index, tuple): edge_index = (edge_index[0][mask], edge_index[1][mask]) else: raise NotImplementedError dim_size: Optional[int] = None if isinstance(edge_attr, (Tensor, list, tuple)) and len(edge_attr) > 0: dim_size = edge_index.size(1) idx = torch.arange(0, num_edges, device=edge_index.device) idx.sub_(mask.logical_not_().cumsum(dim=0)) if edge_attr is None: return edge_index, None if isinstance(edge_attr, Tensor): edge_attr = scatter(edge_attr, idx, 0, dim_size, reduce) return edge_index, edge_attr if isinstance(edge_attr, (list, tuple)): if len(edge_attr) == 0: return edge_index, edge_attr edge_attr = [scatter(e, idx, 0, dim_size, reduce) for e in edge_attr] return edge_index, edge_attr return edge_index ================================================ FILE: torch_geometric/utils/_degree.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.utils.num_nodes import maybe_num_nodes def degree(index: Tensor, num_nodes: Optional[int] = None, dtype: Optional[torch.dtype] = None) -> Tensor: r"""Computes the (unweighted) degree of a given one-dimensional index tensor. Args: index (LongTensor): Index tensor. num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`index`. 
(default: :obj:`None`) dtype (:obj:`torch.dtype`, optional): The desired data type of the returned tensor. :rtype: :class:`Tensor` Example: >>> row = torch.tensor([0, 1, 0, 2, 0]) >>> degree(row, dtype=torch.long) tensor([3, 1, 1]) """ N = maybe_num_nodes(index, num_nodes) out = torch.zeros((N, ), dtype=dtype, device=index.device) one = torch.ones((index.size(0), ), dtype=out.dtype, device=out.device) return out.scatter_add_(0, index, one) ================================================ FILE: torch_geometric/utils/_grid.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch_geometric.utils import coalesce def grid( height: int, width: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tuple[Tensor, Tensor]: r"""Returns the edge indices of a two-dimensional grid graph with height :attr:`height` and width :attr:`width` and its node positions. Args: height (int): The height of the grid. width (int): The width of the grid. dtype (torch.dtype, optional): The desired data type of the returned position tensor. (default: :obj:`None`) device (torch.device, optional): The desired device of the returned tensors. 
(default: :obj:`None`) :rtype: (:class:`LongTensor`, :class:`Tensor`) Example: >>> (row, col), pos = grid(height=2, width=2) >>> row tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) >>> col tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]) >>> pos tensor([[0., 1.], [1., 1.], [0., 0.], [1., 0.]]) """ edge_index = grid_index(height, width, device) pos = grid_pos(height, width, dtype, device) return edge_index, pos def grid_index( height: int, width: int, device: Optional[torch.device] = None, ) -> Tensor: w = width kernel = torch.tensor( [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1], device=device, ) row = torch.arange(height * width, dtype=torch.long, device=device) row = row.view(-1, 1).repeat(1, kernel.size(0)) col = row + kernel.view(1, -1) row, col = row.view(height, -1), col.view(height, -1) index = torch.arange(3, row.size(1) - 3, dtype=torch.long, device=device) row, col = row[:, index].view(-1), col[:, index].view(-1) mask = (col >= 0) & (col < height * width) row, col = row[mask], col[mask] edge_index = torch.stack([row, col], dim=0) edge_index = coalesce(edge_index, num_nodes=height * width) return edge_index def grid_pos( height: int, width: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tensor: dtype = torch.float if dtype is None else dtype x = torch.arange(width, dtype=dtype, device=device) y = (height - 1) - torch.arange(height, dtype=dtype, device=device) x = x.repeat(height) y = y.unsqueeze(-1).repeat(1, width).view(-1) return torch.stack([x, y], dim=-1) ================================================ FILE: torch_geometric/utils/_homophily.py ================================================ from typing import Union, overload import torch from torch import Tensor from torch_geometric.typing import Adj, OptTensor, SparseTensor from torch_geometric.utils import degree, scatter @overload def homophily( edge_index: Adj, y: Tensor, batch: None = ..., method: str = ..., ) -> float: pass @overload def 
homophily( edge_index: Adj, y: Tensor, batch: Tensor, method: str = ..., ) -> Tensor: pass def homophily( edge_index: Adj, y: Tensor, batch: OptTensor = None, method: str = 'edge', ) -> Union[float, Tensor]: r"""The homophily of a graph characterizes how likely nodes with the same label are near each other in a graph. There are many measures of homophily that fits this definition. In particular: - In the `"Beyond Homophily in Graph Neural Networks: Current Limitations and Effective Designs" `_ paper, the homophily is the fraction of edges in a graph which connects nodes that have the same class label: .. math:: \frac{| \{ (v,w) : (v,w) \in \mathcal{E} \wedge y_v = y_w \} | } {|\mathcal{E}|} That measure is called the *edge homophily ratio*. - In the `"Geom-GCN: Geometric Graph Convolutional Networks" `_ paper, edge homophily is normalized across neighborhoods: .. math:: \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{ (w,v) : w \in \mathcal{N}(v) \wedge y_v = y_w \} | } { |\mathcal{N}(v)| } That measure is called the *node homophily ratio*. - In the `"Large-Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods" `_ paper, edge homophily is modified to be insensitive to the number of classes and size of each class: .. math:: \frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, h_k - \frac{|\mathcal{C}_k|} {|\mathcal{V}|} \right), where :math:`C` denotes the number of classes, :math:`|\mathcal{C}_k|` denotes the number of nodes of class :math:`k`, and :math:`h_k` denotes the edge homophily ratio of nodes of class :math:`k`. Thus, that measure is called the *class insensitive edge homophily ratio*. Args: edge_index (Tensor or SparseTensor): The graph connectivity. y (Tensor): The labels. batch (LongTensor, optional): Batch vector\ :math:`\mathbf{b} \in {\{ 0, \ldots,B-1\}}^N`, which assigns each node to a specific example. 
(default: :obj:`None`) method (str, optional): The method used to calculate the homophily, either :obj:`"edge"` (first formula), :obj:`"node"` (second formula) or :obj:`"edge_insensitive"` (third formula). (default: :obj:`"edge"`) Examples: >>> edge_index = torch.tensor([[0, 1, 2, 3], ... [1, 2, 0, 4]]) >>> y = torch.tensor([0, 0, 0, 0, 1]) >>> # Edge homophily ratio >>> homophily(edge_index, y, method='edge') 0.75 >>> # Node homophily ratio >>> homophily(edge_index, y, method='node') 0.6000000238418579 >>> # Class insensitive edge homophily ratio >>> homophily(edge_index, y, method='edge_insensitive') 0.19999998807907104 """ assert method in {'edge', 'node', 'edge_insensitive'} y = y.squeeze(-1) if y.dim() > 1 else y if isinstance(edge_index, SparseTensor): row, col, _ = edge_index.coo() else: row, col = edge_index if method == 'edge': out = torch.zeros(row.size(0), device=row.device) out[y[row] == y[col]] = 1. if batch is None: return float(out.mean()) else: dim_size = int(batch.max()) + 1 return scatter(out, batch[col], 0, dim_size, reduce='mean') elif method == 'node': out = torch.zeros(row.size(0), device=row.device) out[y[row] == y[col]] = 1. 
out = scatter(out, col, 0, dim_size=y.size(0), reduce='mean') if batch is None: return float(out.mean()) else: return scatter(out, batch, dim=0, reduce='mean') elif method == 'edge_insensitive': assert y.dim() == 1 num_classes = int(y.max()) + 1 assert num_classes >= 2 batch = torch.zeros_like(y) if batch is None else batch num_nodes = degree(batch, dtype=torch.int64) num_graphs = num_nodes.numel() batch = num_classes * batch + y h = homophily(edge_index, y, batch, method='edge') h = h.view(num_graphs, num_classes) counts = batch.bincount(minlength=num_classes * num_graphs) counts = counts.view(num_graphs, num_classes) proportions = counts / num_nodes.view(-1, 1) out = (h - proportions).clamp_(min=0).sum(dim=-1) out /= num_classes - 1 return out if out.numel() > 1 else float(out) else: raise NotImplementedError ================================================ FILE: torch_geometric/utils/_index_sort.py ================================================ from typing import Optional, Tuple from torch import Tensor import torch_geometric.typing from torch_geometric import is_compiling from torch_geometric.typing import pyg_lib def index_sort( inputs: Tensor, max_value: Optional[int] = None, stable: bool = False, ) -> Tuple[Tensor, Tensor]: r"""Sorts the elements of the :obj:`inputs` tensor in ascending order. It is expected that :obj:`inputs` is one-dimensional and that it only contains positive integer values. If :obj:`max_value` is given, it can be used by the underlying algorithm for better performance. Args: inputs (torch.Tensor): A vector with positive integer values. max_value (int, optional): The maximum value stored inside :obj:`inputs`. This value can be an estimation, but needs to be greater than or equal to the real maximum. (default: :obj:`None`) stable (bool, optional): Makes the sorting routine stable, which guarantees that the order of equivalent elements is preserved. 
(default: :obj:`False`) """ if stable or not torch_geometric.typing.WITH_INDEX_SORT or is_compiling(): return inputs.sort(stable=stable) return pyg_lib.ops.index_sort(inputs, max_value=max_value) ================================================ FILE: torch_geometric/utils/_lexsort.py ================================================ from typing import List from torch import Tensor def lexsort( keys: List[Tensor], dim: int = -1, descending: bool = False, ) -> Tensor: r"""Performs an indirect stable sort using a sequence of keys. Given multiple sorting keys, returns an array of integer indices that describe their sort order. The last key in the sequence is used for the primary sort order, the second-to-last key for the secondary sort order, and so on. Args: keys ([torch.Tensor]): The :math:`k` different columns to be sorted. The last key is the primary sort key. dim (int, optional): The dimension to sort along. (default: :obj:`-1`) descending (bool, optional): Controls the sorting order (ascending or descending). (default: :obj:`False`) """ assert len(keys) >= 1 out = keys[0].argsort(dim=dim, descending=descending, stable=True) for k in keys[1:]: index = k.gather(dim, out) index = index.argsort(dim=dim, descending=descending, stable=True) out = out.gather(dim, index) return out ================================================ FILE: torch_geometric/utils/_negative_sampling.py ================================================ import random from typing import Optional, Tuple, Union import numpy as np import torch from torch import Tensor from torch_geometric.utils import coalesce, cumsum, degree, remove_self_loops from torch_geometric.utils.num_nodes import maybe_num_nodes def negative_sampling( edge_index: Tensor, num_nodes: Optional[Union[int, Tuple[int, int]]] = None, num_neg_samples: Optional[Union[int, float]] = None, method: str = "sparse", force_undirected: bool = False, ) -> Tensor: r"""Samples random negative edges of a graph given by :attr:`edge_index`. 
Args: edge_index (LongTensor): The edge indices. num_nodes (int or Tuple[int, int], optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. If given as a tuple, then :obj:`edge_index` is interpreted as a bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`. (default: :obj:`None`) num_neg_samples (int or float, optional): The (approximate) number of negative samples to return. If set to a floating-point value, it represents the ratio of negative samples to generate based on the number of positive edges. If set to :obj:`None`, will try to return a negative edge for every positive edge. (default: :obj:`None`) method (str, optional): The method to use for negative sampling, *i.e.* :obj:`"sparse"` or :obj:`"dense"`. This is a memory/runtime trade-off. :obj:`"sparse"` will work on any graph of any size, while :obj:`"dense"` can perform faster true-negative checks. (default: :obj:`"sparse"`) force_undirected (bool, optional): If set to :obj:`True`, sampled negative edges will be undirected. (default: :obj:`False`) :rtype: LongTensor Examples: >>> # Standard usage >>> edge_index = torch.as_tensor([[0, 0, 1, 2], ... [0, 1, 2, 3]]) >>> negative_sampling(edge_index) tensor([[3, 0, 0, 3], [2, 3, 2, 1]]) >>> negative_sampling(edge_index, num_nodes=(3, 4), ... 
num_neg_samples=0.5) # 50% of positive edges tensor([[0, 3], [3, 0]]) >>> # For bipartite graph >>> negative_sampling(edge_index, num_nodes=(3, 4)) tensor([[0, 2, 2, 1], [2, 2, 1, 3]]) """ assert method in ['sparse', 'dense'] if num_nodes is None: num_nodes = maybe_num_nodes(edge_index, num_nodes) if isinstance(num_nodes, int): size = (num_nodes, num_nodes) bipartite = False else: size = num_nodes bipartite = True force_undirected = False idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected) if idx.numel() >= population: return edge_index.new_empty((2, 0)) if num_neg_samples is None: num_neg_samples = edge_index.size(1) elif isinstance(num_neg_samples, float): num_neg_samples = int(num_neg_samples * edge_index.size(1)) if force_undirected: num_neg_samples = num_neg_samples // 2 prob = 1. - idx.numel() / population # Probability to sample a negative. sample_size = int(1.1 * num_neg_samples / prob) # (Over)-sample size. neg_idx: Optional[Tensor] = None if method == 'dense': # The dense version creates a mask of shape `population` to check for # invalid samples. mask = idx.new_ones(population, dtype=torch.bool) mask[idx] = False for _ in range(3): # Number of tries to sample negative indices. rnd = sample(population, sample_size, idx.device) rnd = rnd[mask[rnd]] # Filter true negatives. neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd]) if neg_idx.numel() >= num_neg_samples: neg_idx = neg_idx[:num_neg_samples] break mask[neg_idx] = False else: # 'sparse' # The sparse version checks for invalid samples via `np.isin`. idx = idx.to('cpu') for _ in range(3): # Number of tries to sample negative indices. 
rnd = sample(population, sample_size, device='cpu') mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool() if neg_idx is not None: mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool() rnd = rnd[~mask].to(edge_index.device) neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd]) if neg_idx.numel() >= num_neg_samples: neg_idx = neg_idx[:num_neg_samples] break assert neg_idx is not None return vector_to_edge_index(neg_idx, size, bipartite, force_undirected) def batched_negative_sampling( edge_index: Tensor, batch: Union[Tensor, Tuple[Tensor, Tensor]], num_neg_samples: Optional[Union[int, float]] = None, method: str = "sparse", force_undirected: bool = False, ) -> Tensor: r"""Samples random negative edges of multiple graphs given by :attr:`edge_index` and :attr:`batch`. Args: edge_index (LongTensor): The edge indices. batch (LongTensor or Tuple[LongTensor, LongTensor]): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. If given as a tuple, then :obj:`edge_index` is interpreted as a bipartite graph connecting two different node types. num_neg_samples (int or float, optional): The number of negative samples to return. If set to :obj:`None`, will try to return a negative edge for every positive edge. If float, it will generate :obj:`num_neg_samples * num_edges` negative samples. (default: :obj:`None`) method (str, optional): The method to use for negative sampling, *i.e.* :obj:`"sparse"` or :obj:`"dense"`. This is a memory/runtime trade-off. :obj:`"sparse"` will work on any graph of any size, while :obj:`"dense"` can perform faster true-negative checks. (default: :obj:`"sparse"`) force_undirected (bool, optional): If set to :obj:`True`, sampled negative edges will be undirected. 
(default: :obj:`False`) :rtype: LongTensor Examples: >>> # Standard usage >>> edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) >>> edge_index = torch.cat([edge_index, edge_index + 4], dim=1) >>> edge_index tensor([[0, 0, 1, 2, 4, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6, 7]]) >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) >>> batched_negative_sampling(edge_index, batch) tensor([[3, 1, 3, 2, 7, 7, 6, 5], [2, 0, 1, 1, 5, 6, 4, 4]]) >>> # Using float multiplier for negative samples >>> batched_negative_sampling(edge_index, batch, num_neg_samples=1.5) tensor([[3, 1, 3, 2, 7, 7, 6, 5, 2, 0, 1, 1], [2, 0, 1, 1, 5, 6, 4, 4, 3, 2, 3, 0]]) >>> # For bipartite graph >>> edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]]) >>> edge_index2 = edge_index1 + torch.tensor([[2], [4]]) >>> edge_index3 = edge_index2 + torch.tensor([[2], [4]]) >>> edge_index = torch.cat([edge_index1, edge_index2, ... edge_index3], dim=1) >>> edge_index tensor([[ 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]) >>> src_batch = torch.tensor([0, 0, 1, 1, 2, 2]) >>> dst_batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]) >>> batched_negative_sampling(edge_index, ... 
(src_batch, dst_batch)) tensor([[ 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], [ 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9]]) """ if isinstance(batch, Tensor): src_batch, dst_batch = batch, batch else: src_batch, dst_batch = batch[0], batch[1] split = degree(src_batch[edge_index[0]], dtype=torch.long).tolist() edge_indices = torch.split(edge_index, split, dim=1) num_src = degree(src_batch, dtype=torch.long) cum_src = cumsum(num_src)[:-1] if isinstance(batch, Tensor): num_nodes = num_src.tolist() ptr = cum_src else: num_dst = degree(dst_batch, dtype=torch.long) cum_dst = cumsum(num_dst)[:-1] num_nodes = torch.stack([num_src, num_dst], dim=1).tolist() ptr = torch.stack([cum_src, cum_dst], dim=1).unsqueeze(-1) neg_edge_indices = [] for i, edge_index in enumerate(edge_indices): edge_index = edge_index - ptr[i] neg_edge_index = negative_sampling(edge_index, num_nodes[i], num_neg_samples, method, force_undirected) neg_edge_index += ptr[i] neg_edge_indices.append(neg_edge_index) return torch.cat(neg_edge_indices, dim=1) def structured_negative_sampling( edge_index: Tensor, num_nodes: Optional[int] = None, contains_neg_self_loops: bool = True, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Samples a negative edge :obj:`(i,k)` for every positive edge :obj:`(i,j)` in the graph given by :attr:`edge_index`, and returns it as a tuple of the form :obj:`(i,j,k)`. Args: edge_index (LongTensor): The edge indices. num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) contains_neg_self_loops (bool, optional): If set to :obj:`False`, sampled negative edges will not contain self loops. (default: :obj:`True`) :rtype: (LongTensor, LongTensor, LongTensor) Example: >>> edge_index = torch.as_tensor([[0, 0, 1, 2], ... 
[0, 1, 2, 3]]) >>> structured_negative_sampling(edge_index) (tensor([0, 0, 1, 2]), tensor([0, 1, 2, 3]), tensor([2, 3, 0, 2])) """ num_nodes = maybe_num_nodes(edge_index, num_nodes) row, col = edge_index.cpu() pos_idx = row * num_nodes + col if not contains_neg_self_loops: loop_idx = torch.arange(num_nodes) * (num_nodes + 1) pos_idx = torch.cat([pos_idx, loop_idx], dim=0) rand = torch.randint(num_nodes, (row.size(0), ), dtype=torch.long) neg_idx = row * num_nodes + rand mask = torch.from_numpy(np.isin(neg_idx, pos_idx)).to(torch.bool) rest = mask.nonzero(as_tuple=False).view(-1) while rest.numel() > 0: # pragma: no cover tmp = torch.randint(num_nodes, (rest.size(0), ), dtype=torch.long) rand[rest] = tmp neg_idx = row[rest] * num_nodes + tmp mask = torch.from_numpy(np.isin(neg_idx, pos_idx)).to(torch.bool) rest = rest[mask] return edge_index[0], edge_index[1], rand.to(edge_index.device) def structured_negative_sampling_feasible( edge_index: Tensor, num_nodes: Optional[int] = None, contains_neg_self_loops: bool = True, ) -> bool: r"""Returns :obj:`True` if :meth:`~torch_geometric.utils.structured_negative_sampling` is feasible on the graph given by :obj:`edge_index`. :meth:`~torch_geometric.utils.structured_negative_sampling` is infeasible if at least one node is connected to all other nodes. Args: edge_index (LongTensor): The edge indices. num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) contains_neg_self_loops (bool, optional): If set to :obj:`False`, sampled negative edges will not contain self loops. (default: :obj:`True`) :rtype: bool Examples: >>> edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2], ... 
[1, 2, 0, 2, 0, 1, 1]]) >>> structured_negative_sampling_feasible(edge_index, 3, False) False >>> structured_negative_sampling_feasible(edge_index, 3, True) True """ num_nodes = maybe_num_nodes(edge_index, num_nodes) max_num_neighbors = num_nodes edge_index = coalesce(edge_index, num_nodes=num_nodes) if not contains_neg_self_loops: edge_index, _ = remove_self_loops(edge_index) max_num_neighbors -= 1 # Reduce number of valid neighbors deg = degree(edge_index[0], num_nodes) # True if there exists no node that is connected to all other nodes. return bool(torch.all(deg < max_num_neighbors)) ############################################################################### def sample( population: int, k: int, device: Optional[Union[torch.device, str]] = None, ) -> Tensor: if population <= k: return torch.arange(population, device=device) else: return torch.tensor(random.sample(range(population), k), device=device) def edge_index_to_vector( edge_index: Tensor, size: Tuple[int, int], bipartite: bool, force_undirected: bool = False, ) -> Tuple[Tensor, int]: row, col = edge_index if bipartite: # No need to account for self-loops. idx = (row * size[1]).add_(col) population = size[0] * size[1] return idx, population elif force_undirected: assert size[0] == size[1] num_nodes = size[0] # We only operate on the upper triangular matrix: mask = row < col row, col = row[mask], col[mask] offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row] idx = row.mul_(num_nodes).add_(col).sub_(offset) population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes return idx, population else: assert size[0] == size[1] num_nodes = size[0] # We remove self-loops as we do not want to take them into account # when sampling negative values. 
mask = row != col row, col = row[mask], col[mask] col[row < col] -= 1 idx = row.mul_(num_nodes - 1).add_(col) population = num_nodes * num_nodes - num_nodes return idx, population def vector_to_edge_index( idx: Tensor, size: Tuple[int, int], bipartite: bool, force_undirected: bool = False, ) -> Tensor: if bipartite: # No need to account for self-loops. row = idx.div(size[1], rounding_mode='floor') col = idx % size[1] return torch.stack([row, col], dim=0) elif force_undirected: assert size[0] == size[1] num_nodes = size[0] offset = torch.arange(1, num_nodes, device=idx.device).cumsum(0) end = torch.arange(num_nodes, num_nodes * num_nodes, num_nodes, device=idx.device) row = torch.bucketize(idx, end.sub_(offset), right=True) col = offset[row].add_(idx) % num_nodes return torch.stack([torch.cat([row, col]), torch.cat([col, row])], 0) else: assert size[0] == size[1] num_nodes = size[0] row = idx.div(num_nodes - 1, rounding_mode='floor') col = idx % (num_nodes - 1) col[row <= col] += 1 return torch.stack([row, col], dim=0) ================================================ FILE: torch_geometric/utils/_normalize_edge_index.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch_geometric.utils import add_self_loops as add_self_loops_fn from torch_geometric.utils import degree def normalize_edge_index( edge_index: Tensor, num_nodes: Optional[int] = None, add_self_loops: bool = True, symmetric: bool = True, ) -> Tuple[Tensor, Tensor]: """Applies normalization to the edges of a graph. This function can add self-loops to the graph and apply either symmetric or asymmetric normalization based on the node degrees. Args: edge_index (LongTensor): The edge indices. num_nodes (int, int], optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. 
(default: :obj:`True`) symmetric (bool, optional): If set to :obj:`True`, symmetric normalization (:math:`D^{-1/2} A D^{-1/2}`) is used, otherwise asymmetric normalization (:math:`D^{-1} A`). """ if add_self_loops: edge_index, _ = add_self_loops_fn(edge_index, num_nodes=num_nodes) row, col = edge_index[0], edge_index[1] deg = degree(row, num_nodes, dtype=torch.get_default_dtype()) if symmetric: # D^-1/2 * A * D^-1/2 deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0 edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col] else: # D^-1 * A deg_inv = deg.pow(-1) deg_inv[torch.isinf(deg_inv)] = 0 edge_weight = deg_inv[row] return edge_index, edge_weight ================================================ FILE: torch_geometric/utils/_normalized_cut.py ================================================ from typing import Optional from torch import Tensor from torch_geometric.utils import degree def normalized_cut( edge_index: Tensor, edge_attr: Tensor, num_nodes: Optional[int] = None, ) -> Tensor: r"""Computes the normalized cut :math:`\mathbf{e}_{i,j} \cdot \left( \frac{1}{\deg(i)} + \frac{1}{\deg(j)} \right)` of a weighted graph given by edge indices and edge attributes. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor): Edge weights or multi-dimensional edge features. num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) :rtype: :class:`Tensor` Example: >>> edge_index = torch.tensor([[1, 1, 2, 3], ... [3, 3, 1, 2]]) >>> edge_attr = torch.tensor([1., 1., 1., 1.]) >>> normalized_cut(edge_index, edge_attr) tensor([1.5000, 1.5000, 2.0000, 1.5000]) """ row, col = edge_index[0], edge_index[1] deg = 1. 
/ degree(col, num_nodes, edge_attr.dtype) deg = deg[row] + deg[col] cut = edge_attr * deg return cut ================================================ FILE: torch_geometric/utils/_one_hot.py ================================================ from typing import Optional import torch from torch import Tensor def one_hot( index: Tensor, num_classes: Optional[int] = None, dtype: Optional[torch.dtype] = None, ) -> Tensor: r"""Taskes a one-dimensional :obj:`index` tensor and returns a one-hot encoded representation of it with shape :obj:`[*, num_classes]` that has zeros everywhere except where the index of last dimension matches the corresponding value of the input tensor, in which case it will be :obj:`1`. .. note:: This is a more memory-efficient version of :meth:`torch.nn.functional.one_hot` as you can customize the output :obj:`dtype`. Args: index (torch.Tensor): The one-dimensional input tensor. num_classes (int, optional): The total number of classes. If set to :obj:`None`, the number of classes will be inferred as one greater than the largest class value in the input tensor. (default: :obj:`None`) dtype (torch.dtype, optional): The :obj:`dtype` of the output tensor. 
""" if index.dim() != 1: raise ValueError("'index' tensor needs to be one-dimensional") if num_classes is None: num_classes = int(index.max()) + 1 out = torch.zeros((index.size(0), num_classes), dtype=dtype, device=index.device) return out.scatter_(1, index.unsqueeze(1), 1) ================================================ FILE: torch_geometric/utils/_scatter.py ================================================ from typing import List, Optional, Tuple, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric import is_compiling, is_in_onnx_export, warnings from torch_geometric.typing import torch_scatter from torch_geometric.utils.functions import cumsum warnings.filterwarnings('ignore', '.*is in beta and the API may change.*') def scatter( src: Tensor, index: Tensor, dim: int = 0, dim_size: Optional[int] = None, reduce: str = 'sum', ) -> Tensor: r"""Reduces all values from the :obj:`src` tensor at the indices specified in the :obj:`index` tensor along a given dimension ``dim``. See the `documentation `__ # noqa: E501 of the ``torch_scatter`` package for more information. Args: src (torch.Tensor): The source tensor. index (torch.Tensor): The index tensor. dim (int, optional): The dimension along which to index. (default: ``0``) dim_size (int, optional): The size of the output tensor at dimension ``dim``. If set to :obj:`None`, will create a minimal-sized output tensor according to ``index.max() + 1``. (default: :obj:`None`) reduce (str, optional): The reduce operation (``"sum"``, ``"mean"``, ``"mul"``, ``"min"``, ``"max"`` or ``"any"``). 
(default: ``"sum"``) """ if isinstance(index, Tensor) and index.dim() != 1: raise ValueError(f"The `index` argument must be one-dimensional " f"(got {index.dim()} dimensions)") dim = src.dim() + dim if dim < 0 else dim if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()): raise ValueError(f"The `dim` argument must lay between 0 and " f"{src.dim() - 1} (got {dim})") if dim_size is None: dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 # For now, we maintain various different code paths, based on whether # the input requires gradients and whether it lays on the CPU/GPU. # For example, `torch_scatter` is usually faster than # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster # on CPU. # `torch.scatter_reduce` has a faster forward implementation for # "min"/"max" reductions since it does not compute additional arg # indices, but is therefore way slower in its backward implementation. # More insights can be found in `test/utils/test_scatter.py`. size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:] # For "any" reduction, we use regular `scatter_`: if reduce == 'any': index = broadcast(index, src, dim) return src.new_zeros(size).scatter_(dim, index, src) # For "sum" and "mean" reduction, we make use of `scatter_add_`: if reduce == 'sum' or reduce == 'add': index = broadcast(index, src, dim) return src.new_zeros(size).scatter_add_(dim, index, src) if reduce == 'mean': count = src.new_zeros(dim_size) count.scatter_add_(0, index, src.new_ones(src.size(dim))) count = count.clamp(min=1) index = broadcast(index, src, dim) out = src.new_zeros(size).scatter_add_(dim, index, src) return out / broadcast(count, out, dim) # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or # in case the input does not require gradients: if reduce in ['min', 'max', 'amin', 'amax']: if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling() or is_in_onnx_export() or not src.is_cuda or not src.requires_grad): if (src.is_cuda and 
src.requires_grad and not is_compiling() and not is_in_onnx_export()): warnings.warn( f"The usage of `scatter(reduce='{reduce}')` " f"can be accelerated via the 'torch-scatter'" f" package, but it was not found", stacklevel=2) index = broadcast(index, src, dim) if not is_in_onnx_export(): return src.new_zeros(size).scatter_reduce_( dim, index, src, reduce=f'a{reduce[-3:]}', include_self=False) fill = torch.full( # type: ignore size=(1, ), fill_value=src.min() if 'max' in reduce else src.max(), dtype=src.dtype, device=src.device, ).expand_as(src) out = src.new_zeros(size).scatter_reduce_(dim, index, fill, reduce=f'a{reduce[-3:]}', include_self=True) return out.scatter_reduce_(dim, index, src, reduce=f'a{reduce[-3:]}', include_self=True) return torch_scatter.scatter(src, index, dim, dim_size=dim_size, reduce=reduce[-3:]) # For "mul" reduction, we prefer `scatter_reduce_` on CPU: if reduce == 'mul': if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling() or not src.is_cuda): if src.is_cuda and not is_compiling(): warnings.warn( f"The usage of `scatter(reduce='{reduce}')` " f"can be accelerated via the 'torch-scatter'" f" package, but it was not found", stacklevel=2) index = broadcast(index, src, dim) # We initialize with `one` here to match `scatter_mul` output: return src.new_ones(size).scatter_reduce_(dim, index, src, reduce='prod', include_self=True) return torch_scatter.scatter(src, index, dim, dim_size=dim_size, reduce='mul') raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'") def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor: dim = ref.dim() + dim if dim < 0 else dim size = ((1, ) * dim) + (-1, ) + ((1, ) * (ref.dim() - dim - 1)) return src.view(size).expand_as(ref) def scatter_argmax( src: Tensor, index: Tensor, dim: int = 0, dim_size: Optional[int] = None, ) -> Tensor: if (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling() and not is_in_onnx_export()): out = torch_scatter.scatter_max(src, index, dim=dim, 
dim_size=dim_size) return out[1] # Only implemented under certain conditions for now :( assert src.dim() == 1 and index.dim() == 1 assert dim == 0 or dim == -1 assert src.numel() == index.numel() if dim_size is None: dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 if not is_in_onnx_export(): res = src.new_empty(dim_size) res.scatter_reduce_(0, index, src.detach(), reduce='amax', include_self=False) else: # `include_self=False` is currently not supported by ONNX: res = src.new_full( size=(dim_size, ), fill_value=src.min(), # type: ignore ) res.scatter_reduce_(0, index, src.detach(), reduce="amax", include_self=True) out = index.new_full((dim_size, ), fill_value=dim_size - 1) nonzero = (src == res[index]).nonzero().view(-1) out[index[nonzero]] = nonzero return out def group_argsort( src: Tensor, index: Tensor, dim: int = 0, num_groups: Optional[int] = None, descending: bool = False, return_consecutive: bool = False, stable: bool = False, ) -> Tensor: r"""Returns the indices that sort the tensor :obj:`src` along a given dimension in ascending order by value. In contrast to :meth:`torch.argsort`, sorting is performed in groups according to the values in :obj:`index`. Args: src (torch.Tensor): The source tensor. index (torch.Tensor): The index tensor. dim (int, optional): The dimension along which to index. (default: :obj:`0`) num_groups (int, optional): The number of groups. (default: :obj:`None`) descending (bool, optional): Controls the sorting order (ascending or descending). (default: :obj:`False`) return_consecutive (bool, optional): If set to :obj:`True`, will not offset the output to start from :obj:`0` for each group. (default: :obj:`False`) stable (bool, optional): Controls the relative order of equivalent elements. 
(default: :obj:`False`) Example: >>> src = torch.tensor([0, 1, 5, 4, 3, 2, 6, 7, 8]) >>> index = torch.tensor([0, 0, 1, 1, 1, 1, 2, 2, 2]) >>> group_argsort(src, index) tensor([0, 1, 3, 2, 1, 0, 0, 1, 2]) """ # Only implemented under certain conditions for now :( assert src.dim() == 1 and index.dim() == 1 assert dim == 0 or dim == -1 assert src.numel() == index.numel() if src.numel() == 0: return torch.zeros_like(src) # Normalize `src` to range [0, 1]: src = src - src.min() src = src / src.max() # Compute `grouped_argsort`: src = src - 2 * index if descending else src + 2 * index perm = src.argsort(descending=descending, stable=stable) out = torch.empty_like(index) out[perm] = torch.arange(index.numel(), device=index.device) if return_consecutive: return out # Compute cumulative sum of number of entries with the same index: count = scatter(torch.ones_like(index), index, dim=dim, dim_size=num_groups, reduce='sum') ptr = cumsum(count) return out - ptr[index] def group_cat( tensors: Union[List[Tensor], Tuple[Tensor, ...]], indices: Union[List[Tensor], Tuple[Tensor, ...]], dim: int = 0, return_index: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: r"""Concatenates the given sequence of tensors :obj:`tensors` in the given dimension :obj:`dim`. Different from :meth:`torch.cat`, values along the concatenating dimension are grouped according to the indices defined in the :obj:`index` tensors. All tensors must have the same shape (except in the concatenating dimension). Args: tensors ([Tensor]): Sequence of tensors. indices ([Tensor]): Sequence of index tensors. dim (int, optional): The dimension along which the tensors are concatenated. (default: :obj:`0`) return_index (bool, optional): If set to :obj:`True`, will return the new index tensor. (default: :obj:`False`) Example: >>> x1 = torch.tensor([[0.2716, 0.4233], ... [0.3166, 0.0142], ... [0.2371, 0.3839], ... [0.4100, 0.0012]]) >>> x2 = torch.tensor([[0.3752, 0.5782], ... 
[0.7757, 0.5999]]) >>> index1 = torch.tensor([0, 0, 1, 2]) >>> index2 = torch.tensor([0, 2]) >>> scatter_concat([x1,x2], [index1, index2], dim=0) tensor([[0.2716, 0.4233], [0.3166, 0.0142], [0.3752, 0.5782], [0.2371, 0.3839], [0.4100, 0.0012], [0.7757, 0.5999]]) """ assert len(tensors) == len(indices) index, perm = torch.cat(indices).sort(stable=True) out = torch.cat(tensors, dim=dim).index_select(dim, perm) return (out, index) if return_index else out ================================================ FILE: torch_geometric/utils/_segment.py ================================================ import torch from torch import Tensor import torch_geometric.typing from torch_geometric import is_compiling from torch_geometric.index import ptr2index from torch_geometric.typing import torch_scatter from torch_geometric.utils import scatter def segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor: r"""Reduces all values in the first dimension of the :obj:`src` tensor within the ranges specified in the :obj:`ptr`. See the `documentation `__ of the :obj:`torch_scatter` package for more information. Args: src (torch.Tensor): The source tensor. ptr (torch.Tensor): A monotonically increasing pointer tensor that refers to the boundaries of segments such that :obj:`ptr[0] = 0` and :obj:`ptr[-1] = src.size(0)`. reduce (str, optional): The reduce operation (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). 
(default: :obj:`"sum"`) """ if not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling(): return _torch_segment(src, ptr, reduce) if (ptr.dim() == 1 and torch_geometric.typing.WITH_PT20 and src.is_cuda and reduce == 'mean'): return _torch_segment(src, ptr, reduce) return torch_scatter.segment_csr(src, ptr, reduce=reduce) def _torch_segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor: if not torch_geometric.typing.WITH_PT20: raise ImportError("'segment' requires the 'torch-scatter' package") if ptr.dim() > 1: raise ImportError("'segment' in an arbitrary dimension " "requires the 'torch-scatter' package") if reduce == 'min' or reduce == 'max': reduce = f'a{reduce}' # `amin` or `amax` initial = 0 if reduce == 'mean' else None out = torch._segment_reduce(src, reduce, offsets=ptr, initial=initial) if reduce == 'amin' or reduce == 'amax': out = torch.where(out.isinf(), 0, out) return out def segment_logsumexp( src: Tensor, ptr: Tensor, dim: int, ) -> Tensor: r"""Returns the log summed exponentials of each row of the :obj:`src` tensor within the ranges specified in the :obj:`ptr`. Args: src: The source tensor. ptr (torch.Tensor): A monotonically increasing pointer tensor that refers to the boundaries of segments such that :obj:`ptr[0] = 0` and :obj:`ptr[-1] = src.size(0)`. dim: The dimension to reduce. """ src = src.transpose(0, dim) # Move reduction dimension to first dimension. 
index = ptr2index(ptr, output_size=src.size(0)) max_src = scatter(src, index, dim_size=ptr.numel() - 1, reduce='max') src = src - max_src[index] out = src.exp() out = segment(out, ptr, reduce='sum') out = out.log().nan_to_num(neginf=0.0) + max_src out = out.transpose(0, dim) return out ================================================ FILE: torch_geometric/utils/_select.py ================================================ from typing import Any, List, Union import torch from torch import Tensor from torch_geometric.typing import TensorFrame from torch_geometric.utils.mask import mask_select from torch_geometric.utils.sparse import is_torch_sparse_tensor def select( src: Union[Tensor, List[Any], TensorFrame], index_or_mask: Tensor, dim: int, ) -> Union[Tensor, List[Any]]: r"""Selects the input tensor or input list according to a given index or mask vector. Args: src (torch.Tensor or list): The input tensor or list. index_or_mask (torch.Tensor): The index or mask vector. dim (int): The dimension along which to select. """ if isinstance(src, Tensor): if index_or_mask.dtype == torch.bool: return mask_select(src, dim, index_or_mask) return src.index_select(dim, index_or_mask) if isinstance(src, (tuple, list)): if dim != 0: raise ValueError("Cannot select along dimension other than 0") if index_or_mask.dtype == torch.bool: return [src[i] for i, m in enumerate(index_or_mask) if m] return [src[i] for i in index_or_mask] if isinstance(src, TensorFrame): assert dim == 0 if index_or_mask.dtype == torch.bool: return mask_select(src, dim, index_or_mask) return src[index_or_mask] raise ValueError(f"Encountered invalid input type (got '{type(src)}')") def narrow(src: Union[Tensor, List[Any]], dim: int, start: int, length: int) -> Union[Tensor, List[Any]]: r"""Narrows the input tensor or input list to the specified range. Args: src (torch.Tensor or list): The input tensor or list. dim (int): The dimension along which to narrow. start (int): The starting dimension. 
length (int): The distance to the ending dimension. """ if isinstance(src, Tensor) and is_torch_sparse_tensor(src): # TODO Sparse tensors in `torch.sparse` do not yet support `narrow`. index = torch.arange(start, start + length, device=src.device) return src.index_select(dim, index) if isinstance(src, Tensor): return src.narrow(dim, start, length) if isinstance(src, list): if dim != 0: raise ValueError("Cannot narrow along dimension other than 0") return src[start:start + length] raise ValueError(f"Encountered invalid input type (got '{type(src)}')") ================================================ FILE: torch_geometric/utils/_softmax.py ================================================ from typing import Optional from torch import Tensor import torch_geometric.typing from torch_geometric import is_compiling from torch_geometric.typing import pyg_lib from torch_geometric.utils import scatter, segment from torch_geometric.utils.num_nodes import maybe_num_nodes def softmax( src: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None, dim: int = 0, ) -> Tensor: r"""Computes a sparsely evaluated softmax. Given a value tensor :attr:`src`, this function first groups the values along the first dimension based on the indices specified in :attr:`index`, and then proceeds to compute the softmax individually for each group. Args: src (Tensor): The source tensor. index (LongTensor, optional): The indices of elements for applying the softmax. (default: :obj:`None`) ptr (LongTensor, optional): If given, computes the softmax based on sorted inputs in CSR representation. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) dim (int, optional): The dimension in which to normalize. 
(default: :obj:`0`) :rtype: :class:`Tensor` Examples: >>> src = torch.tensor([1., 1., 1., 1.]) >>> index = torch.tensor([0, 0, 1, 2]) >>> ptr = torch.tensor([0, 2, 3, 4]) >>> softmax(src, index) tensor([0.5000, 0.5000, 1.0000, 1.0000]) >>> softmax(src, None, ptr) tensor([0.5000, 0.5000, 1.0000, 1.0000]) >>> src = torch.randn(4, 4) >>> ptr = torch.tensor([0, 4]) >>> softmax(src, index, dim=-1) tensor([[0.7404, 0.2596, 1.0000, 1.0000], [0.1702, 0.8298, 1.0000, 1.0000], [0.7607, 0.2393, 1.0000, 1.0000], [0.8062, 0.1938, 1.0000, 1.0000]]) """ if (ptr is not None and src.device.type == 'cpu' and torch_geometric.typing.WITH_SOFTMAX and not is_compiling()): # pragma: no cover return pyg_lib.ops.softmax_csr(src, ptr, dim) if (ptr is not None and (ptr.dim() == 1 or (ptr.dim() > 1 and index is None) or (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()))): dim = dim + src.dim() if dim < 0 else dim size = ([1] * dim) + [-1] count = ptr[1:] - ptr[:-1] ptr = ptr.view(size) src_max = segment(src.detach(), ptr, reduce='max') src_max = src_max.repeat_interleave(count, dim=dim) out = (src - src_max).exp() out_sum = segment(out, ptr, reduce='sum') + 1e-16 out_sum = out_sum.repeat_interleave(count, dim=dim) elif index is not None: N = maybe_num_nodes(index, num_nodes) src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max') out = src - src_max.index_select(dim, index) out = out.exp() out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + 1e-16 out_sum = out_sum.index_select(dim, index) else: raise NotImplementedError("'softmax' requires 'index' to be specified") return out / out_sum ================================================ FILE: torch_geometric/utils/_sort_edge_index.py ================================================ import typing from typing import List, Optional, Tuple, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric import EdgeIndex from torch_geometric.edge_index import SortOrder from 
torch_geometric.typing import OptTensor from torch_geometric.utils import index_sort, lexsort from torch_geometric.utils.num_nodes import maybe_num_nodes if typing.TYPE_CHECKING: from typing import overload else: from torch.jit import _overload as overload MISSING = '???' @overload def sort_edge_index( edge_index: Tensor, edge_attr: str = MISSING, num_nodes: Optional[int] = None, sort_by_row: bool = True, ) -> Tensor: pass @overload def sort_edge_index( # noqa: F811 edge_index: Tensor, edge_attr: Tensor, num_nodes: Optional[int] = None, sort_by_row: bool = True, ) -> Tuple[Tensor, Tensor]: pass @overload def sort_edge_index( # noqa: F811 edge_index: Tensor, edge_attr: OptTensor, num_nodes: Optional[int] = None, sort_by_row: bool = True, ) -> Tuple[Tensor, OptTensor]: pass @overload def sort_edge_index( # noqa: F811 edge_index: Tensor, edge_attr: List[Tensor], num_nodes: Optional[int] = None, sort_by_row: bool = True, ) -> Tuple[Tensor, List[Tensor]]: pass def sort_edge_index( # noqa: F811 edge_index: Tensor, edge_attr: Union[OptTensor, List[Tensor], str] = MISSING, num_nodes: Optional[int] = None, sort_by_row: bool = True, ) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]: """Row-wise sorts :obj:`edge_index`. Args: edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor or List[torch.Tensor], optional): Edge weights or multi-dimensional edge features. If given as a list, will re-shuffle and remove duplicates for all its entries. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) sort_by_row (bool, optional): If set to :obj:`False`, will sort :obj:`edge_index` column-wise/by destination node. (default: :obj:`True`) :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]]`) .. 
warning:: From :pyg:`PyG >= 2.3.0` onwards, this function will always return a tuple whenever :obj:`edge_attr` is passed as an argument (even in case it is set to :obj:`None`). Examples: >>> edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) >>> edge_attr = torch.tensor([[1], [2], [3], [4]]) >>> sort_edge_index(edge_index) tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) >>> sort_edge_index(edge_index, edge_attr) (tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), tensor([[4], [3], [2], [1]])) """ num_nodes = maybe_num_nodes(edge_index, num_nodes) if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64: perm = lexsort(keys=[ edge_index[int(sort_by_row)], edge_index[1 - int(sort_by_row)], ]) else: idx = edge_index[1 - int(sort_by_row)] * num_nodes idx += edge_index[int(sort_by_row)] _, perm = index_sort(idx, max_value=num_nodes * num_nodes) if isinstance(edge_index, Tensor): is_undirected = False if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): is_undirected = edge_index.is_undirected edge_index = edge_index[:, perm] if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): edge_index._sort_order = SortOrder('row' if sort_by_row else 'col') edge_index._is_undirected = is_undirected elif isinstance(edge_index, tuple): edge_index = (edge_index[0][perm], edge_index[1][perm]) else: raise NotImplementedError if edge_attr is None: return edge_index, None if isinstance(edge_attr, Tensor): return edge_index, edge_attr[perm] if isinstance(edge_attr, (list, tuple)): return edge_index, [e[perm] for e in edge_attr] return edge_index ================================================ FILE: torch_geometric/utils/_spmm.py ================================================ import warnings import torch from torch import Tensor import torch_geometric.typing from torch_geometric import EdgeIndex from torch_geometric.typing import Adj, SparseTensor, torch_sparse from torch_geometric.utils import is_torch_sparse_tensor, scatter def spmm( src: Adj, other: Tensor, reduce: 
str = 'sum', ) -> Tensor: r"""Matrix product of sparse matrix with dense matrix. Args: src (torch.Tensor or torch_sparse.SparseTensor or EdgeIndex): The input sparse matrix which can be a :pyg:`PyG` :class:`torch_sparse.SparseTensor`, a :pytorch:`PyTorch` :class:`torch.sparse.Tensor` or a :pyg:`PyG` :class:`EdgeIndex`. other (torch.Tensor): The input dense matrix. reduce (str, optional): The reduce operation to use (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`). (default: :obj:`"sum"`) :rtype: :class:`Tensor` """ reduce = 'sum' if reduce == 'add' else reduce if reduce not in ['sum', 'mean', 'min', 'max']: raise ValueError(f"`reduce` argument '{reduce}' not supported") if not torch.jit.is_scripting() and isinstance(src, EdgeIndex): return src.matmul(other=other, reduce=reduce) # type: ignore if isinstance(src, SparseTensor): if src.nnz() == 0: return other.new_zeros(src.size(0), other.size(1)) if (torch_geometric.typing.WITH_PT20 and other.dim() == 2 and not src.is_cuda() and not src.requires_grad()): # Use optimized PyTorch `torch.sparse.mm` path: csr = src.to_torch_sparse_csr_tensor().to(other.dtype) return torch.sparse.mm(csr, other, reduce) return torch_sparse.matmul(src, other, reduce) if not is_torch_sparse_tensor(src): raise ValueError("'src' must be a 'torch_sparse.SparseTensor' or a " "'torch.sparse.Tensor'") # `torch.sparse.mm` only supports reductions on CPU for PyTorch>=2.0. # This will currently throw on error for CUDA tensors. if torch_geometric.typing.WITH_PT20: if src.is_cuda and (reduce == 'min' or reduce == 'max'): raise NotImplementedError(f"`{reduce}` reduction is not yet " f"supported for 'torch.sparse.Tensor' " f"on device '{src.device}'") # Always convert COO to CSR for more efficient processing: if src.layout == torch.sparse_coo: warnings.warn( f"Converting sparse tensor to CSR format for more " f"efficient processing. 
Consider converting your " f"sparse tensor to CSR format beforehand to avoid " f"repeated conversion (got '{src.layout}')", stacklevel=2) src = src.to_sparse_csr() # Warn in case of CSC format without gradient computation: if src.layout == torch.sparse_csc and not other.requires_grad: warnings.warn( f"Converting sparse tensor to CSR format for more " f"efficient processing. Consider converting your " f"sparse tensor to CSR format beforehand to avoid " f"repeated conversion (got '{src.layout}')", stacklevel=2) # Use the default code path for `sum` reduction (works on CPU/GPU): if reduce == 'sum': return torch.sparse.mm(src, other) # Use the default code path with custom reduction (works on CPU): if src.layout == torch.sparse_csr and not src.is_cuda: return torch.sparse.mm(src, other, reduce) # Simulate `mean` reduction by dividing by degree: if reduce == 'mean': if src.layout == torch.sparse_csr: ptr = src.crow_indices() deg = ptr[1:] - ptr[:-1] else: assert src.layout == torch.sparse_csc deg = scatter(torch.ones_like(src.values()), src.row_indices(), dim=0, dim_size=src.size(0), reduce='sum') return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1) # TODO The `torch.sparse.mm` code path with the `reduce` argument does # not yet support CSC :( if src.layout == torch.sparse_csc: warnings.warn( f"Converting sparse tensor to CSR format for more " f"efficient processing. 
Consider converting your " f"sparse tensor to CSR format beforehand to avoid " f"repeated conversion (got '{src.layout}')", stacklevel=2) src = src.to_sparse_csr() return torch.sparse.mm(src, other, reduce) # pragma: no cover # PyTorch < 2.0 only supports sparse COO format: if reduce == 'sum': return torch.sparse.mm(src, other) elif reduce == 'mean': if src.layout == torch.sparse_csr: ptr = src.crow_indices() deg = ptr[1:] - ptr[:-1] elif src.layout == torch.sparse_csc: assert src.layout == torch.sparse_csc ones = torch.ones_like(src.values()) index = src.row_indices() deg = scatter(ones, index, 0, dim_size=src.size(0), reduce='sum') else: assert src.layout == torch.sparse_coo src = src.coalesce() ones = torch.ones_like(src.values()) index = src.indices()[0] deg = scatter(ones, index, 0, dim_size=src.size(0), reduce='sum') return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1) raise ValueError(f"`{reduce}` reduction is not supported for " f"'torch.sparse.Tensor' on device '{src.device}'") ================================================ FILE: torch_geometric/utils/_subgraph.py ================================================ from typing import List, Literal, Optional, Tuple, Union, overload import torch from torch import Tensor from torch_geometric.typing import OptTensor, PairTensor from torch_geometric.utils import scatter from torch_geometric.utils.map import map_index from torch_geometric.utils.mask import index_to_mask from torch_geometric.utils.num_nodes import maybe_num_nodes def get_num_hops(model: torch.nn.Module) -> int: r"""Returns the number of hops the model is aggregating information from. .. note:: This function counts the number of message passing layers as an approximation of the total number of hops covered by the model. Its output may not necessarily be correct in case message passing layers perform multi-hop aggregation, *e.g.*, as in :class:`~torch_geometric.nn.conv.ChebConv`. Example: >>> class GNN(torch.nn.Module): ... 
def __init__(self): ... super().__init__() ... self.conv1 = GCNConv(3, 16) ... self.conv2 = GCNConv(16, 16) ... self.lin = Linear(16, 2) ... ... def forward(self, x, edge_index): ... x = self.conv1(x, edge_index).relu() ... x = self.conv2(x, edge_index).relu() ... return self.lin(x) >>> get_num_hops(GNN()) 2 """ from torch_geometric.nn.conv import MessagePassing num_hops = 0 for module in model.modules(): if isinstance(module, MessagePassing): num_hops += 1 return num_hops @overload def subgraph( subset: Union[Tensor, List[int]], edge_index: Tensor, edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., ) -> Tuple[Tensor, OptTensor]: pass @overload def subgraph( subset: Union[Tensor, List[int]], edge_index: Tensor, edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., *, return_edge_mask: Literal[False], ) -> Tuple[Tensor, OptTensor]: pass @overload def subgraph( subset: Union[Tensor, List[int]], edge_index: Tensor, edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., *, return_edge_mask: Literal[True], ) -> Tuple[Tensor, OptTensor, Tensor]: pass def subgraph( subset: Union[Tensor, List[int]], edge_index: Tensor, edge_attr: OptTensor = None, relabel_nodes: bool = False, num_nodes: Optional[int] = None, *, return_edge_mask: bool = False, ) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]: r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`. Args: subset (LongTensor, BoolTensor or [int]): The nodes to keep. edge_index (LongTensor): The edge indices. edge_attr (Tensor, optional): Edge weights or multi-dimensional edge features. (default: :obj:`None`) relabel_nodes (bool, optional): If set to :obj:`True`, the resulting :obj:`edge_index` will be relabeled to hold consecutive indices starting from zero. 
(default: :obj:`False`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max(edge_index) + 1`. (default: :obj:`None`) return_edge_mask (bool, optional): If set to :obj:`True`, will return the edge mask to filter out additional edge features. (default: :obj:`False`) :rtype: (:class:`LongTensor`, :class:`Tensor`) Examples: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6], ... [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5]]) >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) >>> subset = torch.tensor([3, 4, 5]) >>> subgraph(subset, edge_index, edge_attr) (tensor([[3, 4, 4, 5], [4, 3, 5, 4]]), tensor([ 7., 8., 9., 10.])) >>> subgraph(subset, edge_index, edge_attr, return_edge_mask=True) (tensor([[3, 4, 4, 5], [4, 3, 5, 4]]), tensor([ 7., 8., 9., 10.]), tensor([False, False, False, False, False, False, True, True, True, True, False, False])) """ device = edge_index.device if isinstance(subset, (list, tuple)): subset = torch.tensor(subset, dtype=torch.long, device=device) if subset.dtype != torch.bool: num_nodes = maybe_num_nodes(edge_index, num_nodes) node_mask = index_to_mask(subset, size=num_nodes) else: num_nodes = subset.size(0) node_mask = subset subset = node_mask.nonzero().view(-1) edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] edge_attr = edge_attr[edge_mask] if edge_attr is not None else None if relabel_nodes: edge_index, _ = map_index( edge_index.view(-1), subset, max_index=num_nodes, inclusive=True, ) edge_index = edge_index.view(2, -1) if return_edge_mask: return edge_index, edge_attr, edge_mask else: return edge_index, edge_attr def bipartite_subgraph( subset: Union[PairTensor, Tuple[List[int], List[int]]], edge_index: Tensor, edge_attr: OptTensor = None, relabel_nodes: bool = False, size: Optional[Tuple[int, int]] = None, return_edge_mask: bool = False, ) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, OptTensor]]: r"""Returns the induced subgraph of 
the bipartite graph :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`. Args: subset (Tuple[Tensor, Tensor] or tuple([int],[int])): The nodes to keep. edge_index (LongTensor): The edge indices. edge_attr (Tensor, optional): Edge weights or multi-dimensional edge features. (default: :obj:`None`) relabel_nodes (bool, optional): If set to :obj:`True`, the resulting :obj:`edge_index` will be relabeled to hold consecutive indices starting from zero. (default: :obj:`False`) size (tuple, optional): The number of nodes. (default: :obj:`None`) return_edge_mask (bool, optional): If set to :obj:`True`, will return the edge mask to filter out additional edge features. (default: :obj:`False`) :rtype: (:class:`LongTensor`, :class:`Tensor`) Examples: >>> edge_index = torch.tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6], ... [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]]) >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) >>> subset = (torch.tensor([2, 3, 5]), torch.tensor([2, 3])) >>> bipartite_subgraph(subset, edge_index, edge_attr) (tensor([[2, 3, 5, 5], [3, 2, 2, 3]]), tensor([ 3, 4, 9, 10])) >>> bipartite_subgraph(subset, edge_index, edge_attr, ... 
return_edge_mask=True) (tensor([[2, 3, 5, 5], [3, 2, 2, 3]]), tensor([ 3, 4, 9, 10]), tensor([False, False, True, True, False, False, False, False, True, True, False])) """ device = edge_index.device src_subset, dst_subset = subset if not isinstance(src_subset, Tensor): src_subset = torch.tensor(src_subset, dtype=torch.long, device=device) if not isinstance(dst_subset, Tensor): dst_subset = torch.tensor(dst_subset, dtype=torch.long, device=device) if src_subset.dtype != torch.bool: src_size = int(edge_index[0].max()) + 1 if size is None else size[0] src_node_mask = index_to_mask(src_subset, size=src_size) else: src_size = src_subset.size(0) src_node_mask = src_subset src_subset = src_subset.nonzero().view(-1) if dst_subset.dtype != torch.bool: dst_size = int(edge_index[1].max()) + 1 if size is None else size[1] dst_node_mask = index_to_mask(dst_subset, size=dst_size) else: dst_size = dst_subset.size(0) dst_node_mask = dst_subset dst_subset = dst_subset.nonzero().view(-1) edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] edge_attr = edge_attr[edge_mask] if edge_attr is not None else None if relabel_nodes: src_index, _ = map_index(edge_index[0], src_subset, max_index=src_size, inclusive=True) dst_index, _ = map_index(edge_index[1], dst_subset, max_index=dst_size, inclusive=True) edge_index = torch.stack([src_index, dst_index], dim=0) if return_edge_mask: return edge_index, edge_attr, edge_mask else: return edge_index, edge_attr def k_hop_subgraph( node_idx: Union[int, List[int], Tensor], num_hops: int, edge_index: Tensor, relabel_nodes: bool = False, num_nodes: Optional[int] = None, flow: str = 'source_to_target', directed: bool = False, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: r"""Computes the induced subgraph of :obj:`edge_index` around all nodes in :attr:`node_idx` reachable within :math:`k` hops. The :attr:`flow` argument denotes the direction of edges for finding :math:`k`-hop neighbors. 
If set to :obj:`"source_to_target"`, then the method will find all neighbors that point to the initial set of seed nodes in :attr:`node_idx.` This mimics the natural flow of message passing in Graph Neural Networks. The method returns (1) the nodes involved in the subgraph, (2) the filtered :obj:`edge_index` connectivity, (3) the mapping from node indices in :obj:`node_idx` to their new location, and (4) the edge mask indicating which edges were preserved. Args: node_idx (int, list, tuple or :obj:`torch.Tensor`): The central seed node(s). num_hops (int): The number of hops :math:`k`. edge_index (LongTensor): The edge indices. relabel_nodes (bool, optional): If set to :obj:`True`, the resulting :obj:`edge_index` will be relabeled to hold consecutive indices starting from zero. (default: :obj:`False`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) flow (str, optional): The flow direction of :math:`k`-hop aggregation (:obj:`"source_to_target"` or :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) directed (bool, optional): If set to :obj:`True`, will only include directed edges to the seed nodes :obj:`node_idx`. (default: :obj:`False`) :rtype: (:class:`LongTensor`, :class:`LongTensor`, :class:`LongTensor`, :class:`BoolTensor`) Examples: >>> edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], ... [2, 2, 4, 4, 6, 6]]) >>> # Center node 6, 2-hops >>> subset, edge_index, mapping, edge_mask = k_hop_subgraph( ... 6, 2, edge_index, relabel_nodes=True) >>> subset tensor([2, 3, 4, 5, 6]) >>> edge_index tensor([[0, 1, 2, 3], [2, 2, 4, 4]]) >>> mapping tensor([4]) >>> edge_mask tensor([False, False, True, True, True, True]) >>> subset[mapping] tensor([6]) >>> edge_index = torch.tensor([[1, 2, 4, 5], ... [0, 1, 5, 6]]) >>> (subset, edge_index, ... mapping, edge_mask) = k_hop_subgraph([0, 6], 2, ... edge_index, ... 
relabel_nodes=True) >>> subset tensor([0, 1, 2, 4, 5, 6]) >>> edge_index tensor([[1, 2, 3, 4], [0, 1, 4, 5]]) >>> mapping tensor([0, 5]) >>> edge_mask tensor([True, True, True, True]) >>> subset[mapping] tensor([0, 6]) """ num_nodes = maybe_num_nodes(edge_index, num_nodes) assert flow in ['source_to_target', 'target_to_source'] if flow == 'target_to_source': row, col = edge_index else: col, row = edge_index node_mask = row.new_empty(num_nodes, dtype=torch.bool) edge_mask = row.new_empty(row.size(0), dtype=torch.bool) if isinstance(node_idx, int): node_idx = torch.tensor([node_idx], device=row.device) elif isinstance(node_idx, (list, tuple)): node_idx = torch.tensor(node_idx, device=row.device) else: node_idx = node_idx.to(row.device) subsets = [node_idx] preserved_edge_mask = torch.zeros_like(edge_mask) for _ in range(num_hops): node_mask.fill_(False) node_mask[subsets[-1]] = True torch.index_select(node_mask, 0, row, out=edge_mask) preserved_edge_mask |= edge_mask subsets.append(col[edge_mask]) subset, inv = torch.cat(subsets).unique(return_inverse=True) inv = inv[:node_idx.numel()] node_mask.fill_(False) node_mask[subset] = True if not directed: edge_mask = node_mask[row] & node_mask[col] else: edge_mask = preserved_edge_mask edge_index = edge_index[:, edge_mask] if relabel_nodes: mapping = row.new_full((num_nodes, ), -1) mapping[subset] = torch.arange(subset.size(0), device=row.device) edge_index = mapping[edge_index] return subset, edge_index, inv, edge_mask @overload def hyper_subgraph( subset: Union[Tensor, List[int]], edge_index: Tensor, edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., ) -> Tuple[Tensor, OptTensor]: pass @overload def hyper_subgraph( subset: Union[Tensor, List[int]], edge_index: Tensor, edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., *, return_edge_mask: Literal[False], ) -> Tuple[Tensor, OptTensor]: pass @overload def hyper_subgraph( subset: Union[Tensor, 
List[int]], edge_index: Tensor, edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., *, return_edge_mask: Literal[True], ) -> Tuple[Tensor, OptTensor, Tensor]: pass def hyper_subgraph( subset: Union[Tensor, List[int]], edge_index: Tensor, edge_attr: OptTensor = None, relabel_nodes: bool = False, num_nodes: Optional[int] = None, return_edge_mask: bool = False, ) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]: r"""Returns the induced subgraph of the hyper graph of :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`. Args: subset (torch.Tensor or [int]): The nodes to keep. edge_index (LongTensor): Hyperedge tensor with shape :obj:`[2, num_edges*num_nodes_per_edge]`, where :obj:`edge_index[1]` denotes the hyperedge index and :obj:`edge_index[0]` denotes the node indices that are connected by the hyperedge. edge_attr (torch.Tensor, optional): Edge weights or multi-dimensional edge features of shape :obj:`[num_edges, *]`. (default: :obj:`None`) relabel_nodes (bool, optional): If set to :obj:`True`, the resulting :obj:`edge_index` will be relabeled to hold consecutive indices starting from zero. (default: :obj:`False`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max(edge_index[0]) + 1`. (default: :obj:`None`) return_edge_mask (bool, optional): If set to :obj:`True`, will return the edge mask to filter out additional edge features. (default: :obj:`False`) :rtype: (:class:`LongTensor`, :class:`Tensor`) Examples: >>> edge_index = torch.tensor([[0, 1, 2, 1, 2, 3, 0, 2, 3], ... 
[0, 0, 0, 1, 1, 1, 2, 2, 2]]) >>> edge_attr = torch.tensor([3, 2, 6]) >>> subset = torch.tensor([0, 3]) >>> subgraph(subset, edge_index, edge_attr) (tensor([[0, 3], [0, 0]]), tensor([ 6.])) >>> subgraph(subset, edge_index, edge_attr, return_edge_mask=True) (tensor([[0, 3], [0, 0]]), tensor([ 6.])) tensor([False, False, True]) """ device = edge_index.device if isinstance(subset, (list, tuple)): subset = torch.tensor(subset, dtype=torch.long, device=device) if subset.dtype != torch.bool: num_nodes = maybe_num_nodes(edge_index, num_nodes) node_mask = index_to_mask(subset, size=num_nodes) else: num_nodes = subset.size(0) node_mask = subset # Mask all connections that contain a node not in the subset hyper_edge_connection_mask = node_mask[ edge_index[0]] # num_edges*num_nodes_per_edge # Mask hyperedges that contain one or less nodes from the subset edge_mask = scatter(hyper_edge_connection_mask.to(torch.long), edge_index[1], reduce='sum') > 1 # Mask connections if hyperedge contains one or less nodes from the subset # or is connected to a node not in the subset hyper_edge_connection_mask = hyper_edge_connection_mask & edge_mask[ edge_index[1]] edge_index = edge_index[:, hyper_edge_connection_mask] edge_attr = edge_attr[edge_mask] if edge_attr is not None else None # Relabel edges edge_idx = torch.zeros(edge_mask.size(0), dtype=torch.long, device=device) edge_idx[edge_mask] = torch.arange(edge_mask.sum().item(), device=device) edge_index = torch.cat( [edge_index[0].unsqueeze(0), edge_idx[edge_index[1]].unsqueeze(0)], 0) if relabel_nodes: node_idx = torch.zeros(node_mask.size(0), dtype=torch.long, device=device) node_idx[subset] = torch.arange(node_mask.sum().item(), device=device) edge_index = torch.cat( [node_idx[edge_index[0]].unsqueeze(0), edge_index[1].unsqueeze(0)], 0) if return_edge_mask: return edge_index, edge_attr, edge_mask else: return edge_index, edge_attr ================================================ FILE: torch_geometric/utils/_to_dense_adj.py 
================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.typing import OptTensor from torch_geometric.utils import cumsum, scatter def to_dense_adj( edge_index: Tensor, batch: OptTensor = None, edge_attr: OptTensor = None, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None, ) -> Tensor: r"""Converts batched sparse adjacency matrices given by edge indices and edge attributes to a single dense batched adjacency matrix. Args: edge_index (LongTensor): The edge indices. batch (LongTensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) edge_attr (Tensor, optional): Edge weights or multi-dimensional edge features. If :obj:`edge_index` contains duplicated edges, the dense adjacency matrix output holds the summed up entries of :obj:`edge_attr` for duplicated edges. (default: :obj:`None`) max_num_nodes (int, optional): The size of the output node dimension. (default: :obj:`None`) batch_size (int, optional): The batch size. (default: :obj:`None`) :rtype: :class:`Tensor` Examples: >>> edge_index = torch.tensor([[0, 0, 1, 2, 3], ... 
[0, 1, 0, 3, 0]]) >>> batch = torch.tensor([0, 0, 1, 1]) >>> to_dense_adj(edge_index, batch) tensor([[[1., 1.], [1., 0.]], [[0., 1.], [1., 0.]]]) >>> to_dense_adj(edge_index, batch, max_num_nodes=4) tensor([[[1., 1., 0., 0.], [1., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 1., 0., 0.], [1., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]]) >>> edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) >>> to_dense_adj(edge_index, batch, edge_attr) tensor([[[1., 2.], [3., 0.]], [[0., 4.], [5., 0.]]]) """ if batch is None: max_index = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 batch = edge_index.new_zeros(max_index) if batch_size is None: batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1 one = batch.new_ones(batch.size(0)) num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum') cum_nodes = cumsum(num_nodes) idx0 = batch[edge_index[0]] idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]] idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]] if max_num_nodes is None: max_num_nodes = int(num_nodes.max()) elif ((idx1.numel() > 0 and idx1.max() >= max_num_nodes) or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)): mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes) idx0 = idx0[mask] idx1 = idx1[mask] idx2 = idx2[mask] edge_attr = None if edge_attr is None else edge_attr[mask] if edge_attr is None: edge_attr = torch.ones(idx0.numel(), device=edge_index.device) size = [batch_size, max_num_nodes, max_num_nodes] size += list(edge_attr.size())[1:] flattened_size = batch_size * max_num_nodes * max_num_nodes idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2 adj = scatter(edge_attr, idx, dim=0, dim_size=flattened_size, reduce='sum') adj = adj.view(size) return adj ================================================ FILE: torch_geometric/utils/_to_dense_batch.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from 
torch_geometric.experimental import ( disable_dynamic_shapes, is_experimental_mode_enabled, ) from torch_geometric.utils import cumsum, scatter @disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes']) def to_dense_batch( x: Tensor, batch: Optional[Tensor] = None, fill_value: float = 0.0, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: r"""Given a sparse batch of node features :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a dense node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with :math:`N_{\max} = \max_i^B N_i`). In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times N_{\max}}` is returned, holding information about the existence of fake-nodes in the dense representation. Args: x (Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. batch (LongTensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. Must be ordered. (default: :obj:`None`) fill_value (float, optional): The value for invalid entries in the resulting dense output tensor. (default: :obj:`0`) max_num_nodes (int, optional): The size of the output node dimension. (default: :obj:`None`) batch_size (int, optional): The batch size. 
(default: :obj:`None`) :rtype: (:class:`Tensor`, :class:`BoolTensor`) Examples: >>> x = torch.arange(12).view(6, 2) >>> x tensor([[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7], [ 8, 9], [10, 11]]) >>> out, mask = to_dense_batch(x) >>> mask tensor([[True, True, True, True, True, True]]) >>> batch = torch.tensor([0, 0, 1, 2, 2, 2]) >>> out, mask = to_dense_batch(x, batch) >>> out tensor([[[ 0, 1], [ 2, 3], [ 0, 0]], [[ 4, 5], [ 0, 0], [ 0, 0]], [[ 6, 7], [ 8, 9], [10, 11]]]) >>> mask tensor([[ True, True, False], [ True, False, False], [ True, True, True]]) >>> out, mask = to_dense_batch(x, batch, max_num_nodes=4) >>> out tensor([[[ 0, 1], [ 2, 3], [ 0, 0], [ 0, 0]], [[ 4, 5], [ 0, 0], [ 0, 0], [ 0, 0]], [[ 6, 7], [ 8, 9], [10, 11], [ 0, 0]]]) >>> mask tensor([[ True, True, False, False], [ True, False, False, False], [ True, True, True, False]]) """ if batch is None and max_num_nodes is None: mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device) return x.unsqueeze(0), mask if batch is None: batch = x.new_zeros(x.size(0), dtype=torch.long) if batch_size is None: batch_size = int(batch.max()) + 1 num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0, dim_size=batch_size, reduce='sum') cum_nodes = cumsum(num_nodes) filter_nodes = False dynamic_shapes_disabled = is_experimental_mode_enabled( 'disable_dynamic_shapes') if max_num_nodes is None: max_num_nodes = int(num_nodes.max()) elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes: filter_nodes = True tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch] idx = tmp + (batch * max_num_nodes) if filter_nodes: mask = tmp < max_num_nodes x, idx = x[mask], idx[mask] size = [batch_size * max_num_nodes] + list(x.size())[1:] out = torch.as_tensor(fill_value, device=x.device, dtype=x.dtype) out = out.repeat(size) out[idx] = x out = out.view([batch_size, max_num_nodes] + list(x.size())[1:]) mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool, device=x.device) mask[idx] = 1 mask = 
mask.view(batch_size, max_num_nodes) return out, mask ================================================ FILE: torch_geometric/utils/_train_test_split_edges.py ================================================ import math import torch import torch_geometric from torch_geometric.deprecation import deprecated from torch_geometric.utils import to_undirected @deprecated("use 'transforms.RandomLinkSplit' instead") def train_test_split_edges( data: 'torch_geometric.data.Data', val_ratio: float = 0.05, test_ratio: float = 0.1, ) -> 'torch_geometric.data.Data': r"""Splits the edges of a :class:`torch_geometric.data.Data` object into positive and negative train/val/test edges. As such, it will replace the :obj:`edge_index` attribute with :obj:`train_pos_edge_index`, :obj:`train_pos_neg_adj_mask`, :obj:`val_pos_edge_index`, :obj:`val_neg_edge_index` and :obj:`test_pos_edge_index` attributes. If :obj:`data` has edge features named :obj:`edge_attr`, then :obj:`train_pos_edge_attr`, :obj:`val_pos_edge_attr` and :obj:`test_pos_edge_attr` will be added as well. .. warning:: :meth:`~torch_geometric.utils.train_test_split_edges` is deprecated and will be removed in a future release. Use :class:`torch_geometric.transforms.RandomLinkSplit` instead. Args: data (Data): The data object. val_ratio (float, optional): The ratio of positive validation edges. (default: :obj:`0.05`) test_ratio (float, optional): The ratio of positive test edges. (default: :obj:`0.1`) :rtype: :class:`torch_geometric.data.Data` """ assert 'batch' not in data # No batch-mode. assert data.num_nodes is not None assert data.edge_index is not None num_nodes = data.num_nodes row, col = data.edge_index edge_attr = data.edge_attr del data.edge_index del data.edge_attr # Return upper triangular portion. mask = row < col row, col = row[mask], col[mask] if edge_attr is not None: edge_attr = edge_attr[mask] n_v = int(math.floor(val_ratio * row.size(0))) n_t = int(math.floor(test_ratio * row.size(0))) # Positive edges. 
perm = torch.randperm(row.size(0)) row, col = row[perm], col[perm] if edge_attr is not None: edge_attr = edge_attr[perm] r, c = row[:n_v], col[:n_v] data.val_pos_edge_index = torch.stack([r, c], dim=0) if edge_attr is not None: data.val_pos_edge_attr = edge_attr[:n_v] r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t] data.test_pos_edge_index = torch.stack([r, c], dim=0) if edge_attr is not None: data.test_pos_edge_attr = edge_attr[n_v:n_v + n_t] r, c = row[n_v + n_t:], col[n_v + n_t:] data.train_pos_edge_index = torch.stack([r, c], dim=0) if edge_attr is not None: out = to_undirected(data.train_pos_edge_index, edge_attr[n_v + n_t:]) data.train_pos_edge_index, data.train_pos_edge_attr = out else: data.train_pos_edge_index = to_undirected(data.train_pos_edge_index) # Negative edges. neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8) neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool) neg_adj_mask[row, col] = 0 neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t() perm = torch.randperm(neg_row.size(0))[:n_v + n_t] neg_row, neg_col = neg_row[perm], neg_col[perm] neg_adj_mask[neg_row, neg_col] = 0 data.train_neg_adj_mask = neg_adj_mask row, col = neg_row[:n_v], neg_col[:n_v] data.val_neg_edge_index = torch.stack([row, col], dim=0) row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t] data.test_neg_edge_index = torch.stack([row, col], dim=0) return data ================================================ FILE: torch_geometric/utils/_tree_decomposition.py ================================================ from itertools import chain from typing import Any, List, Literal, Tuple, Union, overload import torch from torch import Tensor from torch_geometric.utils import ( from_scipy_sparse_matrix, to_scipy_sparse_matrix, to_undirected, ) @overload def tree_decomposition(mol: Any) -> Tuple[Tensor, Tensor, int]: pass @overload def tree_decomposition( mol: Any, return_vocab: Literal[False], ) -> Tuple[Tensor, Tensor, int]: pass @overload def 
tree_decomposition( mol: Any, return_vocab: Literal[True], ) -> Tuple[Tensor, Tensor, int, Tensor]: pass def tree_decomposition( mol: Any, return_vocab: bool = False, ) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, int, Tensor]]: r"""The tree decomposition algorithm of molecules from the `"Junction Tree Variational Autoencoder for Molecular Graph Generation" `_ paper. Returns the graph connectivity of the junction tree, the assignment mapping of each atom to the clique in the junction tree, and the number of cliques. Args: mol (rdkit.Chem.Mol): An :obj:`rdkit` molecule. return_vocab (bool, optional): If set to :obj:`True`, will return an identifier for each clique (ring, bond, bridged compounds, single). (default: :obj:`False`) :rtype: :obj:`(LongTensor, LongTensor, int)` if :obj:`return_vocab` is :obj:`False`, else :obj:`(LongTensor, LongTensor, int, LongTensor)` """ import rdkit.Chem as Chem from scipy.sparse.csgraph import minimum_spanning_tree # Cliques = rings and bonds. cliques: List[List[int]] = [list(x) for x in Chem.GetSymmSSSR(mol)] xs: List[int] = [0] * len(cliques) for bond in mol.GetBonds(): if not bond.IsInRing(): cliques.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) xs.append(1) # Generate `atom2cliques` mappings. atom2cliques: List[List[int]] = [[] for _ in range(mol.GetNumAtoms())] for c in range(len(cliques)): for atom in cliques[c]: atom2cliques[atom].append(c) # Merge rings that share more than 2 atoms as they form bridged compounds. for c1 in range(len(cliques)): for atom in cliques[c1]: for c2 in atom2cliques[atom]: if c1 >= c2 or len(cliques[c1]) <= 2 or len(cliques[c2]) <= 2: continue if len(set(cliques[c1]) & set(cliques[c2])) > 2: cliques[c1] = list(set(cliques[c1]) | set(cliques[c2])) xs[c1] = 2 cliques[c2] = [] xs[c2] = -1 cliques = [c for c in cliques if len(c) > 0] xs = [x for x in xs if x >= 0] # Update `atom2cliques` mappings. 
atom2cliques = [[] for i in range(mol.GetNumAtoms())] for c in range(len(cliques)): for atom in cliques[c]: atom2cliques[atom].append(c) # Add singleton cliques in case there are more than 2 intersecting # cliques. We further compute the "initial" clique graph. edges = {} for atom in range(mol.GetNumAtoms()): cs = atom2cliques[atom] if len(cs) <= 1: continue # Number of bond clusters that the atom lies in. bonds = [c for c in cs if len(cliques[c]) == 2] # Number of ring clusters that the atom lies in. rings = [c for c in cs if len(cliques[c]) > 4] if len(bonds) > 2 or (len(bonds) == 2 and len(cs) > 2): cliques.append([atom]) xs.append(3) c2 = len(cliques) - 1 for c1 in cs: edges[(c1, c2)] = 1 elif len(rings) > 2: cliques.append([atom]) xs.append(3) c2 = len(cliques) - 1 for c1 in cs: edges[(c1, c2)] = 99 else: for i in range(len(cs)): for j in range(i + 1, len(cs)): c1, c2 = cs[i], cs[j] count = len(set(cliques[c1]) & set(cliques[c2])) edges[(c1, c2)] = min(count, edges.get((c1, c2), 99)) # Update `atom2cliques` mappings. 
atom2cliques = [[] for i in range(mol.GetNumAtoms())] for c in range(len(cliques)): for atom in cliques[c]: atom2cliques[atom].append(c) if len(edges) > 0: edge_index_T, weight = zip(*edges.items()) edge_index = torch.tensor(edge_index_T).t() inv_weight = 100 - torch.tensor(weight) graph = to_scipy_sparse_matrix(edge_index, inv_weight, len(cliques)) junc_tree = minimum_spanning_tree(graph) edge_index, _ = from_scipy_sparse_matrix(junc_tree) edge_index = to_undirected(edge_index, num_nodes=len(cliques)) else: edge_index = torch.empty((2, 0), dtype=torch.long) rows = [[i] * len(atom2cliques[i]) for i in range(mol.GetNumAtoms())] row = torch.tensor(list(chain.from_iterable(rows))) col = torch.tensor(list(chain.from_iterable(atom2cliques))) atom2clique = torch.stack([row, col], dim=0).to(torch.long) if return_vocab: vocab = torch.tensor(xs, dtype=torch.long) return edge_index, atom2clique, len(cliques), vocab else: return edge_index, atom2clique, len(cliques) ================================================ FILE: torch_geometric/utils/_trim_to_layer.py ================================================ from typing import Dict, List, Optional, Tuple, Union, overload import torch from torch import Tensor from torch_geometric import EdgeIndex from torch_geometric.typing import ( Adj, EdgeType, MaybeHeteroAdjTensor, MaybeHeteroEdgeTensor, MaybeHeteroNodeTensor, NodeType, SparseStorage, SparseTensor, ) @overload def trim_to_layer( layer: int, num_sampled_nodes_per_hop: List[int], num_sampled_edges_per_hop: List[int], x: Tensor, edge_index: Adj, edge_attr: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: pass @overload def trim_to_layer( layer: int, num_sampled_nodes_per_hop: Dict[NodeType, List[int]], num_sampled_edges_per_hop: Dict[EdgeType, List[int]], x: Dict[NodeType, Tensor], edge_index: Dict[EdgeType, Adj], edge_attr: Optional[Dict[EdgeType, Tensor]] = None, ) -> Tuple[Dict[NodeType, Tensor], Dict[EdgeType, Adj], Optional[Dict[ EdgeType, Tensor]]]: 
pass def trim_to_layer( layer: int, num_sampled_nodes_per_hop: Union[List[int], Dict[NodeType, List[int]]], num_sampled_edges_per_hop: Union[List[int], Dict[EdgeType, List[int]]], x: MaybeHeteroNodeTensor, edge_index: MaybeHeteroEdgeTensor, edge_attr: Optional[MaybeHeteroEdgeTensor] = None, ) -> Tuple[MaybeHeteroNodeTensor, MaybeHeteroAdjTensor, Optional[MaybeHeteroEdgeTensor]]: r"""Trims the :obj:`edge_index` representation, node features :obj:`x` and edge features :obj:`edge_attr` to a minimal-sized representation for the current GNN layer :obj:`layer` in directed :class:`~torch_geometric.loader.NeighborLoader` scenarios. This ensures that no computation is performed for nodes and edges that are not included in the current GNN layer, thus avoiding unnecessary computation within the GNN when performing neighborhood sampling. Args: layer (int): The current GNN layer. num_sampled_nodes_per_hop (List[int] or Dict[NodeType, List[int]]): The number of sampled nodes per hop. num_sampled_edges_per_hop (List[int] or Dict[EdgeType, List[int]]): The number of sampled edges per hop. x (torch.Tensor or Dict[NodeType, torch.Tensor]): The homogeneous or heterogeneous (hidden) node features. edge_index (torch.Tensor or Dict[EdgeType, torch.Tensor]): The homogeneous or heterogeneous edge indices. edge_attr (torch.Tensor or Dict[EdgeType, torch.Tensor], optional): The homogeneous or heterogeneous (hidden) edge features. 
""" if layer <= 0: return x, edge_index, edge_attr if isinstance(num_sampled_edges_per_hop, dict): assert isinstance(num_sampled_nodes_per_hop, dict) assert isinstance(x, dict) x = { k: trim_feat(v, layer, num_sampled_nodes_per_hop[k]) for k, v in x.items() } assert isinstance(edge_index, dict) edge_index = { k: trim_adj( v, layer, num_sampled_nodes_per_hop[k[0]], num_sampled_nodes_per_hop[k[-1]], num_sampled_edges_per_hop[k], ) for k, v in edge_index.items() } if edge_attr is not None: assert isinstance(edge_attr, dict) edge_attr = { k: trim_feat(v, layer, num_sampled_edges_per_hop[k]) for k, v in edge_attr.items() } return x, edge_index, edge_attr assert isinstance(num_sampled_nodes_per_hop, list) assert isinstance(x, Tensor) x = trim_feat(x, layer, num_sampled_nodes_per_hop) assert isinstance(edge_index, (Tensor, SparseTensor)) edge_index = trim_adj( edge_index, layer, num_sampled_nodes_per_hop, num_sampled_nodes_per_hop, num_sampled_edges_per_hop, ) if edge_attr is not None: assert isinstance(edge_attr, Tensor) edge_attr = trim_feat(edge_attr, layer, num_sampled_edges_per_hop) return x, edge_index, edge_attr class TrimToLayer(torch.nn.Module): @torch.jit.unused def forward( self, layer: int, num_sampled_nodes_per_hop: Optional[List[int]], num_sampled_edges_per_hop: Optional[List[int]], x: Tensor, edge_index: Adj, edge_attr: Optional[Tensor] = None, ) -> Tuple[Tensor, Adj, Optional[Tensor]]: if (not isinstance(num_sampled_nodes_per_hop, list) and isinstance(num_sampled_edges_per_hop, list)): raise ValueError("'num_sampled_nodes_per_hop' needs to be given") if (not isinstance(num_sampled_edges_per_hop, list) and isinstance(num_sampled_nodes_per_hop, list)): raise ValueError("'num_sampled_edges_per_hop' needs to be given") if num_sampled_nodes_per_hop is None: return x, edge_index, edge_attr if num_sampled_edges_per_hop is None: return x, edge_index, edge_attr return trim_to_layer( layer, num_sampled_nodes_per_hop, num_sampled_edges_per_hop, x, edge_index, 
edge_attr, ) # Helper functions ############################################################ def trim_feat(x: Tensor, layer: int, num_samples_per_hop: List[int]) -> Tensor: if layer <= 0: return x return x.narrow( dim=0, start=0, length=x.size(0) - num_samples_per_hop[-layer], ) def trim_adj( edge_index: Adj, layer: int, num_sampled_src_nodes_per_hop: List[int], num_sampled_dst_nodes_per_hop: List[int], num_sampled_edges_per_hop: List[int], ) -> Adj: if layer <= 0: return edge_index if isinstance(edge_index, Tensor): edge_index = edge_index.narrow( dim=1, start=0, length=edge_index.size(1) - num_sampled_edges_per_hop[-layer], ) if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): num_rows, num_cols = edge_index.sparse_size() if num_rows is not None: num_rows -= num_sampled_src_nodes_per_hop[-layer] if num_cols is not None: num_cols -= num_sampled_dst_nodes_per_hop[-layer] edge_index.sparse_resize_(num_rows, num_cols) return edge_index elif isinstance(edge_index, SparseTensor): size = ( edge_index.size(0) - num_sampled_dst_nodes_per_hop[-layer], edge_index.size(1) - num_sampled_src_nodes_per_hop[-layer], ) num_seed_nodes = size[0] - num_sampled_dst_nodes_per_hop[-(layer + 1)] return trim_sparse_tensor(edge_index, size, num_seed_nodes) raise ValueError(f"Unsupported 'edge_index' type '{type(edge_index)}'") def trim_sparse_tensor(src: SparseTensor, size: Tuple[int, int], num_seed_nodes: int) -> SparseTensor: r"""Trims a :class:`SparseTensor` along both dimensions to only contain the upper :obj:`num_nodes` in both dimensions. It is assumed that :class:`SparseTensor` is obtained from BFS traversing, starting from the nodes that have been initially selected. Args: src (SparseTensor): The sparse tensor. size (Tuple[int, int]): The number of source and destination nodes to keep. num_seed_nodes (int): The number of seed nodes to compute representations. 
""" rowptr, col, value = src.csr() rowptr = torch.narrow(rowptr, 0, 0, size[0] + 1).clone() rowptr[num_seed_nodes + 1:] = rowptr[num_seed_nodes] col = torch.narrow(col, 0, 0, rowptr[-1]) # type: ignore if value is not None: value = torch.narrow(value, 0, 0, rowptr[-1]) # type: ignore csr2csc = src.storage._csr2csc if csr2csc is not None: csr2csc = csr2csc[csr2csc < len(col)] storage = SparseStorage( row=None, rowptr=rowptr, col=col, value=value, sparse_sizes=size, rowcount=None, colptr=None, colcount=None, csr2csc=csr2csc, csc2csr=None, is_sorted=True, trust_data=True, ) return src.from_storage(storage) ================================================ FILE: torch_geometric/utils/_unbatch.py ================================================ from typing import List, Optional import torch from torch import Tensor from torch_geometric.utils import cumsum, degree def unbatch( src: Tensor, batch: Tensor, dim: int = 0, batch_size: Optional[int] = None, ) -> List[Tensor]: r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension :obj:`dim`. Args: src (Tensor): The source tensor. batch (LongTensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each entry in :obj:`src` to a specific example. Must be ordered. dim (int, optional): The dimension along which to split the :obj:`src` tensor. (default: :obj:`0`) batch_size (int, optional): The batch size. (default: :obj:`None`) :rtype: :class:`List[Tensor]` Example: >>> src = torch.arange(7) >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2]) >>> unbatch(src, batch) (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) """ sizes = degree(batch, batch_size, dtype=torch.long).tolist() return src.split(sizes, dim) def unbatch_edge_index( edge_index: Tensor, batch: Tensor, batch_size: Optional[int] = None, ) -> List[Tensor]: r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector. Args: edge_index (Tensor): The edge_index tensor. Must be ordered. 
batch (LongTensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. Must be ordered. batch_size (int, optional): The batch size. (default: :obj:`None`) :rtype: :class:`List[Tensor]` Example: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6], ... [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]]) >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1]) >>> unbatch_edge_index(edge_index, batch) (tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]), tensor([[0, 1, 1, 2], [1, 0, 2, 1]])) """ deg = degree(batch, batch_size, dtype=torch.long) ptr = cumsum(deg) edge_batch = batch[edge_index[0]] edge_index = edge_index - ptr[edge_batch] sizes = degree(edge_batch, batch_size, dtype=torch.long).cpu().tolist() return edge_index.split(sizes, dim=1) ================================================ FILE: torch_geometric/utils/augmentation.py ================================================ from typing import Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.utils import cumsum, negative_sampling, scatter def shuffle_node( x: Tensor, batch: Optional[Tensor] = None, training: bool = True, ) -> Tuple[Tensor, Tensor]: r"""Randomly shuffle the feature matrix :obj:`x` along the first dimension. The method returns (1) the shuffled :obj:`x`, (2) the permutation indicating the orders of original nodes after shuffling. Args: x (FloatTensor): The feature matrix. batch (LongTensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. Must be ordered. (default: :obj:`None`) training (bool, optional): If set to :obj:`False`, this operation is a no-op. (default: :obj:`True`) :rtype: (:class:`FloatTensor`, :class:`LongTensor`) Example: >>> # Standard case >>> x = torch.tensor([[0, 1, 2], ... [3, 4, 5], ... [6, 7, 8], ... 
[9, 10, 11]], dtype=torch.float) >>> x, node_perm = shuffle_node(x) >>> x tensor([[ 3., 4., 5.], [ 9., 10., 11.], [ 0., 1., 2.], [ 6., 7., 8.]]) >>> node_perm tensor([1, 3, 0, 2]) >>> # For batched graphs as inputs >>> batch = torch.tensor([0, 0, 1, 1]) >>> x, node_perm = shuffle_node(x, batch) >>> x tensor([[ 3., 4., 5.], [ 0., 1., 2.], [ 9., 10., 11.], [ 6., 7., 8.]]) >>> node_perm tensor([1, 0, 3, 2]) """ if not training: perm = torch.arange(x.size(0), device=x.device) return x, perm if batch is None: perm = torch.randperm(x.size(0), device=x.device) return x[perm], perm num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0, reduce='sum') ptr = cumsum(num_nodes) perm = torch.cat([ torch.randperm(n, device=x.device) + offset for offset, n in zip(ptr[:-1], num_nodes) ]) return x[perm], perm def mask_feature( x: Tensor, p: float = 0.5, mode: str = 'col', fill_value: float = 0., training: bool = True, ) -> Tuple[Tensor, Tensor]: r"""Randomly masks feature from the feature matrix :obj:`x` with probability :obj:`p` using samples from a Bernoulli distribution. The method returns (1) the retained :obj:`x`, (2) the feature mask broadcastable with :obj:`x` (:obj:`mode='row'` and :obj:`mode='col'`) or with the same shape as :obj:`x` (:obj:`mode='all'`), indicating where features are retained. Args: x (FloatTensor): The feature matrix. p (float, optional): The masking ratio. (default: :obj:`0.5`) mode (str, optional): The masked scheme to use for feature masking. (:obj:`"row"`, :obj:`"col"` or :obj:`"all"`). If :obj:`mode='col'`, will mask entire features of all nodes from the feature matrix. If :obj:`mode='row'`, will mask entire nodes from the feature matrix. If :obj:`mode='all'`, will mask individual features across all nodes. (default: :obj:`'col'`) fill_value (float, optional): The value for masked features in the output tensor. (default: :obj:`0`) training (bool, optional): If set to :obj:`False`, this operation is a no-op. 
(default: :obj:`True`) :rtype: (:class:`FloatTensor`, :class:`BoolTensor`) Examples: >>> # Masked features are column-wise sampled >>> x = torch.tensor([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]], dtype=torch.float) >>> x, feat_mask = mask_feature(x) >>> x tensor([[1., 0., 3.], [4., 0., 6.], [7., 0., 9.]]), >>> feat_mask tensor([[True, False, True]]) >>> # Masked features are row-wise sampled >>> x, feat_mask = mask_feature(x, mode='row') >>> x tensor([[1., 2., 3.], [0., 0., 0.], [7., 8., 9.]]), >>> feat_mask tensor([[True], [False], [True]]) >>> # Masked features are uniformly sampled >>> x, feat_mask = mask_feature(x, mode='all') >>> x tensor([[0., 0., 0.], [4., 0., 6.], [0., 0., 9.]]) >>> feat_mask tensor([[False, False, False], [True, False, True], [False, False, True]]) """ if p < 0. or p > 1.: raise ValueError(f'Masking ratio has to be between 0 and 1 ' f'(got {p}') if not training or p == 0.0: return x, torch.ones_like(x, dtype=torch.bool) assert mode in ['row', 'col', 'all'] if mode == 'row': mask = torch.rand(x.size(0), device=x.device) >= p mask = mask.view(-1, 1) elif mode == 'col': mask = torch.rand(x.size(1), device=x.device) >= p mask = mask.view(1, -1) else: mask = torch.rand_like(x) >= p x = x.masked_fill(~mask, fill_value) return x, mask def add_random_edge( edge_index: Tensor, p: float = 0.5, force_undirected: bool = False, num_nodes: Optional[Union[int, Tuple[int, int]]] = None, training: bool = True, ) -> Tuple[Tensor, Tensor]: r"""Randomly adds edges to :obj:`edge_index`. The method returns (1) the retained :obj:`edge_index`, (2) the added edge indices. Args: edge_index (LongTensor): The edge indices. p (float): Ratio of added edges to the existing edges. (default: :obj:`0.5`) force_undirected (bool, optional): If set to :obj:`True`, added edges will be undirected. 
(default: :obj:`False`) num_nodes (int, Tuple[int], optional): The overall number of nodes, *i.e.* :obj:`max_val + 1`, or the number of source and destination nodes, *i.e.* :obj:`(max_src_val + 1, max_dst_val + 1)` of :attr:`edge_index`. (default: :obj:`None`) training (bool, optional): If set to :obj:`False`, this operation is a no-op. (default: :obj:`True`) :rtype: (:class:`LongTensor`, :class:`LongTensor`) Examples: >>> # Standard case >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2]]) >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5) >>> edge_index tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3], [1, 0, 2, 1, 3, 2, 0, 2, 1]]) >>> added_edges tensor([[2, 1, 3], [0, 2, 1]]) >>> # The returned graph is kept undirected >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5, ... force_undirected=True) >>> edge_index tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3, 0, 2, 1], [1, 0, 2, 1, 3, 2, 0, 2, 1, 2, 1, 3]]) >>> added_edges tensor([[2, 1, 3, 0, 2, 1], [0, 2, 1, 2, 1, 3]]) >>> # For bipartite graphs >>> edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], ... [2, 3, 1, 4, 2, 1]]) >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5, ... num_nodes=(6, 5)) >>> edge_index tensor([[0, 1, 2, 3, 4, 5, 3, 4, 1], [2, 3, 1, 4, 2, 1, 1, 3, 2]]) >>> added_edges tensor([[3, 4, 1], [1, 3, 2]]) """ if p < 0. 
or p > 1.: raise ValueError(f"Ratio of added edges has to be between 0 and 1 " f"(got '{p}')") if force_undirected and isinstance(num_nodes, (tuple, list)): raise RuntimeError("'force_undirected' is not supported for " "bipartite graphs") device = edge_index.device if not training or p == 0.0: edge_index_to_add = torch.tensor([[], []], device=device) return edge_index, edge_index_to_add edge_index_to_add = negative_sampling( edge_index=edge_index, num_nodes=num_nodes, num_neg_samples=round(edge_index.size(1) * p), force_undirected=force_undirected, ) edge_index = torch.cat([edge_index, edge_index_to_add], dim=1) return edge_index, edge_index_to_add ================================================ FILE: torch_geometric/utils/convert.py ================================================ from collections import defaultdict from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union import torch from torch import Tensor from torch.utils.dlpack import from_dlpack, to_dlpack import torch_geometric from torch_geometric.utils.num_nodes import maybe_num_nodes def to_scipy_sparse_matrix( edge_index: Tensor, edge_attr: Optional[Tensor] = None, num_nodes: Optional[int] = None, ) -> Any: r"""Converts a graph given by edge indices and edge attributes to a scipy sparse matrix. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor, optional): Edge weights or multi-dimensional edge features. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) Examples: >>> edge_index = torch.tensor([ ... [0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2], ... 
]) >>> to_scipy_sparse_matrix(edge_index) <4x4 sparse matrix of type '' with 6 stored elements in COOrdinate format> """ import scipy.sparse as sp row, col = edge_index.cpu() if edge_attr is None: edge_attr = torch.ones(row.size(0), device="cpu") else: edge_attr = edge_attr.view(-1).cpu() assert edge_attr.size(0) == row.size(0) N = maybe_num_nodes(edge_index, num_nodes) out = sp.coo_matrix( # (edge_attr.numpy(), (row.numpy(), col.numpy())), (N, N)) return out def from_scipy_sparse_matrix(A: Any) -> Tuple[Tensor, Tensor]: r"""Converts a scipy sparse matrix to edge indices and edge attributes. Args: A (scipy.sparse): A sparse matrix. Examples: >>> edge_index = torch.tensor([ ... [0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2], ... ]) >>> adj = to_scipy_sparse_matrix(edge_index) >>> # `edge_index` and `edge_weight` are both returned >>> from_scipy_sparse_matrix(adj) (tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]), tensor([1., 1., 1., 1., 1., 1.])) """ A = A.tocoo() row = torch.from_numpy(A.row).to(torch.long) col = torch.from_numpy(A.col).to(torch.long) edge_index = torch.stack([row, col], dim=0) edge_weight = torch.from_numpy(A.data) return edge_index, edge_weight def to_networkx( data: Union[ 'torch_geometric.data.Data', 'torch_geometric.data.HeteroData', ], node_attrs: Optional[Iterable[str]] = None, edge_attrs: Optional[Iterable[str]] = None, graph_attrs: Optional[Iterable[str]] = None, to_undirected: Optional[Union[bool, str]] = False, to_multi: bool = False, remove_self_loops: bool = False, ) -> Any: r"""Converts a :class:`torch_geometric.data.Data` instance to a :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or a directed :obj:`networkx.DiGraph` otherwise. Args: data (torch_geometric.data.Data or torch_geometric.data.HeteroData): A homogeneous or heterogeneous data object. node_attrs (iterable of str, optional): The node attributes to be copied. (default: :obj:`None`) edge_attrs (iterable of str, optional): The edge attributes to be copied. 
(default: :obj:`None`) graph_attrs (iterable of str, optional): The graph attributes to be copied. (default: :obj:`None`) to_undirected (bool or str, optional): If set to :obj:`True`, will return a :class:`networkx.Graph` instead of a :class:`networkx.DiGraph`. By default, will include all edges and make them undirected. If set to :obj:`"upper"`, the undirected graph will only correspond to the upper triangle of the input adjacency matrix. If set to :obj:`"lower"`, the undirected graph will only correspond to the lower triangle of the input adjacency matrix. Only applicable in case the :obj:`data` object holds a homogeneous graph. (default: :obj:`False`) to_multi (bool, optional): if set to :obj:`True`, will return a :class:`networkx.MultiGraph` or a :class:`networkx:MultiDiGraph` (depending on the :obj:`to_undirected` option), which will not drop duplicated edges that may exist in :obj:`data`. (default: :obj:`False`) remove_self_loops (bool, optional): If set to :obj:`True`, will not include self-loops in the resulting graph. (default: :obj:`False`) Examples: >>> edge_index = torch.tensor([ ... [0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2], ... 
]) >>> data = Data(edge_index=edge_index, num_nodes=4) >>> to_networkx(data) """ import networkx as nx from torch_geometric.data import HeteroData to_undirected_upper: bool = to_undirected == 'upper' to_undirected_lower: bool = to_undirected == 'lower' to_undirected = to_undirected is True to_undirected |= to_undirected_upper or to_undirected_lower assert isinstance(to_undirected, bool) if isinstance(data, HeteroData) and to_undirected: raise ValueError("'to_undirected' is not supported in " "'to_networkx' for heterogeneous graphs") if to_undirected: G = nx.MultiGraph() if to_multi else nx.Graph() else: G = nx.MultiDiGraph() if to_multi else nx.DiGraph() def to_networkx_value(value: Any) -> Any: return value.tolist() if isinstance(value, Tensor) else value for key in graph_attrs or []: G.graph[key] = to_networkx_value(data[key]) node_offsets = data.node_offsets for node_store in data.node_stores: start = node_offsets[node_store._key] assert node_store.num_nodes is not None for i in range(node_store.num_nodes): node_kwargs: Dict[str, Any] = {} if isinstance(data, HeteroData): node_kwargs['type'] = node_store._key for key in node_attrs or []: node_kwargs[key] = to_networkx_value(node_store[key][i]) G.add_node(start + i, **node_kwargs) for edge_store in data.edge_stores: for i, (v, w) in enumerate(edge_store.edge_index.t().tolist()): if to_undirected_upper and v > w: continue elif to_undirected_lower and v < w: continue elif remove_self_loops and v == w and not edge_store.is_bipartite( ): continue edge_kwargs: Dict[str, Any] = {} if isinstance(data, HeteroData): v = v + node_offsets[edge_store._key[0]] w = w + node_offsets[edge_store._key[-1]] edge_kwargs['type'] = edge_store._key for key in edge_attrs or []: edge_kwargs[key] = to_networkx_value(edge_store[key][i]) G.add_edge(v, w, **edge_kwargs) return G def from_networkx( G: Any, group_node_attrs: Optional[Union[List[str], Literal['all']]] = None, group_edge_attrs: Optional[Union[List[str], Literal['all']]] = None, 
) -> 'torch_geometric.data.Data': r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a :class:`torch_geometric.data.Data` instance. Args: G (networkx.Graph or networkx.DiGraph): A networkx graph. group_node_attrs (List[str] or "all", optional): The node attributes to be concatenated and added to :obj:`data.x`. (default: :obj:`None`) group_edge_attrs (List[str] or "all", optional): The edge attributes to be concatenated and added to :obj:`data.edge_attr`. (default: :obj:`None`) .. note:: All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must be numeric. Examples: >>> edge_index = torch.tensor([ ... [0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2], ... ]) >>> data = Data(edge_index=edge_index, num_nodes=4) >>> g = to_networkx(data) >>> # A `Data` object is returned >>> from_networkx(g) Data(edge_index=[2, 6], num_nodes=4) """ import networkx as nx from torch_geometric.data import Data G = G.to_directed() if not nx.is_directed(G) else G mapping = dict(zip(G.nodes(), range(G.number_of_nodes()))) edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long) for i, (src, dst) in enumerate(G.edges()): edge_index[0, i] = mapping[src] edge_index[1, i] = mapping[dst] data_dict: Dict[str, Any] = defaultdict(list) data_dict['edge_index'] = edge_index node_attrs: List[str] = [] if G.number_of_nodes() > 0: node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys()) edge_attrs: List[str] = [] if G.number_of_edges() > 0: edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys()) if group_node_attrs is not None and not isinstance(group_node_attrs, list): group_node_attrs = node_attrs if group_edge_attrs is not None and not isinstance(group_edge_attrs, list): group_edge_attrs = edge_attrs for _, feat_dict in G.nodes(data=True): if set(feat_dict.keys()) != set(node_attrs): raise ValueError('Not all nodes contain the same attributes') for key, value in feat_dict.items(): data_dict[str(key)].append(value) for _, _, feat_dict in G.edges(data=True): if 
set(feat_dict.keys()) != set(edge_attrs): raise ValueError('Not all edges contain the same attributes') for key, value in feat_dict.items(): key = f'edge_{key}' if key in node_attrs else key data_dict[str(key)].append(value) for key, value in G.graph.items(): if key == 'node_default' or key == 'edge_default': continue # Do not load default attributes. key = f'graph_{key}' if key in node_attrs else key data_dict[str(key)] = value for key, value in data_dict.items(): if isinstance(value, (tuple, list)) and isinstance(value[0], Tensor): data_dict[key] = torch.stack(value, dim=0) else: try: data_dict[key] = torch.as_tensor(value) except Exception: pass data = Data.from_dict(data_dict) if group_node_attrs is not None: xs = [] for key in group_node_attrs: x = data[key] x = x.view(-1, 1) if x.dim() <= 1 else x xs.append(x) del data[key] data.x = torch.cat(xs, dim=-1) if group_edge_attrs is not None: xs = [] for key in group_edge_attrs: key = f'edge_{key}' if key in node_attrs else key x = data[key] x = x.view(-1, 1) if x.dim() <= 1 else x xs.append(x) del data[key] data.edge_attr = torch.cat(xs, dim=-1) if data.x is None and data.pos is None: data.num_nodes = G.number_of_nodes() return data def to_networkit( edge_index: Tensor, edge_weight: Optional[Tensor] = None, num_nodes: Optional[int] = None, directed: bool = True, ) -> Any: r"""Converts a :obj:`(edge_index, edge_weight)` tuple to a :class:`networkit.Graph`. Args: edge_index (torch.Tensor): The edge indices of the graph. edge_weight (torch.Tensor, optional): The edge weights of the graph. (default: :obj:`None`) num_nodes (int, optional): The number of nodes in the graph. (default: :obj:`None`) directed (bool, optional): If set to :obj:`False`, the graph will be undirected. 
(default: :obj:`True`) """ import networkit as nk num_nodes = maybe_num_nodes(edge_index, num_nodes) g = nk.graph.Graph( num_nodes, weighted=edge_weight is not None, directed=directed, ) if edge_weight is None: edge_weight = torch.ones(edge_index.size(1)) if not directed: mask = edge_index[0] <= edge_index[1] edge_index = edge_index[:, mask] edge_weight = edge_weight[mask] for (u, v), w in zip(edge_index.t().tolist(), edge_weight.tolist()): g.addEdge(u, v, w) return g def from_networkit(g: Any) -> Tuple[Tensor, Optional[Tensor]]: r"""Converts a :class:`networkit.Graph` to a :obj:`(edge_index, edge_weight)` tuple. If the :class:`networkit.Graph` is not weighted, the returned :obj:`edge_weight` will be :obj:`None`. Args: g (networkkit.graph.Graph): A :obj:`networkit` graph object. """ is_directed = g.isDirected() is_weighted = g.isWeighted() edge_indices, edge_weights = [], [] for u, v, w in g.iterEdgesWeights(): edge_indices.append([u, v]) edge_weights.append(w) if not is_directed: edge_indices.append([v, u]) edge_weights.append(w) edge_index = torch.tensor(edge_indices).t().contiguous() edge_weight = torch.tensor(edge_weights) if is_weighted else None return edge_index, edge_weight def to_trimesh(data: 'torch_geometric.data.Data') -> Any: r"""Converts a :class:`torch_geometric.data.Data` instance to a :obj:`trimesh.Trimesh`. Args: data (torch_geometric.data.Data): The data object. Example: >>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], ... dtype=torch.float) >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t() >>> data = Data(pos=pos, face=face) >>> to_trimesh(data) """ import trimesh assert data.pos is not None assert data.face is not None return trimesh.Trimesh( vertices=data.pos.detach().cpu().numpy(), faces=data.face.detach().t().cpu().numpy(), process=False, ) def from_trimesh(mesh: Any) -> 'torch_geometric.data.Data': r"""Converts a :obj:`trimesh.Trimesh` to a :class:`torch_geometric.data.Data` instance. 
Args: mesh (trimesh.Trimesh): A :obj:`trimesh` mesh. Example: >>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], ... dtype=torch.float) >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t() >>> data = Data(pos=pos, face=face) >>> mesh = to_trimesh(data) >>> from_trimesh(mesh) Data(pos=[4, 3], face=[3, 2]) """ from torch_geometric.data import Data pos = torch.from_numpy(mesh.vertices).to(torch.float) face = torch.from_numpy(mesh.faces).t().contiguous() return Data(pos=pos, face=face) def to_cugraph( edge_index: Tensor, edge_weight: Optional[Tensor] = None, relabel_nodes: bool = True, directed: bool = True, ) -> Any: r"""Converts a graph given by :obj:`edge_index` and optional :obj:`edge_weight` into a :obj:`cugraph` graph object. Args: edge_index (torch.Tensor): The edge indices of the graph. edge_weight (torch.Tensor, optional): The edge weights of the graph. (default: :obj:`None`) relabel_nodes (bool, optional): If set to :obj:`True`, :obj:`cugraph` will remove any isolated nodes, leading to a relabeling of nodes. (default: :obj:`True`) directed (bool, optional): If set to :obj:`False`, the graph will be undirected. (default: :obj:`True`) """ import cudf import cugraph g = cugraph.Graph(directed=directed) df = cudf.from_dlpack(to_dlpack(edge_index.t())) df = cudf.DataFrame({ 'source': cudf.from_dlpack(to_dlpack(edge_index[0])), 'destination': cudf.from_dlpack(to_dlpack(edge_index[1])), }) if edge_weight is not None: assert edge_weight.dim() == 1 df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight)) g.from_cudf_edgelist( df, source='source', destination='destination', edge_attr='weight' if edge_weight is not None else None, renumber=relabel_nodes, ) return g def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]: r"""Converts a :obj:`cugraph` graph object into :obj:`edge_index` and optional :obj:`edge_weight` tensors. Args: g (cugraph.Graph): A :obj:`cugraph` graph object. 
""" df = g.view_edge_list() src = from_dlpack(df[g.source_columns].to_dlpack()).long() dst = from_dlpack(df[g.destination_columns].to_dlpack()).long() edge_index = torch.stack([src, dst], dim=0) edge_weight = None if g.weight_column is not None: edge_weight = from_dlpack(df[g.weight_column].to_dlpack()) return edge_index, edge_weight def to_dgl( data: Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData'] ) -> Any: r"""Converts a :class:`torch_geometric.data.Data` or :class:`torch_geometric.data.HeteroData` instance to a :obj:`dgl` graph object. Args: data (torch_geometric.data.Data or torch_geometric.data.HeteroData): The data object. Example: >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]]) >>> x = torch.randn(5, 3) >>> edge_attr = torch.randn(6, 2) >>> data = Data(x=x, edge_index=edge_index, edge_attr=y) >>> g = to_dgl(data) >>> g Graph(num_nodes=5, num_edges=6, ndata_schemes={'x': Scheme(shape=(3,))} edata_schemes={'edge_attr': Scheme(shape=(2, ))}) >>> data = HeteroData() >>> data['paper'].x = torch.randn(5, 3) >>> data['author'].x = torch.ones(5, 3) >>> edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) >>> data['author', 'cites', 'paper'].edge_index = edge_index >>> g = to_dgl(data) >>> g Graph(num_nodes={'author': 5, 'paper': 5}, num_edges={('author', 'cites', 'paper'): 5}, metagraph=[('author', 'paper', 'cites')]) """ import dgl from torch_geometric.data import Data, HeteroData if isinstance(data, Data): if data.edge_index is not None: row, col = data.edge_index elif 'adj' in data: row, col, _ = data.adj.coo() elif 'adj_t' in data: row, col, _ = data.adj_t.t().coo() else: row, col = [], [] g = dgl.graph((row, col), num_nodes=data.num_nodes) for attr in data.node_attrs(): g.ndata[attr] = data[attr] for attr in data.edge_attrs(): if attr in ['edge_index', 'adj_t']: continue g.edata[attr] = data[attr] return g if isinstance(data, HeteroData): data_dict = {} for edge_type, edge_store in data.edge_items(): if 
edge_store.get('edge_index') is not None: row, col = edge_store.edge_index else: row, col, _ = edge_store['adj_t'].t().coo() data_dict[edge_type] = (row, col) g = dgl.heterograph(data_dict) for node_type, node_store in data.node_items(): for attr, value in node_store.items(): g.nodes[node_type].data[attr] = value for edge_type, edge_store in data.edge_items(): for attr, value in edge_store.items(): if attr in ['edge_index', 'adj_t']: continue g.edges[edge_type].data[attr] = value return g raise ValueError(f"Invalid data type (got '{type(data)}')") def from_dgl( g: Any, ) -> Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData']: r"""Converts a :obj:`dgl` graph object to a :class:`torch_geometric.data.Data` or :class:`torch_geometric.data.HeteroData` instance. Args: g (dgl.DGLGraph): The :obj:`dgl` graph object. Example: >>> g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) >>> g.ndata['x'] = torch.randn(g.num_nodes(), 3) >>> g.edata['edge_attr'] = torch.randn(g.num_edges(), 2) >>> data = from_dgl(g) >>> data Data(x=[6, 3], edge_attr=[4, 2], edge_index=[2, 4]) >>> g = dgl.heterograph({ >>> g = dgl.heterograph({ ... ('author', 'writes', 'paper'): ([0, 1, 1, 2, 3, 3, 4], ... 
[0, 0, 1, 1, 1, 2, 2])}) >>> g.nodes['author'].data['x'] = torch.randn(5, 3) >>> g.nodes['paper'].data['x'] = torch.randn(5, 3) >>> data = from_dgl(g) >>> data HeteroData( author={ x=[5, 3] }, paper={ x=[3, 3] }, (author, writes, paper)={ edge_index=[2, 7] } ) """ import dgl from torch_geometric.data import Data, HeteroData if not isinstance(g, dgl.DGLGraph): raise ValueError(f"Invalid data type (got '{type(g)}')") data: Union[Data, HeteroData] if g.is_homogeneous: data = Data() data.edge_index = torch.stack(g.edges(), dim=0) for attr, value in g.ndata.items(): data[attr] = value for attr, value in g.edata.items(): data[attr] = value return data data = HeteroData() for node_type in g.ntypes: for attr, value in g.nodes[node_type].data.items(): data[node_type][attr] = value for edge_type in g.canonical_etypes: row, col = g.edges(form="uv", etype=edge_type) data[edge_type].edge_index = torch.stack([row, col], dim=0) for attr, value in g.edge_attr_schemes(edge_type).items(): data[edge_type][attr] = value return data ================================================ FILE: torch_geometric/utils/cross_entropy.py ================================================ from typing import Any, Optional, Tuple import torch from torch import Tensor from torch_geometric.utils import scatter class SparseCrossEntropy(torch.autograd.Function): # We implement our own custom autograd function for this to avoid the # double gradient computation to `inputs`. @staticmethod def forward( ctx: Any, inputs: Tensor, edge_label_index: Tensor, edge_label_weight: Optional[Tensor], ) -> Tensor: assert inputs.dim() == 2 # Support for both positive and negative weights: # Positive weights scale the logits *after* softmax. 
# Negative weights scale the denominator *before* softmax: pos_y = edge_label_index neg_y = pos_weight = neg_weight = None if edge_label_weight is not None: pos_mask = edge_label_weight >= 0 pos_y = edge_label_index[:, pos_mask] pos_weight = edge_label_weight[pos_mask] if pos_y.size(1) < edge_label_index.size(1): neg_mask = ~pos_mask neg_y = edge_label_index[:, neg_mask] neg_weight = edge_label_weight[neg_mask] if neg_y is not None and neg_weight is not None: inputs = inputs.clone() inputs[ neg_y[0], neg_y[1], ] += neg_weight.abs().log().clamp(min=1e-12) logsumexp = inputs.logsumexp(dim=-1) ctx.save_for_backward(inputs, pos_y, pos_weight, logsumexp) out = inputs[pos_y[0], pos_y[1]] out.neg_().add_(logsumexp[pos_y[0]]) if pos_weight is not None: out *= pos_weight return out.sum() / inputs.size(0) @staticmethod @torch.autograd.function.once_differentiable def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, None, None]: inputs, pos_y, pos_weight, logsumexp = ctx.saved_tensors grad_out = grad_out / inputs.size(0) grad_out = grad_out.expand(pos_y.size(1)) if pos_weight is not None: grad_out = grad_out * pos_weight grad_logsumexp = scatter(grad_out, pos_y[0], dim=0, dim_size=inputs.size(0), reduce='sum') # Gradient computation of `logsumexp`: `grad * (self - result).exp()` grad_input = (inputs - logsumexp.view(-1, 1)) grad_input.exp_() grad_input.mul_(grad_logsumexp.view(-1, 1)) grad_input[pos_y[0], pos_y[1]] -= grad_out return grad_input, None, None def sparse_cross_entropy( inputs: Tensor, edge_label_index: Tensor, edge_label_weight: Optional[Tensor] = None, ) -> Tensor: r"""A sparse-label variant of :func:`torch.nn.functional.cross_entropy`. In particular, the binary target matrix is solely given by sparse indices :obj:`edge_label_index`. Args: inputs (torch.Tensor): The predicted unnormalized logits of shape :obj:`[batch_size, num_classes]`. edge_label_index (torch.Tensor): The sparse ground-truth indices with shape :obj:`[2, num_labels]`. 
edge_label_weight (torch.Tensor, optional): The weight of ground-truth indices with shape :obj:`[num_labels]`. (default: :obj:`None`) :rtype: :class:`torch.Tensor` Example: >>> inputs = torch.randn(2, 3) >>> edge_label_index = torch.tensor([ ... [0, 0, 1], ... [0, 1, 2], ... ]) >>> loss = sparse_cross_entropy(inputs, edge_label_index) tensor(1.2919) """ if edge_label_weight is not None: assert not edge_label_weight.requires_grad return SparseCrossEntropy.apply( inputs, edge_label_index, edge_label_weight, ) ================================================ FILE: torch_geometric/utils/dropout.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor import torch_geometric.typing from torch_geometric import is_compiling from torch_geometric.deprecation import deprecated from torch_geometric.typing import OptTensor from torch_geometric.utils import cumsum, degree, sort_edge_index, subgraph from torch_geometric.utils.num_nodes import maybe_num_nodes def filter_adj(row: Tensor, col: Tensor, edge_attr: OptTensor, mask: Tensor) -> Tuple[Tensor, Tensor, OptTensor]: return row[mask], col[mask], None if edge_attr is None else edge_attr[mask] @deprecated("use 'dropout_edge' instead") def dropout_adj( edge_index: Tensor, edge_attr: OptTensor = None, p: float = 0.5, force_undirected: bool = False, num_nodes: Optional[int] = None, training: bool = True, ) -> Tuple[Tensor, OptTensor]: r"""Randomly drops edges from the adjacency matrix :obj:`(edge_index, edge_attr)` with probability :obj:`p` using samples from a Bernoulli distribution. .. warning:: :class:`~torch_geometric.utils.dropout_adj` is deprecated and will be removed in a future release. Use :class:`torch_geometric.utils.dropout_edge` instead. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor, optional): Edge weights or multi-dimensional edge features. (default: :obj:`None`) p (float, optional): Dropout probability. 
(default: :obj:`0.5`) force_undirected (bool, optional): If set to :obj:`True`, will either drop or keep both edges of an undirected edge. (default: :obj:`False`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) training (bool, optional): If set to :obj:`False`, this operation is a no-op. (default: :obj:`True`) Examples: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2]]) >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6]) >>> dropout_adj(edge_index, edge_attr) (tensor([[0, 1, 2, 3], [1, 2, 3, 2]]), tensor([1, 3, 5, 6])) >>> # The returned graph is kept undirected >>> dropout_adj(edge_index, edge_attr, force_undirected=True) (tensor([[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]]), tensor([1, 3, 5, 1, 3, 5])) """ if p < 0. or p > 1.: raise ValueError(f'Dropout probability has to be between 0 and 1 ' f'(got {p}') if not training or p == 0.0: return edge_index, edge_attr row, col = edge_index mask = torch.rand(row.size(0), device=edge_index.device) >= p if force_undirected: mask[row > col] = False row, col, edge_attr = filter_adj(row, col, edge_attr, mask) if force_undirected: edge_index = torch.stack( [torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)], dim=0) if edge_attr is not None: edge_attr = torch.cat([edge_attr, edge_attr], dim=0) else: edge_index = torch.stack([row, col], dim=0) return edge_index, edge_attr def dropout_node( edge_index: Tensor, p: float = 0.5, num_nodes: Optional[int] = None, training: bool = True, relabel_nodes: bool = False, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Randomly drops nodes from the adjacency matrix :obj:`edge_index` with probability :obj:`p` using samples from a Bernoulli distribution. The method returns (1) the retained :obj:`edge_index`, (2) the edge mask indicating which edges were retained. (3) the node mask indicating which nodes were retained. Args: edge_index (LongTensor): The edge indices. 
p (float, optional): Dropout probability. (default: :obj:`0.5`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) training (bool, optional): If set to :obj:`False`, this operation is a no-op. (default: :obj:`True`) relabel_nodes (bool, optional): If set to `True`, the resulting `edge_index` will be relabeled to hold consecutive indices starting from zero. :rtype: (:class:`LongTensor`, :class:`BoolTensor`, :class:`BoolTensor`) Examples: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2]]) >>> edge_index, edge_mask, node_mask = dropout_node(edge_index) >>> edge_index tensor([[0, 1], [1, 0]]) >>> edge_mask tensor([ True, True, False, False, False, False]) >>> node_mask tensor([ True, True, False, False]) """ if p < 0. or p > 1.: raise ValueError(f'Dropout probability has to be between 0 and 1 ' f'(got {p}') num_nodes = maybe_num_nodes(edge_index, num_nodes) if not training or p == 0.0: node_mask = edge_index.new_ones(num_nodes, dtype=torch.bool) edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool) return edge_index, edge_mask, node_mask prob = torch.rand(num_nodes, device=edge_index.device) node_mask = prob > p edge_index, _, edge_mask = subgraph( node_mask, edge_index, relabel_nodes=relabel_nodes, num_nodes=num_nodes, return_edge_mask=True, ) return edge_index, edge_mask, node_mask def dropout_edge(edge_index: Tensor, p: float = 0.5, force_undirected: bool = False, training: bool = True) -> Tuple[Tensor, Tensor]: r"""Randomly drops edges from the adjacency matrix :obj:`edge_index` with probability :obj:`p` using samples from a Bernoulli distribution. The method returns (1) the retained :obj:`edge_index`, (2) the edge mask or index indicating which edges were retained, depending on the argument :obj:`force_undirected`. Args: edge_index (LongTensor): The edge indices. p (float, optional): Dropout probability. 
(default: :obj:`0.5`) force_undirected (bool, optional): If set to :obj:`True`, will either drop or keep both edges of an undirected edge. (default: :obj:`False`) training (bool, optional): If set to :obj:`False`, this operation is a no-op. (default: :obj:`True`) :rtype: (:class:`LongTensor`, :class:`BoolTensor` or :class:`LongTensor`) Examples: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2]]) >>> edge_index, edge_mask = dropout_edge(edge_index) >>> edge_index tensor([[0, 1, 2, 2], [1, 2, 1, 3]]) >>> edge_mask # masks indicating which edges are retained tensor([ True, False, True, True, True, False]) >>> edge_index, edge_id = dropout_edge(edge_index, ... force_undirected=True) >>> edge_index tensor([[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]]) >>> edge_id # indices indicating which edges are retained tensor([0, 2, 4, 0, 2, 4]) """ if p < 0. or p > 1.: raise ValueError(f'Dropout probability has to be between 0 and 1 ' f'(got {p}') if not training or p == 0.0: edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool) return edge_index, edge_mask row, col = edge_index edge_mask = torch.rand(row.size(0), device=edge_index.device) >= p if force_undirected: edge_mask[row > col] = False edge_index = edge_index[:, edge_mask] if force_undirected: edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) edge_mask = edge_mask.nonzero().repeat((2, 1)).squeeze() return edge_index, edge_mask def dropout_path(edge_index: Tensor, p: float = 0.2, walks_per_node: int = 1, walk_length: int = 3, num_nodes: Optional[int] = None, is_sorted: bool = False, training: bool = True) -> Tuple[Tensor, Tensor]: r"""Drops edges from the adjacency matrix :obj:`edge_index` based on random walks. The source nodes to start random walks from are sampled from :obj:`edge_index` with probability :obj:`p`, following a Bernoulli distribution. The method returns (1) the retained :obj:`edge_index`, (2) the edge mask indicating which edges were retained. 
Args: edge_index (LongTensor): The edge indices. p (float, optional): Sample probability. (default: :obj:`0.2`) walks_per_node (int, optional): The number of walks per node, same as :class:`~torch_geometric.nn.models.Node2Vec`. (default: :obj:`1`) walk_length (int, optional): The walk length, same as :class:`~torch_geometric.nn.models.Node2Vec`. (default: :obj:`3`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) is_sorted (bool, optional): If set to :obj:`True`, will expect :obj:`edge_index` to be already sorted row-wise. (default: :obj:`False`) training (bool, optional): If set to :obj:`False`, this operation is a no-op. (default: :obj:`True`) :rtype: (:class:`LongTensor`, :class:`BoolTensor`) Example: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2]]) >>> edge_index, edge_mask = dropout_path(edge_index) >>> edge_index tensor([[1, 2], [2, 3]]) >>> edge_mask # masks indicating which edges are retained tensor([False, False, True, False, True, False]) """ if p < 0. 
or p > 1.: raise ValueError(f'Sample probability has to be between 0 and 1 ' f'(got {p}') num_edges = edge_index.size(1) edge_mask = edge_index.new_ones(num_edges, dtype=torch.bool) if not training or p == 0.0: return edge_index, edge_mask if not torch_geometric.typing.WITH_TORCH_CLUSTER or is_compiling(): raise ImportError('`dropout_path` requires `torch-cluster`.') num_nodes = maybe_num_nodes(edge_index, num_nodes) edge_orders = None ori_edge_index = edge_index if not is_sorted: edge_orders = torch.arange(num_edges, device=edge_index.device) edge_index, edge_orders = sort_edge_index(edge_index, edge_orders, num_nodes=num_nodes) row, col = edge_index sample_mask = torch.rand(row.size(0), device=edge_index.device) <= p start = row[sample_mask].repeat(walks_per_node) rowptr = cumsum(degree(row, num_nodes=num_nodes, dtype=torch.long)) n_id, e_id = torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length, 1.0, 1.0) e_id = e_id[e_id != -1].view(-1) # filter illegal edges if edge_orders is not None: # Permute edge indices: e_id = edge_orders[e_id] edge_mask[e_id] = False edge_index = ori_edge_index[:, edge_mask] return edge_index, edge_mask ================================================ FILE: torch_geometric/utils/embedding.py ================================================ import warnings from typing import Any, Dict, List, Optional, Type import torch from torch import Tensor from torch_geometric.typing import NodeType def get_embeddings( model: torch.nn.Module, *args: Any, **kwargs: Any, ) -> List[Tensor]: """Returns the output embeddings of all :class:`~torch_geometric.nn.conv.MessagePassing` layers in :obj:`model`. Internally, this method registers forward hooks on all :class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`, and runs the forward pass of the :obj:`model` by calling :obj:`model(*args, **kwargs)`. Args: model (torch.nn.Module): The message passing model. *args: Arguments passed to the model. 
**kwargs (optional): Additional keyword arguments passed to the model. """ from torch_geometric.nn import MessagePassing embeddings: List[Tensor] = [] def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None: # Clone output in case it will be later modified in-place: outputs = outputs[0] if isinstance(outputs, tuple) else outputs assert isinstance(outputs, Tensor) embeddings.append(outputs.clone()) hook_handles = [] for module in model.modules(): # Register forward hooks: if isinstance(module, MessagePassing): hook_handles.append(module.register_forward_hook(hook)) if len(hook_handles) == 0: warnings.warn("The 'model' does not have any 'MessagePassing' layers", stacklevel=2) training = model.training model.eval() with torch.no_grad(): model(*args, **kwargs) model.train(training) for handle in hook_handles: # Remove hooks: handle.remove() return embeddings def get_embeddings_hetero( model: torch.nn.Module, supported_models: Optional[List[Type[torch.nn.Module]]] = None, *args: Any, **kwargs: Any, ) -> Dict[NodeType, List[Tensor]]: """Returns the output embeddings of all :class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous :obj:`model`, organized by edge type. Internally, this method registers forward hooks on all modules that process heterogeneous graphs in the model and runs the forward pass of the model. For heterogeneous models, the output is a dictionary where each key is a node type and each value is a list of embeddings from different layers. Args: model (torch.nn.Module): The heterogeneous GNN model. supported_models (List[Type[torch.nn.Module]], optional): A list of supported model classes. If not provided, defaults to [HGTConv, HANConv, HeteroConv]. *args: Arguments passed to the model. **kwargs (optional): Additional keyword arguments passed to the model. Returns: Dict[NodeType, List[Tensor]]: A dictionary mapping each node type to a list of embeddings from different layers. 
""" from torch_geometric.nn import HANConv, HeteroConv, HGTConv if not supported_models: supported_models = [HGTConv, HANConv, HeteroConv] # Dictionary to store node embeddings by type node_embeddings_dict: Dict[NodeType, List[Tensor]] = {} # Hook function to capture node embeddings def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None: # Check if the outputs is a dictionary mapping node types to embeddings if isinstance(outputs, dict) and outputs: # Store embeddings for each node type for node_type, embedding in outputs.items(): # Made sure that the outputs are a dictionary mapping node # types to embeddings and remove the false positives. if node_type not in node_embeddings_dict: node_embeddings_dict[node_type] = [] node_embeddings_dict[node_type].append(embedding.clone()) # List to store hook handles hook_handles = [] # Find ModuleDict objects in the model for _, module in model.named_modules(): # Handle the native heterogenous models, e.g. HGTConv, HANConv # and HeteroConv, etc. if isinstance(module, tuple(supported_models)): hook_handles.append(module.register_forward_hook(hook)) else: # Handle the heterogenous models that are generated by calling # to_hetero() on the homogeneous models. 
submodules = list(module.children()) submodules_contains_module_dict = any([ isinstance(submodule, torch.nn.ModuleDict) for submodule in submodules ]) if submodules_contains_module_dict: hook_handles.append(module.register_forward_hook(hook)) if len(hook_handles) == 0: warnings.warn( "The 'model' does not have any heterogenous " "'MessagePassing' layers", stacklevel=2) # Run the model forward pass training = model.training model.eval() with torch.no_grad(): model(*args, **kwargs) model.train(training) # Clean up hooks for handle in hook_handles: handle.remove() return node_embeddings_dict ================================================ FILE: torch_geometric/utils/functions.py ================================================ import torch from torch import Tensor def cumsum(x: Tensor, dim: int = 0) -> Tensor: r"""Returns the cumulative sum of elements of :obj:`x`. In contrast to :meth:`torch.cumsum`, prepends the output with zero. Args: x (torch.Tensor): The input tensor. dim (int, optional): The dimension to do the operation over. (default: :obj:`0`) Example: >>> x = torch.tensor([2, 4, 1]) >>> cumsum(x) tensor([0, 2, 6, 7]) """ size = x.size()[:dim] + (x.size(dim) + 1, ) + x.size()[dim + 1:] out = x.new_empty(size) out.narrow(dim, 0, 1).zero_() torch.cumsum(x, dim=dim, out=out.narrow(dim, 1, x.size(dim))) return out ================================================ FILE: torch_geometric/utils/geodesic.py ================================================ import multiprocessing as mp import warnings from typing import Optional import numpy as np import torch from torch import Tensor def geodesic_distance( # noqa: D417 pos: Tensor, face: Tensor, src: Optional[Tensor] = None, dst: Optional[Tensor] = None, norm: bool = True, max_distance: Optional[float] = None, num_workers: int = 0, # Backward compatibility for `dest`: **kwargs: Optional[Tensor], ) -> Tensor: r"""Computes (normalized) geodesic distances of a mesh given by :obj:`pos` and :obj:`face`. 
If :obj:`src` and :obj:`dst` are given, this method only computes the geodesic distances for the respective source and target node-pairs. .. note:: This function requires the :obj:`gdist` package. To install, run :obj:`pip install cython && pip install gdist`. Args: pos (torch.Tensor): The node positions. face (torch.Tensor): The face indices. src (torch.Tensor, optional): If given, only compute geodesic distances for the specified source indices. (default: :obj:`None`) dst (torch.Tensor, optional): If given, only compute geodesic distances for the specified target indices. (default: :obj:`None`) norm (bool, optional): Normalizes geodesic distances by :math:`\sqrt{\textrm{area}(\mathcal{M})}`. (default: :obj:`True`) max_distance (float, optional): If given, only yields results for geodesic distances less than :obj:`max_distance`. This will speed up runtime dramatically. (default: :obj:`None`) num_workers (int, optional): How many subprocesses to use for calculating geodesic distances. :obj:`0` means that computation takes place in the main process. :obj:`-1` means that the available amount of CPU cores is used. (default: :obj:`0`) :rtype: :class:`Tensor` Example: >>> pos = torch.tensor([[0.0, 0.0, 0.0], ... [2.0, 0.0, 0.0], ... [0.0, 2.0, 0.0], ... [2.0, 2.0, 0.0]]) >>> face = torch.tensor([[0, 0], ... [1, 2], ... [3, 3]]) >>> geodesic_distance(pos, face) [[0, 1, 1, 1.4142135623730951], [1, 0, 1.4142135623730951, 1], [1, 1.4142135623730951, 0, 1], [1.4142135623730951, 1, 1, 0]] """ import gdist if 'dest' in kwargs: dst = kwargs['dest'] warnings.warn( "'dest' attribute in 'geodesic_distance' is deprecated " "and will be removed in a future release. 
Use the 'dst' " "argument instead.", stacklevel=2) max_distance = float('inf') if max_distance is None else max_distance if norm: area = (pos[face[1]] - pos[face[0]]).cross( pos[face[2]] - pos[face[0]], dim=1, ) scale = float((area.norm(p=2, dim=1) / 2).sum().sqrt()) else: scale = 1.0 dtype = pos.dtype pos_np = pos.detach().cpu().to(torch.double).numpy() face_np = face.detach().t().cpu().to(torch.int).numpy() if src is None and dst is None: out = gdist.local_gdist_matrix( pos_np, face_np, max_distance * scale, ).toarray() / scale return torch.from_numpy(out).to(dtype) if src is None: src_np = torch.arange(pos.size(0), dtype=torch.int).numpy() else: src_np = src.detach().cpu().to(torch.int).numpy() dst_np = None if dst is None else dst.detach().cpu().to(torch.int).numpy() def _parallel_loop( pos_np: np.ndarray, face_np: np.ndarray, src_np: np.ndarray, dst_np: Optional[np.ndarray], max_distance: float, scale: float, i: int, dtype: torch.dtype, ) -> Tensor: s = src_np[i:i + 1] d = None if dst_np is None else dst_np[i:i + 1] out = gdist.compute_gdist(pos_np, face_np, s, d, max_distance * scale) out = out / scale return torch.from_numpy(out).to(dtype) num_workers = mp.cpu_count() if num_workers <= -1 else num_workers if num_workers > 0: with mp.Pool(num_workers) as pool: data = [(pos_np, face_np, src_np, dst_np, max_distance, scale, i, dtype) for i in range(len(src_np))] outs = pool.starmap(_parallel_loop, data) else: outs = [ _parallel_loop(pos_np, face_np, src_np, dst_np, max_distance, scale, i, dtype) for i in range(len(src_np)) ] out = torch.cat(outs, dim=0) if dst is None: out = out.view(-1, pos.size(0)) return out ================================================ FILE: torch_geometric/utils/hetero.py ================================================ from typing import Dict, List, Optional, Set, Tuple, Union import torch from torch import Tensor from torch.nn import ParameterDict from torch_geometric.typing import Adj, EdgeType, NodeType, SparseTensor from 
torch_geometric.utils import is_sparse, to_edge_index from torch_geometric.utils.num_nodes import maybe_num_nodes_dict def group_hetero_graph( edge_index_dict: Dict[EdgeType, Tensor], num_nodes_dict: Optional[Dict[NodeType, int]] = None, ) -> Tuple[ Tensor, Tensor, Tensor, Tensor, Dict[Union[str, int], Tensor], Dict[Union[NodeType, EdgeType], int], ]: num_nodes_dict = maybe_num_nodes_dict(edge_index_dict, num_nodes_dict) tmp = list(edge_index_dict.values())[0] key2int: Dict[Union[NodeType, EdgeType], int] = {} cumsum, offset = 0, {} # Helper data. node_types, local_node_indices = [], [] local2global: Dict[Union[str, int], Tensor] = {} for i, (key, N) in enumerate(num_nodes_dict.items()): key2int[key] = i node_types.append(tmp.new_full((N, ), i)) local_node_indices.append(torch.arange(N, device=tmp.device)) offset[key] = cumsum local2global[key] = local_node_indices[-1] + cumsum local2global[i] = local2global[key] cumsum += N node_type = torch.cat(node_types, dim=0) local_node_idx = torch.cat(local_node_indices, dim=0) edge_indices, edge_types = [], [] for i, (keys, edge_index) in enumerate(edge_index_dict.items()): key2int[keys] = i inc = torch.tensor([offset[keys[0]], offset[keys[-1]]]).view(2, 1) edge_indices.append(edge_index + inc.to(tmp.device)) edge_types.append(tmp.new_full((edge_index.size(1), ), i)) edge_index = torch.cat(edge_indices, dim=-1) edge_type = torch.cat(edge_types, dim=0) return ( edge_index, edge_type, node_type, local_node_idx, local2global, key2int, ) def get_unused_node_types(node_types: List[NodeType], edge_types: List[EdgeType]) -> Set[NodeType]: dst_node_types = {edge_type[-1] for edge_type in edge_types} return set(node_types) - set(dst_node_types) def check_add_self_loops( module: torch.nn.Module, edge_types: List[EdgeType], ) -> None: is_bipartite = any([key[0] != key[-1] for key in edge_types]) if is_bipartite and getattr(module, 'add_self_loops', False): raise ValueError( f"'add_self_loops' attribute set to 'True' on module 
'{module}' " f"for use with edge type(s) '{edge_types}'. This will lead to " f"incorrect message passing results.") def construct_bipartite_edge_index( edge_index_dict: Dict[EdgeType, Adj], src_offset_dict: Dict[EdgeType, int], dst_offset_dict: Dict[NodeType, int], edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None, num_nodes: Optional[int] = None, ) -> Tuple[Adj, Optional[Tensor]]: """Constructs a tensor of edge indices by concatenating edge indices for each edge type. The edge indices are increased by the offset of the source and destination nodes. Args: edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A dictionary holding graph connectivity information for each individual edge type, either as a :class:`torch.Tensor` of shape :obj:`[2, num_edges]` or a :class:`torch_sparse.SparseTensor`. src_offset_dict (Dict[Tuple[str, str, str], int]): A dictionary of offsets to apply to the source node type for each edge type. dst_offset_dict (Dict[str, int]): A dictionary of offsets to apply for destination node types. edge_attr_dict (Dict[Tuple[str, str, str], torch.Tensor]): A dictionary holding edge features for each individual edge type. (default: :obj:`None`) num_nodes (int, optional): The final number of nodes in the bipartite adjacency matrix. (default: :obj:`None`) """ is_sparse_tensor = False edge_indices: List[Tensor] = [] edge_attrs: List[Tensor] = [] for edge_type, src_offset in src_offset_dict.items(): edge_index = edge_index_dict[edge_type] dst_offset = dst_offset_dict[edge_type[-1]] # TODO Add support for SparseTensor w/o converting. 
is_sparse_tensor = isinstance(edge_index, SparseTensor) if is_sparse(edge_index): edge_index, _ = to_edge_index(edge_index) edge_index = edge_index.flip([0]) else: edge_index = edge_index.clone() edge_index[0] += src_offset edge_index[1] += dst_offset edge_indices.append(edge_index) if edge_attr_dict is not None: if isinstance(edge_attr_dict, ParameterDict): value = edge_attr_dict['__'.join(edge_type)] else: value = edge_attr_dict[edge_type] if value.size(0) != edge_index.size(1): value = value.expand(edge_index.size(1), -1) edge_attrs.append(value) edge_index = torch.cat(edge_indices, dim=1) edge_attr: Optional[Tensor] = None if edge_attr_dict is not None: edge_attr = torch.cat(edge_attrs, dim=0) if is_sparse_tensor: edge_index = SparseTensor( row=edge_index[1], col=edge_index[0], value=edge_attr, sparse_sizes=(num_nodes, num_nodes), ) return edge_index, edge_attr ================================================ FILE: torch_geometric/utils/influence.py ================================================ from typing import List, Tuple, Union, cast import torch from torch import Tensor from torch.autograd.functional import jacobian from tqdm.auto import tqdm from torch_geometric.data import Data from torch_geometric.utils import k_hop_subgraph def k_hop_subsets_rough( node_idx: int, num_hops: int, edge_index: Tensor, num_nodes: int, ) -> List[Tensor]: r"""Return *rough* (possibly overlapping) *k*-hop node subsets. This is a thin wrapper around :pyfunc:`torch_geometric.utils.k_hop_subgraph` that *additionally* returns **all** intermediate hop subsets rather than the full union only. Parameters ---------- node_idx: int Index or indices of the central node(s). num_hops: int Number of hops *k*. edge_index: Tensor Edge index in COO format with shape :math:`[2, \text{num_edges}]`. num_nodes: int Total number of nodes in the graph. Required to allocate the masks. 
Returns: ------- List[Tensor] A list ``[H₀, H₁, …, H_k]`` where ``H₀`` contains the seed node(s) and ``H_i`` (for *i*>0) contains **all** nodes that are exactly *i* hops away in the *expanded* neighbourhood (i.e. overlaps are *not* removed). """ col, row = edge_index node_mask = row.new_empty(num_nodes, dtype=torch.bool) edge_mask = row.new_empty(row.size(0), dtype=torch.bool) node_idx_ = torch.tensor([node_idx], device=row.device) subsets = [node_idx_] for _ in range(num_hops): node_mask.zero_() node_mask[subsets[-1]] = True torch.index_select(node_mask, 0, row, out=edge_mask) subsets.append(col[edge_mask]) return subsets def k_hop_subsets_exact( node_idx: int, num_hops: int, edge_index: Tensor, num_nodes: int, device: Union[torch.device, str], ) -> List[Tensor]: """Return **disjoint** *k*-hop subsets. This function refines :pyfunc:`k_hop_subsets_rough` by removing nodes that have already appeared in previous hops, ensuring that each subset contains nodes *exactly* *i* hops away from the seed. """ rough_subsets = k_hop_subsets_rough(node_idx, num_hops, edge_index, num_nodes) exact_subsets: List[List[int]] = [rough_subsets[0].tolist()] visited: set[int] = set(exact_subsets[0]) for hop_subset in rough_subsets[1:]: fresh = set(hop_subset.tolist()) - visited visited |= fresh exact_subsets.append(list(fresh)) return [ torch.tensor(s, device=device, dtype=edge_index.dtype) for s in exact_subsets ] def jacobian_l1( model: torch.nn.Module, data: Data, max_hops: int, node_idx: int, device: Union[torch.device, str], *, vectorize: bool = True, ) -> Tensor: """Compute the **L1 norm** of the Jacobian for a given node. The Jacobian is evaluated w.r.t. the node features of the *k*-hop induced sub‑graph centred at ``node_idx``. The result is *folded back* onto the **original** node index space so that the returned tensor has length ``data.num_nodes``, where the influence score will be zero for nodes outside the *k*-hop subgraph. 
Notes: ----- * The function assumes that the model *and* ``data.x`` share the same floating‑point precision (e.g. both ``float32`` or both ``float16``). """ # Build the induced *k*-hop sub‑graph (with node re‑labelling). edge_index = cast(Tensor, data.edge_index) x = cast(Tensor, data.x) k_hop_nodes, sub_edge_index, mapping, _ = k_hop_subgraph( node_idx, max_hops, edge_index, relabel_nodes=True) # get the location of the *center* node inside the sub‑graph root_pos = cast(int, mapping[0]) # Move tensors & model to the correct device device = torch.device(device) sub_x = x[k_hop_nodes].to(device) sub_edge_index = sub_edge_index.to(device) model = model.to(device) # Jacobian evaluation def _forward(x: Tensor) -> Tensor: return model(x, sub_edge_index)[root_pos] jac = jacobian(_forward, sub_x, vectorize=vectorize) influence_sub = jac.abs().sum(dim=(0, 2)) # Sum of L1 norm num_nodes = cast(int, data.num_nodes) # Scatter the influence scores back to the *global* node space influence_full = torch.zeros(num_nodes, dtype=influence_sub.dtype, device=device) influence_full[k_hop_nodes] = influence_sub return influence_full def jacobian_l1_agg_per_hop( model: torch.nn.Module, data: Data, max_hops: int, node_idx: int, device: Union[torch.device, str], vectorize: bool = True, ) -> Tensor: """Aggregate Jacobian L1 norms **per hop** for node_idx. Returns a vector ``[I_0, I_1, …, I_k]`` where ``I_i`` is the *total* influence exerted by nodes that are exactly *i* hops away from ``node_idx``. 
""" num_nodes = cast(int, data.num_nodes) edge_index = cast(Tensor, data.edge_index) influence = jacobian_l1(model, data, max_hops, node_idx, device, vectorize=vectorize) hop_subsets = k_hop_subsets_exact(node_idx, max_hops, edge_index, num_nodes, influence.device) single_node_influence_per_hop = [influence[s].sum() for s in hop_subsets] return torch.tensor(single_node_influence_per_hop, device=influence.device) def avg_total_influence( influence_all_nodes: Tensor, normalize: bool = True, ) -> Tensor: """Compute the *influence‑weighted receptive field* ``R``.""" avg_total_influences = torch.mean(influence_all_nodes, dim=0) if normalize: # normalize by hop_0 (jacobian of the center node feature) avg_total_influences = avg_total_influences / avg_total_influences[0] return avg_total_influences def influence_weighted_receptive_field(T: Tensor) -> float: """Compute the *influence‑weighted receptive field* ``R``. Given an influence matrix ``T`` of shape ``[N, k+1]`` (i‑th row contains the per‑hop influences of node *i*), the receptive field breadth *R* is defined as the expected hop distance when weighting by influence. A larger *R* indicates that, on average, influence comes from **farther** hops. """ normalised = T / torch.sum(T, dim=1, keepdim=True) hops = torch.arange(T.shape[1]).float() # 0 … k breadth = normalised @ hops # shape (N,) return breadth.mean().item() def total_influence( model: torch.nn.Module, data: Data, max_hops: int, num_samples: Union[int, None] = None, normalize: bool = True, average: bool = True, device: Union[torch.device, str] = "cpu", vectorize: bool = True, ) -> Tuple[Tensor, float]: r"""Compute Jacobian‑based influence aggregates for *multiple* seed nodes, as introduced in the `"Towards Quantifying Long-Range Interactions in Graph Machine Learning: a Large Graph Dataset and a Measurement" `_ paper. This measurement quantifies how a GNN model's output at a node is influenced by features of other nodes at increasing hop distances. 
Specifically, for every sampled node :math:`v`, this method 1. evaluates the **L1‑norm** of the Jacobian of the model output at :math:`v` w.r.t. the node features of its *k*-hop induced sub‑graph; 2. sums these scores **per hop** to obtain the influence vector :math:`(I_{0}, I_{1}, \dots, I_{k})`; 3. optionally averages those vectors over all sampled nodes and optionally normalises them by :math:`I_{0}`. Please refer to Section 4 of the paper for a more detailed definition. Args: model (torch.nn.Module): A PyTorch Geometric‑compatible model with forward signature ``model(x, edge_index) -> Tensor``. data (torch_geometric.data.Data): Graph data object providing at least :obj:`x` (node features) and :obj:`edge_index` (connectivity). max_hops (int): Maximum hop distance :math:`k`. num_samples (int, optional): Number of random seed nodes to evaluate. If :obj:`None`, all nodes are used. (default: :obj:`None`) normalize (bool, optional): If :obj:`True`, normalize each hop‑wise influence by the influence of hop 0. (default: :obj:`True`) average (bool, optional): If :obj:`True`, return the hop‑wise **mean** over all seed nodes (shape ``[k+1]``). If :obj:`False`, return the full influence matrix of shape ``[N, k+1]``. (default: :obj:`True`) device (torch.device or str, optional): Device on which to perform the computation. (default: :obj:`"cpu"`) vectorize (bool, optional): Forwarded to :func:`torch.autograd.functional.jacobian`. Keeping this :obj:`True` is often faster but increases memory usage. (default: :obj:`True`) Returns: Tuple[Tensor, float]: * **avg_influence** (*Tensor*): shape ``[k+1]`` if :obj:`average=True`; shape ``[N, k+1]`` otherwise. * **R** (*float*): Influence‑weighted receptive‑field breadth returned by :func:`influence_weighted_receptive_field`. Example:: >>> avg_I, R = total_influence(model, data, max_hops=3, ... 
num_samples=1000) >>> avg_I tensor([1.0000, 0.1273, 0.0142, 0.0019]) >>> R 0.216 """ num_samples = data.num_nodes if num_samples is None else num_samples num_nodes = cast(int, data.num_nodes) nodes = torch.randperm(num_nodes)[:num_samples].tolist() influence_all_nodes: List[Tensor] = [ jacobian_l1_agg_per_hop(model, data, max_hops, n, device, vectorize=vectorize) for n in tqdm(nodes, desc="Influence") ] allnodes = torch.vstack(influence_all_nodes).detach().cpu() # Average total influence at each hop if average: avg_influence = avg_total_influence(allnodes, normalize=normalize) else: avg_influence = allnodes # Influence‑weighted receptive field R = influence_weighted_receptive_field(allnodes) return avg_influence, R ================================================ FILE: torch_geometric/utils/isolated.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch_geometric.utils import remove_self_loops, segregate_self_loops from torch_geometric.utils.num_nodes import maybe_num_nodes def contains_isolated_nodes( edge_index: Tensor, num_nodes: Optional[int] = None, ) -> bool: r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains isolated nodes. Args: edge_index (LongTensor): The edge indices. num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) :rtype: bool Examples: >>> edge_index = torch.tensor([[0, 1, 0], ... 
[1, 0, 0]])
        >>> contains_isolated_nodes(edge_index)
        False

        >>> contains_isolated_nodes(edge_index, num_nodes=3)
        True
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    # A self-loop does not connect a node to any *other* node, so drop
    # self-loops before counting the nodes that still appear in some edge.
    edge_index, _ = remove_self_loops(edge_index)
    return torch.unique(edge_index.view(-1)).numel() < num_nodes


def remove_isolated_nodes(
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor]:
    r"""Removes the isolated nodes from the graph given by :attr:`edge_index`
    with optional edge attributes :attr:`edge_attr`.
    In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter
    out isolated node features later on.
    Self-loops are preserved for non-isolated nodes.

    A node whose only edges are self-loops counts as isolated and is removed.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (LongTensor, Tensor, BoolTensor)

    Examples:
        >>> edge_index = torch.tensor([[0, 1, 0],
        ...                            [1, 0, 0]])
        >>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index)
        >>> mask # node mask (2 nodes)
        tensor([True, True])

        >>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index,
        ...                                    num_nodes=3)
        >>> mask # node mask (3 nodes)
        tensor([True, True, False])
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # Split the graph into non-loop edges and self-loop edges; only the
    # former decide whether a node is isolated.
    out = segregate_self_loops(edge_index, edge_attr)
    edge_index, edge_attr, loop_edge_index, loop_edge_attr = out

    # Nodes that appear in at least one non-loop edge are kept.
    mask = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)
    mask[edge_index.view(-1)] = 1

    # Relabel kept nodes to the contiguous range [0, num_kept); removed nodes
    # map to -1.
    assoc = torch.full((num_nodes, ), -1, dtype=torch.long, device=mask.device)
    assoc[mask] = torch.arange(mask.sum(), device=assoc.device)  # type: ignore
    edge_index = assoc[edge_index]

    # Self-loops survive only on nodes that are kept.
    loop_mask = torch.zeros_like(mask)
    loop_mask[loop_edge_index[0]] = 1
    loop_mask = loop_mask & mask
    # For each node with a self-loop, record that loop's column position in
    # `loop_edge_index`.
    # NOTE(review): if a node has several self-loops, only one position is
    # recorded (whichever write wins), so only one loop survives — confirm.
    loop_assoc = torch.full_like(assoc, -1)
    loop_assoc[loop_edge_index[0]] = torch.arange(loop_edge_index.size(1),
                                                  device=loop_assoc.device)
    loop_idx = loop_assoc[loop_mask]
    # Select the surviving self-loops and relabel them with the new node ids.
    loop_edge_index = assoc[loop_edge_index[:, loop_idx]]

    # Non-loop edges first, surviving self-loops appended at the end.
    edge_index = torch.cat([edge_index, loop_edge_index], dim=1)

    if edge_attr is not None:
        assert loop_edge_attr is not None
        loop_edge_attr = loop_edge_attr[loop_idx]
        edge_attr = torch.cat([edge_attr, loop_edge_attr], dim=0)

    return edge_index, edge_attr, mask


================================================
FILE: torch_geometric/utils/laplacian.py
================================================
from typing import Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.typing import OptTensor
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter
from torch_geometric.utils.num_nodes import maybe_num_nodes


def get_laplacian(
    edge_index: Tensor,
    edge_weight: OptTensor = None,
    normalization: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Computes the graph Laplacian of the graph given by :obj:`edge_index`
    and optional :obj:`edge_weight`.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_weight (Tensor, optional): One-dimensional edge weights.
(default: :obj:`None`)
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`None`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
        dtype (torch.dtype, optional): The desired data type of returned tensor
            in case :obj:`edge_weight=None`. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    Self-loops already present in :obj:`edge_index` are removed before the
    Laplacian is built. The result contains the (weighted, negated)
    off-diagonal entries followed by one diagonal entry per node.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)

    Examples:
        >>> edge_index = torch.tensor([[0, 1, 1, 2],
        ...                            [1, 0, 2, 1]])
        >>> edge_weight = torch.tensor([1., 2., 2., 4.])

        >>> # No normalization
        >>> lap = get_laplacian(edge_index, edge_weight)

        >>> # Symmetric normalization
        >>> lap_sym = get_laplacian(edge_index, edge_weight,
        ...                         normalization='sym')

        >>> # Random-walk normalization
        >>> lap_rw = get_laplacian(edge_index, edge_weight, normalization='rw')
    """
    if normalization is not None:
        # NOTE(review): validated via `assert`, which is stripped under
        # `python -O`; an unknown value then falls through to the 'rw' branch.
        assert normalization in ['sym', 'rw']  # 'Invalid normalization'

    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), dtype=dtype,
                                 device=edge_index.device)

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # Weighted out-degree: sum of edge weights grouped by source node.
    row, col = edge_index[0], edge_index[1]
    deg = scatter(edge_weight, row, 0, dim_size=num_nodes, reduce='sum')

    if normalization is None:
        # L = D - A: off-diagonal entries carry -w, and one self-loop per node
        # (appended by `add_self_loops`) carries the degree.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        edge_weight = torch.cat([-edge_weight, deg], dim=0)
    elif normalization == 'sym':
        # Compute D^{-1/2} A D^{-1/2} (not yet negated; the sign is applied
        # when adding the self-loops below). `pow_` overwrites `deg` in place,
        # which is fine since `deg` is not used afterwards. Zero-degree nodes
        # would give inf and are set to 0 instead.
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

        # L = I - D^{-1/2} A D^{-1/2}: negate the off-diagonal weights and add
        # a self-loop of weight 1 for every node.
        assert isinstance(edge_weight, Tensor)
        edge_index, edge_weight = add_self_loops(  #
            edge_index, -edge_weight, fill_value=1., num_nodes=num_nodes)
    else:
        # Compute D^{-1} A (not yet negated; the sign is applied when adding
        # the self-loops below). Zero-degree nodes would give inf and are set
        # to 0 instead.
        deg_inv = 1.0 / deg
        deg_inv.masked_fill_(deg_inv == float('inf'), 0)
        edge_weight = deg_inv[row] * edge_weight

        # L = I - D^{-1} A: negate the off-diagonal weights and add a
        # self-loop of weight 1 for every node.
        assert isinstance(edge_weight, Tensor)
        edge_index, edge_weight = add_self_loops(  #
            edge_index, -edge_weight, fill_value=1., num_nodes=num_nodes)

    return edge_index, edge_weight


================================================
FILE: torch_geometric/utils/loop.py
================================================
import typing
from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric import EdgeIndex
from torch_geometric.utils import scatter
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils.sparse import (
    is_torch_sparse_tensor,
    to_edge_index,
    to_torch_coo_tensor,
    to_torch_csr_tensor,
)

# Static type checkers use `typing.overload`; at runtime the TorchScript
# variant `torch.jit._overload` is used instead (presumably so that scripted
# callers can resolve the overloaded signatures below).
if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload as overload


def contains_self_loops(edge_index: Tensor) -> bool:
    r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains
    self-loops.

    Args:
        edge_index (LongTensor): The edge indices.

    :rtype: bool

    Examples:
        >>> edge_index = torch.tensor([[0, 1, 0],
        ...                            [1, 0, 0]])
        >>> contains_self_loops(edge_index)
        True

        >>> edge_index = torch.tensor([[0, 1, 1],
        ...
[1, 0, 2]])
        >>> contains_self_loops(edge_index)
        False
    """
    # An edge (i, j) is a self-loop iff source and target coincide.
    mask = edge_index[0] == edge_index[1]
    return mask.sum().item() > 0


@overload
def remove_self_loops(
    edge_index: Tensor,
    edge_attr: None = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def remove_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def remove_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


def remove_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Removes every self-loop in the graph given by :attr:`edge_index`, so
    that :math:`(i,i) \not\in \mathcal{E}` for every :math:`i \in \mathcal{V}`.

    If :attr:`edge_index` is a sparse COO/CSR tensor, the result is returned
    in the same sparse layout and :attr:`edge_attr` must be :obj:`None`.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)

    Example:
        >>> edge_index = torch.tensor([[0, 1, 0],
        ...                            [1, 0, 0]])
        >>> edge_attr = [[1, 2], [3, 4], [5, 6]]
        >>> edge_attr = torch.tensor(edge_attr)
        >>> remove_self_loops(edge_index, edge_attr)
        (tensor([[0, 1],
                [1, 0]]),
        tensor([[1, 2],
                [3, 4]]))
    """
    size: Optional[Tuple[int, int]] = None
    # While scripting, `layout` is annotated as `int` (presumably because
    # TorchScript cannot use `torch.layout` in this position).
    if not typing.TYPE_CHECKING and torch.jit.is_scripting():
        layout: Optional[int] = None
    else:
        layout: Optional[torch.layout] = None

    # Sparse input: remember layout and shape, then operate on the
    # (edge_index, value) representation.
    value: Optional[Tensor] = None
    if is_torch_sparse_tensor(edge_index):
        layout = edge_index.layout
        size = (edge_index.size(0), edge_index.size(1))
        edge_index, value = to_edge_index(edge_index)

    # Preserve the `is_undirected` flag of an `EdgeIndex` across the column
    # selection below.
    is_undirected = False
    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
        is_undirected = edge_index.is_undirected

    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]

    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
        edge_index._is_undirected = is_undirected

    # Convert back to the original sparse layout.
    if layout is not None:
        assert edge_attr is None
        assert value is not None
        value = value[mask]
        if str(layout) == 'torch.sparse_coo':  # str(...) for TorchScript :(
            return to_torch_coo_tensor(edge_index, value, size, True), None
        elif str(layout) == 'torch.sparse_csr':
            return to_torch_csr_tensor(edge_index, value, size, True), None
        raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")

    if edge_attr is None:
        return edge_index, None
    else:
        return edge_index, edge_attr[mask]


@overload
def segregate_self_loops(
    edge_index: Tensor,
    edge_attr: None = None,
) -> Tuple[Tensor, None, Tensor, None]:
    pass


@overload
def segregate_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    pass


@overload
def segregate_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:
    pass


def segregate_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:
    r"""Segregates self-loops from the graph, returning the non-loop edges
    and the self-loop edges (with their features) separately.
Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`, :class:`LongTensor`,
        :class:`Tensor`)

    Example:
        >>> edge_index = torch.tensor([[0, 0, 1],
        ...                            [0, 1, 0]])
        >>> (edge_index, edge_attr,
        ...  loop_edge_index,
        ...  loop_edge_attr) = segregate_self_loops(edge_index)
        >>>  loop_edge_index
        tensor([[0],
                [0]])
    """
    mask = edge_index[0] != edge_index[1]
    inv_mask = ~mask

    is_undirected = False
    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
        is_undirected = edge_index.is_undirected

    loop_edge_index = edge_index[:, inv_mask]
    loop_edge_attr = None if edge_attr is None else edge_attr[inv_mask]
    edge_index = edge_index[:, mask]
    edge_attr = None if edge_attr is None else edge_attr[mask]

    # Both halves inherit the `is_undirected` flag of the input `EdgeIndex`.
    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
        assert isinstance(loop_edge_index, EdgeIndex)
        edge_index._is_undirected = is_undirected
        loop_edge_index._is_undirected = is_undirected

    return edge_index, edge_attr, loop_edge_index, loop_edge_attr


@overload
def add_self_loops(
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[float] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[float] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[str] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[str] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[float] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[float] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[str] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[str] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[float] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[float] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[str] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


@overload
def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[str] = None,
    num_nodes: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


def add_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
    fill_value: Optional[Union[float, Tensor, str]] = None,
    num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted or has multi-dimensional edge features
    (:obj:`edge_attr != None`), edge features of self-loops will be added
    according to :obj:`fill_value`.

    Existing self-loops are kept; the new loops are appended after all input
    edges.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        fill_value (float or Tensor or str, optional): The way to generate
            edge features of self-loops (in case :obj:`edge_attr != None`).
            If given as :obj:`float` or :class:`torch.Tensor`, edge features of
            self-loops will be directly given by :obj:`fill_value`.
            If given as :obj:`str`, edge features of self-loops are computed by
            aggregating all features of edges that point to the specific node,
            according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`1.`)
        num_nodes (int or Tuple[int, int], optional): The number of nodes,
            *i.e.* :obj:`max_val + 1` of :attr:`edge_index`.
            If given as a tuple, then :obj:`edge_index` is interpreted as a
            bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.
            (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)

    Examples:
        >>> edge_index = torch.tensor([[0, 1, 0],
        ...                            [1, 0, 0]])
        >>> edge_weight = torch.tensor([0.5, 0.5, 0.5])
        >>> add_self_loops(edge_index)
        (tensor([[0, 1, 0, 0, 1],
                [1, 0, 0, 0, 1]]),
        None)

        >>> add_self_loops(edge_index, edge_weight)
        (tensor([[0, 1, 0, 0, 1],
                [1, 0, 0, 0, 1]]),
        tensor([0.5000, 0.5000, 0.5000, 1.0000, 1.0000]))

        >>> # edge features of self-loops are filled by constant `2.0`
        >>> add_self_loops(edge_index, edge_weight,
        ...                fill_value=2.)
        (tensor([[0, 1, 0, 0, 1],
                [1, 0, 0, 0, 1]]),
        tensor([0.5000, 0.5000, 0.5000, 2.0000, 2.0000]))

        >>> # Use 'add' operation to merge edge features for self-loops
        >>> add_self_loops(edge_index, edge_weight,
        ...                fill_value='add')
        (tensor([[0, 1, 0, 0, 1],
                [1, 0, 0, 0, 1]]),
        tensor([0.5000, 0.5000, 0.5000, 1.0000, 0.5000]))
    """
    # While scripting, `layout` is annotated as `int` (presumably because
    # TorchScript cannot use `torch.layout` in this position).
    if not typing.TYPE_CHECKING and torch.jit.is_scripting():
        layout: Optional[int] = None
    else:
        layout: Optional[torch.layout] = None

    # Determine the output shape `size` and the number of loops `N`. For
    # non-square shapes (bipartite graphs or rectangular sparse tensors) only
    # the first `min(size)` node ids receive a self-loop.
    is_sparse = is_torch_sparse_tensor(edge_index)
    value: Optional[Tensor] = None
    if is_sparse:
        assert edge_attr is None
        layout = edge_index.layout
        size = (edge_index.size(0), edge_index.size(1))
        N = min(size)
        edge_index, value = to_edge_index(edge_index)
    elif isinstance(num_nodes, (tuple, list)):
        size = (num_nodes[0], num_nodes[1])
        N = min(size)
    else:
        N = maybe_num_nodes(edge_index, num_nodes)
        size = (N, N)

    # `loop_index` is [[0, ..., N-1], [0, ..., N-1]].
    device = edge_index.device
    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
        loop_index: Tensor = EdgeIndex(
            torch.arange(0, N, device=device).view(1, -1).repeat(2, 1),
            sparse_size=(N, N),
            is_undirected=True,
        )
    else:
        loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)

    full_edge_index = torch.cat([edge_index, loop_index], dim=1)

    if is_sparse:
        assert edge_attr is None
        assert value is not None
        loop_attr = compute_loop_attr(  #
edge_index, value, N, is_sparse, fill_value)
        value = torch.cat([value, loop_attr], dim=0)

        # Convert back to the input's sparse layout.
        if str(layout) == 'torch.sparse_coo':  # str(...) for TorchScript :(
            return to_torch_coo_tensor(full_edge_index, value, size), None
        elif str(layout) == 'torch.sparse_csr':
            return to_torch_csr_tensor(full_edge_index, value, size), None
        raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")

    # Dense input: append one feature row per new self-loop.
    if edge_attr is not None:
        loop_attr = compute_loop_attr(  #
            edge_index, edge_attr, N, is_sparse, fill_value)
        edge_attr = torch.cat([edge_attr, loop_attr], dim=0)

    return full_edge_index, edge_attr


@overload
def add_remaining_self_loops(
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[float] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: None = None,
    fill_value: Optional[str] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, None]:
    pass


@overload
def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[float] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Tensor,
    fill_value: Optional[str] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    pass


@overload
def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[float] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


@overload
def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


@overload
def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    fill_value: Optional[str] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    pass


def add_remaining_self_loops(  # noqa: F811
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
    fill_value: Optional[Union[float, Tensor, str]] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Adds remaining self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted or has multi-dimensional edge features
    (:obj:`edge_attr != None`), edge features of non-existing self-loops will
    be added according to :obj:`fill_value`.

    Self-loops that already exist keep their original features. The output
    lists all non-loop edges first, followed by exactly one self-loop per
    node.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        fill_value (float or Tensor or str, optional): The way to generate
            edge features of self-loops (in case :obj:`edge_attr != None`).
            If given as :obj:`float` or :class:`torch.Tensor`, edge features of
            self-loops will be directly given by :obj:`fill_value`.
            If given as :obj:`str`, edge features of self-loops are computed by
            aggregating all features of edges that point to the specific node,
            according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)

    Example:
        >>> edge_index = torch.tensor([[0, 1],
        ...
[1, 0]]) >>> edge_weight = torch.tensor([0.5, 0.5]) >>> add_remaining_self_loops(edge_index, edge_weight) (tensor([[0, 1, 0, 1], [1, 0, 0, 1]]), tensor([0.5000, 0.5000, 1.0000, 1.0000])) """ N = maybe_num_nodes(edge_index, num_nodes) mask = edge_index[0] != edge_index[1] device = edge_index.device if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): loop_index: Tensor = EdgeIndex( torch.arange(0, N, device=device).view(1, -1).repeat(2, 1), sparse_size=(N, N), is_undirected=True, ) else: loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1) if edge_attr is not None: loop_attr = compute_loop_attr( # edge_index, edge_attr, N, False, fill_value) inv_mask = ~mask loop_attr[edge_index[0][inv_mask]] = edge_attr[inv_mask] edge_attr = torch.cat([edge_attr[mask], loop_attr], dim=0) is_undirected = False if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): is_undirected = edge_index.is_undirected edge_index = edge_index[:, mask] if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): edge_index._is_undirected = is_undirected edge_index = torch.cat([edge_index, loop_index], dim=1) return edge_index, edge_attr def get_self_loop_attr( edge_index: Tensor, edge_attr: Optional[Tensor] = None, num_nodes: Optional[int] = None, ) -> Tensor: r"""Returns the edge features or weights of self-loops :math:`(i, i)` of every node :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`. Edge features of missing self-loops not present in :attr:`edge_index` will be filled with zeros. If :attr:`edge_attr` is not given, it will be the vector of ones. .. note:: This operation is analogous to getting the diagonal elements of the dense adjacency matrix. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor, optional): Edge weights or multi-dimensional edge features. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`) :rtype: :class:`Tensor` Examples: >>> edge_index = torch.tensor([[0, 1, 0], ... [1, 0, 0]]) >>> edge_weight = torch.tensor([0.2, 0.3, 0.5]) >>> get_self_loop_attr(edge_index, edge_weight) tensor([0.5000, 0.0000]) >>> get_self_loop_attr(edge_index, edge_weight, num_nodes=4) tensor([0.5000, 0.0000, 0.0000, 0.0000]) """ loop_mask = edge_index[0] == edge_index[1] loop_index = edge_index[0][loop_mask] if edge_attr is not None: loop_attr = edge_attr[loop_mask] else: # A vector of ones: loop_attr = torch.ones(loop_index.numel(), device=edge_index.device) num_nodes = maybe_num_nodes(edge_index, num_nodes) full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:]) full_loop_attr[loop_index] = loop_attr return full_loop_attr @overload def compute_loop_attr( edge_index: Tensor, edge_attr: Tensor, num_nodes: int, is_sparse: bool, fill_value: Optional[float] = None, ) -> Tensor: pass @overload def compute_loop_attr( # noqa: F811 edge_index: Tensor, edge_attr: Tensor, num_nodes: int, is_sparse: bool, fill_value: Optional[Tensor] = None, ) -> Tensor: pass @overload def compute_loop_attr( # noqa: F811 edge_index: Tensor, edge_attr: Tensor, num_nodes: int, is_sparse: bool, fill_value: Optional[str] = None, ) -> Tensor: pass def compute_loop_attr( # noqa: F811 edge_index: Tensor, edge_attr: Tensor, num_nodes: int, is_sparse: bool, fill_value: Optional[Union[float, Tensor, str]] = None, ) -> Tensor: if fill_value is None: size = (num_nodes, ) + edge_attr.size()[1:] return edge_attr.new_ones(size) elif isinstance(fill_value, (int, float)): size = (num_nodes, ) + edge_attr.size()[1:] return edge_attr.new_full(size, fill_value) elif isinstance(fill_value, Tensor): size = (num_nodes, ) + edge_attr.size()[1:] loop_attr = fill_value.to(edge_attr.device, edge_attr.dtype) if edge_attr.dim() != loop_attr.dim(): loop_attr = loop_attr.unsqueeze(0) return loop_attr.expand(size).contiguous() elif isinstance(fill_value, str): col = edge_index[0] if is_sparse 
else edge_index[1] return scatter(edge_attr, col, 0, num_nodes, fill_value) raise AttributeError("No valid 'fill_value' provided") ================================================ FILE: torch_geometric/utils/map.py ================================================ from typing import Optional, Tuple, Union import numpy as np import torch from torch import Tensor from torch.utils.dlpack import from_dlpack from torch_geometric.warnings import WarningCache _warning_cache = WarningCache() def map_index( src: Tensor, index: Tensor, max_index: Optional[Union[int, Tensor]] = None, inclusive: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: r"""Maps indices in :obj:`src` to the positional value of their corresponding occurrence in :obj:`index`. Indices must be strictly positive. Args: src (torch.Tensor): The source tensor to map. index (torch.Tensor): The index tensor that denotes the new mapping. max_index (int, optional): The maximum index value. (default :obj:`None`) inclusive (bool, optional): If set to :obj:`True`, it is assumed that every entry in :obj:`src` has a valid entry in :obj:`index`. Can speed-up computation. (default: :obj:`False`) :rtype: (:class:`torch.Tensor`, :class:`torch.BoolTensor`) Examples: >>> src = torch.tensor([2, 0, 1, 0, 3]) >>> index = torch.tensor([3, 2, 0, 1]) >>> map_index(src, index) (tensor([1, 2, 3, 2, 0]), tensor([True, True, True, True, True])) >>> src = torch.tensor([2, 0, 1, 0, 3]) >>> index = torch.tensor([3, 2, 0]) >>> map_index(src, index) (tensor([1, 2, 2, 0]), tensor([True, True, False, True, True])) .. note:: If inputs are on GPU and :obj:`cudf` is available, consider using RMM for significant speed boosts. Proceed with caution as RMM may conflict with other allocators or fragments. .. 
code-block:: python import rmm rmm.reinitialize(pool_allocator=True) torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator) """ if src.is_floating_point(): raise ValueError(f"Expected 'src' to be an index (got '{src.dtype}')") if index.is_floating_point(): raise ValueError(f"Expected 'index' to be an index (got " f"'{index.dtype}')") if src.device != index.device: raise ValueError(f"Both 'src' and 'index' must be on the same device " f"(got '{src.device}' and '{index.device}')") if max_index is None: max_index = torch.maximum(src.max(), index.max()) # If the `max_index` is in a reasonable range, we can accelerate this # operation by creating a helper vector to perform the mapping. # NOTE This will potentially consumes a large chunk of memory # (max_index=10 million => ~75MB), so we cap it at a reasonable size: THRESHOLD = 40_000_000 if src.is_cuda else 10_000_000 if max_index <= THRESHOLD: if inclusive: assoc = src.new_empty(max_index + 1) # type: ignore else: assoc = src.new_full((max_index + 1, ), -1) # type: ignore assoc[index] = torch.arange(index.numel(), dtype=src.dtype, device=src.device) out = assoc[src] if inclusive: return out, None else: mask = out != -1 return out[mask], mask WITH_CUDF = False if src.is_cuda: try: import cudf WITH_CUDF = True except ImportError: import pandas as pd _warning_cache.warn("Using CPU-based processing within " "'map_index' which may cause slowdowns and " "device synchronization. 
Consider installing " "'cudf' to accelerate computation") else: import pandas as pd if not WITH_CUDF: left_ser = pd.Series(src.cpu().numpy(), name='left_ser') right_ser = pd.Series( index=index.cpu().numpy(), data=pd.RangeIndex(0, index.size(0)), name='right_ser', ) result = pd.merge(left_ser, right_ser, how='left', left_on='left_ser', right_index=True) out_numpy = result['right_ser'].values if (index.device.type == 'mps' # MPS does not support `float64` and issubclass(out_numpy.dtype.type, np.floating)): out_numpy = out_numpy.astype(np.float32) out = torch.from_numpy(out_numpy).to(index.device) if out.is_floating_point() and inclusive: raise ValueError("Found invalid entries in 'src' that do not have " "a corresponding entry in 'index'. Set " "`inclusive=False` to ignore these entries.") if out.is_floating_point(): mask = torch.isnan(out).logical_not_() out = out[mask].to(index.dtype) return out, mask if inclusive: return out, None else: mask = out != -1 return out[mask], mask else: left_ser = cudf.Series(src, name='left_ser') right_ser = cudf.Series( index=index, data=cudf.RangeIndex(0, index.size(0)), name='right_ser', ) result = cudf.merge(left_ser, right_ser, how='left', left_on='left_ser', right_index=True, sort=True) if inclusive: try: out = from_dlpack(result['right_ser'].to_dlpack()) except ValueError as e: raise ValueError( "Found invalid entries in 'src' that do not " "have a corresponding entry in 'index'. Set " "`inclusive=False` to ignore these entries.") from e else: out = from_dlpack(result['right_ser'].fillna(-1).to_dlpack()) out = out[src.argsort().argsort()] # Restore original order. 
if inclusive: return out, None else: mask = out != -1 return out[mask], mask ================================================ FILE: torch_geometric/utils/mask.py ================================================ from typing import Optional import torch from torch import Tensor from torch_geometric.typing import TensorFrame def mask_select(src: Tensor, dim: int, mask: Tensor) -> Tensor: r"""Returns a new tensor which masks the :obj:`src` tensor along the dimension :obj:`dim` according to the boolean mask :obj:`mask`. Args: src (torch.Tensor): The input tensor. dim (int): The dimension in which to mask. mask (torch.BoolTensor): The 1-D tensor containing the binary mask to index with. """ assert mask.dim() == 1 if not torch.jit.is_scripting(): if isinstance(src, TensorFrame): assert dim == 0 and src.num_rows == mask.numel() return src[mask] assert src.size(dim) == mask.numel() dim = dim + src.dim() if dim < 0 else dim assert dim >= 0 and dim < src.dim() # Applying a 1-dimensional mask in the first dimension is significantly # faster than broadcasting the mask and utilizing `masked_select`. # As such, we transpose in the first dimension, perform the masking, and # then transpose back to the original shape. src = src.transpose(0, dim) if dim != 0 else src out = src[mask] out = out.transpose(0, dim) if dim != 0 else out return out def index_to_mask(index: Tensor, size: Optional[int] = None) -> Tensor: r"""Converts indices to a mask representation. Args: index (Tensor): The indices. size (int, optional): The size of the mask. If set to :obj:`None`, a minimal sized output mask is returned. 
Example: >>> index = torch.tensor([1, 3, 5]) >>> index_to_mask(index) tensor([False, True, False, True, False, True]) >>> index_to_mask(index, size=7) tensor([False, True, False, True, False, True, False]) """ index = index.view(-1) size = int(index.max()) + 1 if size is None else size mask = index.new_zeros(size, dtype=torch.bool) mask[index] = True return mask def mask_to_index(mask: Tensor) -> Tensor: r"""Converts a mask to an index representation. Args: mask (Tensor): The mask. Example: >>> mask = torch.tensor([False, True, False]) >>> mask_to_index(mask) tensor([1]) """ return mask.nonzero(as_tuple=False).view(-1) ================================================ FILE: torch_geometric/utils/mesh_laplacian.py ================================================ from typing import Optional, Tuple import torch from torch import Tensor from torch_geometric.utils import add_self_loops, scatter, to_undirected def get_mesh_laplacian( pos: Tensor, face: Tensor, normalization: Optional[str] = None, ) -> Tuple[Tensor, Tensor]: r"""Computes the mesh Laplacian of a mesh given by :obj:`pos` and :obj:`face`. Computation is based on the cotangent matrix defined as .. math:: \mathbf{C}_{ij} = \begin{cases} \frac{\cot \angle_{ikj}~+\cot \angle_{ilj}}{2} & \text{if } i, j \text{ is an edge} \\ -\sum_{j \in N(i)}{C_{ij}} & \text{if } i \text{ is in the diagonal} \\ 0 & \text{otherwise} \end{cases} Normalization depends on the mass matrix defined as .. math:: \mathbf{M}_{ij} = \begin{cases} a(i) & \text{if } i \text{ is in the diagonal} \\ 0 & \text{otherwise} \end{cases} where :math:`a(i)` is obtained by joining the barycenters of the triangles around vertex :math:`i`. Args: pos (Tensor): The node positions. face (LongTensor): The face indices. normalization (str, optional): The normalization scheme for the mesh Laplacian (default: :obj:`None`): 1. :obj:`None`: No normalization :math:`\mathbf{L} = \mathbf{C}` 2. 
:obj:`"sym"`: Symmetric normalization :math:`\mathbf{L} = \mathbf{M}^{-1/2} \mathbf{C}\mathbf{M}^{-1/2}` 3. :obj:`"rw"`: Row-wise normalization :math:`\mathbf{L} = \mathbf{M}^{-1} \mathbf{C}` """ assert pos.size(1) == 3 and face.size(0) == 3 num_nodes = pos.shape[0] def get_cots(left: Tensor, centre: Tensor, right: Tensor) -> Tensor: left_pos, central_pos, right_pos = pos[left], pos[centre], pos[right] left_vec = left_pos - central_pos right_vec = right_pos - central_pos dot = torch.einsum('ij, ij -> i', left_vec, right_vec) cross = torch.norm(torch.cross(left_vec, right_vec, dim=1), dim=1) cot = dot / cross # cot = cos / sin return cot / 2.0 # by definition # For each triangle face, get all three cotangents: cot_021 = get_cots(face[0], face[2], face[1]) cot_102 = get_cots(face[1], face[0], face[2]) cot_012 = get_cots(face[0], face[1], face[2]) cot_weight = torch.cat([cot_021, cot_102, cot_012]) # Face to edge: cot_index = torch.cat([face[:2], face[1:], face[::2]], dim=1) cot_index, cot_weight = to_undirected(cot_index, cot_weight) # Compute the diagonal part: cot_deg = scatter(cot_weight, cot_index[0], 0, num_nodes, reduce='sum') edge_index, _ = add_self_loops(cot_index, num_nodes=num_nodes) edge_weight = torch.cat([cot_weight, -cot_deg], dim=0) if normalization is not None: def get_areas(left: Tensor, centre: Tensor, right: Tensor) -> Tensor: central_pos = pos[centre] left_vec = pos[left] - central_pos right_vec = pos[right] - central_pos cross = torch.norm(torch.cross(left_vec, right_vec, dim=1), dim=1) area = cross / 6.0 # one-third of a triangle's area is cross / 6.0 return area / 2.0 # since each corresponding area is counted twice # Like before, but here we only need the diagonal (the mass matrix): area_021 = get_areas(face[0], face[2], face[1]) area_102 = get_areas(face[1], face[0], face[2]) area_012 = get_areas(face[0], face[1], face[2]) area_weight = torch.cat([area_021, area_102, area_012]) area_index = torch.cat([face[:2], face[1:], face[::2]], dim=1) 
area_index, area_weight = to_undirected(area_index, area_weight) area_deg = scatter(area_weight, area_index[0], 0, num_nodes, 'sum') if normalization == 'sym': area_deg_inv_sqrt = area_deg.pow_(-0.5) area_deg_inv_sqrt[area_deg_inv_sqrt == float('inf')] = 0.0 edge_weight = (area_deg_inv_sqrt[edge_index[0]] * edge_weight * area_deg_inv_sqrt[edge_index[1]]) elif normalization == 'rw': area_deg_inv = 1.0 / area_deg area_deg_inv[area_deg_inv == float('inf')] = 0.0 edge_weight = area_deg_inv[edge_index[0]] * edge_weight return edge_index, edge_weight ================================================ FILE: torch_geometric/utils/mixin.py ================================================ from typing import Any, Iterator, TypeVar T = TypeVar('T') class CastMixin: @classmethod def cast(cls: T, *args: Any, **kwargs: Any) -> T: if len(args) == 1 and len(kwargs) == 0: elem = args[0] if elem is None: return None # type: ignore if isinstance(elem, CastMixin): return elem # type: ignore if isinstance(elem, tuple): return cls(*elem) # type: ignore if isinstance(elem, dict): return cls(**elem) # type: ignore return cls(*args, **kwargs) # type: ignore def __iter__(self) -> Iterator: return iter(self.__dict__.values()) ================================================ FILE: torch_geometric/utils/nested.py ================================================ from typing import Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.utils import scatter def to_nested_tensor( x: Tensor, batch: Optional[Tensor] = None, ptr: Optional[Tensor] = None, batch_size: Optional[int] = None, ) -> Tensor: r"""Given a contiguous batch of tensors :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}` (with :math:`N_i` indicating the number of elements in example :math:`i`), creates a `nested PyTorch tensor `__. Reverse operation of :meth:`from_nested_tensor`. Args: x (torch.Tensor): The input tensor :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}`. 
batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. Must be ordered. (default: :obj:`None`) ptr (torch.Tensor, optional): Alternative representation of :obj:`batch` in compressed format. (default: :obj:`None`) batch_size (int, optional): The batch size :math:`B`. (default: :obj:`None`) """ if ptr is not None: offsets = ptr[1:] - ptr[:-1] sizes = offsets.tolist() xs = list(torch.split(x, sizes, dim=0)) elif batch is not None: offsets = scatter(torch.ones_like(batch), batch, dim_size=batch_size) sizes = offsets.tolist() xs = list(torch.split(x, sizes, dim=0)) else: xs = [x] # This currently copies the data, although `x` is already contiguous. # Sadly, there does not exist any (public) API to prevent this :( return torch.nested.as_nested_tensor(xs) def from_nested_tensor( x: Tensor, return_batch: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: r"""Given a `nested PyTorch tensor `__, creates a contiguous batch of tensors :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}`, and optionally a batch vector which assigns each element to a specific example. Reverse operation of :meth:`to_nested_tensor`. Args: x (torch.Tensor): The nested input tensor. The size of nested tensors need to match except for the first dimension. return_batch (bool, optional): If set to :obj:`True`, will also return the batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`. 
(default: :obj:`False`) """ if not x.is_nested: raise ValueError("Input tensor in 'from_nested_tensor' is not nested") sizes = x._nested_tensor_size() for dim, (a, b) in enumerate(zip(sizes[0, 1:], sizes.t()[1:])): if not torch.equal(a.expand_as(b), b): raise ValueError(f"Not all nested tensors have the same size " f"in dimension {dim + 1} " f"(expected size {a.item()} for all tensors)") out = x.contiguous().values() out = out.view(-1, *sizes[0, 1:].tolist()) if not return_batch: return out batch = torch.arange(x.size(0), device=x.device) batch = batch.repeat_interleave(sizes[:, 0].to(batch.device)) return out, batch ================================================ FILE: torch_geometric/utils/noise_scheduler.py ================================================ import math from typing import Literal, Optional import torch from torch import Tensor def get_smld_sigma_schedule( sigma_min: float, sigma_max: float, num_scales: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tensor: r"""Generates a set of noise values on a logarithmic scale for "Score Matching with Langevin Dynamics" from the `"Generative Modeling by Estimating Gradients of the Data Distribution" `_ paper. This function returns a vector of sigma values that define the schedule of noise levels used during Score Matching with Langevin Dynamics. The sigma values are determined on a logarithmic scale from :obj:`sigma_max` to :obj:`sigma_min`, inclusive. Args: sigma_min (float): The minimum value of sigma, corresponding to the lowest noise level. sigma_max (float): The maximum value of sigma, corresponding to the highest noise level. num_scales (int): The number of sigma values to generate, defining the granularity of the noise schedule. dtype (torch.dtype, optional): The output data type. (default: :obj:`None`) device (torch.device, optional): The output device. 
(default: :obj:`None`) """ return torch.linspace( math.log(sigma_max), math.log(sigma_min), num_scales, dtype=dtype, device=device, ).exp() def get_diffusion_beta_schedule( schedule_type: Literal['linear', 'quadratic', 'constant', 'sigmoid'], beta_start: float, beta_end: float, num_diffusion_timesteps: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tensor: r"""Generates a schedule of beta values according to the specified strategy for the diffusion process from the `"Denoising Diffusion Probabilistic Models" `_ paper. Beta values are used to scale the noise added during the diffusion process in generative models. This function creates an array of beta values according to a pre-defined schedule, which can be either :obj:`"linear"`, :obj:`"quadratic"`, :obj:`"constant"`, or :obj:`"sigmoid"`. Args: schedule_type (str): The type of schedule to use for beta values. beta_start (float): The starting value of beta. beta_end (float): The ending value of beta. num_diffusion_timesteps (int): The number of timesteps for the diffusion process. dtype (torch.dtype, optional): The output data type. (default: :obj:`None`) device (torch.device, optional): The output device. 
(default: :obj:`None`) """ if schedule_type == 'linear': return torch.linspace( beta_start, beta_end, num_diffusion_timesteps, dtype=dtype, device=device, ) if schedule_type == 'quadratic': return torch.linspace( beta_start**0.5, beta_end**0.5, num_diffusion_timesteps, dtype=dtype, device=device, )**2 if schedule_type == 'constant': return torch.full( (num_diffusion_timesteps, ), fill_value=beta_end, dtype=dtype, device=device, ) if schedule_type == 'sigmoid': return torch.linspace( -6, 6, num_diffusion_timesteps, dtype=dtype, device=device, ).sigmoid() * (beta_end - beta_start) + beta_start raise ValueError(f"Found invalid 'schedule_type' (got '{schedule_type}')") ================================================ FILE: torch_geometric/utils/num_nodes.py ================================================ from copy import copy from typing import Dict, Optional, Tuple, Union import torch from torch import Tensor import torch_geometric from torch_geometric import EdgeIndex from torch_geometric.typing import EdgeType, NodeType, SparseTensor def maybe_num_nodes( edge_index: Union[Tensor, Tuple[Tensor, Tensor], SparseTensor], num_nodes: Optional[int] = None, ) -> int: if num_nodes is not None: return num_nodes elif not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex): return max(edge_index.get_sparse_size()) elif isinstance(edge_index, Tensor): if torch_geometric.utils.is_torch_sparse_tensor(edge_index): return max(edge_index.size(0), edge_index.size(1)) if torch.jit.is_tracing(): # Avoid non-traceable if-check for empty `edge_index` tensor: tmp = torch.concat([ edge_index.view(-1), edge_index.new_full((1, ), fill_value=-1) ]) return tmp.max() + 1 # type: ignore return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 elif isinstance(edge_index, tuple): return max( int(edge_index[0].max()) + 1 if edge_index[0].numel() > 0 else 0, int(edge_index[1].max()) + 1 if edge_index[1].numel() > 0 else 0, ) elif isinstance(edge_index, SparseTensor): return 
max(edge_index.size(0), edge_index.size(1)) raise NotImplementedError def maybe_num_nodes_dict( edge_index_dict: Dict[EdgeType, Tensor], num_nodes_dict: Optional[Dict[NodeType, int]] = None, ) -> Dict[NodeType, int]: num_nodes_dict = {} if num_nodes_dict is None else copy(num_nodes_dict) found_types = list(num_nodes_dict.keys()) for keys, edge_index in edge_index_dict.items(): key = keys[0] if key not in found_types: N = int(edge_index[0].max() + 1) num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N)) key = keys[-1] if key not in found_types: N = int(edge_index[1].max() + 1) num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N)) return num_nodes_dict ================================================ FILE: torch_geometric/utils/ppr.py ================================================ from itertools import chain from typing import Callable, List, Optional, Tuple import numpy as np import torch from torch import Tensor from torch_geometric import EdgeIndex from torch_geometric.utils.num_nodes import maybe_num_nodes try: import numba WITH_NUMBA = True except Exception: # pragma: no cover WITH_NUMBA = False def _get_ppr( # pragma: no cover rowptr: np.ndarray, col: np.ndarray, alpha: float, eps: float, target: Optional[np.ndarray] = None, ) -> Tuple[List[List[int]], List[List[float]]]: num_nodes = len(rowptr) - 1 if target is None else len(target) alpha_eps = alpha * eps js = [[0]] * num_nodes vals = [[0.]] * num_nodes for inode_uint in numba.prange(num_nodes): if target is None: inode = numba.int64(inode_uint) else: inode = target[inode_uint] p = {inode: 0.0} r = {} r[inode] = alpha q = [inode] while len(q) > 0: unode = q.pop() res = r[unode] if unode in r else 0 if unode in p: p[unode] += res else: p[unode] = res r[unode] = 0 start, end = rowptr[unode], rowptr[unode + 1] ucount = end - start for vnode in col[start:end]: _val = (1 - alpha) * res / ucount if vnode in r: r[vnode] += _val else: r[vnode] = _val res_vnode = r[vnode] if vnode in r else 0 vcount = 
rowptr[vnode + 1] - rowptr[vnode] if res_vnode >= alpha_eps * vcount: if vnode not in q: q.append(vnode) js[inode_uint] = list(p.keys()) vals[inode_uint] = list(p.values()) return js, vals _get_ppr_numba: Optional[Callable] = None def get_ppr( edge_index: Tensor, alpha: float = 0.2, eps: float = 1e-5, target: Optional[Tensor] = None, num_nodes: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: r"""Calculates the personalized PageRank (PPR) vector for all or a subset of nodes using a variant of the `Andersen algorithm `_. Args: edge_index (torch.Tensor): The indices of the graph. alpha (float, optional): The alpha value of the PageRank algorithm. (default: :obj:`0.2`) eps (float, optional): The threshold for stopping the PPR calculation (:obj:`edge_weight >= eps * out_degree`). (default: :obj:`1e-5`) target (torch.Tensor, optional): The target nodes to compute PPR for. If not given, calculates PPR vectors for all nodes. (default: :obj:`None`) num_nodes (int, optional): The number of nodes. 
(default: :obj:`None`) :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`) """ if not WITH_NUMBA: # pragma: no cover raise ImportError("'get_ppr' requires the 'numba' package") global _get_ppr_numba if _get_ppr_numba is None: _get_ppr_numba = numba.jit(nopython=True, parallel=True)(_get_ppr) num_nodes = maybe_num_nodes(edge_index, num_nodes) edge_index = EdgeIndex(edge_index, sparse_size=(num_nodes, num_nodes)) edge_index = edge_index.sort_by('row')[0] (rowptr, col), _ = edge_index.get_csr() cols, weights = _get_ppr_numba( rowptr.cpu().numpy(), col.cpu().numpy(), alpha, eps, None if target is None else target.cpu().numpy(), ) device = edge_index.device col = torch.tensor(list(chain.from_iterable(cols)), device=device) weight = torch.tensor(list(chain.from_iterable(weights)), device=device) deg = torch.tensor([len(value) for value in cols], device=device) row = torch.arange(num_nodes) if target is None else target row = row.repeat_interleave(deg, output_size=col.numel()) edge_index = torch.stack([row, col], dim=0) return edge_index, weight ================================================ FILE: torch_geometric/utils/random.py ================================================ import warnings from typing import List, Union import numpy as np import torch from torch_geometric.utils import remove_self_loops, to_undirected def erdos_renyi_graph( num_nodes: int, edge_prob: float, directed: bool = False, ) -> torch.Tensor: r"""Returns the :obj:`edge_index` of a random Erdos-Renyi graph. Args: num_nodes (int): The number of nodes. edge_prob (float): Probability of an edge. directed (bool, optional): If set to :obj:`True`, will return a directed graph. 
(default: :obj:`False`) Examples: >>> erdos_renyi_graph(5, 0.2, directed=False) tensor([[0, 1, 1, 4], [1, 0, 4, 1]]) >>> erdos_renyi_graph(5, 0.2, directed=True) tensor([[0, 1, 3, 3, 4, 4], [4, 3, 1, 2, 1, 3]]) """ if directed: idx = torch.arange((num_nodes - 1) * num_nodes) idx = idx.view(num_nodes - 1, num_nodes) idx = idx + torch.arange(1, num_nodes).view(-1, 1) idx = idx.view(-1) else: warnings.filterwarnings('ignore', '.*pass the indexing argument.*') idx = torch.combinations(torch.arange(num_nodes), r=2) # Filter edges. mask = torch.rand(idx.size(0)) < edge_prob idx = idx[mask] if directed: row = idx.div(num_nodes, rounding_mode='floor') col = idx % num_nodes edge_index = torch.stack([row, col], dim=0) else: edge_index = to_undirected(idx.t(), num_nodes=num_nodes) return edge_index def stochastic_blockmodel_graph( block_sizes: Union[List[int], torch.Tensor], edge_probs: Union[List[List[float]], torch.Tensor], directed: bool = False, ) -> torch.Tensor: r"""Returns the :obj:`edge_index` of a stochastic blockmodel graph. Args: block_sizes ([int] or LongTensor): The sizes of blocks. edge_probs ([[float]] or FloatTensor): The density of edges going from each block to each other block. Must be symmetric if the graph is undirected. directed (bool, optional): If set to :obj:`True`, will return a directed graph. (default: :obj:`False`) Examples: >>> block_sizes = [2, 2, 4] >>> edge_probs = [[0.25, 0.05, 0.02], ... [0.05, 0.35, 0.07], ... [0.02, 0.07, 0.40]] >>> stochastic_blockmodel_graph(block_sizes, edge_probs, ... directed=False) tensor([[2, 4, 4, 5, 5, 6, 7, 7], [5, 6, 7, 2, 7, 4, 4, 5]]) >>> stochastic_blockmodel_graph(block_sizes, edge_probs, ... 
directed=True) tensor([[0, 2, 3, 4, 4, 5, 5], [3, 4, 1, 5, 6, 6, 7]]) """ size, prob = block_sizes, edge_probs if not isinstance(size, torch.Tensor): size = torch.tensor(size, dtype=torch.long) if not isinstance(prob, torch.Tensor): prob = torch.tensor(prob, dtype=torch.float) assert size.dim() == 1 assert prob.dim() == 2 and prob.size(0) == prob.size(1) assert size.size(0) == prob.size(0) if not directed: assert torch.allclose(prob, prob.t()) node_idx = torch.cat([size.new_full((b, ), i) for i, b in enumerate(size)]) num_nodes = node_idx.size(0) if directed: idx = torch.arange((num_nodes - 1) * num_nodes) idx = idx.view(num_nodes - 1, num_nodes) idx = idx + torch.arange(1, num_nodes).view(-1, 1) idx = idx.view(-1) row = idx.div(num_nodes, rounding_mode='floor') col = idx % num_nodes else: row, col = torch.combinations(torch.arange(num_nodes), r=2).t() mask = torch.bernoulli(prob[node_idx[row], node_idx[col]]).to(torch.bool) edge_index = torch.stack([row[mask], col[mask]], dim=0) if not directed: edge_index = to_undirected(edge_index, num_nodes=num_nodes) return edge_index def barabasi_albert_graph(num_nodes: int, num_edges: int) -> torch.Tensor: r"""Returns the :obj:`edge_index` of a Barabasi-Albert preferential attachment model, where a graph of :obj:`num_nodes` nodes grows by attaching new nodes with :obj:`num_edges` edges that are preferentially attached to existing nodes with high degree. Args: num_nodes (int): The number of nodes. num_edges (int): The number of edges from a new node to existing nodes. 
Example: >>> barabasi_albert_graph(num_nodes=4, num_edges=3) tensor([[0, 0, 0, 1, 1, 2, 2, 3], [1, 2, 3, 0, 2, 0, 1, 0]]) """ assert num_edges > 0 and num_edges < num_nodes row, col = torch.arange(num_edges), torch.randperm(num_edges) for i in range(num_edges, num_nodes): row = torch.cat([row, torch.full((num_edges, ), i, dtype=torch.long)]) choice = np.random.choice(torch.cat([row, col]).numpy(), num_edges) col = torch.cat([col, torch.from_numpy(choice)]) edge_index = torch.stack([row, col], dim=0) edge_index, _ = remove_self_loops(edge_index) edge_index = to_undirected(edge_index, num_nodes=num_nodes) return edge_index ================================================ FILE: torch_geometric/utils/repeat.py ================================================ import itertools import numbers from typing import Any import torch from torch import Tensor def repeat(src: Any, length: int) -> Any: if src is None: return None if isinstance(src, Tensor): if src.numel() == 1: return src.repeat(length) if src.numel() > length: return src[:length] if src.numel() < length: last_elem = src[-1].unsqueeze(0) padding = last_elem.repeat(length - src.numel()) return torch.cat([src, padding]) return src if isinstance(src, numbers.Number): return list(itertools.repeat(src, length)) if (len(src) > length): return src[:length] if (len(src) < length): return src + list(itertools.repeat(src[-1], length - len(src))) return src ================================================ FILE: torch_geometric/utils/smiles.py ================================================ from typing import Any, Dict, List import torch import torch_geometric x_map: Dict[str, List[Any]] = { 'atomic_num': list(range(0, 119)), 'chirality': [ 'CHI_UNSPECIFIED', 'CHI_TETRAHEDRAL_CW', 'CHI_TETRAHEDRAL_CCW', 'CHI_OTHER', 'CHI_TETRAHEDRAL', 'CHI_ALLENE', 'CHI_SQUAREPLANAR', 'CHI_TRIGONALBIPYRAMIDAL', 'CHI_OCTAHEDRAL', ], 'degree': list(range(0, 11)), 'formal_charge': list(range(-5, 7)), 'num_hs': list(range(0, 9)), 
'num_radical_electrons': list(range(0, 5)), 'hybridization': [ 'UNSPECIFIED', 'S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'OTHER', ], 'is_aromatic': [False, True], 'is_in_ring': [False, True], } e_map: Dict[str, List[Any]] = { 'bond_type': [ 'UNSPECIFIED', 'SINGLE', 'DOUBLE', 'TRIPLE', 'QUADRUPLE', 'QUINTUPLE', 'HEXTUPLE', 'ONEANDAHALF', 'TWOANDAHALF', 'THREEANDAHALF', 'FOURANDAHALF', 'FIVEANDAHALF', 'AROMATIC', 'IONIC', 'HYDROGEN', 'THREECENTER', 'DATIVEONE', 'DATIVE', 'DATIVEL', 'DATIVER', 'OTHER', 'ZERO', ], 'stereo': [ 'STEREONONE', 'STEREOANY', 'STEREOZ', 'STEREOE', 'STEREOCIS', 'STEREOTRANS', ], 'is_conjugated': [False, True], } def from_rdmol(mol: Any) -> 'torch_geometric.data.Data': r"""Converts a :class:`rdkit.Chem.Mol` instance to a :class:`torch_geometric.data.Data` instance. Args: mol (rdkit.Chem.Mol): The :class:`rdkit` molecule. """ from rdkit import Chem from torch_geometric.data import Data assert isinstance(mol, Chem.Mol) xs: List[List[int]] = [] for atom in mol.GetAtoms(): row: List[int] = [] row.append(x_map['atomic_num'].index(atom.GetAtomicNum())) row.append(x_map['chirality'].index(str(atom.GetChiralTag()))) row.append(x_map['degree'].index(atom.GetTotalDegree())) row.append(x_map['formal_charge'].index(atom.GetFormalCharge())) row.append(x_map['num_hs'].index(atom.GetTotalNumHs())) row.append(x_map['num_radical_electrons'].index( atom.GetNumRadicalElectrons())) row.append(x_map['hybridization'].index(str(atom.GetHybridization()))) row.append(x_map['is_aromatic'].index(atom.GetIsAromatic())) row.append(x_map['is_in_ring'].index(atom.IsInRing())) xs.append(row) x = torch.tensor(xs, dtype=torch.long).view(-1, 9) edge_indices, edge_attrs = [], [] for bond in mol.GetBonds(): i = bond.GetBeginAtomIdx() j = bond.GetEndAtomIdx() e = [] e.append(e_map['bond_type'].index(str(bond.GetBondType()))) e.append(e_map['stereo'].index(str(bond.GetStereo()))) e.append(e_map['is_conjugated'].index(bond.GetIsConjugated())) edge_indices += [[i, j], [j, i]] 
edge_attrs += [e, e] edge_index = torch.tensor(edge_indices) edge_index = edge_index.t().to(torch.long).view(2, -1) edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3) if edge_index.numel() > 0: # Sort indices. perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort() edge_index, edge_attr = edge_index[:, perm], edge_attr[perm] return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) def from_smiles( smiles: str, with_hydrogen: bool = False, kekulize: bool = False, ) -> 'torch_geometric.data.Data': r"""Converts a SMILES string to a :class:`torch_geometric.data.Data` instance. Args: smiles (str): The SMILES string. with_hydrogen (bool, optional): If set to :obj:`True`, will store hydrogens in the molecule graph. (default: :obj:`False`) kekulize (bool, optional): If set to :obj:`True`, converts aromatic bonds to single/double bonds. (default: :obj:`False`) """ from rdkit import Chem, RDLogger RDLogger.DisableLog('rdApp.*') # type: ignore[attr-defined] mol = Chem.MolFromSmiles(smiles) if mol is None: mol = Chem.MolFromSmiles('') if with_hydrogen: mol = Chem.AddHs(mol) if kekulize: Chem.Kekulize(mol) data = from_rdmol(mol) data.smiles = smiles return data def to_rdmol( data: 'torch_geometric.data.Data', kekulize: bool = False, ) -> Any: """Converts a :class:`torch_geometric.data.Data` instance to a :class:`rdkit.Chem.Mol` instance. Args: data (torch_geometric.data.Data): The molecular graph data. kekulize (bool, optional): If set to :obj:`True`, converts aromatic bonds to single/double bonds. 
(default: :obj:`False`) """ from rdkit import Chem mol = Chem.RWMol() assert data.x is not None assert data.num_nodes is not None assert data.edge_index is not None assert data.edge_attr is not None for i in range(data.num_nodes): atom = Chem.Atom(int(data.x[i, 0])) atom.SetChiralTag(Chem.rdchem.ChiralType.values[int(data.x[i, 1])]) atom.SetFormalCharge(x_map['formal_charge'][int(data.x[i, 3])]) atom.SetNumExplicitHs(x_map['num_hs'][int(data.x[i, 4])]) atom.SetNumRadicalElectrons(x_map['num_radical_electrons'][int( data.x[i, 5])]) atom.SetHybridization(Chem.rdchem.HybridizationType.values[int( data.x[i, 6])]) atom.SetIsAromatic(bool(data.x[i, 7])) mol.AddAtom(atom) edges = [tuple(i) for i in data.edge_index.t().tolist()] visited = set() for i in range(len(edges)): src, dst = edges[i] if tuple(sorted(edges[i])) in visited: continue bond_type = Chem.BondType.values[int(data.edge_attr[i, 0])] mol.AddBond(src, dst, bond_type) # Set stereochemistry: stereo = Chem.rdchem.BondStereo.values[int(data.edge_attr[i, 1])] if stereo != Chem.rdchem.BondStereo.STEREONONE: db = mol.GetBondBetweenAtoms(src, dst) db.SetStereoAtoms(dst, src) db.SetStereo(stereo) # Set conjugation: is_conjugated = bool(data.edge_attr[i, 2]) mol.GetBondBetweenAtoms(src, dst).SetIsConjugated(is_conjugated) visited.add(tuple(sorted(edges[i]))) mol = mol.GetMol() if kekulize: Chem.Kekulize(mol) Chem.SanitizeMol(mol) Chem.AssignStereochemistry(mol) return mol def to_smiles( data: 'torch_geometric.data.Data', kekulize: bool = False, ) -> str: """Converts a :class:`torch_geometric.data.Data` instance to a SMILES string. Args: data (torch_geometric.data.Data): The molecular graph. kekulize (bool, optional): If set to :obj:`True`, converts aromatic bonds to single/double bonds. 
(default: :obj:`False`) """ from rdkit import Chem mol = to_rdmol(data, kekulize=kekulize) return Chem.MolToSmiles(mol, isomericSmiles=True) ================================================ FILE: torch_geometric/utils/sparse.py ================================================ import warnings from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor import torch_geometric.typing from torch_geometric.index import index2ptr, ptr2index from torch_geometric.typing import SparseTensor from torch_geometric.utils import coalesce, cumsum def dense_to_sparse( adj: Tensor, mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: r"""Converts a dense adjacency matrix to a sparse adjacency matrix defined by edge indices and edge attributes. Args: adj (torch.Tensor): The dense adjacency matrix of shape :obj:`[num_nodes, num_nodes]` or :obj:`[batch_size, num_nodes, num_nodes]`. mask (torch.Tensor, optional): A boolean tensor of shape :obj:`[batch_size, num_nodes]` holding information about which nodes are in each example are valid. (default: :obj:`None`) :rtype: (:class:`LongTensor`, :class:`Tensor`) Examples: >>> # For a single adjacency matrix: >>> adj = torch.tensor([[3, 1], ... [2, 0]]) >>> dense_to_sparse(adj) (tensor([[0, 0, 1], [0, 1, 0]]), tensor([3, 1, 2])) >>> # For two adjacency matrixes: >>> adj = torch.tensor([[[3, 1], ... [2, 0]], ... [[0, 1], ... [0, 2]]]) >>> dense_to_sparse(adj) (tensor([[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]]), tensor([3, 1, 2, 1, 2])) >>> # First graph with two nodes, second with three: >>> adj = torch.tensor([[ ... [3, 1, 0], ... [2, 0, 0], ... [0, 0, 0] ... ], [ ... [0, 1, 0], ... [0, 2, 3], ... [0, 5, 0] ... ]]) >>> mask = torch.tensor([ ... [True, True, False], ... [True, True, True] ... 
]) >>> dense_to_sparse(adj, mask) (tensor([[0, 0, 1, 2, 3, 3, 4], [0, 1, 0, 3, 3, 4, 3]]), tensor([3, 1, 2, 1, 2, 3, 5])) """ if adj.dim() < 2 or adj.dim() > 3: raise ValueError(f"Dense adjacency matrix 'adj' must be two- or " f"three-dimensional (got {adj.dim()} dimensions)") if mask is not None and adj.dim() == 2: warnings.warn( "Mask should not be provided in case the dense " "adjacency matrix is two-dimensional", stacklevel=2) mask = None if mask is not None and mask.dim() != 2: raise ValueError(f"Mask must be two-dimensional " f"(got {mask.dim()} dimensions)") if mask is not None and adj.size(-2) != adj.size(-1): raise ValueError(f"Mask is only supported on quadratic adjacency " f"matrices (got [*, {adj.size(-2)}, {adj.size(-1)}])") if adj.dim() == 2: edge_index = adj.nonzero().t() edge_attr = adj[edge_index[0], edge_index[1]] return edge_index, edge_attr else: flatten_adj = adj.view(-1, adj.size(-1)) if mask is not None: flatten_adj = flatten_adj[mask.view(-1)] edge_index = flatten_adj.nonzero().t() edge_attr = flatten_adj[edge_index[0], edge_index[1]] if mask is None: offset = torch.arange( start=0, end=adj.size(0) * adj.size(2), step=adj.size(2), device=adj.device, ) offset = offset.repeat_interleave(adj.size(1)) else: count = mask.sum(dim=-1) offset = cumsum(count)[:-1] offset = offset.repeat_interleave(count) edge_index[1] += offset[edge_index[0]] return edge_index, edge_attr def is_torch_sparse_tensor(src: Any) -> bool: r"""Returns :obj:`True` if the input :obj:`src` is a :class:`torch.sparse.Tensor` (in any sparse layout). Args: src (Any): The input object to be checked. """ if isinstance(src, Tensor): if src.layout == torch.sparse_coo: return True if src.layout == torch.sparse_csr: return True if src.layout == torch.sparse_csc: return True return False def is_sparse(src: Any) -> bool: r"""Returns :obj:`True` if the input :obj:`src` is of type :class:`torch.sparse.Tensor` (in any sparse layout) or of type :class:`torch_sparse.SparseTensor`. 
Args: src (Any): The input object to be checked. """ return is_torch_sparse_tensor(src) or isinstance(src, SparseTensor) def to_torch_coo_tensor( edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[Optional[int], Optional[int]]]] = None, is_coalesced: bool = False, ) -> Tensor: r"""Converts a sparse adjacency matrix defined by edge indices and edge attributes to a :class:`torch.sparse.Tensor` with layout `torch.sparse_coo`. See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor, optional): The edge attributes. (default: :obj:`None`) size (int or (int, int), optional): The size of the sparse matrix. If given as an integer, will create a quadratic sparse matrix. If set to :obj:`None`, will infer a quadratic sparse matrix based on :obj:`edge_index.max() + 1`. (default: :obj:`None`) is_coalesced (bool): If set to :obj:`True`, will assume that :obj:`edge_index` is already coalesced and thus avoids expensive computation. (default: :obj:`False`) :rtype: :class:`torch.sparse.Tensor` Example: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... 
[1, 0, 2, 1, 3, 2]]) >>> to_torch_coo_tensor(edge_index) tensor(indices=tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]), values=tensor([1., 1., 1., 1., 1., 1.]), size=(4, 4), nnz=6, layout=torch.sparse_coo) """ if size is None: size = int(edge_index.max()) + 1 if isinstance(size, (tuple, list)): num_src_nodes, num_dst_nodes = size if num_src_nodes is None: num_src_nodes = int(edge_index[0].max()) + 1 if num_dst_nodes is None: num_dst_nodes = int(edge_index[1].max()) + 1 size = (num_src_nodes, num_dst_nodes) else: size = (size, size) if not is_coalesced: edge_index, edge_attr = coalesce(edge_index, edge_attr, max(size)) if edge_attr is None: # Expanded tensors are not yet supported in all PyTorch code paths :( # edge_attr = torch.ones(1, device=edge_index.device) # edge_attr = edge_attr.expand(edge_index.size(1)) edge_attr = torch.ones(edge_index.size(1), device=edge_index.device) if not torch_geometric.typing.WITH_PT21: adj = torch.sparse_coo_tensor( indices=edge_index, values=edge_attr, size=tuple(size) + edge_attr.size()[1:], device=edge_index.device, ) adj = adj._coalesced_(True) return adj return torch.sparse_coo_tensor( indices=edge_index, values=edge_attr, size=tuple(size) + edge_attr.size()[1:], device=edge_index.device, is_coalesced=True, ) def to_torch_csr_tensor( edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[Optional[int], Optional[int]]]] = None, is_coalesced: bool = False, ) -> Tensor: r"""Converts a sparse adjacency matrix defined by edge indices and edge attributes to a :class:`torch.sparse.Tensor` with layout `torch.sparse_csr`. See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor, optional): The edge attributes. (default: :obj:`None`) size (int or (int, int), optional): The size of the sparse matrix. If given as an integer, will create a quadratic sparse matrix. 
If set to :obj:`None`, will infer a quadratic sparse matrix based on :obj:`edge_index.max() + 1`. (default: :obj:`None`) is_coalesced (bool): If set to :obj:`True`, will assume that :obj:`edge_index` is already coalesced and thus avoids expensive computation. (default: :obj:`False`) :rtype: :class:`torch.sparse.Tensor` Example: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2]]) >>> to_torch_csr_tensor(edge_index) tensor(crow_indices=tensor([0, 1, 3, 5, 6]), col_indices=tensor([1, 0, 2, 1, 3, 2]), values=tensor([1., 1., 1., 1., 1., 1.]), size=(4, 4), nnz=6, layout=torch.sparse_csr) """ if size is None: size = int(edge_index.max()) + 1 if isinstance(size, (tuple, list)): num_src_nodes, num_dst_nodes = size if num_src_nodes is None: num_src_nodes = int(edge_index[0].max()) + 1 if num_dst_nodes is None: num_dst_nodes = int(edge_index[1].max()) + 1 size = (num_src_nodes, num_dst_nodes) else: size = (size, size) if not is_coalesced: edge_index, edge_attr = coalesce(edge_index, edge_attr, max(size)) if edge_attr is None: # Expanded tensors are not yet supported in all PyTorch code paths :( # edge_attr = torch.ones(1, device=edge_index.device) # edge_attr = edge_attr.expand(edge_index.size(1)) edge_attr = torch.ones(edge_index.size(1), device=edge_index.device) adj = torch.sparse_csr_tensor( crow_indices=index2ptr(edge_index[0], size[0]), col_indices=edge_index[1], values=edge_attr, size=tuple(size) + edge_attr.size()[1:], device=edge_index.device, ) return adj def to_torch_csc_tensor( edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[Optional[int], Optional[int]]]] = None, is_coalesced: bool = False, ) -> Tensor: r"""Converts a sparse adjacency matrix defined by edge indices and edge attributes to a :class:`torch.sparse.Tensor` with layout `torch.sparse_csc`. See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation. Args: edge_index (LongTensor): The edge indices. 
edge_attr (Tensor, optional): The edge attributes. (default: :obj:`None`) size (int or (int, int), optional): The size of the sparse matrix. If given as an integer, will create a quadratic sparse matrix. If set to :obj:`None`, will infer a quadratic sparse matrix based on :obj:`edge_index.max() + 1`. (default: :obj:`None`) is_coalesced (bool): If set to :obj:`True`, will assume that :obj:`edge_index` is already coalesced and thus avoids expensive computation. (default: :obj:`False`) :rtype: :class:`torch.sparse.Tensor` Example: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2]]) >>> to_torch_csc_tensor(edge_index) tensor(ccol_indices=tensor([0, 1, 3, 5, 6]), row_indices=tensor([1, 0, 2, 1, 3, 2]), values=tensor([1., 1., 1., 1., 1., 1.]), size=(4, 4), nnz=6, layout=torch.sparse_csc) """ if size is None: size = int(edge_index.max()) + 1 if isinstance(size, (tuple, list)): num_src_nodes, num_dst_nodes = size if num_src_nodes is None: num_src_nodes = int(edge_index[0].max()) + 1 if num_dst_nodes is None: num_dst_nodes = int(edge_index[1].max()) + 1 size = (num_src_nodes, num_dst_nodes) else: size = (size, size) if not is_coalesced: edge_index, edge_attr = coalesce(edge_index, edge_attr, max(size), sort_by_row=False) if edge_attr is None: # Expanded tensors are not yet supported in all PyTorch code paths :( # edge_attr = torch.ones(1, device=edge_index.device) # edge_attr = edge_attr.expand(edge_index.size(1)) edge_attr = torch.ones(edge_index.size(1), device=edge_index.device) adj = torch.sparse_csc_tensor( ccol_indices=index2ptr(edge_index[1], size[1]), row_indices=edge_index[0], values=edge_attr, size=tuple(size) + edge_attr.size()[1:], device=edge_index.device, ) return adj def to_torch_sparse_tensor( edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[Optional[int], Optional[int]]]] = None, is_coalesced: bool = False, layout: torch.layout = torch.sparse_coo, ) -> Tensor: r"""Converts a sparse adjacency 
matrix defined by edge indices and edge attributes to a :class:`torch.sparse.Tensor` with custom :obj:`layout`. See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor, optional): The edge attributes. (default: :obj:`None`) size (int or (int, int), optional): The size of the sparse matrix. If given as an integer, will create a quadratic sparse matrix. If set to :obj:`None`, will infer a quadratic sparse matrix based on :obj:`edge_index.max() + 1`. (default: :obj:`None`) is_coalesced (bool): If set to :obj:`True`, will assume that :obj:`edge_index` is already coalesced and thus avoids expensive computation. (default: :obj:`False`) layout (torch.layout, optional): The layout of the output sparse tensor (:obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`, :obj:`torch.sparse_csc`). (default: :obj:`torch.sparse_coo`) :rtype: :class:`torch.sparse.Tensor` """ if layout == torch.sparse_coo: return to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced) if layout == torch.sparse_csr: return to_torch_csr_tensor(edge_index, edge_attr, size, is_coalesced) if layout == torch.sparse_csc: return to_torch_csc_tensor(edge_index, edge_attr, size, is_coalesced) raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')") def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]: r"""Converts a :class:`torch.sparse.Tensor` or a :class:`torch_sparse.SparseTensor` to edge indices and edge attributes. Args: adj (torch.sparse.Tensor or SparseTensor): The adjacency matrix. :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`) Example: >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], ... 
[1, 0, 2, 1, 3, 2]]) >>> adj = to_torch_coo_tensor(edge_index) >>> to_edge_index(adj) (tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]), tensor([1., 1., 1., 1., 1., 1.])) """ if isinstance(adj, SparseTensor): row, col, value = adj.coo() if value is None: value = torch.ones(row.size(0), device=row.device) return torch.stack([row, col], dim=0).long(), value if adj.layout == torch.sparse_coo: adj = adj._coalesced_(True) return adj.indices().detach().long(), adj.values() if adj.layout == torch.sparse_csr: row = ptr2index(adj.crow_indices().detach()) col = adj.col_indices().detach() return torch.stack([row, col], dim=0).long(), adj.values() if adj.layout == torch.sparse_csc: col = ptr2index(adj.ccol_indices().detach()) row = adj.row_indices().detach() return torch.stack([row, col], dim=0).long(), adj.values() raise ValueError(f"Unexpected sparse tensor layout (got '{adj.layout}')") # Helper functions ############################################################ def get_sparse_diag( size: int, fill_value: float = 1.0, layout: Optional[int] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tensor: return torch.sparse.spdiags( torch.full((1, size), fill_value, dtype=dtype, device=device), offsets=torch.zeros(1, dtype=torch.long, device=device), shape=(size, size), layout=layout, ) def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor: if value.dim() > 1: size = adj.size() + value.size()[1:] else: size = adj.size() if adj.layout == torch.sparse_coo: return torch.sparse_coo_tensor( indices=adj.indices(), values=value, size=size, device=value.device, ).coalesce() if adj.layout == torch.sparse_csr: return torch.sparse_csr_tensor( crow_indices=adj.crow_indices(), col_indices=adj.col_indices(), values=value, size=size, device=value.device, ) if adj.layout == torch.sparse_csc: return torch.sparse_csc_tensor( ccol_indices=adj.ccol_indices(), row_indices=adj.row_indices(), values=value, size=size, device=value.device, ) raise 
ValueError(f"Unexpected sparse tensor layout (got '{adj.layout}')") def cat_coo(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor: assert dim in {0, 1, (0, 1)} assert tensors[0].layout == torch.sparse_coo indices, values = [], [] num_rows = num_cols = 0 is_coalesced = True if dim == 0: for i, tensor in enumerate(tensors): if i == 0: indices.append(tensor._indices()) else: offset = torch.tensor([[num_rows], [0]], device=tensor.device) indices.append(tensor._indices() + offset) values.append(tensor._values()) num_rows += tensor.size(0) num_cols = max(num_cols, tensor.size(1)) if not tensor.is_coalesced(): is_coalesced = False elif dim == 1: for i, tensor in enumerate(tensors): if i == 0: indices.append(tensor._indices()) else: offset = torch.tensor([[0], [num_cols]], device=tensor.device) indices.append(tensor.indices() + offset) values.append(tensor._values()) num_rows = max(num_rows, tensor.size(0)) num_cols += tensor.size(1) is_coalesced = False else: for i, tensor in enumerate(tensors): if i == 0: indices.append(tensor._indices()) else: offset = torch.tensor([[num_rows], [num_cols]], device=tensor.device) indices.append(tensor._indices() + offset) values.append(tensor._values()) num_rows += tensor.size(0) num_cols += tensor.size(1) if not tensor.is_coalesced(): is_coalesced = False if not torch_geometric.typing.WITH_PT21: out = torch.sparse_coo_tensor( indices=torch.cat(indices, dim=-1), values=torch.cat(values), size=(num_rows, num_cols) + values[-1].size()[1:], device=tensor.device, ) if is_coalesced: out = out._coalesced_(True) return out return torch.sparse_coo_tensor( indices=torch.cat(indices, dim=-1), values=torch.cat(values), size=(num_rows, num_cols) + values[-1].size()[1:], device=tensor.device, is_coalesced=True if is_coalesced else None, ) def cat_csr(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor: assert dim in {0, 1, (0, 1)} assert tensors[0].layout == torch.sparse_csr rows, cols, values = [], [], [] num_rows = 
num_cols = nnz = 0 if dim == 0: for i, tensor in enumerate(tensors): if i == 0: rows.append(tensor.crow_indices()) else: rows.append(tensor.crow_indices()[1:] + nnz) cols.append(tensor.col_indices()) values.append(tensor.values()) num_rows += tensor.size(0) num_cols = max(num_cols, tensor.size(1)) nnz += cols[-1].numel() return torch.sparse_csr_tensor( crow_indices=torch.cat(rows), col_indices=torch.cat(cols), values=torch.cat(values), size=(num_rows, num_cols) + values[-1].size()[1:], device=tensor.device, ) elif dim == 1: for i, tensor in enumerate(tensors): rows.append(ptr2index(tensor.crow_indices())) if i == 0: cols.append(tensor.col_indices()) else: cols.append(tensor.col_indices() + num_cols) values.append(tensor.values()) num_rows = max(num_rows, tensor.size(0)) num_cols += tensor.size(1) return torch.sparse_coo_tensor( indices=torch.stack((torch.cat(rows), torch.cat(cols)), 0), values=torch.cat(values), size=(num_rows, num_cols) + values[-1].size()[1:], device=tensor.device, ) else: for i, tensor in enumerate(tensors): if i == 0: rows.append(tensor.crow_indices()) cols.append(tensor.col_indices()) else: rows.append(tensor.crow_indices()[1:] + nnz) cols.append(tensor.col_indices() + num_cols) values.append(tensor.values()) num_rows += tensor.size(0) num_cols += tensor.size(1) nnz += cols[-1].numel() return torch.sparse_csr_tensor( crow_indices=torch.cat(rows), col_indices=torch.cat(cols), values=torch.cat(values), size=(num_rows, num_cols) + values[-1].size()[1:], device=tensor.device, ) def cat_csc(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor: assert dim in {0, 1, (0, 1)} assert tensors[0].layout == torch.sparse_csc rows, cols, values = [], [], [] num_rows = num_cols = nnz = 0 if dim == 0: for i, tensor in enumerate(tensors): cols.append(ptr2index(tensor.ccol_indices())) if i == 0: rows.append(tensor.row_indices()) else: rows.append(tensor.row_indices() + num_rows) values.append(tensor.values()) num_rows += tensor.size(0) num_cols = 
max(num_cols, tensor.size(1)) return torch.sparse_coo_tensor( indices=torch.stack((torch.cat(rows), torch.cat(cols)), 0), values=torch.cat(values), size=(num_rows, num_cols) + values[-1].size()[1:], device=tensor.device, ) elif dim == 1: for i, tensor in enumerate(tensors): if i == 0: cols.append(tensor.ccol_indices()) else: cols.append(tensor.ccol_indices()[1:] + nnz) rows.append(tensor.row_indices()) values.append(tensor.values()) num_rows = max(num_rows, tensor.size(0)) num_cols += tensor.size(1) nnz += rows[-1].numel() return torch.sparse_csc_tensor( row_indices=torch.cat(rows), ccol_indices=torch.cat(cols), values=torch.cat(values), size=(num_rows, num_cols) + values[-1].size()[1:], device=tensor.device, ) else: for i, tensor in enumerate(tensors): if i == 0: rows.append(tensor.row_indices()) cols.append(tensor.ccol_indices()) else: rows.append(tensor.row_indices() + num_rows) cols.append(tensor.ccol_indices()[1:] + nnz) values.append(tensor.values()) num_rows += tensor.size(0) num_cols += tensor.size(1) nnz += rows[-1].numel() return torch.sparse_csc_tensor( row_indices=torch.cat(rows), ccol_indices=torch.cat(cols), values=torch.cat(values), size=(num_rows, num_cols) + values[-1].size()[1:], device=tensor.device, ) def cat(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor: assert is_torch_sparse_tensor(tensors[0]) if tensors[0].layout == torch.sparse_coo: return cat_coo(tensors, dim) elif tensors[0].layout == torch.sparse_csr: return cat_csr(tensors, dim) else: return cat_csc(tensors, dim) ================================================ FILE: torch_geometric/utils/undirected.py ================================================ import typing from typing import List, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.typing import OptTensor from torch_geometric.utils import coalesce, sort_edge_index from torch_geometric.utils.num_nodes import maybe_num_nodes if typing.TYPE_CHECKING: from typing import overload 
else: from torch.jit import _overload as overload MISSING = '???' @overload def is_undirected( edge_index: Tensor, edge_attr: Optional[Tensor] = None, num_nodes: Optional[int] = None, ) -> bool: pass @overload def is_undirected( # noqa: F811 edge_index: Tensor, edge_attr: List[Tensor], num_nodes: Optional[int] = None, ) -> bool: pass def is_undirected( # noqa: F811 edge_index: Tensor, edge_attr: Union[Optional[Tensor], List[Tensor]] = None, num_nodes: Optional[int] = None, ) -> bool: r"""Returns :obj:`True` if the graph given by :attr:`edge_index` is undirected. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor or List[Tensor], optional): Edge weights or multi- dimensional edge features. If given as a list, will check for equivalence in all its entries. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max(edge_index) + 1`. (default: :obj:`None`) :rtype: bool Examples: >>> edge_index = torch.tensor([[0, 1, 0], ... [1, 0, 0]]) >>> weight = torch.tensor([0, 0, 1]) >>> is_undirected(edge_index, weight) True >>> weight = torch.tensor([0, 1, 1]) >>> is_undirected(edge_index, weight) False """ num_nodes = maybe_num_nodes(edge_index, num_nodes) edge_attrs: List[Tensor] = [] if isinstance(edge_attr, Tensor): edge_attrs.append(edge_attr) elif isinstance(edge_attr, (list, tuple)): edge_attrs = edge_attr edge_index1, edge_attrs1 = sort_edge_index( edge_index, edge_attrs, num_nodes=num_nodes, sort_by_row=True, ) edge_index2, edge_attrs2 = sort_edge_index( edge_index, edge_attrs, num_nodes=num_nodes, sort_by_row=False, ) if not torch.equal(edge_index1[0], edge_index2[1]): return False if not torch.equal(edge_index1[1], edge_index2[0]): return False assert isinstance(edge_attrs1, list) and isinstance(edge_attrs2, list) for edge_attr1, edge_attr2 in zip(edge_attrs1, edge_attrs2): if not torch.equal(edge_attr1, edge_attr2): return False return True @overload def to_undirected( edge_index: Tensor, edge_attr: str = MISSING, 
num_nodes: Optional[int] = None, reduce: str = 'add', ) -> Tensor: pass @overload def to_undirected( # noqa: F811 edge_index: Tensor, edge_attr: Tensor, num_nodes: Optional[int] = None, reduce: str = 'add', ) -> Tuple[Tensor, Tensor]: pass @overload def to_undirected( # noqa: F811 edge_index: Tensor, edge_attr: Optional[Tensor], num_nodes: Optional[int] = None, reduce: str = 'add', ) -> Tuple[Tensor, Optional[Tensor]]: pass @overload def to_undirected( # noqa: F811 edge_index: Tensor, edge_attr: List[Tensor], num_nodes: Optional[int] = None, reduce: str = 'add', ) -> Tuple[Tensor, List[Tensor]]: pass def to_undirected( # noqa: F811 edge_index: Tensor, edge_attr: Union[Optional[Tensor], List[Tensor], str] = MISSING, num_nodes: Optional[int] = None, reduce: str = 'add', ) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]: r"""Converts the graph given by :attr:`edge_index` to an undirected graph such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in \mathcal{E}`. Args: edge_index (LongTensor): The edge indices. edge_attr (Tensor or List[Tensor], optional): Edge weights or multi- dimensional edge features. If given as a list, will remove duplicates for all its entries. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max(edge_index) + 1`. (default: :obj:`None`) reduce (str, optional): The reduce operation to use for merging edge features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"add"`) :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]]`) .. warning:: From :pyg:`PyG >= 2.3.0` onwards, this function will always return a tuple whenever :obj:`edge_attr` is passed as an argument (even in case it is set to :obj:`None`). Examples: >>> edge_index = torch.tensor([[0, 1, 1], ... 
[1, 0, 2]]) >>> to_undirected(edge_index) tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) >>> edge_index = torch.tensor([[0, 1, 1], ... [1, 0, 2]]) >>> edge_weight = torch.tensor([1., 1., 1.]) >>> to_undirected(edge_index, edge_weight) (tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), tensor([2., 2., 1., 1.])) >>> # Use 'mean' operation to merge edge features >>> to_undirected(edge_index, edge_weight, reduce='mean') (tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), tensor([1., 1., 1., 1.])) """ # Maintain backward compatibility to `to_undirected(edge_index, num_nodes)` if isinstance(edge_attr, int): num_nodes = edge_attr edge_attr = MISSING row, col = edge_index[0], edge_index[1] row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) edge_index = torch.stack([row, col], dim=0) if isinstance(edge_attr, Tensor): edge_attr = torch.cat([edge_attr, edge_attr], dim=0) elif isinstance(edge_attr, (list, tuple)): edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr] return coalesce(edge_index, edge_attr, num_nodes, reduce) ================================================ FILE: torch_geometric/visualization/__init__.py ================================================ r"""Visualization package.""" from .graph import visualize_graph, visualize_hetero_graph from .influence import influence __all__ = [ 'visualize_graph', 'visualize_hetero_graph', 'influence', ] ================================================ FILE: torch_geometric/visualization/graph.py ================================================ from math import sqrt from typing import Any, Dict, List, Optional, Set, Tuple import torch from torch import Tensor BACKENDS = {'graphviz', 'networkx'} def has_graphviz() -> bool: try: import graphviz except ImportError: return False try: graphviz.Digraph().pipe() except graphviz.backend.ExecutableNotFound: return False return True def visualize_graph( edge_index: Tensor, edge_weight: Optional[Tensor] = None, path: Optional[str] = None, backend: Optional[str] = None, node_labels: 
Optional[List[str]] = None, ) -> Any: r"""Visualizes the graph given via :obj:`edge_index` and (optional) :obj:`edge_weight`. Args: edge_index (torch.Tensor): The edge indices. edge_weight (torch.Tensor, optional): The edge weights. path (str, optional): The path to where the plot is saved. If set to :obj:`None`, will visualize the plot on-the-fly. (default: :obj:`None`) backend (str, optional): The graph drawing backend to use for visualization (:obj:`"graphviz"`, :obj:`"networkx"`). If set to :obj:`None`, will use the most appropriate visualization backend based on available system packages. (default: :obj:`None`) node_labels (List[str], optional): The labels/IDs of nodes. (default: :obj:`None`) """ if edge_weight is not None: # Normalize edge weights. edge_weight = edge_weight - edge_weight.min() edge_weight = edge_weight / edge_weight.max() if edge_weight is not None: # Discard any edges with zero edge weight: mask = edge_weight > 1e-7 edge_index = edge_index[:, mask] edge_weight = edge_weight[mask] if edge_weight is None: edge_weight = torch.ones(edge_index.size(1)) if backend is None: backend = 'graphviz' if has_graphviz() else 'networkx' if backend.lower() == 'networkx': return _visualize_graph_via_networkx(edge_index, edge_weight, path, node_labels) elif backend.lower() == 'graphviz': return _visualize_graph_via_graphviz(edge_index, edge_weight, path, node_labels) raise ValueError(f"Expected graph drawing backend to be in " f"{BACKENDS} (got '{backend}')") def _visualize_graph_via_graphviz( edge_index: Tensor, edge_weight: Tensor, path: Optional[str] = None, node_labels: Optional[List[str]] = None, ) -> Any: import graphviz suffix = path.split('.')[-1] if path is not None else None g = graphviz.Digraph('graph', format=suffix) g.attr('node', shape='circle', fontsize='11pt') for node in edge_index.view(-1).unique().tolist(): g.node(str(node) if node_labels is None else node_labels[node]) for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()): 
hex_color = hex(255 - round(255 * w))[2:] hex_color = f'{hex_color}0' if len(hex_color) == 1 else hex_color if node_labels is not None: src = node_labels[src] dst = node_labels[dst] g.edge(str(src), str(dst), color=f'#{hex_color}{hex_color}{hex_color}') if path is not None: path = '.'.join(path.split('.')[:-1]) g.render(path, cleanup=True) else: g.view() return g def _visualize_graph_via_networkx( edge_index: Tensor, edge_weight: Tensor, path: Optional[str] = None, node_labels: Optional[List[str]] = None, ) -> Any: import matplotlib.pyplot as plt import networkx as nx g = nx.DiGraph() node_size = 800 for node in edge_index.view(-1).unique().tolist(): g.add_node(node if node_labels is None else node_labels[node]) for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()): if node_labels is not None: src = node_labels[src] dst = node_labels[dst] g.add_edge(src, dst, alpha=w) ax = plt.gca() pos = nx.spring_layout(g) for src, dst, data in g.edges(data=True): ax.annotate( '', xy=pos[src], xytext=pos[dst], arrowprops=dict( arrowstyle="<-", alpha=data['alpha'], shrinkA=sqrt(node_size) / 2.0, shrinkB=sqrt(node_size) / 2.0, connectionstyle="arc3,rad=0.1", ), ) nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color='white', margins=0.1, edgecolors='black') nx.draw_networkx_labels(g, pos, font_size=10) if path is not None: plt.savefig(path) else: plt.show() plt.close() def visualize_hetero_graph( edge_index_dict: Dict[Tuple[str, str, str], Tensor], edge_weight_dict: Dict[Tuple[str, str, str], Tensor], path: Optional[str] = None, backend: Optional[str] = None, node_labels_dict: Optional[Dict[str, List[str]]] = None, node_weight_dict: Optional[Dict[str, Tensor]] = None, node_size_range: Tuple[float, float] = (50, 500), node_opacity_range: Tuple[float, float] = (1.0, 1.0), edge_width_range: Tuple[float, float] = (0.1, 2.0), edge_opacity_range: Tuple[float, float] = (1.0, 1.0), ) -> Any: """Visualizes a heterogeneous graph using networkx.""" if backend is not 
None and backend != "networkx": raise ValueError("Only 'networkx' backend is supported") # Filter out edges with 0 weight filtered_edge_index_dict = {} filtered_edge_weight_dict = {} for edge_type in edge_index_dict.keys(): mask = edge_weight_dict[edge_type] > 0 if mask.sum() > 0: filtered_edge_index_dict[edge_type] = edge_index_dict[ edge_type][:, mask] filtered_edge_weight_dict[edge_type] = edge_weight_dict[edge_type][ mask] # Get all unique nodes that are still in the filtered edges remaining_nodes: Dict[str, Set[int]] = {} for edge_type, edge_index in filtered_edge_index_dict.items(): src_type, _, dst_type = edge_type if src_type not in remaining_nodes: remaining_nodes[src_type] = set() if dst_type not in remaining_nodes: remaining_nodes[dst_type] = set() remaining_nodes[src_type].update(edge_index[0].tolist()) remaining_nodes[dst_type].update(edge_index[1].tolist()) # Filter node weights to only include remaining nodes if node_weight_dict is not None: filtered_node_weight_dict = {} for node_type, weights in node_weight_dict.items(): if node_type in remaining_nodes: mask = torch.zeros(len(weights), dtype=torch.bool) mask[list(remaining_nodes[node_type])] = True filtered_node_weight_dict[node_type] = weights[mask] node_weight_dict = filtered_node_weight_dict # Filter node labels to only include remaining nodes if node_labels_dict is not None: filtered_node_labels_dict = {} for node_type, labels in node_labels_dict.items(): if node_type in remaining_nodes: filtered_node_labels_dict[node_type] = [ label for i, label in enumerate(labels) if i in remaining_nodes[node_type] ] node_labels_dict = filtered_node_labels_dict return _visualize_hetero_graph_via_networkx( filtered_edge_index_dict, filtered_edge_weight_dict, path, node_labels_dict, node_weight_dict, node_size_range, node_opacity_range, edge_width_range, edge_opacity_range, ) def _visualize_hetero_graph_via_networkx( edge_index_dict: Dict[Tuple[str, str, str], Tensor], edge_weight_dict: Dict[Tuple[str, str, 
str], Tensor], path: Optional[str] = None, node_labels_dict: Optional[Dict[str, List[str]]] = None, node_weight_dict: Optional[Dict[str, Tensor]] = None, node_size_range: Tuple[float, float] = (50, 500), node_opacity_range: Tuple[float, float] = (1.0, 1.0), edge_width_range: Tuple[float, float] = (0.1, 2.0), edge_opacity_range: Tuple[float, float] = (1.0, 1.0), ) -> Any: import matplotlib.pyplot as plt import networkx as nx g = nx.DiGraph() node_offsets: Dict[str, int] = {} current_offset = 0 # First, collect all unique node types and their counts node_types = set() node_counts: Dict[str, int] = {} remaining_nodes: Dict[str, Set[int]] = { } # Track which nodes are actually present in edges # Get all unique nodes that are in the edges for edge_type in edge_index_dict.keys(): src_type, _, dst_type = edge_type node_types.add(src_type) node_types.add(dst_type) if src_type not in remaining_nodes: remaining_nodes[src_type] = set() if dst_type not in remaining_nodes: remaining_nodes[dst_type] = set() remaining_nodes[src_type].update( edge_index_dict[edge_type][0].tolist()) remaining_nodes[dst_type].update( edge_index_dict[edge_type][1].tolist()) # Set node counts based on remaining nodes for node_type in node_types: node_counts[node_type] = len(remaining_nodes[node_type]) # Add nodes for each node type for node_type in node_types: num_nodes = node_counts[node_type] node_offsets[node_type] = current_offset # Get node weights if provided weights = None if node_weight_dict is not None and node_type in node_weight_dict: weights = node_weight_dict[node_type] if len(weights) != num_nodes: raise ValueError(f"Number of weights for node type " f"{node_type} ({len(weights)}) does not " f"match number of nodes ({num_nodes})") for i in range(num_nodes): node_id = current_offset + i label = (node_labels_dict[node_type][i] if node_labels_dict is not None and node_type in node_labels_dict else "") # Calculate node size and opacity if weights provided size = node_size_range[1] opacity = 
node_opacity_range[1] if weights is not None: w = weights[i].item() size = node_size_range[0] + w * \ (node_size_range[1] - node_size_range[0]) opacity = node_opacity_range[0] + w * \ (node_opacity_range[1] - node_opacity_range[0]) g.add_node(node_id, label=label, type=node_type, size=size, alpha=opacity) current_offset += num_nodes # Add edges with remapped node indices for edge_type, edge_index in edge_index_dict.items(): src_type, _, dst_type = edge_type edge_weight = edge_weight_dict[edge_type] src_offset = node_offsets[src_type] dst_offset = node_offsets[dst_type] # Create mappings for source and target nodes src_mapping = { old_idx: new_idx for new_idx, old_idx in enumerate(sorted( remaining_nodes[src_type])) } dst_mapping = { old_idx: new_idx for new_idx, old_idx in enumerate(sorted( remaining_nodes[dst_type])) } for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()): # Remap node indices new_src = src_mapping[src] + src_offset new_dst = dst_mapping[dst] + dst_offset # Calculate edge width and opacity based on weight width = edge_width_range[0] + w * \ (edge_width_range[1] - edge_width_range[0]) opacity = edge_opacity_range[0] + w * \ (edge_opacity_range[1] - edge_opacity_range[0]) g.add_edge(new_src, new_dst, width=width, alpha=opacity) # Draw the graph ax = plt.gca() pos = nx.arf_layout(g) # Draw edges with arrows for src, dst, data in g.edges(data=True): ax.annotate( '', xy=pos[src], xytext=pos[dst], arrowprops=dict( arrowstyle="<-", alpha=data['alpha'], linewidth=data['width'], shrinkA=sqrt(g.nodes[src]['size']) / 2.0, shrinkB=sqrt(g.nodes[dst]['size']) / 2.0, connectionstyle="arc3,rad=0.1", ), ) # Draw nodes colored by type node_colors = [] node_sizes = [] node_alphas = [] # Use matplotlib tab20 colormap for consistent coloring tab10_cmap = plt.cm.tab10 # type: ignore[attr-defined] node_type_colors: Dict[str, Any] = {} # Store color for each node type for node in g.nodes(): node_type = g.nodes[node]['type'] # Assign a consistent color 
for each node type if node_type not in node_type_colors: color_idx = len(node_type_colors) % 10 # Cycle through colors node_type_colors[node_type] = tab10_cmap(color_idx) node_colors.append(node_type_colors[node_type]) node_sizes.append(g.nodes[node]['size']) node_alphas.append(g.nodes[node]['alpha']) nx.draw_networkx_nodes(g, pos, node_size=node_sizes, node_color=node_colors, margins=0.1, alpha=node_alphas) # Draw labels labels = nx.get_node_attributes(g, 'label') nx.draw_networkx_labels(g, pos, labels, font_size=10) # Add legend legend_elements = [] for node_type, color in node_type_colors.items(): legend_elements.append( plt.Line2D([0], [0], marker='o', color='w', label=node_type, markerfacecolor=color, markersize=10)) ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.9, 1)) if path is not None: plt.savefig(path, bbox_inches='tight') else: plt.show() plt.close() ================================================ FILE: torch_geometric/visualization/influence.py ================================================ from typing import Any import torch from torch import Tensor from torch.autograd import grad def influence(model: torch.nn.Module, src: Tensor, *args: Any) -> Tensor: x = src.clone().requires_grad_() out = model(x, *args).sum(dim=-1) influences = [] for j in range(src.size(0)): influence = grad([out[j]], [x], retain_graph=True)[0].abs().sum(dim=-1) influences.append(influence / influence.sum()) return torch.stack(influences, dim=0) ================================================ FILE: torch_geometric/warnings.py ================================================ import warnings from typing import Literal import torch_geometric def warn(message: str, stacklevel: int = 5) -> None: if torch_geometric.is_compiling(): return warnings.warn(message, stacklevel=stacklevel) def filterwarnings( action: Literal['default', 'error', 'ignore', 'always', 'module', 'once'], message: str, ) -> None: if torch_geometric.is_compiling(): return 
warnings.filterwarnings(action, message) class WarningCache(set): """Cache for warnings.""" def warn(self, message: str, stacklevel: int = 5) -> None: """Trigger warning message.""" if message not in self: self.add(message) warn(message, stacklevel=stacklevel)