Full Code of RUCAIBox/RecBole-GNN for AI

Repository: RUCAIBox/RecBole-GNN
Branch: main
Commit: 632ef8885899
Files: 74
Total size: 361.6 KB

Directory structure:
gitextract_g58jlyza/

├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.md
│   │   ├── bug_report_CN.md
│   │   ├── feature_request.md
│   │   └── feature_request_CN.md
│   └── workflows/
│       └── python-package.yml
├── .gitignore
├── LICENSE
├── README.md
├── recbole_gnn/
│   ├── config.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   └── transform.py
│   ├── model/
│   │   ├── abstract_recommender.py
│   │   ├── general_recommender/
│   │   │   ├── __init__.py
│   │   │   ├── directau.py
│   │   │   ├── hmlet.py
│   │   │   ├── lightgcl.py
│   │   │   ├── lightgcn.py
│   │   │   ├── ncl.py
│   │   │   ├── ngcf.py
│   │   │   ├── sgl.py
│   │   │   ├── simgcl.py
│   │   │   ├── ssl4rec.py
│   │   │   └── xsimgcl.py
│   │   ├── layers.py
│   │   ├── sequential_recommender/
│   │   │   ├── __init__.py
│   │   │   ├── gcegnn.py
│   │   │   ├── gcsan.py
│   │   │   ├── lessr.py
│   │   │   ├── niser.py
│   │   │   ├── sgnnhn.py
│   │   │   ├── srgnn.py
│   │   │   └── tagnn.py
│   │   └── social_recommender/
│   │       ├── __init__.py
│   │       ├── diffnet.py
│   │       ├── mhcn.py
│   │       └── sept.py
│   ├── properties/
│   │   ├── model/
│   │   │   ├── DiffNet.yaml
│   │   │   ├── DirectAU.yaml
│   │   │   ├── GCEGNN.yaml
│   │   │   ├── GCSAN.yaml
│   │   │   ├── HMLET.yaml
│   │   │   ├── LESSR.yaml
│   │   │   ├── LightGCL.yaml
│   │   │   ├── LightGCN.yaml
│   │   │   ├── MHCN.yaml
│   │   │   ├── NCL.yaml
│   │   │   ├── NGCF.yaml
│   │   │   ├── NISER.yaml
│   │   │   ├── SEPT.yaml
│   │   │   ├── SGL.yaml
│   │   │   ├── SGNNHN.yaml
│   │   │   ├── SRGNN.yaml
│   │   │   ├── SSL4REC.yaml
│   │   │   ├── SimGCL.yaml
│   │   │   ├── TAGNN.yaml
│   │   │   └── XSimGCL.yaml
│   │   └── quick_start_config/
│   │       ├── sequential_base.yaml
│   │       └── social_base.yaml
│   ├── quick_start.py
│   ├── trainer.py
│   └── utils.py
├── results/
│   ├── README.md
│   ├── general/
│   │   └── ml-1m.md
│   ├── sequential/
│   │   └── diginetica.md
│   └── social/
│       └── lastfm.md
├── run_hyper.py
├── run_recbole_gnn.py
├── run_test.sh
└── tests/
    ├── test_data/
    │   └── test/
    │       ├── test.inter
    │       └── test.net
    ├── test_model.py
    └── test_model.yaml

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: "[\U0001F41BBUG] Describe your problem in one sentence."
labels: bug
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. extra yaml file
2. your code
3. script for running

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Colab Links**
If applicable, add links to Colab or other Jupyter laboratory platforms that can reproduce the bug.

**Desktop (please complete the following information):**
 - OS: [e.g. Linux, macOS or Windows]
 - RecBole Version [e.g. 0.1.0]
 - Python Version [e.g. 3.7.9]
 - PyTorch Version [e.g. 1.6.0]
 - cudatoolkit Version [e.g. 9.2, none]


================================================
FILE: .github/ISSUE_TEMPLATE/bug_report_CN.md
================================================
---
name: Bug report
about: Submit a bug report to help make RecBole-GNN better
title: "[\U0001F41BBUG] Describe your problem in one sentence."
labels: bug
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. the extra yaml file you introduced
2. your code
3. your running script

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
Add screenshots to help explain your problem. (optional)

**Links**
Add links to code that can reproduce the bug, e.g. Colab or other online Jupyter platforms. (optional)

**Environment (please complete the following information):**
 - OS: [e.g. Linux, macOS or Windows]
 - RecBole Version [e.g. 0.1.0]
 - Python Version [e.g. 3.7.9]
 - PyTorch Version [e.g. 1.6.0]
 - cudatoolkit Version [e.g. 9.2, none]


================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: "[\U0001F4A1SUG] Description of what you want to happen in one sentence"
labels: enhancement
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.


================================================
FILE: .github/ISSUE_TEMPLATE/feature_request_CN.md
================================================
---
name: Feature request
about: Suggest a new feature or enhancement for this project
title: "[\U0001F4A1SUG] Describe the feature you would like in one sentence"
labels: enhancement
assignees: ''

---

**Is the feature you would like related to a problem?**
A clear and concise description of the problem, e.g., I'm always frustrated when [...].

**Describe the solution you'd like**
A clear and concise description of the solution.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions that could provide this feature.

**Additional context**
Add any other material, links or screenshots here to help us understand the requested feature.


================================================
FILE: .github/workflows/python-package.yml
================================================
name: RecBole-GNN tests

# Controls when the action will run. 
on:
  # Triggers the workflow on push or pull request events
  push:
  pull_request:

  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.9]
        torch-version: [2.0.0]
    defaults:
      run:
        shell: bash -l {0}

    steps:
    - uses: actions/checkout@v2
    - name: Setup Miniconda
      uses: conda-incubator/setup-miniconda@v2
      with:
        python-version: ${{ matrix.python-version }}
        channels: conda-forge
        channel-priority: true
        auto-activate-base: true
    # install setuptools as an interim solution for bugs in PyTorch 1.10.2 (#69904)
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pytest
        pip install dgl
        pip install torch==${{ matrix.torch-version}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
        pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
        pip install recbole==1.1.1
        conda install -c conda-forge faiss-cpu
    # Use "python -m pytest" instead of "pytest" to fix imports
    - name: Test model
      run: |
        python -m pytest -v tests/test_model.py


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# RecBole
log_tensorboard/
saved/
dataset/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 RUCAIBox

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# RecBole-GNN

![](asset/recbole-gnn-logo.png)

-----

*Updates*:

* [Oct 29, 2023] Add [SSL4Rec](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/ssl4rec.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/76, by [@downeykking](https://github.com/downeykking))
* [Oct 23, 2023] Add sparse tensor support, accelerating LightGCN & NGCF by ~5x with about 1/6 of the GPU memory. (https://github.com/RUCAIBox/RecBole-GNN/pull/75, by [@downeykking](https://github.com/downeykking))
* [Oct 20, 2023] Add [DirectAU](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/directau.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/74, by [@downeykking](https://github.com/downeykking))
* [Oct 16, 2023] Add [XSimGCL](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/xsimgcl.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/72, by [@downeykking](https://github.com/downeykking))
* [Apr 12, 2023] Add [LightGCL](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/lightgcl.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/63, by [@wending0417](https://github.com/wending0417))
* [Oct 29, 2022] Adaptation to RecBole 1.1.1. (https://github.com/RUCAIBox/RecBole-GNN/pull/53)
* [Jun 15, 2022] Add [MultiBehaviorDataset](https://github.com/RUCAIBox/RecBole-GNN/blob/8c61463451b294dce9af2d1939a5e054f7955e0f/recbole_gnn/data/dataset.py#L145). (https://github.com/RUCAIBox/RecBole-GNN/pull/43, by [@Tokkiu](https://github.com/Tokkiu))

-----

**RecBole-GNN** is a library built upon [PyTorch](https://pytorch.org) and [RecBole](https://github.com/RUCAIBox/RecBole) for reproducing and developing recommendation algorithms based on graph neural networks (GNNs). Our library includes algorithms covering three major categories:
* **General Recommendation** with user-item interaction graphs;
* **Sequential Recommendation** with session/sequence graphs;
* **Social Recommendation** with social networks.

![](asset/arch.png)

## Highlights

* **Easy-to-use and unified API**:
    Our library shares the same unified API and input (atomic files) as RecBole.
* **Efficient and reusable graph processing**:
    We provide highly efficient and reusable basic datasets, dataloaders and layers for graph processing and learning.
* **Extensive graph library**:
    Graph neural networks from widely-used libraries like [PyG](https://github.com/pyg-team/pytorch_geometric) are incorporated. Recently proposed graph algorithms can be easily equipped and compared with existing methods.

## Requirements

```
recbole==1.1.1
pyg>=2.0.4
pytorch>=1.7.0
python>=3.7.0
```

> If you are using `recbole==1.0.1`, please refer to our `recbole1.0.1` branch [[link]](https://github.com/hyp1231/RecBole-GNN/tree/recbole1.0.1).
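
A minimal installation sketch, assuming a fresh CPU-only environment (the exact PyTorch and PyG builds depend on your platform and CUDA version, so follow the official installation guides when in doubt):

```bash
pip install recbole==1.1.1
pip install torch            # pick the build matching your CUDA setup
pip install torch_geometric  # PyG
```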

## Quick-Start

With the source code, you can use the provided script for initial usage of our library:

```bash
python run_recbole_gnn.py
```

If you want to change the models or datasets, just run the script by setting additional command parameters:

```bash
python run_recbole_gnn.py -m [model] -d [dataset]
```
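
For example, to train and evaluate LightGCN on the `ml-100k` dataset (the model and dataset names here are only illustrative; any implemented model and any dataset in RecBole's atomic-file format can be used):

```bash
python run_recbole_gnn.py -m LightGCN -d ml-100k
```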

## Implemented Models

We list currently supported models according to category:

**General Recommendation**:

* **[NGCF](recbole_gnn/model/general_recommender/ngcf.py)** from Wang *et al.*: [Neural Graph Collaborative Filtering](https://arxiv.org/abs/1905.08108) (SIGIR 2019).
* **[LightGCN](recbole_gnn/model/general_recommender/lightgcn.py)** from He *et al.*: [LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation](https://arxiv.org/abs/2002.02126) (SIGIR 2020).
* **[SSL4Rec](recbole_gnn/model/general_recommender/ssl4rec.py)** from Yao *et al.*: [Self-supervised Learning for Large-scale Item Recommendations](https://arxiv.org/abs/2007.12865) (CIKM 2021).
* **[SGL](recbole_gnn/model/general_recommender/sgl.py)** from Wu *et al.*: [Self-supervised Graph Learning for Recommendation](https://arxiv.org/abs/2010.10783) (SIGIR 2021).
* **[HMLET](recbole_gnn/model/general_recommender/hmlet.py)** from Kong *et al.*: [Linear, or Non-Linear, That is the Question!](https://arxiv.org/abs/2111.07265) (WSDM 2022).
* **[NCL](recbole_gnn/model/general_recommender/ncl.py)** from Lin *et al.*: [Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning](https://arxiv.org/abs/2202.06200) (TheWebConf 2022).
* **[DirectAU](recbole_gnn/model/general_recommender/directau.py)** from Wang *et al.*: [Towards Representation Alignment and Uniformity in Collaborative Filtering](https://arxiv.org/abs/2206.12811) (KDD 2022).
* **[SimGCL](recbole_gnn/model/general_recommender/simgcl.py)** from Yu *et al.*: [Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2112.08679) (SIGIR 2022).
* **[XSimGCL](recbole_gnn/model/general_recommender/xsimgcl.py)** from Yu *et al.*: [XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2209.02544) (TKDE 2023).
* **[LightGCL](recbole_gnn/model/general_recommender/lightgcl.py)** from Cai *et al.*: [LightGCL: Simple Yet Effective Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2302.08191) (ICLR 2023).

**Sequential Recommendation**:

* **[SR-GNN](recbole_gnn/model/sequential_recommender/srgnn.py)** from Wu *et al.*: [Session-based Recommendation with Graph Neural Networks](https://arxiv.org/abs/1811.00855) (AAAI 2019).
* **[GC-SAN](recbole_gnn/model/sequential_recommender/gcsan.py)** from Xu *et al.*: [Graph Contextualized Self-Attention Network for Session-based Recommendation](https://www.ijcai.org/proceedings/2019/547) (IJCAI 2019).
* **[NISER+](recbole_gnn/model/sequential_recommender/niser.py)** from Gupta *et al.*: [NISER: Normalized Item and Session Representations to Handle Popularity Bias](https://arxiv.org/abs/1909.04276) (GRLA, CIKM 2019 workshop).
* **[LESSR](recbole_gnn/model/sequential_recommender/lessr.py)** from Chen *et al.*: [Handling Information Loss of Graph Neural Networks for Session-based Recommendation](https://dl.acm.org/doi/10.1145/3394486.3403170) (KDD 2020).
* **[TAGNN](recbole_gnn/model/sequential_recommender/tagnn.py)** from Yu *et al.*: [TAGNN: Target Attentive Graph Neural Networks for Session-based Recommendation](https://arxiv.org/abs/2005.02844) (SIGIR 2020 short).
* **[GCE-GNN](recbole_gnn/model/sequential_recommender/gcegnn.py)** from Wang *et al.*: [Global Context Enhanced Graph Neural Networks for Session-based Recommendation](https://arxiv.org/abs/2106.05081) (SIGIR 2020).
* **[SGNN-HN](recbole_gnn/model/sequential_recommender/sgnnhn.py)** from Pan *et al.*: [Star Graph Neural Networks for Session-based Recommendation](https://dl.acm.org/doi/10.1145/3340531.3412014) (CIKM 2020).

**Social Recommendation**:

> Note that datasets for social recommendation methods can be downloaded from [Social-Datasets](https://github.com/Sherry-XLL/Social-Datasets).

* **[DiffNet](recbole_gnn/model/social_recommender/diffnet.py)** from Wu *et al.*: [A Neural Influence Diffusion Model for Social Recommendation](https://arxiv.org/abs/1904.10322) (SIGIR 2019).
* **[MHCN](recbole_gnn/model/social_recommender/mhcn.py)** from Yu *et al.*: [Self-Supervised Multi-Channel Hypergraph Convolutional Network for Social Recommendation](https://doi.org/10.1145/3442381.3449844) (WWW 2021).
* **[SEPT](recbole_gnn/model/social_recommender/sept.py)** from Yu *et al.*: [Socially-Aware Self-Supervised Tri-Training for Recommendation](https://doi.org/10.1145/3447548.3467340) (KDD 2021).

## Results

### Leaderboard

We carefully tune the hyper-parameters of the implemented models in each research field and release the corresponding leaderboards for reference:

- **General** recommendation on `MovieLens-1M` dataset [[link]](results/general/ml-1m.md);
- **Sequential** recommendation on `Diginetica` dataset [[link]](results/sequential/diginetica.md);
- **Social** recommendation on `LastFM` dataset [[link]](results/social/lastfm.md);

### Efficiency

With the session-graph preprocessing technique and efficient GNN layers, we greatly speed up the training of the implemented sequential recommenders.

<img src='asset/ml-1m.png' width='25%'><img src='asset/diginetica.png' width='25%'>

## The Team

RecBole-GNN was initially developed and is maintained by members of [RUCAIBox](http://aibox.ruc.edu.cn/); the main developers are Yupeng Hou ([@hyp1231](https://github.com/hyp1231)), Lanling Xu ([@Sherry-XLL](https://github.com/Sherry-XLL)) and Changxin Tian ([@ChangxinTian](https://github.com/ChangxinTian)). We also thank Xinzhou ([@downeykking](https://github.com/downeykking)), Wanli ([@wending0417](https://github.com/wending0417)), and Jingqi ([@Tokkiu](https://github.com/Tokkiu)) for their great contributions! ❤️

## Acknowledgement

The implementation is based on the open-source recommendation library [RecBole](https://github.com/RUCAIBox/RecBole). RecBole-GNN is part of [RecBole 2.0](https://github.com/RUCAIBox/RecBole2.0) now!

Please cite the following paper as the reference if you use our code or processed datasets.

```bibtex
@inproceedings{zhao2022recbole2,
  author={Wayne Xin Zhao and Yupeng Hou and Xingyu Pan and Chen Yang and Zeyu Zhang and Zihan Lin and Jingsen Zhang and Shuqing Bian and Jiakai Tang and Wenqi Sun and Yushuo Chen and Lanling Xu and Gaowei Zhang and Zhen Tian and Changxin Tian and Shanlei Mu and Xinyan Fan and Xu Chen and Ji-Rong Wen},
  title={RecBole 2.0: Towards a More Up-to-Date Recommendation Library},
  booktitle = {{CIKM}},
  year={2022}
}

@inproceedings{zhao2021recbole,
  author    = {Wayne Xin Zhao and Shanlei Mu and Yupeng Hou and Zihan Lin and Yushuo Chen and Xingyu Pan and Kaiyuan Li and Yujie Lu and Hui Wang and Changxin Tian and  Yingqian Min and Zhichao Feng and Xinyan Fan and Xu Chen and Pengfei Wang and Wendi Ji and Yaliang Li and Xiaoling Wang and Ji{-}Rong Wen},
  title     = {RecBole: Towards a Unified, Comprehensive and Efficient Framework for Recommendation Algorithms},
  booktitle = {{CIKM}},
  pages     = {4653--4664},
  publisher = {{ACM}},
  year      = {2021}
}
```


================================================
FILE: recbole_gnn/config.py
================================================
import os
import recbole
from recbole.config.configurator import Config as RecBole_Config
from recbole.utils import ModelType as RecBoleModelType

from recbole_gnn.utils import get_model, ModelType


class Config(RecBole_Config):
    def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=None):
        """
        Args:
            model (str/AbstractRecommender): the model name or the model class, default is None. If it is None, config
                will search the parameter 'model' from the external input as the model name or model class.
            dataset (str): the dataset name, default is None. If it is None, config will search the parameter 'dataset'
                from the external input as the dataset name.
            config_file_list (list of str): the external config files; multiple config files are allowed, default is None.
            config_dict (dict): the external parameter dictionary, default is None.
        """
        if recbole.__version__ == "1.1.1":
            self.compatibility_settings()
        super(Config, self).__init__(model, dataset, config_file_list, config_dict)

    def compatibility_settings(self):
        import numpy as np
        np.bool = np.bool_
        np.int = np.int_
        np.float = np.float_
        np.complex = np.complex_
        np.object = np.object_
        np.str = np.str_
        np.long = np.int_
        np.unicode = np.unicode_

    def _get_model_and_dataset(self, model, dataset):

        if model is None:
            try:
                model = self.external_config_dict['model']
            except KeyError:
                raise KeyError(
                    'model needs to be specified in at least one of these ways: '
                    '[model variable, config file, config dict, command line] '
                )
        if not isinstance(model, str):
            final_model_class = model
            final_model = model.__name__
        else:
            final_model = model
            final_model_class = get_model(final_model)

        if dataset is None:
            try:
                final_dataset = self.external_config_dict['dataset']
            except KeyError:
                raise KeyError(
                    'dataset needs to be specified in at least one of these ways: '
                    '[dataset variable, config file, config dict, command line] '
                )
        else:
            final_dataset = dataset

        return final_model, final_model_class, final_dataset

    def _load_internal_config_dict(self, model, model_class, dataset):
        super()._load_internal_config_dict(model, model_class, dataset)
        current_path = os.path.dirname(os.path.realpath(__file__))
        model_init_file = os.path.join(current_path, './properties/model/' + model + '.yaml')
        quick_start_config_path = os.path.join(current_path, './properties/quick_start_config/')
        sequential_base_init = os.path.join(quick_start_config_path, 'sequential_base.yaml')
        social_base_init = os.path.join(quick_start_config_path, 'social_base.yaml')

        if os.path.isfile(model_init_file):
            config_dict = self._update_internal_config_dict(model_init_file)

        self.internal_config_dict['MODEL_TYPE'] = model_class.type
        if self.internal_config_dict['MODEL_TYPE'] == RecBoleModelType.SEQUENTIAL:
            self._update_internal_config_dict(sequential_base_init)
        if self.internal_config_dict['MODEL_TYPE'] == ModelType.SOCIAL:
            self._update_internal_config_dict(social_base_init)
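

# Usage sketch (illustrative only; assumes RecBole-GNN is installed and the
# corresponding model yaml, e.g. properties/model/LightGCN.yaml, exists):
#     from recbole_gnn.config import Config
#     config = Config(model='LightGCN', dataset='ml-100k',
#                     config_dict={'enable_sparse': True})
#     print(config['MODEL_TYPE'])  # RecBole ModelType.GENERAL for LightGCN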


================================================
FILE: recbole_gnn/data/__init__.py
================================================


================================================
FILE: recbole_gnn/data/dataloader.py
================================================
import numpy as np
import torch
from recbole.data.interaction import cat_interactions
from recbole.data.dataloader.general_dataloader import TrainDataLoader, NegSampleEvalDataLoader, FullSortEvalDataLoader

from recbole_gnn.data.transform import gnn_construct_transform


class CustomizedTrainDataLoader(TrainDataLoader):
    def __init__(self, config, dataset, sampler, shuffle=False):
        super().__init__(config, dataset, sampler, shuffle=shuffle)
        if config['gnn_transform'] is not None:
            self.transform = gnn_construct_transform(config)


class CustomizedNegSampleEvalDataLoader(NegSampleEvalDataLoader):
    def __init__(self, config, dataset, sampler, shuffle=False):
        super().__init__(config, dataset, sampler, shuffle=shuffle)
        if config['gnn_transform'] is not None:
            self.transform = gnn_construct_transform(config)

    def collate_fn(self, index):
        index = np.array(index)
        if (
            self.neg_sample_args["distribution"] != "none"
            and self.neg_sample_args["sample_num"] != "none"
        ):
            uid_list = self.uid_list[index]
            data_list = []
            idx_list = []
            positive_u = []
            positive_i = torch.tensor([], dtype=torch.int64)

            for idx, uid in enumerate(uid_list):
                index = self.uid2index[uid]
                data_list.append(self._neg_sampling(self._dataset[index]))
                idx_list += [idx for i in range(self.uid2items_num[uid] * self.times)]
                positive_u += [idx for i in range(self.uid2items_num[uid])]
                positive_i = torch.cat(
                    (positive_i, self._dataset[index][self.iid_field]), 0
                )

            cur_data = cat_interactions(data_list)
            idx_list = torch.from_numpy(np.array(idx_list)).long()
            positive_u = torch.from_numpy(np.array(positive_u)).long()

            return self.transform(self._dataset, cur_data), idx_list, positive_u, positive_i
        else:
            data = self._dataset[index]
            transformed_data = self.transform(self._dataset, data)
            cur_data = self._neg_sampling(transformed_data)
            return cur_data, None, None, None


class CustomizedFullSortEvalDataLoader(FullSortEvalDataLoader):
    def __init__(self, config, dataset, sampler, shuffle=False):
        super().__init__(config, dataset, sampler, shuffle=shuffle)
        if config['gnn_transform'] is not None:
            self.transform = gnn_construct_transform(config)


================================================
FILE: recbole_gnn/data/dataset.py
================================================
import os
import torch
import numpy as np
import pandas as pd

from tqdm import tqdm
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import degree
try:
    from torch_sparse import SparseTensor
    is_sparse = True
except ImportError:
    is_sparse = False

from recbole.data.dataset import SequentialDataset
from recbole.data.dataset import Dataset as RecBoleDataset
from recbole.utils import set_color, FeatureSource

import recbole
import pickle
from recbole.utils import ensure_dir


class GeneralGraphDataset(RecBoleDataset):
    def __init__(self, config):
        super().__init__(config)

    if recbole.__version__ == "1.1.1":

        def save(self):
            """Saving this :class:`Dataset` object to :attr:`config['checkpoint_dir']`."""
            save_dir = self.config["checkpoint_dir"]
            ensure_dir(save_dir)
            file = os.path.join(save_dir, f'{self.config["dataset"]}-{self.__class__.__name__}.pth')
            self.logger.info(
                set_color("Saving filtered dataset into ", "pink") + f"[{file}]"
            )
            with open(file, "wb") as f:
                pickle.dump(self, f)

    @staticmethod
    def edge_index_to_adj_t(edge_index, edge_weight, m_num_nodes, n_num_nodes):
        adj = SparseTensor(row=edge_index[0],
                           col=edge_index[1],
                           value=edge_weight,
                           sparse_sizes=(m_num_nodes, n_num_nodes))
        return adj.t()

    def get_norm_adj_mat(self, enable_sparse=False):
        r"""Get the normalized interaction matrix of users and items.
        Construct the square matrix from the training data and symmetrically
        normalize it with the node degrees.
        .. math::
            A_{hat} = D^{-0.5} \times A \times D^{-0.5}
        Returns:
            The normalized interaction matrix in Tensor.
        """
        self.is_sparse = is_sparse

        row = self.inter_feat[self.uid_field]
        col = self.inter_feat[self.iid_field] + self.user_num
        edge_index1 = torch.stack([row, col])
        edge_index2 = torch.stack([col, row])
        edge_index = torch.cat([edge_index1, edge_index2], dim=1)
        edge_weight = torch.ones(edge_index.size(1))
        num_nodes = self.user_num + self.item_num

        if enable_sparse:
            if not is_sparse:
                self.logger.warning(
                    "Import `torch_sparse` error, please install corrsponding version of `torch_sparse`. Now we will use dense edge_index instead of SparseTensor in dataset.")
            else:
                adj_t = self.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes)
                adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False)
                return adj_t, None

        edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False)

        return edge_index, edge_weight

    def get_bipartite_inter_mat(self, row='user', row_norm=True):
        r"""Get the row-normalized bipartite interaction matrix of users and items.
        """
        if row == 'user':
            row_field, col_field = self.uid_field, self.iid_field
        else:
            row_field, col_field = self.iid_field, self.uid_field

        row = self.inter_feat[row_field]
        col = self.inter_feat[col_field]
        edge_index = torch.stack([row, col])

        if row_norm:
            deg = degree(edge_index[0], self.num(row_field))
            norm_deg = 1. / torch.where(deg == 0, torch.ones([1]), deg)
            edge_weight = norm_deg[edge_index[0]]
        else:
            row_deg = degree(edge_index[0], self.num(row_field))
            col_deg = degree(edge_index[1], self.num(col_field))

            row_norm_deg = 1. / torch.sqrt(torch.where(row_deg == 0, torch.ones([1]), row_deg))
            col_norm_deg = 1. / torch.sqrt(torch.where(col_deg == 0, torch.ones([1]), col_deg))

            edge_weight = row_norm_deg[edge_index[0]] * col_norm_deg[edge_index[1]]

        return edge_index, edge_weight


class SessionGraphDataset(SequentialDataset):
    def __init__(self, config):
        super().__init__(config)

    def session_graph_construction(self):
        # The default session graph construction follows the operator used in SR-GNN.
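        # Worked example (illustrative, assuming item ids v1 < v2 < v3): a session
        # [v1, v2, v1, v3] yields unique nodes x = [v1, v2, v3], alias sequence
        # [0, 1, 0, 2], and the directed edge set {(0, 1), (1, 0), (0, 2)}.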
        self.logger.info('Constructing session graphs.')
        item_seq = self.inter_feat[self.item_id_list_field]
        item_seq_len = self.inter_feat[self.item_list_length_field]
        x = []
        edge_index = []
        alias_inputs = []

        for i, seq in enumerate(tqdm(list(torch.chunk(item_seq, item_seq.shape[0])))):
            seq, idx = torch.unique(seq, return_inverse=True)
            x.append(seq)
            alias_seq = idx.squeeze(0)[:item_seq_len[i]]
            alias_inputs.append(alias_seq)
            # No repeat click
            edge = torch.stack([alias_seq[:-1], alias_seq[1:]]).unique(dim=-1)
            edge_index.append(edge)

        self.inter_feat.interaction['graph_idx'] = torch.arange(item_seq.shape[0])
        self.graph_objs = {
            'x': x,
            'edge_index': edge_index,
            'alias_inputs': alias_inputs
        }

    def build(self):
        datasets = super().build()
        for dataset in datasets:
            dataset.session_graph_construction()
        return datasets


class MultiBehaviorDataset(SessionGraphDataset):

    def session_graph_construction(self):
        self.logger.info('Constructing multi-behavior session graphs.')
        self.item_behavior_list_field = self.config['ITEM_BEHAVIOR_LIST_FIELD']
        self.behavior_id_field = self.config['BEHAVIOR_ID_FIELD']
        item_seq = self.inter_feat[self.item_id_list_field]
        item_seq_len = self.inter_feat[self.item_list_length_field]
        if self.item_behavior_list_field is None or self.behavior_id_field is None:
            # To be compatible with existing datasets
            item_behavior_seq = torch.tensor([0] * len(item_seq))
            self.behavior_id_field = 'behavior_id'
            self.field2id_token[self.behavior_id_field] = {0: 'interaction'}
        else:
            item_behavior_seq = self.inter_feat[self.item_behavior_list_field]

        edge_index = []
        alias_inputs = []
        behaviors = torch.unique(item_behavior_seq)
        x = {}
        for behavior in behaviors:
            x[behavior.item()] = []

        behavior_seqs = list(torch.chunk(item_behavior_seq, item_seq.shape[0]))
        for i, seq in enumerate(tqdm(list(torch.chunk(item_seq, item_seq.shape[0])))):
            bseq = behavior_seqs[i]
            for behavior in behaviors:
                bidx = torch.where(bseq == behavior)
                subseq = torch.index_select(seq, 0, bidx[0])
                subseq, _ = torch.unique(subseq, return_inverse=True)
                x[behavior.item()].append(subseq)

            seq, idx = torch.unique(seq, return_inverse=True)
            alias_seq = idx.squeeze(0)[:item_seq_len[i]]
            alias_inputs.append(alias_seq)
            # No repeat click
            edge = torch.stack([alias_seq[:-1], alias_seq[1:]]).unique(dim=-1)
            edge_index.append(edge)

        nx = {}
        for k, v in x.items():
            behavior_name = self.id2token(self.behavior_id_field, k)
            nx[behavior_name] = v

        self.inter_feat.interaction['graph_idx'] = torch.arange(item_seq.shape[0])
        self.graph_objs = {
            'x': nx,
            'edge_index': edge_index,
            'alias_inputs': alias_inputs
        }


class LESSRDataset(SessionGraphDataset):
    def session_graph_construction(self):
        self.logger.info('Constructing LESSR session graphs.')
        item_seq = self.inter_feat[self.item_id_list_field]
        item_seq_len = self.inter_feat[self.item_list_length_field]

        empty_edge = torch.stack([torch.LongTensor([]), torch.LongTensor([])])

        x = []
        edge_index_EOP = []
        edge_index_shortcut = []
        is_last = []

        for i, seq in enumerate(tqdm(list(torch.chunk(item_seq, item_seq.shape[0])))):
            seq, idx = torch.unique(seq, return_inverse=True)
            x.append(seq)
            alias_seq = idx.squeeze(0)[:item_seq_len[i]]
            edge = torch.stack([alias_seq[:-1], alias_seq[1:]])
            edge_index_EOP.append(edge)
            last = torch.zeros_like(seq, dtype=torch.bool)
            last[alias_seq[-1]] = True
            is_last.append(last)
            sub_edges = []
            for j in range(1, item_seq_len[i]):
                sub_edges.append(torch.stack([alias_seq[:-j], alias_seq[j:]]))
            shortcut_edge = torch.cat(sub_edges, dim=-1).unique(dim=-1) if len(sub_edges) > 0 else empty_edge
            edge_index_shortcut.append(shortcut_edge)

        self.inter_feat.interaction['graph_idx'] = torch.arange(item_seq.shape[0])
        self.graph_objs = {
            'x': x,
            'edge_index_EOP': edge_index_EOP,
            'edge_index_shortcut': edge_index_shortcut,
            'is_last': is_last
        }
        self.node_attr = ['x', 'is_last']


class GCEGNNDataset(SequentialDataset):
    def __init__(self, config):
        super().__init__(config)

    def reverse_session(self):
        self.logger.info('Reversing sessions.')
        item_seq = self.inter_feat[self.item_id_list_field]
        item_seq_len = self.inter_feat[self.item_list_length_field]
        for i in tqdm(range(item_seq.shape[0])):
            item_seq[i, :item_seq_len[i]] = item_seq[i, :item_seq_len[i]].flip(dims=[0])

    def bidirectional_edge(self, edge_index):
        seq_len = edge_index.shape[1]
        ed = edge_index.T
        ed2 = edge_index.T.flip(dims=[1])
        idc = ed.unsqueeze(1).expand(-1, seq_len, 2) == ed2.unsqueeze(0).expand(seq_len, -1, 2)
        return torch.logical_and(idc[:, :, 0], idc[:, :, 1]).any(dim=-1)

    def session_graph_construction(self):
        self.logger.info('Constructing session graphs.')
        item_seq = self.inter_feat[self.item_id_list_field]
        item_seq_len = self.inter_feat[self.item_list_length_field]
        x = []
        edge_index = []
        edge_attr = []
        alias_inputs = []

        for i, seq in enumerate(tqdm(list(torch.chunk(item_seq, item_seq.shape[0])))):
            seq, idx = torch.unique(seq, return_inverse=True)
            x.append(seq)
            alias_seq = idx.squeeze(0)[:item_seq_len[i]]
            alias_inputs.append(alias_seq)

            edge_index_backward = torch.stack([alias_seq[:-1], alias_seq[1:]])
            edge_attr_backward = torch.where(self.bidirectional_edge(edge_index_backward), 3, 1)
            edge_backward = torch.cat([edge_index_backward, edge_attr_backward.unsqueeze(0)], dim=0)

            edge_index_forward = torch.stack([alias_seq[1:], alias_seq[:-1]])
            edge_attr_forward = torch.where(self.bidirectional_edge(edge_index_forward), 3, 2)
            edge_forward = torch.cat([edge_index_forward, edge_attr_forward.unsqueeze(0)], dim=0)

            edge_index_selfloop = torch.stack([alias_seq, alias_seq])
            edge_selfloop = torch.cat([edge_index_selfloop, torch.zeros([1, edge_index_selfloop.shape[1]])], dim=0)

            edge = torch.cat([edge_backward, edge_forward, edge_selfloop], dim=-1).long()
            edge = edge.unique(dim=-1)

            cur_edge_index = edge[:2]
            cur_edge_attr = edge[2]
            edge_index.append(cur_edge_index)
            edge_attr.append(cur_edge_attr)

        self.inter_feat.interaction['graph_idx'] = torch.arange(item_seq.shape[0])
        self.graph_objs = {
            'x': x,
            'edge_index': edge_index,
            'edge_attr': edge_attr,
            'alias_inputs': alias_inputs
        }

    def build(self):
        datasets = super().build()
        for dataset in datasets:
            dataset.reverse_session()
            dataset.session_graph_construction()
        return datasets


class SocialDataset(GeneralGraphDataset):
    """:class:`SocialDataset` is based on :class:`~recbole_gnn.data.dataset.GeneralGraphDataset`,
    and additionally loads ``.net``.

    All users in ``.inter`` and ``.net`` are remapped into the same ID sections.
    Users that only exist in the social network will be filtered.

    It also provides several interfaces to transfer ``.net`` features into coo sparse matrix,
    csr sparse matrix, :class:`DGL.Graph` or :class:`PyG.Data`.

    Attributes:
        net_src_field (str): The same as ``config['NET_SOURCE_ID_FIELD']``.

        net_tgt_field (str): The same as ``config['NET_TARGET_ID_FIELD']``.

        net_feat (pandas.DataFrame): Internal data structure stores the users' social network relations.
            It's loaded from file ``.net``.
    """

    def __init__(self, config):
        super().__init__(config)

    def _get_field_from_config(self):
        super()._get_field_from_config()

        self.net_src_field = self.config['NET_SOURCE_ID_FIELD']
        self.net_tgt_field = self.config['NET_TARGET_ID_FIELD']
        self.filter_net_by_inter = self.config['filter_net_by_inter']
        self.undirected_net = self.config['undirected_net']
        self._check_field('net_src_field', 'net_tgt_field')

        self.logger.debug(set_color('net_src_field', 'blue') + f': {self.net_src_field}')
        self.logger.debug(set_color('net_tgt_field', 'blue') + f': {self.net_tgt_field}')

    def _data_filtering(self):
        super()._data_filtering()
        if self.filter_net_by_inter:
            self._filter_net_by_inter()

    def _filter_net_by_inter(self):
        """Filter users in ``net_feat`` that don't occur in interactions.
        """
        inter_uids = set(self.inter_feat[self.uid_field])
        self.net_feat.drop(self.net_feat.index[~self.net_feat[self.net_src_field].isin(inter_uids)], inplace=True)
        self.net_feat.drop(self.net_feat.index[~self.net_feat[self.net_tgt_field].isin(inter_uids)], inplace=True)

    def _load_data(self, token, dataset_path):
        super()._load_data(token, dataset_path)
        self.net_feat = self._load_net(self.dataset_name, self.dataset_path)

    @property
    def net_num(self):
        """Get the number of social network records.

        Returns:
            int: Number of social network records.
        """
        return len(self.net_feat)

    def __str__(self):
        info = [
            super().__str__(),
            set_color('The number of social network relations', 'blue') + f': {self.net_num}'
        ]  # yapf: disable
        return '\n'.join(info)

    def _build_feat_name_list(self):
        feat_name_list = super()._build_feat_name_list()
        if self.net_feat is not None:
            feat_name_list.append('net_feat')
        return feat_name_list

    def _load_net(self, token, dataset_path):
        self.logger.debug(set_color(f'Loading social network from [{dataset_path}].', 'green'))
        net_path = os.path.join(dataset_path, f'{token}.net')
        if not os.path.isfile(net_path):
            raise ValueError(f'[{token}.net] not found in [{dataset_path}].')
        df = self._load_feat(net_path, FeatureSource.NET)
        if self.undirected_net:
            row = df[self.net_src_field]
            col = df[self.net_tgt_field]
            df_net_src = pd.concat([row, col], axis=0)
            df_net_tgt = pd.concat([col, row], axis=0)
            df_net_src.name = self.net_src_field
            df_net_tgt.name = self.net_tgt_field
            df = pd.concat([df_net_src, df_net_tgt], axis=1)
        self._check_net(df)
        return df

    def _check_net(self, net):
        net_warn_message = 'net data requires field [{}]'
        assert self.net_src_field in net, net_warn_message.format(self.net_src_field)
        assert self.net_tgt_field in net, net_warn_message.format(self.net_tgt_field)

    def _init_alias(self):
        """Add :attr:`alias_of_user_id`.
        """
        self._set_alias('user_id', [self.uid_field, self.net_src_field, self.net_tgt_field])
        self._set_alias('item_id', [self.iid_field])

        for alias_name_1, alias_1 in self.alias.items():
            for alias_name_2, alias_2 in self.alias.items():
                if alias_name_1 != alias_name_2:
                    intersect = np.intersect1d(alias_1, alias_2, assume_unique=True)
                    if len(intersect) > 0:
                        raise ValueError(
                            f'`alias_of_{alias_name_1}` and `alias_of_{alias_name_2}` '
                            f'should not have the same field {list(intersect)}.'
                        )

        self._rest_fields = self.token_like_fields
        for alias_name, alias in self.alias.items():
            isin = np.isin(alias, self._rest_fields, assume_unique=True)
            if not isin.all():
                raise ValueError(
                    f'`alias_of_{alias_name}` should not contain '
                    f'non-token-like field {list(alias[~isin])}.'
                )
            self._rest_fields = np.setdiff1d(self._rest_fields, alias, assume_unique=True)

    def get_norm_net_adj_mat(self, row_norm=False):
        r"""Get the normalized socail matrix of users and users.
        Construct the square matrix from the social network data and 
        normalize it using the laplace matrix.
        .. math::
            A_{hat} = D^{-0.5} \times A \times D^{-0.5}
        Returns:
            The normalized social network matrix in Tensor.
        """

        row = self.net_feat[self.net_src_field]
        col = self.net_feat[self.net_tgt_field]
        edge_index = torch.stack([row, col])

        deg = degree(edge_index[0], self.user_num)

        if row_norm:
            norm_deg = 1. / torch.where(deg == 0, torch.ones([1]), deg)
            edge_weight = norm_deg[edge_index[0]]
        else:
            norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg))
            edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]]

        return edge_index, edge_weight

    def net_matrix(self, form='coo', value_field=None):
        """Get sparse matrix that describe social relations between user_id and user_id.

        Sparse matrix has shape (user_num, user_num).

        Returns:
            scipy.sparse: Sparse matrix in form ``coo`` or ``csr``.
        """
        return self._create_sparse_matrix(self.net_feat, self.net_src_field, self.net_tgt_field, form, value_field)


================================================
FILE: recbole_gnn/data/transform.py
================================================
from logging import getLogger
import torch
from torch.nn.utils.rnn import pad_sequence
from recbole.data.interaction import Interaction


def gnn_construct_transform(config):
    if config['gnn_transform'] is None:
        raise ValueError('config["gnn_transform"] is None but trying to construct transform.')
    str2transform = {
        'sess_graph': SessionGraph,
    }
    return str2transform[config['gnn_transform']](config)


class SessionGraph:
    def __init__(self, config):
        self.logger = getLogger()
        self.logger.info('SessionGraph Transform in DataLoader.')

    def __call__(self, dataset, interaction):
        graph_objs = dataset.graph_objs
        index = interaction['graph_idx']
        graph_batch = {
            k: [graph_objs[k][_.item()] for _ in index]
            for k in graph_objs
        }
        graph_batch['batch'] = []

        # Merge the per-session graphs into one large disconnected graph: node
        # indices of each graph are shifted by the number of nodes seen so far.
        # The offset starts at 1 because node 0 is reserved as a padding node,
        # which is prepended to the node-level attributes below.
        tot_node_num = torch.ones([1], dtype=torch.long)
        for i in range(index.shape[0]):
            for k in graph_batch:
                if 'edge_index' in k:
                    graph_batch[k][i] = graph_batch[k][i] + tot_node_num
            if 'alias_inputs' in graph_batch:
                graph_batch['alias_inputs'][i] = graph_batch['alias_inputs'][i] + tot_node_num
            # batch[j] = i marks that node j belongs to the i-th session graph
            graph_batch['batch'].append(torch.full_like(graph_batch['x'][i], i))
            tot_node_num += graph_batch['x'][i].shape[0]

        if hasattr(dataset, 'node_attr'):
            node_attr = ['batch'] + dataset.node_attr
        else:
            node_attr = ['x', 'batch']
        for k in node_attr:
            graph_batch[k] = [torch.zeros([1], dtype=graph_batch[k][-1].dtype)] + graph_batch[k]

        for k in graph_batch:
            if k == 'alias_inputs':
                graph_batch[k] = pad_sequence(graph_batch[k], batch_first=True)
            else:
                graph_batch[k] = torch.cat(graph_batch[k], dim=-1)

        interaction.update(Interaction(graph_batch))
        return interaction


================================================
FILE: recbole_gnn/model/abstract_recommender.py
================================================
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.utils import ModelType as RecBoleModelType

from recbole_gnn.utils import ModelType


class GeneralGraphRecommender(GeneralRecommender):
    """This is an abstract general graph recommender. All the general graph models should implement in this class.
    The base general graph recommender class provide the basic U-I graph dataset and parameters information.
    """
    type = RecBoleModelType.GENERAL

    def __init__(self, config, dataset):
        super(GeneralGraphRecommender, self).__init__(config, dataset)
        self.edge_index, self.edge_weight = dataset.get_norm_adj_mat(enable_sparse=config["enable_sparse"])
        self.use_sparse = config["enable_sparse"] and dataset.is_sparse
        if self.use_sparse:
            self.edge_index, self.edge_weight = self.edge_index.to(self.device), None
        else:
            self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device)


class SocialRecommender(GeneralRecommender):
    """This is an abstract social recommender. All the social graph model should implement this class.
    The base social recommender class provide the basic social graph dataset and parameters information.
    """
    type = ModelType.SOCIAL

    def __init__(self, config, dataset):
        super(SocialRecommender, self).__init__(config, dataset)


================================================
FILE: recbole_gnn/model/general_recommender/__init__.py
================================================
from recbole_gnn.model.general_recommender.lightgcn import LightGCN
from recbole_gnn.model.general_recommender.hmlet import HMLET
from recbole_gnn.model.general_recommender.ncl import NCL
from recbole_gnn.model.general_recommender.ngcf import NGCF
from recbole_gnn.model.general_recommender.sgl import SGL
from recbole_gnn.model.general_recommender.lightgcl import LightGCL
from recbole_gnn.model.general_recommender.simgcl import SimGCL
from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
from recbole_gnn.model.general_recommender.directau import DirectAU
from recbole_gnn.model.general_recommender.ssl4rec import SSL4REC


================================================
FILE: recbole_gnn/model/general_recommender/directau.py
================================================
# r"""
# DirectAU
# ################################################
# Reference:
#     Chenyang Wang et al. "Towards Representation Alignment and Uniformity in Collaborative Filtering." in KDD 2022.

# Reference code:
#     https://github.com/THUwangcy/DirectAU
# """

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType
from recbole.model.general_recommender import BPR
from recbole_gnn.model.general_recommender import LightGCN

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender


class DirectAU(GeneralGraphRecommender):
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(DirectAU, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.gamma = config['gamma']
        self.encoder_name = config['encoder']

        # define encoder
        if self.encoder_name == 'MF':
            self.encoder = MFEncoder(config, dataset)
        elif self.encoder_name == 'LightGCN':
            self.encoder = LGCNEncoder(config, dataset)
        else:
            raise ValueError('Non-implemented Encoder.')

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_normal_initialization)

    def forward(self, user, item):
        user_e, item_e = self.encoder(user, item)
        return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1)

    @staticmethod
    def alignment(x, y, alpha=2):
        # alignment loss: mean squared distance between positive user-item pairs
        return (x - y).norm(p=2, dim=1).pow(alpha).mean()

    @staticmethod
    def uniformity(x, t=2):
        # uniformity loss: log of the mean Gaussian potential over all embedding pairs
        return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

    def calculate_loss(self, interaction):
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_e, item_e = self.forward(user, item)
        align = self.alignment(user_e, item_e)
        uniform = self.gamma * (self.uniformity(user_e) + self.uniformity(item_e)) / 2

        return align, uniform

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        # DirectAU itself defines no embedding tables; score with the encoder's embeddings
        user_e, item_e = self.encoder(user, item)
        return torch.mul(user_e, item_e).sum(dim=1)

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.encoder_name == 'LightGCN':
            if self.restore_user_e is None or self.restore_item_e is None:
                self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings()
            user_e = self.restore_user_e[user]
            all_item_e = self.restore_item_e
        else:
            user_e = self.encoder.user_embedding(user)
            all_item_e = self.encoder.item_embedding.weight
        score = torch.matmul(user_e, all_item_e.transpose(0, 1))
        return score.view(-1)


class MFEncoder(BPR):
    def __init__(self, config, dataset):
        super(MFEncoder, self).__init__(config, dataset)

    def forward(self, user_id, item_id):
        return super().forward(user_id, item_id)

    def get_all_embeddings(self):
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        return user_embeddings, item_embeddings


class LGCNEncoder(LightGCN):
    def __init__(self, config, dataset):
        super(LGCNEncoder, self).__init__(config, dataset)

    def forward(self, user_id, item_id):
        user_all_embeddings, item_all_embeddings = self.get_all_embeddings()
        u_embed = user_all_embeddings[user_id]
        i_embed = item_all_embeddings[item_id]
        return u_embed, i_embed

    def get_all_embeddings(self):
        return super().forward()


================================================
FILE: recbole_gnn/model/general_recommender/hmlet.py
================================================
# @Time   : 2022/3/21
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

r"""
HMLET
################################################
Reference:
    Taeyong Kong et al. "Linear, or Non-Linear, That is the Question!." in WSDM 2022.

Reference code:
    https://github.com/qbxlvnf11/HMLET
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.model.layers import activation_layer
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class Gating_Net(nn.Module):
    def __init__(self, embedding_dim, mlp_dims, dropout_p):
        super(Gating_Net, self).__init__()
        self.embedding_dim = embedding_dim

        fc_layers = []
        for i in range(len(mlp_dims)):
            if i == 0:
                fc = nn.Linear(embedding_dim*2, mlp_dims[i])
                fc_layers.append(fc)
            else:
                fc = nn.Linear(mlp_dims[i-1], mlp_dims[i])
                fc_layers.append(fc)
            if i != len(mlp_dims) - 1:
                fc_layers.append(nn.BatchNorm1d(mlp_dims[i]))
                fc_layers.append(nn.Dropout(p=dropout_p))
                fc_layers.append(nn.ReLU(inplace=True))
        self.mlp = nn.Sequential(*fc_layers)

    def gumbel_softmax(self, logits, temperature, hard):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
          logits: [batch_size, n_class] unnormalized log-probs
          temperature: non-negative scalar
          hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
          [batch_size, n_class] sample from the Gumbel-Softmax distribution.
          If hard=True, then the returned sample will be one-hot, otherwise it will
          be a probability distribution that sums to 1 across classes
        """
        y = self.gumbel_softmax_sample(logits, temperature) ## (0.6, 0.2, 0.1,..., 0.11)
        if hard:
            k = logits.size(1)  # k is the number of classes
            # y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)  ## (1, 0, 0, ..., 0)
            y_hard = torch.eq(y, torch.max(y, dim=1, keepdim=True)[0]).type_as(y)
            y = (y_hard - y).detach() + y
        return y

    def gumbel_softmax_sample(self, logits, temperature):
        """ Draw a sample from the Gumbel-Softmax distribution"""
        noise = self.sample_gumbel(logits)
        y = (logits + noise) / temperature
        return F.softmax(y, dim=1)

    def sample_gumbel(self, logits):
        """Sample from Gumbel(0, 1)"""
        noise = torch.rand(logits.size())
        eps = 1e-20
        noise.add_(eps).log_().neg_()
        noise.add_(eps).log_().neg_()
        return noise.to(logits.device)

    def forward(self, feature, temperature, hard):
        x = self.mlp(feature)
        out = self.gumbel_softmax(x, temperature, hard)
        out_value = out.unsqueeze(2)
        gating_out = out_value.repeat(1, 1, self.embedding_dim)
        return gating_out


class HMLET(GeneralGraphRecommender):
    r"""HMLET combines both linear and non-linear propagation layers for general recommendation and yields better performance.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(HMLET, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type: the embedding size
        self.n_layers = config['n_layers']  # int type: the number of propagation layers
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization
        self.require_pow = config['require_pow']  # bool type: whether to use the powered norm in regularization
        self.gate_layer_ids = config['gate_layer_ids']  # list type: layer ids for non-linear gating
        self.gating_mlp_dims = config['gating_mlp_dims']  # list type: list of mlp dimensions in gating module
        self.dropout_ratio = config['dropout_ratio']  # dropout ratio for mlp in gating module
        self.gum_temp = config['ori_temp']
        self.logger.info(f'Model initialization, gumbel softmax temperature: {self.gum_temp}')

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.activation = nn.ELU() if config['activation_function'] == 'elu' else activation_layer(config['activation_function'])
        self.gating_nets = nn.ModuleList([
            Gating_Net(self.latent_dim, self.gating_mlp_dims, self.dropout_ratio) for _ in range(len(self.gate_layer_ids))
        ])

        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e', 'gum_temp']

        for gating in self.gating_nets:
            self._gating_freeze(gating, False)

    def _gating_freeze(self, model, freeze_flag):
        for name, child in model.named_children():
            for param in child.parameters():
                param.requires_grad = freeze_flag

    def __choosing_one(self, features, gumbel_out):
        feature = torch.sum(torch.mul(features, gumbel_out), dim=1)  # batch x embedding_dim (or batch x embedding_dim x layer_num)
        return feature

    def __where(self, idx, lst):
        for i in range(len(lst)):
            if lst[i] == idx:
                return i
        raise ValueError(f'{idx} not in {lst}.')

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        non_lin_emb_list = [all_embeddings]

        for layer_idx in range(self.n_layers):
            linear_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
            if layer_idx not in self.gate_layer_ids:
                all_embeddings = linear_embeddings
            else:
                non_lin_id = self.__where(layer_idx, self.gate_layer_ids)
                last_non_lin_emb = non_lin_emb_list[non_lin_id]
                non_lin_embeddings = self.activation(self.gcn_conv(last_non_lin_emb, self.edge_index, self.edge_weight))
                stack_embeddings = torch.stack([linear_embeddings, non_lin_embeddings], dim=1)
                concat_embeddings = torch.cat((linear_embeddings, non_lin_embeddings), dim=-1)
                gumbel_out = self.gating_nets[non_lin_id](concat_embeddings, self.gum_temp, not self.training)
                all_embeddings = self.__choosing_one(stack_embeddings, gumbel_out)
                non_lin_emb_list.append(all_embeddings)
            embeddings_list.append(all_embeddings)
        hmlet_all_embeddings = torch.stack(embeddings_list, dim=1)
        hmlet_all_embeddings = torch.mean(hmlet_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(hmlet_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
        loss = mf_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)
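
# A minimal illustrative sketch of the straight-through Gumbel-Softmax trick that
# Gating_Net relies on: Gumbel(0, 1) noise is added to the logits, the result is
# softened with a temperature, and the hard one-hot choice is discretized while
# gradients still flow through the soft sample. The helper name, shapes, and
# values below are toy assumptions, not part of the model above.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    def _gumbel_softmax_demo(logits, temperature=0.7, hard=True):
        # Gumbel(0, 1) noise: -log(-log(U)), with eps for numerical stability
        gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
        y = F.softmax((logits + gumbel) / temperature, dim=1)
        if hard:
            # forward pass: one-hot argmax; backward pass: gradient of the soft sample
            y_hard = torch.zeros_like(y).scatter_(1, y.argmax(dim=1, keepdim=True), 1.0)
            y = (y_hard - y).detach() + y
        return y

    demo_logits = torch.randn(4, 2, requires_grad=True)  # 4 nodes, 2 branches (linear / non-linear)
    demo_gate = _gumbel_softmax_demo(demo_logits)
    print(demo_gate)               # each row is one-hot
    demo_gate.sum().backward()     # yet gradients reach demo_logits through the soft sample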

================================================
FILE: recbole_gnn/model/general_recommender/lightgcl.py
================================================
# -*- coding: utf-8 -*-
# @Time   : 2023/04/12
# @Author : Wanli Yang
# @Email  : 2013774@mail.nankai.edu.cn

r"""
LightGCL
################################################
Reference:
    Xuheng Cai et al. "LightGCL: Simple Yet Effective Graph Contrastive Learning for Recommendation" in ICLR 2023.

Reference code:
    https://github.com/HKUDS/LightGCL
"""

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType
import torch.nn.functional as F


class LightGCL(GeneralRecommender):
    r"""LightGCL is a GCN-based recommender model.

    LightGCL guides graph augmentation by singular value decomposition (SVD) to not only
    distill the useful information of user-item interactions but also inject the global
    collaborative context into the representation alignment of contrastive learning.

    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(LightGCL, self).__init__(config, dataset)
        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]

        # load parameters info
        self.embed_dim = config["embedding_size"]
        self.n_layers = config["n_layers"]
        self.dropout = config["dropout"]
        self.temp = config["temp"]
        self.lambda_1 = config["lambda1"]
        self.lambda_2 = config["lambda2"]
        self.q = config["q"]
        self.act = nn.LeakyReLU(0.5)
        self.reg_loss = EmbLoss()

        # get the normalized interaction (adjacency) matrix
        self.adj_norm = self.coo2tensor(self.create_adjust_matrix())

        # perform svd reconstruction
        svd_u, s, svd_v = torch.svd_lowrank(self.adj_norm, q=self.q)
        self.u_mul_s = svd_u @ (torch.diag(s))
        self.v_mul_s = svd_v @ (torch.diag(s))
        del s
        self.ut = svd_u.T
        self.vt = svd_v.T

        self.E_u_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_users, self.embed_dim)))
        self.E_i_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_items, self.embed_dim)))
        self.E_u_list = [None] * (self.n_layers + 1)
        self.E_i_list = [None] * (self.n_layers + 1)
        self.E_u_list[0] = self.E_u_0
        self.E_i_list[0] = self.E_i_0
        self.Z_u_list = [None] * (self.n_layers + 1)
        self.Z_i_list = [None] * (self.n_layers + 1)
        self.G_u_list = [None] * (self.n_layers + 1)
        self.G_i_list = [None] * (self.n_layers + 1)
        self.G_u_list[0] = self.E_u_0
        self.G_i_list[0] = self.E_i_0

        self.E_u = None
        self.E_i = None
        self.restore_user_e = None
        self.restore_item_e = None

        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def create_adjust_matrix(self):
        r"""Get the normalized interaction matrix of users and items.

        Returns:
            coo_matrix of the normalized interaction matrix.
        """
        ratings = np.ones_like(self._user, dtype=np.float32)
        matrix = sp.csr_matrix(
            (ratings, (self._user, self._item)),
            shape=(self.n_users, self.n_items),
        ).tocoo()
        rowD = np.squeeze(np.array(matrix.sum(1)), axis=1)
        colD = np.squeeze(np.array(matrix.sum(0)), axis=0)
        for i in range(len(matrix.data)):
            matrix.data[i] = matrix.data[i] / pow(rowD[matrix.row[i]] * colD[matrix.col[i]], 0.5)
        return matrix

    def coo2tensor(self, matrix: sp.coo_matrix):
        r"""Convert coo_matrix to tensor.

        Args:
            matrix (scipy.sparse.coo_matrix): Sparse matrix to be converted.

        Returns:
            torch.sparse.FloatTensor: Transformed sparse matrix.
        """
        indices = torch.from_numpy(
            np.vstack((matrix.row, matrix.col)).astype(np.int64))
        values = torch.from_numpy(matrix.data)
        shape = torch.Size(matrix.shape)
        x = torch.sparse.FloatTensor(indices, values, shape).coalesce().to(self.device)
        return x

    def sparse_dropout(self, matrix, dropout):
        if dropout == 0.0:
            return matrix
        indices = matrix.indices()
        values = F.dropout(matrix.values(), p=dropout)
        size = matrix.size()
        return torch.sparse.FloatTensor(indices, values, size)

    def forward(self):
        for layer in range(1, self.n_layers + 1):
            # GNN propagation
            self.Z_u_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout),
                                              self.E_i_list[layer - 1])
            self.Z_i_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout).transpose(0, 1),
                                              self.E_u_list[layer - 1])
            # aggregate
            self.E_u_list[layer] = self.Z_u_list[layer]
            self.E_i_list[layer] = self.Z_i_list[layer]

        # aggregate across layers
        self.E_u = sum(self.E_u_list)
        self.E_i = sum(self.E_i_list)

        return self.E_u, self.E_i

    def calculate_loss(self, interaction):
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user_list = interaction[self.USER_ID]
        pos_item_list = interaction[self.ITEM_ID]
        neg_item_list = interaction[self.NEG_ITEM_ID]
        E_u_norm, E_i_norm = self.forward()
        bpr_loss = self.calc_bpr_loss(E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list)
        ssl_loss = self.calc_ssl_loss(E_u_norm, E_i_norm, user_list, pos_item_list)
        total_loss = bpr_loss + ssl_loss
        return total_loss

    def calc_bpr_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list):
        r"""Calculate the pairwise Bayesian Personalized Ranking (BPR) loss and parameter regularization loss.

        Args:
            E_u_norm (torch.Tensor): Ego embedding of all users after forwarding.
            E_i_norm (torch.Tensor): Ego embedding of all items after forwarding.
            user_list (torch.Tensor): List of user IDs.
            pos_item_list (torch.Tensor): List of positive examples.
            neg_item_list (torch.Tensor): List of negative examples.

        Returns:
            torch.Tensor: Loss of BPR tasks and parameter regularization.
        """
        u_e = E_u_norm[user_list]
        pi_e = E_i_norm[pos_item_list]
        ni_e = E_i_norm[neg_item_list]
        pos_scores = torch.mul(u_e, pi_e).sum(dim=1)
        neg_scores = torch.mul(u_e, ni_e).sum(dim=1)
        loss1 = -(pos_scores - neg_scores).sigmoid().log().mean()

        # reg loss
        loss_reg = 0
        for param in self.parameters():
            loss_reg += param.norm(2).square()
        loss_reg *= self.lambda_2
        return loss1 + loss_reg

    def calc_ssl_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list):
        r"""Calculate the loss of self-supervised tasks.

        Args:
            E_u_norm (torch.Tensor): Ego embedding of all users in the original graph after forwarding.
            E_i_norm (torch.Tensor): Ego embedding of all items in the original graph after forwarding.
            user_list (torch.Tensor): List of user IDs.
            pos_item_list (torch.Tensor): List of positive examples.

        Returns:
            torch.Tensor: Loss of self-supervised tasks.
        """
        # calculate G_u_norm & G_i_norm
        for layer in range(1, self.n_layers + 1):
            # svd_adj propagation
            vt_ei = self.vt @ self.E_i_list[layer - 1]
            self.G_u_list[layer] = self.u_mul_s @ vt_ei
            ut_eu = self.ut @ self.E_u_list[layer - 1]
            self.G_i_list[layer] = self.v_mul_s @ ut_eu

        # aggregate across layers
        G_u_norm = sum(self.G_u_list)
        G_i_norm = sum(self.G_i_list)

        neg_score = torch.log(torch.exp(G_u_norm[user_list] @ E_u_norm.T / self.temp).sum(1) + 1e-8).mean()
        neg_score += torch.log(torch.exp(G_i_norm[pos_item_list] @ E_i_norm.T / self.temp).sum(1) + 1e-8).mean()
        pos_score = (torch.clamp((G_u_norm[user_list] * E_u_norm[user_list]).sum(1) / self.temp, -5.0, 5.0)).mean() + (
            torch.clamp((G_i_norm[pos_item_list] * E_i_norm[pos_item_list]).sum(1) / self.temp, -5.0, 5.0)).mean()
        ssl_loss = -pos_score + neg_score
        return self.lambda_1 * ssl_loss

    def predict(self, interaction):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        user = self.restore_user_e[interaction[self.USER_ID]]
        item = self.restore_item_e[interaction[self.ITEM_ID]]
        return torch.sum(user * item, dim=1)

    def full_sort_predict(self, interaction):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        user = self.restore_user_e[interaction[self.USER_ID]]
        return user.matmul(self.restore_item_e.T)
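
# A minimal illustrative sketch of the SVD-guided view used above: propagation is
# done on a rank-q reconstruction (u_mul_s @ vt) of the normalized interaction
# matrix rather than on a perturbed graph. All sizes and values here are toy
# assumptions for illustration only.
if __name__ == "__main__":
    import torch

    n_users_demo, n_items_demo, dim_demo, q_demo = 6, 5, 4, 2
    adj_demo = torch.rand(n_users_demo, n_items_demo)        # stand-in for adj_norm
    item_emb_demo = torch.randn(n_items_demo, dim_demo)      # stand-in for E_i

    svd_u, s, svd_v = torch.svd_lowrank(adj_demo, q=q_demo)
    u_mul_s_demo = svd_u @ torch.diag(s)                     # (n_users, q), as in __init__
    vt_demo = svd_v.T                                        # (q, n_items)

    approx = u_mul_s_demo @ (vt_demo @ item_emb_demo)        # rank-q view of adj @ E_i
    exact = adj_demo @ item_emb_demo
    print(torch.linalg.norm(approx - exact) / torch.linalg.norm(exact))  # relative error of the low-rank view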


================================================
FILE: recbole_gnn/model/general_recommender/lightgcn.py
================================================
# @Time   : 2022/3/8
# @Author : Lanling Xu
# @Email  : xulanling_sherry@163.com

r"""
LightGCN
################################################
Reference:
    Xiangnan He et al. "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation." in SIGIR 2020.

Reference code:
    https://github.com/kuandeng/LightGCN
"""

import numpy as np
import torch

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class LightGCN(GeneralGraphRecommender):
    r"""LightGCN is a GCN-based recommender model, implemented via PyG.
    LightGCN includes only the most essential component in GCN — neighborhood aggregation — for
    collaborative filtering. Specifically, LightGCN learns user and item embeddings by linearly 
    propagating them on the user-item interaction graph, and uses the weighted sum of the embeddings
    learned at all layers as the final embedding.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(LightGCN, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type: the embedding size of LightGCN
        self.n_layers = config['n_layers']  # int type: the number of layers in LightGCN
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization
        self.require_pow = config['require_pow']  # bool type: whether to use the powered norm in regularization

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]

        for layer_idx in range(self.n_layers):
            all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
            embeddings_list.append(all_embeddings)
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
        loss = mf_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)
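
# A minimal illustrative sketch of what forward() above computes: embeddings are
# propagated linearly over a normalized adjacency and the per-layer results are
# averaged. The dense toy adjacency and sizes below are assumptions for
# illustration only.
if __name__ == "__main__":
    import torch

    n_nodes, dim, n_layers = 8, 4, 3
    adj = torch.rand(n_nodes, n_nodes)
    adj = (adj + adj.T) / 2                                  # toy symmetric adjacency
    d_inv_sqrt = adj.sum(dim=1).clamp(min=1e-8).pow(-0.5)
    adj_norm = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]

    emb = torch.randn(n_nodes, dim)                          # ego embeddings
    layer_embs = [emb]
    for _ in range(n_layers):
        emb = adj_norm @ emb                                 # propagation without feature transforms
        layer_embs.append(emb)
    final = torch.stack(layer_embs, dim=1).mean(dim=1)       # layer-wise mean, as in forward()
    print(final.shape)                                       # torch.Size([8, 4])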


================================================
FILE: recbole_gnn/model/general_recommender/ncl.py
================================================
# -*- coding: utf-8 -*-
r"""
NCL
################################################
Reference:
    Zihan Lin*, Changxin Tian*, Yupeng Hou*, Wayne Xin Zhao. "Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning." in WWW 2022.
"""

import torch
import torch.nn.functional as F

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class NCL(GeneralGraphRecommender):
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(NCL, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type: the embedding size of the base model
        self.n_layers = config['n_layers']          # int type: the layer num of the base model
        self.reg_weight = config['reg_weight']      # float32 type: the weight decay for l2 normalization

        self.ssl_temp = config['ssl_temp']
        self.ssl_reg = config['ssl_reg']
        self.hyper_layers = config['hyper_layers']

        self.alpha = config['alpha']

        self.proto_reg = config['proto_reg']
        self.k = config['num_clusters']

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

        self.user_centroids = None
        self.user_2cluster = None
        self.item_centroids = None
        self.item_2cluster = None

    def e_step(self):
        user_embeddings = self.user_embedding.weight.detach().cpu().numpy()
        item_embeddings = self.item_embedding.weight.detach().cpu().numpy()
        self.user_centroids, self.user_2cluster = self.run_kmeans(user_embeddings)
        self.item_centroids, self.item_2cluster = self.run_kmeans(item_embeddings)

    def run_kmeans(self, x):
        """Run K-means algorithm to get k clusters of the input tensor x
        """
        import faiss
        kmeans = faiss.Kmeans(d=self.latent_dim, k=self.k, gpu=True)
        kmeans.train(x)
        cluster_cents = kmeans.centroids

        _, I = kmeans.index.search(x, 1)

        # convert to cuda Tensors for broadcast
        centroids = torch.Tensor(cluster_cents).to(self.device)
        centroids = F.normalize(centroids, p=2, dim=1)

        node2cluster = torch.LongTensor(I).squeeze().to(self.device)
        return centroids, node2cluster

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        for layer_idx in range(max(self.n_layers, self.hyper_layers * 2)):
            all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
            embeddings_list.append(all_embeddings)

        lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers + 1], dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings, embeddings_list

    def ProtoNCE_loss(self, node_embedding, user, item):
        user_embeddings_all, item_embeddings_all = torch.split(node_embedding, [self.n_users, self.n_items])

        user_embeddings = user_embeddings_all[user]     # [B, e]
        norm_user_embeddings = F.normalize(user_embeddings)

        user2cluster = self.user_2cluster[user]     # [B,]
        user2centroids = self.user_centroids[user2cluster]   # [B, e]
        pos_score_user = torch.mul(norm_user_embeddings, user2centroids).sum(dim=1)
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.matmul(norm_user_embeddings, self.user_centroids.transpose(0, 1))
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        proto_nce_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        item_embeddings = item_embeddings_all[item]
        norm_item_embeddings = F.normalize(item_embeddings)

        item2cluster = self.item_2cluster[item]  # [B, ]
        item2centroids = self.item_centroids[item2cluster]  # [B, e]
        pos_score_item = torch.mul(norm_item_embeddings, item2centroids).sum(dim=1)
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.matmul(norm_item_embeddings, self.item_centroids.transpose(0, 1))
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
        proto_nce_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item)
        return proto_nce_loss

    def ssl_layer_loss(self, current_embedding, previous_embedding, user, item):
        current_user_embeddings, current_item_embeddings = torch.split(current_embedding, [self.n_users, self.n_items])
        previous_user_embeddings_all, previous_item_embeddings_all = torch.split(previous_embedding, [self.n_users, self.n_items])

        current_user_embeddings = current_user_embeddings[user]
        previous_user_embeddings = previous_user_embeddings_all[user]
        norm_user_emb1 = F.normalize(current_user_embeddings)
        norm_user_emb2 = F.normalize(previous_user_embeddings)
        norm_all_user_emb = F.normalize(previous_user_embeddings_all)
        pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1)
        ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1))
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        current_item_embeddings = current_item_embeddings[item]
        previous_item_embeddings = previous_item_embeddings_all[item]
        norm_item_emb1 = F.normalize(current_item_embeddings)
        norm_item_emb2 = F.normalize(previous_item_embeddings)
        norm_all_item_emb = F.normalize(previous_item_embeddings_all)
        pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1)
        ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1))
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)

        ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item)
        return ssl_loss

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

        center_embedding = embeddings_list[0]
        context_embedding = embeddings_list[self.hyper_layers * 2]

        ssl_loss = self.ssl_layer_loss(context_embedding, center_embedding, user, pos_item)
        proto_loss = self.ProtoNCE_loss(center_embedding, user, pos_item)

        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)

        mf_loss = self.mf_loss(pos_scores, neg_scores)

        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)

        return mf_loss + self.reg_weight * reg_loss, ssl_loss, proto_loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e, embedding_list = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)
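
# A minimal illustrative sketch of the ProtoNCE term above: an InfoNCE loss
# between each node and the centroid of its cluster, contrasted against all
# centroids. The random embeddings and the fake cluster assignment below are toy
# assumptions standing in for the faiss k-means step.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    batch, dim, k, temp = 5, 8, 3, 0.1
    node_emb = F.normalize(torch.randn(batch, dim), dim=1)
    centroids = F.normalize(torch.randn(k, dim), dim=1)
    node2cluster = torch.randint(0, k, (batch,))             # pretend k-means assignment

    pos = torch.exp((node_emb * centroids[node2cluster]).sum(dim=1) / temp)
    ttl = torch.exp(node_emb @ centroids.T / temp).sum(dim=1)
    print(-torch.log(pos / ttl).sum())                       # ProtoNCE for this toy batch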


================================================
FILE: recbole_gnn/model/general_recommender/ngcf.py
================================================
# @Time   : 2022/3/8
# @Author : Changxin Tian
# @Email  : cx.tian@outlook.com
r"""
NGCF
################################################
Reference:
    Xiang Wang et al. "Neural Graph Collaborative Filtering." in SIGIR 2019.

Reference code:
    https://github.com/xiangwang1223/neural_graph_collaborative_filtering

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dropout_adj

from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import BiGNNConv


class NGCF(GeneralGraphRecommender):
    r"""NGCF is a model that incorporate GNN for recommendation.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(NGCF, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.hidden_size_list = config['hidden_size_list']
        self.hidden_size_list = [self.embedding_size] + self.hidden_size_list
        self.node_dropout = config['node_dropout']
        self.message_dropout = config['message_dropout']
        self.reg_weight = config['reg_weight']

        # define layers and loss
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
        self.GNNlayers = torch.nn.ModuleList()
        for input_size, output_size in zip(self.hidden_size_list[:-1], self.hidden_size_list[1:]):
            self.GNNlayers.append(BiGNNConv(input_size, output_size))
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_normal_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.

        Returns:
            Tensor of the embedding matrix. Shape of (n_items+n_users, embedding_dim)
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        if self.node_dropout == 0:
            edge_index, edge_weight = self.edge_index, self.edge_weight
        else:
            edge_index, edge_weight = self.edge_index, self.edge_weight
            if self.use_sparse:
                row, col, edge_weight = edge_index.t().coo()
                edge_index = torch.stack([row, col], 0)
                edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
                                                      p=self.node_dropout, training=self.training)
                from torch_sparse import SparseTensor
                edge_index = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight,
                                          sparse_sizes=(self.n_users + self.n_items, self.n_users + self.n_items))
                edge_index = edge_index.t()
                edge_weight = None
            else:
                edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
                                                      p=self.node_dropout, training=self.training)

        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        for gnn in self.GNNlayers:
            all_embeddings = gnn(all_embeddings, edge_index, edge_weight)
            all_embeddings = nn.LeakyReLU(negative_slope=0.2)(all_embeddings)
            all_embeddings = nn.Dropout(self.message_dropout)(all_embeddings)
            all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
            embeddings_list += [all_embeddings]  # storage output embedding of each layer
        ngcf_all_embeddings = torch.cat(embeddings_list, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(ngcf_all_embeddings, [self.n_users, self.n_items])

        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)  # calculate BPR Loss

        reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings)  # L2 regularization of embeddings

        return mf_loss + self.reg_weight * reg_loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)
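
# A minimal illustrative sketch of the node-dropout step in forward() above:
# dropout_adj randomly removes entries from an edge list (and their weights)
# during training. The tiny edge list below is a toy assumption; note that
# newer PyG versions replace dropout_adj with dropout_edge.
if __name__ == "__main__":
    import torch
    from torch_geometric.utils import dropout_adj

    demo_edge_index = torch.tensor([[0, 1, 2, 3],
                                    [1, 2, 3, 0]])
    demo_edge_weight = torch.ones(demo_edge_index.size(1))
    kept_index, kept_weight = dropout_adj(demo_edge_index, edge_attr=demo_edge_weight,
                                          p=0.5, training=True)
    print(kept_index)                                        # a random subset of the original edges
    print(kept_weight)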


================================================
FILE: recbole_gnn/model/general_recommender/sgl.py
================================================
# -*- coding: utf-8 -*-
# @Time   : 2022/3/8
# @Author : Changxin Tian
# @Email  : cx.tian@outlook.com
r"""
SGL
################################################
Reference:
    Jiancan Wu et al. "SGL: Self-supervised Graph Learning for Recommendation" in SIGIR 2021.

Reference code:
    https://github.com/wujcan/SGL
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.utils import degree
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class SGL(GeneralGraphRecommender):
    r"""SGL is a GCN-based recommender model.

    SGL supplements the classical supervised task of recommendation with an auxiliary
    self-supervised task, which reinforces node representation learning via
    self-discrimination. Specifically, SGL generates multiple views of a node, maximizing the
    agreement between different views of the same node compared to that of other nodes.
    SGL devises three operators to generate the views — node dropout, edge dropout, and
    random walk — that change the graph structure in different manners.

    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(SGL, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config["embedding_size"]
        self.n_layers = int(config["n_layers"])
        self.aug_type = config["type"]
        self.drop_ratio = config["drop_ratio"]
        self.ssl_tau = config["ssl_tau"]
        self.reg_weight = config["reg_weight"]
        self.ssl_weight = config["ssl_weight"]

        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]
        self.dataset = dataset

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim)
        self.item_embedding = torch.nn.Embedding(self.n_items, self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def train(self, mode: bool = True):
        r"""Override train method of base class. The subgraph is reconstructed each time it is called.

        """
        T = super().train(mode=mode)
        if mode:
            self.graph_construction()
        return T

    def graph_construction(self):
        r"""Devise three operators to generate the views — node dropout, edge dropout, and random walk of a node.

        """
        if self.aug_type == "ND" or self.aug_type == "ED":
            self.sub_graph1 = [self.random_graph_augment()] * self.n_layers
            self.sub_graph2 = [self.random_graph_augment()] * self.n_layers
        elif self.aug_type == "RW":
            self.sub_graph1 = [self.random_graph_augment() for _ in range(self.n_layers)]
            self.sub_graph2 = [self.random_graph_augment() for _ in range(self.n_layers)]

    def random_graph_augment(self):
        def rand_sample(high, size=None, replace=True):
            return np.random.choice(np.arange(high), size=size, replace=replace)

        if self.aug_type == "ND":
            drop_user = rand_sample(self.n_users, size=int(self.n_users * self.drop_ratio), replace=False)
            drop_item = rand_sample(self.n_items, size=int(self.n_items * self.drop_ratio), replace=False)

            mask = np.isin(self._user.numpy(), drop_user)
            mask |= np.isin(self._item.numpy(), drop_item)
            keep = np.where(~mask)

            row = self._user[keep]
            col = self._item[keep] + self.n_users

        elif self.aug_type == "ED" or self.aug_type == "RW":
            keep = rand_sample(len(self._user), size=int(len(self._user) * (1 - self.drop_ratio)), replace=False)
            row = self._user[keep]
            col = self._item[keep] + self.n_users

        edge_index1 = torch.stack([row, col])
        edge_index2 = torch.stack([col, row])
        edge_index = torch.cat([edge_index1, edge_index2], dim=1)
        edge_weight = torch.ones(edge_index.size(1))
        num_nodes = self.n_users + self.n_items

        if self.use_sparse:
            adj_t = self.dataset.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes)
            adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False)
            return adj_t.to(self.device), None

        edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False)

        return edge_index.to(self.device), edge_weight.to(self.device)

    def forward(self, graph=None):
        all_embeddings = torch.cat([self.user_embedding.weight, self.item_embedding.weight])
        embeddings_list = [all_embeddings]

        if graph is None:  # for the original graph
            for _ in range(self.n_layers):
                all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
                embeddings_list.append(all_embeddings)
        else:  # for the augmented graph
            for graph_edge_index, graph_edge_weight in graph:
                all_embeddings = self.gcn_conv(all_embeddings, graph_edge_index, graph_edge_weight)
                embeddings_list.append(all_embeddings)

        embeddings_list = torch.stack(embeddings_list, dim=1)
        embeddings_list = torch.mean(embeddings_list, dim=1, keepdim=False)
        user_all_embeddings, item_all_embeddings = torch.split(embeddings_list, [self.n_users, self.n_items], dim=0)

        return user_all_embeddings, item_all_embeddings

    def calc_bpr_loss(self, user_emd, item_emd, user_list, pos_item_list, neg_item_list):
        r"""Calculate the the pairwise Bayesian Personalized Ranking (BPR) loss and parameter regularization loss.

        Args:
            user_emd (torch.Tensor): Ego embedding of all users after forwarding.
            item_emd (torch.Tensor): Ego embedding of all items after forwarding.
            user_list (torch.Tensor): List of user IDs.
            pos_item_list (torch.Tensor): List of positive examples.
            neg_item_list (torch.Tensor): List of negative examples.

        Returns:
            torch.Tensor: Loss of BPR tasks and parameter regularization.
        """
        u_e = user_emd[user_list]
        pi_e = item_emd[pos_item_list]
        ni_e = item_emd[neg_item_list]
        p_scores = torch.mul(u_e, pi_e).sum(dim=1)
        n_scores = torch.mul(u_e, ni_e).sum(dim=1)

        l1 = torch.sum(-F.logsigmoid(p_scores - n_scores))

        u_e_p = self.user_embedding(user_list)
        pi_e_p = self.item_embedding(pos_item_list)
        ni_e_p = self.item_embedding(neg_item_list)

        l2 = self.reg_loss(u_e_p, pi_e_p, ni_e_p)

        return l1 + l2 * self.reg_weight

    def calc_ssl_loss(self, user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2):
        r"""Calculate the loss of self-supervised tasks.

        Args:
            user_list (torch.Tensor): List of user IDs.
            pos_item_list (torch.Tensor): List of positive examples.
            user_sub1 (torch.Tensor): Ego embedding of all users in the first subgraph after forwarding.
            user_sub2 (torch.Tensor): Ego embedding of all users in the second subgraph after forwarding.
            item_sub1 (torch.Tensor): Ego embedding of all items in the first subgraph after forwarding.
            item_sub2 (torch.Tensor): Ego embedding of all items in the second subgraph after forwarding.

        Returns:
            torch.Tensor: Loss of self-supervised tasks.
        """

        u_emd1 = F.normalize(user_sub1[user_list], dim=1)
        u_emd2 = F.normalize(user_sub2[user_list], dim=1)
        all_user2 = F.normalize(user_sub2, dim=1)
        v1 = torch.sum(u_emd1 * u_emd2, dim=1)
        v2 = u_emd1.matmul(all_user2.T)
        v1 = torch.exp(v1 / self.ssl_tau)
        v2 = torch.sum(torch.exp(v2 / self.ssl_tau), dim=1)
        ssl_user = -torch.sum(torch.log(v1 / v2))

        i_emd1 = F.normalize(item_sub1[pos_item_list], dim=1)
        i_emd2 = F.normalize(item_sub2[pos_item_list], dim=1)
        all_item2 = F.normalize(item_sub2, dim=1)
        v3 = torch.sum(i_emd1 * i_emd2, dim=1)
        v4 = i_emd1.matmul(all_item2.T)
        v3 = torch.exp(v3 / self.ssl_tau)
        v4 = torch.sum(torch.exp(v4 / self.ssl_tau), dim=1)
        ssl_item = -torch.sum(torch.log(v3 / v4))

        return (ssl_item + ssl_user) * self.ssl_weight

    def calculate_loss(self, interaction):
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user_list = interaction[self.USER_ID]
        pos_item_list = interaction[self.ITEM_ID]
        neg_item_list = interaction[self.NEG_ITEM_ID]

        user_emd, item_emd = self.forward()
        user_sub1, item_sub1 = self.forward(self.sub_graph1)
        user_sub2, item_sub2 = self.forward(self.sub_graph2)

        total_loss = self.calc_bpr_loss(user_emd, item_emd, user_list, pos_item_list, neg_item_list) + \
            self.calc_ssl_loss(user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2)
        return total_loss

    def predict(self, interaction):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()

        user = self.restore_user_e[interaction[self.USER_ID]]
        item = self.restore_item_e[interaction[self.ITEM_ID]]
        return torch.sum(user * item, dim=1)

    def full_sort_predict(self, interaction):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()

        user = self.restore_user_e[interaction[self.USER_ID]]
        return user.matmul(self.restore_item_e.T)
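
# A minimal illustrative sketch of the edge-dropout ("ED") view built by
# random_graph_augment above: a fraction of interactions is dropped and the
# survivors are assembled into a symmetric user-item edge_index (items are
# offset by n_users). The interaction list below is a toy assumption.
if __name__ == "__main__":
    import numpy as np
    import torch

    n_users_demo, drop_ratio_demo = 4, 0.5
    users_demo = torch.tensor([0, 1, 1, 2, 3, 3])
    items_demo = torch.tensor([0, 0, 1, 2, 1, 2])

    keep = torch.from_numpy(
        np.random.choice(len(users_demo), size=int(len(users_demo) * (1 - drop_ratio_demo)), replace=False)
    )
    row = users_demo[keep]
    col = items_demo[keep] + n_users_demo
    edge_index_demo = torch.cat([torch.stack([row, col]), torch.stack([col, row])], dim=1)
    print(edge_index_demo)                                   # one randomly thinned, symmetric view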


================================================
FILE: recbole_gnn/model/general_recommender/simgcl.py
================================================
# -*- coding: utf-8 -*-
r"""
SimGCL
################################################
Reference:
    Junliang Yu, Hongzhi Yin, Xin Xia, Tong Chen, Lizhen Cui, Quoc Viet Hung Nguyen. "Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation." in SIGIR 2022.
"""


import torch
import torch.nn.functional as F

from recbole_gnn.model.general_recommender import LightGCN


class SimGCL(LightGCN):
    def __init__(self, config, dataset):
        super(SimGCL, self).__init__(config, dataset)

        self.cl_rate = config['lambda']
        self.eps = config['eps']
        self.temperature = config['temperature']

    def forward(self, perturbed=False):
        all_embs = self.get_ego_embeddings()
        embeddings_list = []

        for layer_idx in range(self.n_layers):
            all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight)
            if perturbed:
                random_noise = torch.rand_like(all_embs, device=all_embs.device)
                all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps
            embeddings_list.append(all_embs)
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings

    def calculate_cl_loss(self, x1, x2):
        x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
        pos_score = (x1 * x2).sum(dim=-1)
        pos_score = torch.exp(pos_score / self.temperature)
        ttl_score = torch.matmul(x1, x2.transpose(0, 1))
        ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1)
        return -torch.log(pos_score / ttl_score).sum()

    def calculate_loss(self, interaction):
        loss = super().calculate_loss(interaction)

        user = torch.unique(interaction[self.USER_ID])
        pos_item = torch.unique(interaction[self.ITEM_ID])

        perturbed_user_embs_1, perturbed_item_embs_1 = self.forward(perturbed=True)
        perturbed_user_embs_2, perturbed_item_embs_2 = self.forward(perturbed=True)

        user_cl_loss = self.calculate_cl_loss(perturbed_user_embs_1[user], perturbed_user_embs_2[user])
        item_cl_loss = self.calculate_cl_loss(perturbed_item_embs_1[pos_item], perturbed_item_embs_2[pos_item])

        return loss + self.cl_rate * (user_cl_loss + item_cl_loss)
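
# A minimal illustrative sketch of the perturbation applied in forward(perturbed=True)
# above: each embedding row is shifted by eps along a random direction whose signs are
# aligned with the embedding itself, instead of augmenting the graph. The toy matrix
# below is an assumption for illustration only.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    eps_demo = 0.1
    embs_demo = torch.randn(6, 4)
    noise_demo = torch.rand_like(embs_demo)
    perturbed_demo = embs_demo + torch.sign(embs_demo) * F.normalize(noise_demo, dim=-1) * eps_demo
    print((perturbed_demo - embs_demo).norm(dim=-1))         # every row is moved by exactly eps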


================================================
FILE: recbole_gnn/model/general_recommender/ssl4rec.py
================================================
r"""
SSL4REC
################################################
Reference:
    Tiansheng Yao et al. "Self-supervised Learning for Large-scale Item Recommendations." in CIKM 2021.

Reference code:
    https://github.com/Coder-Yu/SELFRec/model/graph/SSL4Rec.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.loss import EmbLoss
from recbole.utils import InputType

from recbole.model.init import xavier_uniform_initialization
from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender


class SSL4REC(GeneralGraphRecommender):
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(SSL4REC, self).__init__(config, dataset)

        # load parameters info
        self.tau = config["tau"]
        self.reg_weight = config["reg_weight"]
        self.cl_rate = config["ssl_weight"]
        self.require_pow = config["require_pow"]

        self.reg_loss = EmbLoss()

        self.encoder = DNN_Encoder(config, dataset)

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def forward(self, user, item):
        user_e, item_e = self.encoder(user, item)
        return user_e, item_e

    def calculate_batch_softmax_loss(self, user_emb, item_emb, temperature):
        user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1)
        pos_score = (user_emb * item_emb).sum(dim=-1)
        pos_score = torch.exp(pos_score / temperature)
        ttl_score = torch.matmul(user_emb, item_emb.transpose(0, 1))
        ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
        loss = -torch.log(pos_score / ttl_score + 10e-6)
        return torch.mean(loss)

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]

        user_embeddings, item_embeddings = self.forward(user, pos_item)

        rec_loss = self.calculate_batch_softmax_loss(user_embeddings, item_embeddings, self.tau)
        cl_loss = self.encoder.calculate_cl_loss(pos_item)
        reg_loss = self.reg_loss(user_embeddings, item_embeddings, require_pow=self.require_pow)

        loss = rec_loss + self.cl_rate * cl_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        # forward() already returns the tower outputs for this batch of user/item
        # IDs, so the scores are a row-wise dot product; re-indexing the
        # batch-sized outputs by raw IDs would go out of range.
        user_embeddings, item_embeddings = self.forward(user, item)
        scores = torch.mul(user_embeddings, item_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward(torch.arange(
                self.n_users, device=self.device), torch.arange(self.n_items, device=self.device))
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)


class DNN_Encoder(nn.Module):
    def __init__(self, config, dataset):
        super(DNN_Encoder, self).__init__()

        self.emb_size = config["embedding_size"]
        self.drop_ratio = config["drop_ratio"]
        self.tau = config["tau"]

        self.USER_ID = config["USER_ID_FIELD"]
        self.ITEM_ID = config["ITEM_ID_FIELD"]
        self.n_users = dataset.num(self.USER_ID)
        self.n_items = dataset.num(self.ITEM_ID)

        self.user_tower = nn.Sequential(
            nn.Linear(self.emb_size, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.Tanh()
        )
        self.item_tower = nn.Sequential(
            nn.Linear(self.emb_size, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.Tanh()
        )
        self.dropout = nn.Dropout(self.drop_ratio)

        self.initial_user_emb = nn.Embedding(self.n_users, self.emb_size)
        self.initial_item_emb = nn.Embedding(self.n_items, self.emb_size)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.initial_user_emb.weight)
        nn.init.xavier_uniform_(self.initial_item_emb.weight)

    def forward(self, q, x):
        q_emb = self.initial_user_emb(q)
        i_emb = self.initial_item_emb(x)

        q_emb = self.user_tower(q_emb)
        i_emb = self.item_tower(i_emb)

        return q_emb, i_emb

    def item_encoding(self, x):
        i_emb = self.initial_item_emb(x)
        i1_emb = self.dropout(i_emb)
        i2_emb = self.dropout(i_emb)

        i1_emb = self.item_tower(i1_emb)
        i2_emb = self.item_tower(i2_emb)

        return i1_emb, i2_emb

    def calculate_cl_loss(self, idx):
        x1, x2 = self.item_encoding(idx)
        x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
        pos_score = (x1 * x2).sum(dim=-1)
        pos_score = torch.exp(pos_score / self.tau)
        ttl_score = torch.matmul(x1, x2.transpose(0, 1))
        ttl_score = torch.exp(ttl_score / self.tau).sum(dim=1)
        return -torch.log(pos_score / ttl_score).mean()
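

# A minimal sketch of the in-batch softmax objective used by
# `calculate_batch_softmax_loss` and `calculate_cl_loss` above, relying only on
# the torch / F imports at the top of this file. The batch size, embedding size
# and temperature below are illustrative and not taken from any config.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch_size, dim, temperature = 4, 8, 0.1
    user_emb = F.normalize(torch.randn(batch_size, dim), dim=1)
    item_emb = F.normalize(torch.randn(batch_size, dim), dim=1)

    # Matched user-item pairs are the positives (the diagonal of the similarity
    # matrix); every other item in the batch acts as a negative.
    pos_score = torch.exp((user_emb * item_emb).sum(dim=-1) / temperature)
    ttl_score = torch.exp(user_emb @ item_emb.t() / temperature).sum(dim=1)
    print((-torch.log(pos_score / ttl_score)).mean().item())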


================================================
FILE: recbole_gnn/model/general_recommender/xsimgcl.py
================================================
# -*- coding: utf-8 -*-
r"""
XSimGCL
################################################
Reference:
    Junliang Yu, Xin Xia, Tong Chen, Lizhen Cui, Nguyen Quoc Viet Hung, Hongzhi Yin. "XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation" in TKDE 2023.

Reference code:
    https://github.com/Coder-Yu/SELFRec/blob/main/model/graph/XSimGCL.py
"""


import torch
import torch.nn.functional as F

from recbole_gnn.model.general_recommender import LightGCN


class XSimGCL(LightGCN):
    def __init__(self, config, dataset):
        super(XSimGCL, self).__init__(config, dataset)

        self.cl_rate = config['lambda']
        self.eps = config['eps']
        self.temperature = config['temperature']
        self.layer_cl = config['layer_cl']

    def forward(self, perturbed=False):
        all_embs = self.get_ego_embeddings()
        all_embs_cl = all_embs
        embeddings_list = []

        for layer_idx in range(self.n_layers):
            all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight)
            if perturbed:
                random_noise = torch.rand_like(all_embs, device=all_embs.device)
                all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps
            embeddings_list.append(all_embs)
            if layer_idx == self.layer_cl - 1:
                all_embs_cl = all_embs
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embs_cl, [self.n_users, self.n_items])
        if perturbed:
            return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl
        return user_all_embeddings, item_all_embeddings

    def calculate_cl_loss(self, x1, x2):
        x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
        pos_score = (x1 * x2).sum(dim=-1)
        pos_score = torch.exp(pos_score / self.temperature)
        ttl_score = torch.matmul(x1, x2.transpose(0, 1))
        ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1)
        return -torch.log(pos_score / ttl_score).mean()

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl = self.forward(perturbed=True)
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)
        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)

        user = torch.unique(interaction[self.USER_ID])
        pos_item = torch.unique(interaction[self.ITEM_ID])

        # calculate CL Loss
        user_cl_loss = self.calculate_cl_loss(user_all_embeddings[user], user_all_embeddings_cl[user])
        item_cl_loss = self.calculate_cl_loss(item_all_embeddings[pos_item], item_all_embeddings_cl[pos_item])

        return mf_loss, self.reg_weight * reg_loss, self.cl_rate * (user_cl_loss + item_cl_loss)
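

# A small sketch of the layer-wise perturbation applied in `forward` above,
# relying only on the torch / F imports at the top of this file: uniform noise
# is rescaled to unit length, signed to point in the same direction as each
# embedding coordinate, and scaled by eps. The shapes and eps value are
# illustrative.
if __name__ == '__main__':
    torch.manual_seed(0)
    eps = 0.2
    all_embs = torch.randn(6, 4)  # (n_users + n_items, embedding_size)
    random_noise = torch.rand_like(all_embs)
    perturbed = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * eps
    # Each row is displaced by roughly eps in L2 norm.
    print((perturbed - all_embs).norm(dim=-1))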


================================================
FILE: recbole_gnn/model/layers.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_sparse import matmul


class LightGCNConv(MessagePassing):
    def __init__(self, dim):
        super(LightGCNConv, self).__init__(aggr='add')
        self.dim = dim

    def forward(self, x, edge_index, edge_weight):
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_weight):
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.dim)


class BipartiteGCNConv(MessagePassing):
    def __init__(self, dim):
        super(BipartiteGCNConv, self).__init__(aggr='add')
        self.dim = dim

    def forward(self, x, edge_index, edge_weight, size):
        return self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)

    def message(self, x_j, edge_weight):
        return edge_weight.view(-1, 1) * x_j

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.dim)


class BiGNNConv(MessagePassing):
    r"""Propagate a layer of Bi-interaction GNN

    .. math::
        output = (L+I)EW_1 + LE \otimes EW_2
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.in_channels, self.out_channels = in_channels, out_channels
        self.lin1 = torch.nn.Linear(in_features=in_channels, out_features=out_channels)
        self.lin2 = torch.nn.Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, x, edge_index, edge_weight):
        x_prop = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        x_trans = self.lin1(x_prop + x)
        x_inter = self.lin2(torch.mul(x_prop, x))
        return x_trans + x_inter

    def message(self, x_j, edge_weight):
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)


class SRGNNConv(MessagePassing):
    def __init__(self, dim):
        # mean aggregation to incorporate weight naturally
        super(SRGNNConv, self).__init__(aggr='mean')

        self.lin = torch.nn.Linear(dim, dim)

    def forward(self, x, edge_index):
        x = self.lin(x)
        return self.propagate(edge_index, x=x)


class SRGNNCell(nn.Module):
    def __init__(self, dim):
        super(SRGNNCell, self).__init__()

        self.dim = dim
        self.incoming_conv = SRGNNConv(dim)
        self.outgoing_conv = SRGNNConv(dim)

        self.lin_ih = nn.Linear(2 * dim, 3 * dim)
        self.lin_hh = nn.Linear(dim, 3 * dim)

        self._reset_parameters()

    def forward(self, hidden, edge_index):
        input_in = self.incoming_conv(hidden, edge_index)
        reversed_edge_index = torch.flip(edge_index, dims=[0])
        input_out = self.outgoing_conv(hidden, reversed_edge_index)
        inputs = torch.cat([input_in, input_out], dim=-1)

        gi = self.lin_ih(inputs)
        gh = self.lin_hh(hidden)
        i_r, i_i, i_n = gi.chunk(3, -1)
        h_r, h_i, h_n = gh.chunk(3, -1)
        reset_gate = torch.sigmoid(i_r + h_r)
        input_gate = torch.sigmoid(i_i + h_i)
        new_gate = torch.tanh(i_n + reset_gate * h_n)
        hy = (1 - input_gate) * hidden + input_gate * new_gate
        return hy

    def _reset_parameters(self):
        stdv = 1.0 / np.sqrt(self.dim)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
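

# A minimal usage sketch for LightGCNConv, assuming PyTorch Geometric is
# installed (it is already imported above). The tiny bipartite graph and the
# symmetrically normalized edge weights below are illustrative; in this
# repository the (edge_index, edge_weight) pair is prepared by the dataset.
if __name__ == '__main__':
    # Two users (nodes 0-1) and two items (nodes 2-3); each undirected edge is
    # stored as a pair of directed edges.
    edge_index = torch.tensor([[0, 2, 0, 3, 1, 3],
                               [2, 0, 3, 0, 3, 1]])
    deg = torch.bincount(edge_index[0], minlength=4).float()
    src, dst = edge_index
    edge_weight = (deg[src] * deg[dst]).pow(-0.5)  # D^{-1/2} A D^{-1/2}

    x = torch.randn(4, 8)  # one 8-dimensional embedding per node
    conv = LightGCNConv(dim=8)
    print(conv(x, edge_index, edge_weight).shape)  # torch.Size([4, 8])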


================================================
FILE: recbole_gnn/model/sequential_recommender/__init__.py
================================================
from recbole_gnn.model.sequential_recommender.gcegnn import GCEGNN
from recbole_gnn.model.sequential_recommender.gcsan import GCSAN
from recbole_gnn.model.sequential_recommender.lessr import LESSR
from recbole_gnn.model.sequential_recommender.niser import NISER
from recbole_gnn.model.sequential_recommender.sgnnhn import SGNNHN
from recbole_gnn.model.sequential_recommender.srgnn import SRGNN
from recbole_gnn.model.sequential_recommender.tagnn import TAGNN


================================================
FILE: recbole_gnn/model/sequential_recommender/gcegnn.py
================================================
# @Time   : 2022/3/22
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

r"""
GCE-GNN
################################################

Reference:
    Ziyang Wang et al. "Global Context Enhanced Graph Neural Networks for Session-based Recommendation." in SIGIR 2020.

Reference code:
    https://github.com/CCIIPLab/GCE-GNN

"""

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from recbole.model.loss import BPRLoss
from recbole.model.abstract_recommender import SequentialRecommender


class LocalAggregator(MessagePassing):
    def __init__(self, dim, alpha):
        super().__init__(aggr='add')
        self.edge_emb = nn.Embedding(4, dim)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, x_i, edge_attr, index, ptr, size_i):
        x = x_j * x_i
        a = self.edge_emb(edge_attr)
        e = (x * a).sum(dim=-1)
        e = self.leakyrelu(e)
        e = softmax(e, index, ptr, size_i)
        return e.unsqueeze(-1) * x_j


class GlobalAggregator(nn.Module):
    def __init__(self, dim, dropout, act=torch.relu):
        super(GlobalAggregator, self).__init__()
        self.dropout = dropout
        self.act = act
        self.dim = dim

        self.w_1 = nn.Parameter(torch.Tensor(self.dim + 1, self.dim))
        self.w_2 = nn.Parameter(torch.Tensor(self.dim, 1))
        self.w_3 = nn.Parameter(torch.Tensor(2 * self.dim, self.dim))
        self.bias = nn.Parameter(torch.Tensor(self.dim))

    def forward(self, self_vectors, neighbor_vector, batch_size, masks, neighbor_weight, extra_vector=None):
        if extra_vector is not None:
            alpha = torch.matmul(torch.cat([extra_vector.unsqueeze(2).repeat(1, 1, neighbor_vector.shape[2], 1)*neighbor_vector, neighbor_weight.unsqueeze(-1)], -1), self.w_1).squeeze(-1)
            alpha = F.leaky_relu(alpha, negative_slope=0.2)
            alpha = torch.matmul(alpha, self.w_2).squeeze(-1)
            alpha = torch.softmax(alpha, -1).unsqueeze(-1)
            neighbor_vector = torch.sum(alpha * neighbor_vector, dim=-2)
        else:
            neighbor_vector = torch.mean(neighbor_vector, dim=2)
        # self_vectors = F.dropout(self_vectors, 0.5, training=self.training)
        output = torch.cat([self_vectors, neighbor_vector], -1)
        output = F.dropout(output, self.dropout, training=self.training)
        output = torch.matmul(output, self.w_3)
        output = output.view(batch_size, -1, self.dim)
        output = self.act(output)
        return output


class GCEGNN(SequentialRecommender):
    def __init__(self, config, dataset):
        super(GCEGNN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.leakyrelu_alpha = config['leakyrelu_alpha']
        self.dropout_local = config['dropout_local']
        self.dropout_global = config['dropout_global']
        self.dropout_gcn = config['dropout_gcn']
        self.device = config['device']
        self.loss_type = config['loss_type']
        self.build_global_graph = config['build_global_graph']
        self.sample_num = config['sample_num']
        self.hop = config['hop']
        self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ]

        # global graph construction
        self.global_graph = None
        if self.build_global_graph:
            self.global_adj, self.global_weight = self.construct_global_graph(dataset)

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size)

        # define layers and loss
        # Aggregator
        self.local_agg = LocalAggregator(self.embedding_size, self.leakyrelu_alpha)
        global_agg_list = []
        for i in range(self.hop):
            global_agg_list.append(GlobalAggregator(self.embedding_size, self.dropout_gcn))
        self.global_agg = nn.ModuleList(global_agg_list)

        self.w_1 = nn.Linear(2 * self.embedding_size, self.embedding_size, bias=False)
        self.w_2 = nn.Linear(self.embedding_size, 1, bias=False)
        self.glu1 = nn.Linear(self.embedding_size, self.embedding_size)
        self.glu2 = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        self.reset_parameters()
        self.other_parameter_name = ['global_adj', 'global_weight']

    def reset_parameters(self):
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def _add_edge(self, graph, sid, tid):
        if tid not in graph[sid]:
            graph[sid][tid] = 0
        graph[sid][tid] += 1

    def construct_global_graph(self, dataset):
        self.logger.info('Constructing global graphs.')
        item_id_list = dataset.inter_feat['item_id_list']
        src_item_ids = item_id_list[:,:4].tolist()
        tgt_item_id = dataset.inter_feat['item_id'].tolist()
        global_graph = [{} for _ in range(self.n_items)]
        for i in tqdm(range(len(tgt_item_id)), desc='Converting: '):
            tid = tgt_item_id[i]
            for sid in src_item_ids[i]:
                if sid > 0:
                    self._add_edge(global_graph, tid, sid)
                    self._add_edge(global_graph, sid, tid)
        global_adj = [[] for _ in range(self.n_items)]
        global_weight = [[] for _ in range(self.n_items)]
        for i in tqdm(range(self.n_items), desc='Sorting: '):
            sorted_out_edges = [v for v in sorted(global_graph[i].items(), reverse=True, key=lambda x: x[1])]
            global_adj[i] = [v[0] for v in sorted_out_edges[:self.sample_num]]
            global_weight[i] = [v[1] for v in sorted_out_edges[:self.sample_num]]
            if len(global_adj[i]) < self.sample_num:
                for j in range(self.sample_num - len(global_adj[i])):
                    global_adj[i].append(0)
                    global_weight[i].append(0)
        return torch.LongTensor(global_adj).to(self.device), torch.FloatTensor(global_weight).to(self.device)

    def fusion(self, hidden, mask):
        batch_size = hidden.shape[0]
        length = hidden.shape[1]
        pos_emb = self.pos_embedding.weight[:length]
        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)

        hs = torch.sum(hidden * mask, -2) / torch.sum(mask, 1)
        hs = hs.unsqueeze(-2).expand(-1, length, -1)
        nh = self.w_1(torch.cat([pos_emb, hidden], -1))
        nh = torch.tanh(nh)
        nh = torch.sigmoid(self.glu1(nh) + self.glu2(hs))
        beta = self.w_2(nh)
        beta = beta * mask
        final_h = torch.sum(beta * hidden, 1)
        return final_h

    def forward(self, x, edge_index, edge_attr, alias_inputs, item_seq_len):
        batch_size = alias_inputs.shape[0]
        mask = alias_inputs.gt(0).unsqueeze(-1)
        h = self.item_embedding(x)

        # local
        h_local = self.local_agg(h, edge_index, edge_attr)

        # global
        item_neighbors = [F.pad(x[alias_inputs], (0, self.max_seq_length - x[alias_inputs].shape[1]), "constant", 0)]
        weight_neighbors = []
        support_size = self.max_seq_length

        for i in range(self.hop):
            item_sample_i, weight_sample_i = self.global_adj[item_neighbors[-1].view(-1)], self.global_weight[item_neighbors[-1].view(-1)]
            support_size *= self.sample_num
            item_neighbors.append(item_sample_i.view(batch_size, support_size))
            weight_neighbors.append(weight_sample_i.view(batch_size, support_size))

        entity_vectors = [self.item_embedding(i) for i in item_neighbors]
        weight_vectors = weight_neighbors

        session_info = []
        item_emb = h[alias_inputs] * mask

        # mean 
        sum_item_emb = torch.sum(item_emb, 1) / torch.sum(mask.float(), 1)

        # sum
        # sum_item_emb = torch.sum(item_emb, 1)

        sum_item_emb = sum_item_emb.unsqueeze(-2)
        for i in range(self.hop):
            session_info.append(sum_item_emb.repeat(1, entity_vectors[i].shape[1], 1))

        for n_hop in range(self.hop):
            entity_vectors_next_iter = []
            shape = [batch_size, -1, self.sample_num, self.embedding_size]
            for hop in range(self.hop - n_hop):
                aggregator = self.global_agg[n_hop]
                vector = aggregator(self_vectors=entity_vectors[hop],
                                    neighbor_vector=entity_vectors[hop + 1].view(shape),
                                    masks=None,
                                    batch_size=batch_size,
                                    neighbor_weight=weight_vectors[hop].view(batch_size, -1, self.sample_num),
                                    extra_vector=session_info[hop])
                entity_vectors_next_iter.append(vector)
            entity_vectors = entity_vectors_next_iter

        h_global = entity_vectors[0].view(batch_size, self.max_seq_length, self.embedding_size)
        h_global = h_global[:,:alias_inputs.shape[1],:]

        h_local = F.dropout(h_local, self.dropout_local, training=self.training)
        h_global = F.dropout(h_global, self.dropout_global, training=self.training)
        h_local = h_local[alias_inputs]

        h_session = h_local + h_global
        h_session = self.fusion(h_session, mask)
        return h_session

    def calculate_loss(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        edge_attr = interaction['edge_attr']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index = interaction['edge_index']
        edge_attr = interaction['edge_attr']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        edge_attr = interaction['edge_attr']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
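

# A toy sketch of what `construct_global_graph` above computes, using plain
# Python only: count co-occurrences between each target item and its context
# items, keep the `sample_num` most frequent neighbors per item, and zero-pad
# the rest. The item ids, sessions and sample_num below are illustrative.
if __name__ == '__main__':
    n_items, sample_num = 5, 2
    # (context item ids, target item id) pairs drawn from toy sessions
    samples = [([1, 2], 3), ([2, 3], 4), ([1, 3], 4)]

    graph = [{} for _ in range(n_items)]

    def add_edge(sid, tid):
        graph[sid][tid] = graph[sid].get(tid, 0) + 1

    for context, tid in samples:
        for sid in context:
            if sid > 0:
                add_edge(tid, sid)
                add_edge(sid, tid)

    global_adj, global_weight = [], []
    for i in range(n_items):
        top = sorted(graph[i].items(), key=lambda kv: kv[1], reverse=True)[:sample_num]
        pad = sample_num - len(top)
        global_adj.append([v for v, _ in top] + [0] * pad)
        global_weight.append([c for _, c in top] + [0] * pad)
    print(global_adj)
    print(global_weight)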


================================================
FILE: recbole_gnn/model/sequential_recommender/gcsan.py
================================================
# @Time   : 2022/3/7
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

r"""
GCSAN
################################################

Reference:
    Chengfeng Xu et al. "Graph Contextualized Self-Attention Network for Session-based Recommendation." in IJCAI 2019.

"""

import torch
from torch import nn
from recbole.model.layers import TransformerEncoder
from recbole.model.loss import EmbLoss, BPRLoss
from recbole.model.abstract_recommender import SequentialRecommender

from recbole_gnn.model.layers import SRGNNCell


class GCSAN(SequentialRecommender):
    r"""GCSAN captures rich local dependencies via graph neural network,
     and learns long-range dependencies by applying the self-attention mechanism.
     
    Note:

        In the original paper, the attention mechanism in the self-attention layer is a single head,
        for the reusability of the project code, we use a unified transformer component.
        According to the experimental results, we only applied regularization to embedding.
    """

    def __init__(self, config, dataset):
        super(GCSAN, self).__init__(config, dataset)

        # load parameters info
        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.hidden_size = config['hidden_size']  # same as embedding_size
        self.inner_size = config['inner_size']  # the dimensionality in feed-forward layer
        self.hidden_dropout_prob = config['hidden_dropout_prob']
        self.attn_dropout_prob = config['attn_dropout_prob']
        self.hidden_act = config['hidden_act']
        self.layer_norm_eps = config['layer_norm_eps']

        self.step = config['step']
        self.device = config['device']
        self.weight = config['weight']
        self.reg_weight = config['reg_weight']
        self.loss_type = config['loss_type']
        self.initializer_range = config['initializer_range']

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0)

        # define layers and loss
        self.gnncell = SRGNNCell(self.hidden_size)
        self.self_attention = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )
        self.reg_loss = EmbLoss()
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_attention_mask(self, item_seq):
        """Generate left-to-right uni-directional attention mask for multi-head attention."""
        attention_mask = (item_seq > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64
        # mask for left-to-right unidirectional
        max_len = attention_mask.size(-1)
        attn_shape = (1, max_len, max_len)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)  # torch.uint8
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1)
        subsequent_mask = subsequent_mask.long().to(item_seq.device)

        extended_attention_mask = extended_attention_mask * subsequent_mask
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(self, x, edge_index, alias_inputs, item_seq_len):
        hidden = self.item_embedding(x)
        for i in range(self.step):
            hidden = self.gnncell(hidden, edge_index)

        seq_hidden = hidden[alias_inputs]
        # fetch the last hidden state of last timestamp
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)

        attention_mask = self.get_attention_mask(alias_inputs)
        outputs = self.self_attention(seq_hidden, attention_mask, output_all_encoded_layers=True)
        output = outputs[-1]
        at = self.gather_indexes(output, item_seq_len - 1)
        seq_output = self.weight * at + (1 - self.weight) * ht
        return seq_output

    def calculate_loss(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
        reg_loss = self.reg_loss(self.item_embedding.weight)
        total_loss = loss + self.reg_weight * reg_loss
        return total_loss

    def predict(self, interaction):
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
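

# A small sketch of the left-to-right attention mask built by
# `get_attention_mask` above, relying only on the torch import at the top of
# this file: padding positions (item id 0) and future positions are pushed to
# a large negative value before the softmax. The toy session is illustrative.
if __name__ == '__main__':
    item_seq = torch.tensor([[5, 3, 0, 0]])  # one padded session of length 4
    attention_mask = (item_seq > 0).long()
    extended = attention_mask.unsqueeze(1).unsqueeze(2)           # [B, 1, 1, L]
    max_len = attention_mask.size(-1)
    subsequent = torch.triu(torch.ones((1, max_len, max_len)), diagonal=1)
    subsequent = (subsequent == 0).unsqueeze(1).long()            # [1, 1, L, L]
    mask = (1.0 - (extended * subsequent).float()) * -10000.0
    print(mask[0, 0])  # rows are query positions, columns are key positions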


================================================
FILE: recbole_gnn/model/sequential_recommender/lessr.py
================================================
# @Time   : 2022/3/11
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

r"""
LESSR
################################################

Reference:
    Tianwen Chen and Raymond Chi-Wing Wong. "Handling Information Loss of Graph Neural Networks for Session-based Recommendation." in KDD 2020.

Reference code:
    https://github.com/twchen/lessr

"""

import torch
from torch import nn
from torch_geometric.utils import softmax
from torch_geometric.nn import global_add_pool
from recbole.model.abstract_recommender import SequentialRecommender


class EOPA(nn.Module):
    def __init__(
        self, input_dim, output_dim, batch_norm=True, feat_drop=0.0, activation=None
    ):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.feat_drop = nn.Dropout(feat_drop)
        self.gru = nn.GRU(input_dim, input_dim, batch_first=True)
        self.fc_self = nn.Linear(input_dim, output_dim, bias=False)
        self.fc_neigh = nn.Linear(input_dim, output_dim, bias=False)
        self.activation = activation

    def reducer(self, nodes):
        m = nodes.mailbox['m']  # (num_nodes, deg, d)
        # m[i]: the messages passed to the i-th node with in-degree equal to 'deg'
        # the order of messages follows the order of incoming edges
        # since the edges are sorted by occurrence time when the EOP multigraph is built
        # the messages are in the order required by EOPA
        _, hn = self.gru(m)  # hn: (1, num_nodes, d)
        return {'neigh': hn.squeeze(0)}

    def forward(self, mg, feat):
        import dgl.function as fn

        with mg.local_scope():
            if self.batch_norm is not None:
                feat = self.batch_norm(feat)
            mg.ndata['ft'] = self.feat_drop(feat)
            if mg.number_of_edges() > 0:
                mg.update_all(fn.copy_u('ft', 'm'), self.reducer)
                neigh = mg.ndata['neigh']
                rst = self.fc_self(feat) + self.fc_neigh(neigh)
            else:
                rst = self.fc_self(feat)
            if self.activation is not None:
                rst = self.activation(rst)
            return rst


class SGAT(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        batch_norm=True,
        feat_drop=0.0,
        activation=None,
    ):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.feat_drop = nn.Dropout(feat_drop)
        self.fc_q = nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc_k = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc_v = nn.Linear(input_dim, output_dim, bias=False)
        self.fc_e = nn.Linear(hidden_dim, 1, bias=False)
        self.activation = activation

    def forward(self, sg, feat):
        import dgl.ops as F

        if self.batch_norm is not None:
            feat = self.batch_norm(feat)
        feat = self.feat_drop(feat)
        q = self.fc_q(feat)
        k = self.fc_k(feat)
        v = self.fc_v(feat)
        e = F.u_add_v(sg, q, k)
        e = self.fc_e(torch.sigmoid(e))
        a = F.edge_softmax(sg, e)
        rst = F.u_mul_e_sum(sg, v, a)
        if self.activation is not None:
            rst = self.activation(rst)
        return rst


class AttnReadout(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        batch_norm=True,
        feat_drop=0.0,
        activation=None,
    ):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.feat_drop = nn.Dropout(feat_drop)
        self.fc_u = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc_v = nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc_e = nn.Linear(hidden_dim, 1, bias=False)
        self.fc_out = (
            nn.Linear(input_dim, output_dim, bias=False)
            if output_dim != input_dim else None
        )
        self.activation = activation

    def forward(self, g, feat, last_nodes, batch):
        if self.batch_norm is not None:
            feat = self.batch_norm(feat)
        feat = self.feat_drop(feat)
        feat_u = self.fc_u(feat)
        feat_v = self.fc_v(feat[last_nodes])
        feat_v = torch.index_select(feat_v, dim=0, index=batch)
        e = self.fc_e(torch.sigmoid(feat_u + feat_v))
        alpha = softmax(e, batch)
        feat_norm = feat * alpha
        rst = global_add_pool(feat_norm, batch)
        if self.fc_out is not None:
            rst = self.fc_out(rst)
        if self.activation is not None:
            rst = self.activation(rst)
        return rst


class LESSR(SequentialRecommender):
    r"""LESSR analyzes the information losses when constructing session graphs,
    and emphasises lossy session encoding problem and the ineffective long-range dependency capturing problem.
    To solve the first problem, authors propose a lossless encoding scheme and an edge-order preserving aggregation layer.
    To solve the second problem, authors propose a shortcut graph attention layer that effectively captures long-range dependencies.

    Note:
        We follow the original implementation, which requires DGL package.
        We find it difficult to implement these functions via PyG, so we remain them.
        If you would like to test this model, please install DGL.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        embedding_dim = config['embedding_size']
        self.num_layers = config['n_layers']
        batch_norm = config['batch_norm']
        feat_drop = config['feat_drop']
        self.loss_type = config['loss_type']

        self.item_embedding = nn.Embedding(self.n_items, embedding_dim, max_norm=1)
        self.layers = nn.ModuleList()
        input_dim = embedding_dim
        for i in range(self.num_layers):
            if i % 2 == 0:
                layer = EOPA(
                    input_dim,
                    embedding_dim,
                    batch_norm=batch_norm,
                    feat_drop=feat_drop,
                    activation=nn.PReLU(embedding_dim),
                )
            else:
                layer = SGAT(
                    input_dim,
                    embedding_dim,
                    embedding_dim,
                    batch_norm=batch_norm,
                    feat_drop=feat_drop,
                    activation=nn.PReLU(embedding_dim),
                )
            input_dim += embedding_dim
            self.layers.append(layer)
        self.readout = AttnReadout(
            input_dim,
            embedding_dim,
            embedding_dim,
            batch_norm=batch_norm,
            feat_drop=feat_drop,
            activation=nn.PReLU(embedding_dim),
        )
        input_dim += embedding_dim
        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.feat_drop = nn.Dropout(feat_drop)
        self.fc_sr = nn.Linear(input_dim, embedding_dim, bias=False)

        if self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['CE']!")

    def forward(self, x, edge_index_EOP, edge_index_shortcut, batch, is_last):
        import dgl

        mg = dgl.graph((edge_index_EOP[0], edge_index_EOP[1]), num_nodes=batch.shape[0])
        sg = dgl.graph((edge_index_shortcut[0], edge_index_shortcut[1]), num_nodes=batch.shape[0])

        feat = self.item_embedding(x)
        for i, layer in enumerate(self.layers):
            if i % 2 == 0:
                out = layer(mg, feat)
            else:
                out = layer(sg, feat)
            feat = torch.cat([out, feat], dim=1)
        sr_g = self.readout(mg, feat, is_last, batch)
        sr_l = feat[is_last]
        sr = torch.cat([sr_l, sr_g], dim=1)
        if self.batch_norm is not None:
            sr = self.batch_norm(sr)
        sr = self.fc_sr(self.feat_drop(sr))
        return sr

    def calculate_loss(self, interaction):
        x = interaction['x']
        edge_index_EOP = interaction['edge_index_EOP']
        edge_index_shortcut = interaction['edge_index_shortcut']
        batch = interaction['batch']
        is_last = interaction['is_last']
        seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
        pos_items = interaction[self.POS_ITEM_ID]
        test_item_emb = self.item_embedding.weight
        logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
        loss = self.loss_fct(logits, pos_items)
        return loss

    def predict(self, interaction):
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index_EOP = interaction['edge_index_EOP']
        edge_index_shortcut = interaction['edge_index_shortcut']
        batch = interaction['batch']
        is_last = interaction['is_last']
        seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        x = interaction['x']
        edge_index_EOP = interaction['edge_index_EOP']
        edge_index_shortcut = interaction['edge_index_shortcut']
        batch = interaction['batch']
        is_last = interaction['is_last']
        seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
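

# A minimal sketch of the attention readout idea used by AttnReadout above,
# relying on the torch_geometric utilities already imported in this file:
# per-node scores are softmax-normalized within each session (the `batch`
# vector) and the weighted node features are sum-pooled into one vector per
# session. The shapes below are illustrative.
if __name__ == '__main__':
    feat = torch.randn(5, 4)                        # 5 nodes, 4-d features
    batch = torch.tensor([0, 0, 0, 1, 1])           # nodes 0-2 -> session 0, 3-4 -> session 1
    e = torch.randn(5, 1)                           # unnormalized attention logits
    alpha = softmax(e, batch)                       # normalized within each session
    readout = global_add_pool(feat * alpha, batch)  # [2, 4], one vector per session
    print(readout.shape)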


================================================
FILE: recbole_gnn/model/sequential_recommender/niser.py
================================================
# @Time   : 2022/3/7
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

r"""
NISER
################################################

Reference:
    Priyanka Gupta et al. "NISER: Normalized Item and Session Representations to Handle Popularity Bias." in CIKM 2019 GRLA workshop.

"""
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from recbole.model.loss import BPRLoss
from recbole.model.abstract_recommender import SequentialRecommender

from recbole_gnn.model.layers import SRGNNCell


class NISER(SequentialRecommender):
    r"""NISER+ is a GNN-based model that normalizes session and item embeddings to handle popularity bias.
    """

    def __init__(self, config, dataset):
        super(NISER, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.step = config['step']
        self.device = config['device']
        self.loss_type = config['loss_type']
        self.sigma = config['sigma']
        self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ]

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size)
        self.item_dropout = nn.Dropout(config['item_dropout'])

        # define layers and loss
        self.gnncell = SRGNNCell(self.embedding_size)
        self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_three = nn.Linear(self.embedding_size, 1, bias=False)
        self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self._reset_parameters()

    def _reset_parameters(self):
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, edge_index, alias_inputs, item_seq_len):
        mask = alias_inputs.gt(0)
        hidden = self.item_embedding(x)
        # Dropout in NISER+
        hidden = self.item_dropout(hidden)
        # Normalize item embeddings
        hidden = F.normalize(hidden, dim=-1)
        for i in range(self.step):
            hidden = self.gnncell(hidden, edge_index)

        seq_hidden = hidden[alias_inputs]
        batch_size = seq_hidden.shape[0]
        pos_emb = self.pos_embedding.weight[:seq_hidden.shape[1]]
        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)
        seq_hidden = seq_hidden + pos_emb
        # fetch the last hidden state of last timestamp
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
        q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
        q2 = self.linear_two(seq_hidden)

        alpha = self.linear_three(torch.sigmoid(q1 + q2))
        a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
        seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
        # Normalize session embeddings
        seq_output = F.normalize(seq_output, dim=-1)
        return seq_output

    def calculate_loss(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = F.normalize(self.item_embedding(pos_items), dim=-1)
            neg_items_emb = F.normalize(self.item_embedding(neg_items), dim=-1)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(self.sigma * pos_score, self.sigma * neg_score)
            return loss
        else:  # self.loss_type = 'CE'
            test_item_emb = F.normalize(self.item_embedding.weight, dim=-1)
            logits = self.sigma * torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_item_emb = F.normalize(self.item_embedding(test_item), dim=-1)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_items_emb = F.normalize(self.item_embedding.weight, dim=-1)
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
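

# A small sketch of the NISER+ scoring convention used above, relying only on
# the torch / F imports at the top of this file: both session and item
# embeddings are L2-normalized, so the logits are cosine similarities, and a
# scale factor sigma sharpens the softmax. The sizes and sigma value are
# illustrative.
if __name__ == '__main__':
    sigma = 16.0
    seq_output = F.normalize(torch.randn(2, 8), dim=-1)   # [B, d]
    item_emb = F.normalize(torch.randn(10, 8), dim=-1)    # [n_items, d]
    logits = sigma * seq_output @ item_emb.t()            # values in [-sigma, sigma]
    print(logits.shape)  # torch.Size([2, 10])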


================================================
FILE: recbole_gnn/model/sequential_recommender/sgnnhn.py
================================================
# @Time   : 2022/3/28
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

r"""
SGNN-HN
################################################

Reference:
    Zhiqiang Pan et al. "Star Graph Neural Networks for Session-based Recommendation." in CIKM 2020.

Reference code:
    https://bitbucket.org/nudtpanzq/sgnn-hn

"""

import math
import numpy as np
import torch
from torch import nn
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.utils import softmax
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss

from recbole_gnn.model.layers import SRGNNCell


def layer_norm(x):
    ave_x = torch.mean(x, -1).unsqueeze(-1)
    x = x - ave_x
    norm_x = torch.sqrt(torch.sum(x**2, -1)).unsqueeze(-1)
    y = x / norm_x
    return y


class SGNNHN(SequentialRecommender):
    r"""SGNN-HN applies a star graph neural network to model the complex transition relationship between items in an ongoing session.
        To avoid overfitting, it applies highway networks to adaptively select embeddings from item representations.
    """

    def __init__(self, config, dataset):
        super(SGNNHN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.step = config['step']
        self.device = config['device']
        self.loss_type = config['loss_type']
        self.scale = config['scale']

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ]
        self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size)

        # define layers and loss
        self.gnncell = SRGNNCell(self.embedding_size)
        self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_three = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_four = nn.Linear(self.embedding_size, 1, bias=False)
        self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self._reset_parameters()

    def _reset_parameters(self):
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def att_out(self, hidden, star_node, batch):
        star_node_repeat = torch.index_select(star_node, 0, batch)
        sim = (hidden * star_node_repeat).sum(dim=-1)
        sim = softmax(sim, batch)
        att_hidden = sim.unsqueeze(-1) * hidden
        output = global_add_pool(att_hidden, batch)

        return output

    def forward(self, x, edge_index, batch, alias_inputs, item_seq_len):
        mask = alias_inputs.gt(0)
        hidden = self.item_embedding(x)

        star_node = global_mean_pool(hidden, batch)
        for i in range(self.step):
            hidden = self.gnncell(hidden, edge_index)
            star_node_repeat = torch.index_select(star_node, 0, batch)
            sim = (hidden * star_node_repeat).sum(dim=-1, keepdim=True) / math.sqrt(self.embedding_size)
            alpha = torch.sigmoid(sim)
            hidden = (1 - alpha) * hidden + alpha * star_node_repeat
            star_node = self.att_out(hidden, star_node, batch)

        seq_hidden = hidden[alias_inputs]
        bs, item_num, _ = seq_hidden.shape
        pos_emb = self.pos_embedding.weight[:item_num]
        pos_emb = pos_emb.unsqueeze(0).expand(bs, -1, -1)
        seq_hidden = seq_hidden + pos_emb

        # fetch the last hidden state of last timestamp
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
        q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
        q2 = self.linear_two(seq_hidden)
        q3 = self.linear_three(star_node).view(star_node.shape[0], 1, star_node.shape[1])

        alpha = self.linear_four(torch.sigmoid(q1 + q2 + q3))
        a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
        seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
        return layer_norm(seq_output)

    def calculate_loss(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        batch = interaction['batch']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, batch, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = layer_norm(self.item_embedding(pos_items))
            neg_items_emb = layer_norm(self.item_embedding(neg_items))
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) * self.scale  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) * self.scale  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type = 'CE'
            test_item_emb = layer_norm(self.item_embedding.weight)
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) * self.scale
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index = interaction['edge_index']
        batch = interaction['batch']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, batch, alias_inputs, item_seq_len)
        test_item_emb = layer_norm(self.item_embedding(test_item))
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1) * self.scale  # [B]
        return scores

    def full_sort_predict(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        batch = interaction['batch']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, batch, alias_inputs, item_seq_len)
        test_items_emb = layer_norm(self.item_embedding.weight)
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) * self.scale  # [B, n_items]
        return scores
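

# A tiny check of the `layer_norm` helper defined at the top of this file,
# relying only on the torch import above: unlike nn.LayerNorm it has no
# learnable affine parameters and no epsilon; it centers each vector and
# rescales it to unit L2 norm along the last dimension.
if __name__ == '__main__':
    x = torch.randn(3, 8)
    y = layer_norm(x)
    print(y.mean(dim=-1))  # close to zero
    print(y.norm(dim=-1))  # close to one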


================================================
FILE: recbole_gnn/model/sequential_recommender/srgnn.py
================================================
# @Time   : 2022/3/7
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

r"""
SRGNN
################################################

Reference:
    Shu Wu et al. "Session-based Recommendation with Graph Neural Networks." in AAAI 2019.

Reference code:
    https://github.com/CRIPAC-DIG/SR-GNN

"""
import numpy as np
import torch
from torch import nn
from recbole.model.loss import BPRLoss
from recbole.model.abstract_recommender import SequentialRecommender

from recbole_gnn.model.layers import SRGNNCell


class SRGNN(SequentialRecommender):
    r"""SRGNN regards the conversation history as a directed graph.
    In addition to considering the connection between the item and the adjacent item,
    it also considers the connection with other interactive items.

    Such as: A example of a session sequence(eg:item1, item2, item3, item2, item4) and the connection matrix A

    Outgoing edges:
        === ===== ===== ===== =====
         \    1     2     3     4
        === ===== ===== ===== =====
         1    0     1     0     0
         2    0     0    1/2   1/2
         3    0     1     0     0
         4    0     0     0     0
        === ===== ===== ===== =====

    Incoming edges:
        === ===== ===== ===== =====
         \    1     2     3     4
        === ===== ===== ===== =====
         1    0     0     0     0
         2   1/2    0    1/2    0
         3    0     1     0     0
         4    0     1     0     0
        === ===== ===== ===== =====
    """

    def __init__(self, config, dataset):
        super(SRGNN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.step = config['step']
        self.device = config['device']
        self.loss_type = config['loss_type']

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)

        # define layers and loss
        self.gnncell = SRGNNCell(self.embedding_size)
        self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_three = nn.Linear(self.embedding_size, 1, bias=False)
        self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self._reset_parameters()

    def _reset_parameters(self):
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, edge_index, alias_inputs, item_seq_len):
        mask = alias_inputs.gt(0)
        hidden = self.item_embedding(x)
        for i in range(self.step):
            hidden = self.gnncell(hidden, edge_index)

        seq_hidden = hidden[alias_inputs]
        # fetch the last hidden state of last timestamp
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
        q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
        q2 = self.linear_two(seq_hidden)

        alpha = self.linear_three(torch.sigmoid(q1 + q2))
        a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
        seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
        return seq_output

    def calculate_loss(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type == 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores


================================================
FILE: recbole_gnn/model/sequential_recommender/tagnn.py
================================================
# @Time   : 2022/3/17
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

r"""
TAGNN
################################################

Reference:
    Feng Yu et al. "TAGNN: Target Attentive Graph Neural Networks for Session-based Recommendation." in SIGIR 2020 short.
    Implemented using PyTorch Geometric.

Reference code:
    https://github.com/CRIPAC-DIG/TAGNN

"""
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from recbole.model.abstract_recommender import SequentialRecommender

from recbole_gnn.model.layers import SRGNNCell


class TAGNN(SequentialRecommender):
    r"""TAGNN introduces target-aware attention and adaptively activates different user interests with respect to varied target items.
    """

    def __init__(self, config, dataset):
        super(TAGNN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.step = config['step']
        self.device = config['device']
        self.loss_type = config['loss_type']

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)

        # define layers and loss
        self.gnncell = SRGNNCell(self.embedding_size)
        self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_three = nn.Linear(self.embedding_size, 1, bias=False)
        self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)
        self.linear_t = nn.Linear(self.embedding_size, self.embedding_size, bias=False)  #target attention
        if self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' is 'CE'!")

        # parameters initialization
        self._reset_parameters()

    def _reset_parameters(self):
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, edge_index, alias_inputs, item_seq_len):
        mask = alias_inputs.gt(0)
        hidden = self.item_embedding(x)
        for i in range(self.step):
            hidden = self.gnncell(hidden, edge_index)

        seq_hidden = hidden[alias_inputs]
        # fetch the last hidden state of last timestamp
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
        q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
        q2 = self.linear_two(seq_hidden)

        alpha = self.linear_three(torch.sigmoid(q1 + q2))
        alpha = F.softmax(alpha, 1)
        a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
        seq_output = self.linear_transform(torch.cat([a, ht], dim=1))

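        # Target-aware attention: score every candidate item embedding against each
        # (masked) position of the session, softmax over positions, and add the
        # resulting per-target session vector to the global preference seq_output.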
        seq_hidden = seq_hidden * mask.view(mask.shape[0], -1, 1).float()
        qt = self.linear_t(seq_hidden)
        b = self.item_embedding.weight
        beta = F.softmax(b @ qt.transpose(1,2), -1)
        target = beta @ seq_hidden
        a = seq_output.view(ht.shape[0], 1, ht.shape[1])  # b,1,d
        a = a + target  # b,n,d
        scores = torch.sum(a * b, -1)  # b,n
        return scores

    def calculate_loss(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        logits = self.forward(x, edge_index, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        loss = self.loss_fct(logits, pos_items)
        return loss

    def predict(self, interaction):
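        # Point-wise prediction is left unimplemented in this repository; TAGNN's
        # forward already produces scores over all items (see full_sort_predict).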
        pass

    def full_sort_predict(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        scores = self.forward(x, edge_index, alias_inputs, item_seq_len)
        return scores


================================================
FILE: recbole_gnn/model/social_recommender/__init__.py
================================================
from recbole_gnn.model.social_recommender.diffnet import DiffNet
from recbole_gnn.model.social_recommender.mhcn import MHCN
from recbole_gnn.model.social_recommender.sept import SEPT

================================================
FILE: recbole_gnn/model/social_recommender/diffnet.py
================================================
# @Time   : 2022/3/15
# @Author : Lanling Xu
# @Email  : xulanling_sherry@163.com

r"""
DiffNet
################################################
Reference:
    Le Wu et al. "A Neural Influence Diffusion Model for Social Recommendation." in SIGIR 2019.

Reference code:
    https://github.com/PeiJieSun/diffnet
"""

import numpy as np
import torch
import torch.nn as nn

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import SocialRecommender
from recbole_gnn.model.layers import BipartiteGCNConv


class DiffNet(SocialRecommender):
    r"""DiffNet is a deep influence propagation model to stimulate how users are influenced by the recursive social diffusion process for social recommendation.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(DiffNet, self).__init__(config, dataset)

        # load dataset info
        self.edge_index, self.edge_weight = dataset.get_bipartite_inter_mat(row='user')
        self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device)

        self.net_edge_index, self.net_edge_weight = dataset.get_norm_net_adj_mat(row_norm=True)
        self.net_edge_index, self.net_edge_weight = self.net_edge_index.to(self.device), self.net_edge_weight.to(self.device)

        # load parameters info
        self.embedding_size = config['embedding_size']  # int type:the embedding size of DiffNet
        self.n_layers = config['n_layers']  # int type:the GCN layer num of DiffNet for social net
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization
        self.pretrained_review = config['pretrained_review']  # bool type:whether to load pre-trained review vectors of users and items

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.embedding_size)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.embedding_size)
        self.bipartite_gcn_conv = BipartiteGCNConv(dim=self.embedding_size)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

        if self.pretrained_review:
            # handle review information, map the origin review into the new space
            self.user_review_embedding = nn.Embedding(self.n_users, self.embedding_size, padding_idx=0)
            self.user_review_embedding.weight.requires_grad = False
            self.user_review_embedding.weight.data.copy_(self.convertDistribution(dataset.user_feat['user_review_emb']))

            self.item_review_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
            self.item_review_embedding.weight.requires_grad = False
            self.item_review_embedding.weight.data.copy_(self.convertDistribution(dataset.item_feat['item_review_emb']))

            self.user_fusion_layer = nn.Linear(self.embedding_size, self.embedding_size)
            self.item_fusion_layer = nn.Linear(self.embedding_size, self.embedding_size)
            self.activation = nn.Sigmoid()

    def convertDistribution(self, x):
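        # Rescale the input to zero mean and a standard deviation of 0.2,
        # following the original DiffNet implementation.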
        mean, std = torch.mean(x), torch.std(x)
        y = (x - mean) * 0.2 / std
        return y

    def forward(self):
        user_embedding = self.user_embedding.weight
        final_item_embedding = self.item_embedding.weight

        if self.pretrained_review:
            user_reduce_dim_vector_matrix = self.activation(self.user_fusion_layer(self.user_review_embedding.weight))
            item_reduce_dim_vector_matrix = self.activation(self.item_fusion_layer(self.item_review_embedding.weight))

            user_review_vector_matrix = self.convertDistribution(user_reduce_dim_vector_matrix)
            item_review_vector_matrix = self.convertDistribution(item_reduce_dim_vector_matrix)

            user_embedding = user_embedding + user_review_vector_matrix
            final_item_embedding = final_item_embedding + item_review_vector_matrix

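        # Aggregate the embeddings of consumed items into each user's representation
        # with one bipartite convolution over the (flipped) user-item edge index.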
        user_embedding_from_consumed_items = self.bipartite_gcn_conv(x=(final_item_embedding, user_embedding), edge_index=self.edge_index.flip([0]), edge_weight=self.edge_weight, size=(self.n_items, self.n_users))

        embeddings_list = [user_embedding]
        for layer_idx in range(self.n_layers):
            user_embedding = self.bipartite_gcn_conv((user_embedding, user_embedding), self.net_edge_index.flip([0]), self.net_edge_weight, size=(self.n_users, self.n_users))
            embeddings_list.append(user_embedding)
        final_user_embedding = torch.stack(embeddings_list, dim=1)
        final_user_embedding = torch.sum(final_user_embedding, dim=1) + user_embedding_from_consumed_items

        return final_user_embedding, final_item_embedding

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)
        loss = mf_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)

================================================
FILE: recbole_gnn/model/social_recommender/mhcn.py
================================================
# @Time   : 2022/4/5
# @Author : Lanling Xu
# @Email  : xulanling_sherry@163.com

r"""
MHCN
################################################
Reference:
    Junliang Yu et al. "Self-Supervised Multi-Channel Hypergraph Convolutional Network for Social Recommendation." in WWW 2021.

Reference code:
    https://github.com/Coder-Yu/QRec
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy.sparse import coo_matrix

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import SocialRecommender
from recbole_gnn.model.layers import BipartiteGCNConv


class GatingLayer(nn.Module):
    def __init__(self, dim):
        super(GatingLayer, self).__init__()
        self.dim = dim
        self.linear = nn.Linear(self.dim, self.dim)
        self.activation = nn.Sigmoid()

    def forward(self, emb):
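        # Self-gating: emb * sigmoid(Linear(emb)), used to derive the
        # channel-specific user embeddings in MHCN.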
        embedding = self.linear(emb)
        embedding = self.activation(embedding)
        embedding = torch.mul(emb, embedding)
        return embedding


class AttLayer(nn.Module):
    def __init__(self, dim):
        super(AttLayer, self).__init__()
        self.dim = dim
        self.attention_mat = nn.Parameter(torch.randn([self.dim, self.dim]))
        self.attention = nn.Parameter(torch.randn([1, self.dim]))

    def forward(self, *embs):
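        # Channel attention: score each channel embedding with the learned
        # attention vector/matrix, softmax across channels, and return the
        # weighted sum of the channel embeddings.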
        weights = []
        emb_list = []
        for embedding in embs:
            weights.append(torch.sum(torch.mul(self.attention, torch.matmul(embedding, self.attention_mat)), dim=1))
            emb_list.append(embedding)
        score = torch.nn.Softmax(dim=0)(torch.stack(weights, dim=0))
        embeddings = torch.stack(emb_list, dim=0)
        mixed_embeddings = torch.mul(embeddings, score.unsqueeze(dim=2).repeat(1, 1, self.dim)).sum(dim=0)
        return mixed_embeddings


class MHCN(SocialRecommender):
    r"""MHCN fuses hypergraph modeling and graph neural networks in social recommendation by 
    exploiting multiple types of high-order user relations under a multi-channel setting.
    
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(MHCN, self).__init__(config, dataset)

        # load dataset info
        self.R_user_edge_index, self.R_user_edge_weight, self.R_item_edge_index, self.R_item_edge_weight = self.get_bipartite_inter_mat(dataset)
        H_s, H_j, H_p = self.get_motif_adj_matrix(dataset)

        # transform matrix to edge index and edge weight for convolution
        self.H_s_edge_index, self.H_s_edge_weight = self.get_edge_index_weight(H_s)
        self.H_j_edge_index, self.H_j_edge_weight = self.get_edge_index_weight(H_j)
        self.H_p_edge_index, self.H_p_edge_weight = self.get_edge_index_weight(H_p)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.n_layers = config['n_layers']
        self.ssl_reg = config['ssl_reg']
        self.reg_weight = config['reg_weight']

        # define embedding and loss
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
        self.bipartite_gcn_conv = BipartiteGCNConv(dim=self.embedding_size)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # define gating layers
        self.gating_c1 = GatingLayer(self.embedding_size)
        self.gating_c2 = GatingLayer(self.embedding_size)
        self.gating_c3 = GatingLayer(self.embedding_size)
        self.gating_simple = GatingLayer(self.embedding_size)

        # define self supervised gating layers
        self.ss_gating_c1 = GatingLayer(self.embedding_size)
        self.ss_gating_c2 = GatingLayer(self.embedding_size)
        self.ss_gating_c3 = GatingLayer(self.embedding_size)

        # define attention layers
        self.attention_layer = AttLayer(self.embedding_size)

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_bipartite_inter_mat(self, dataset):
        R_user_edge_index, R_user_edge_weight = dataset.get_bipartite_inter_mat(row='user', row_norm=False)
        R_item_edge_index, R_item_edge_weight = dataset.get_bipartite_inter_mat(row='item', row_norm=False)
        return R_user_edge_index.to(self.device), R_user_edge_weight.to(self.device), R_item_edge_index.to(self.device), R_item_edge_weight.to(self.device)

    def get_edge_index_weight(self, matrix):
        matrix = coo_matrix(matrix)
        edge_index = torch.stack([torch.LongTensor(matrix.row), torch.LongTensor(matrix.col)])
        edge_weight = torch.FloatTensor(matrix.data)
        return edge_index.to(self.device), edge_weight.to(self.device)

    def get_motif_adj_matrix(self, dataset):
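        # Build the motif-induced adjacency matrices A1..A10 from the social matrix S
        # and the interaction matrix Y (B and U are the bidirectional and
        # unidirectional parts of S), then merge them into the three channel
        # hypergraph adjacencies of the MHCN paper: H_s (social motifs M1-M7),
        # H_j (joint motifs M8-M9) and H_p (purchase motif M10), each row-normalized.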
        S = dataset.net_matrix()
        Y = dataset.inter_matrix()
        B = S.multiply(S.T)
        U = S - B
        C1 = (U.dot(U)).multiply(U.T)
        A1 = C1 + C1.T
        C2 = (B.dot(U)).multiply(U.T) + (U.dot(B)).multiply(U.T) + (U.dot(U)).multiply(B)
        A2 = C2 + C2.T
        C3 = (B.dot(B)).multiply(U) + (B.dot(U)).multiply(B) + (U.dot(B)).multiply(B)
        A3 = C3 + C3.T
        A4 = (B.dot(B)).multiply(B)
        C5 = (U.dot(U)).multiply(U) + (U.dot(U.T)).multiply(U) + (U.T.dot(U)).multiply(U)
        A5 = C5 + C5.T
        A6 = (U.dot(B)).multiply(U) + (B.dot(U.T)).multiply(U.T) + (U.T.dot(U)).multiply(B)
        A7 = (U.T.dot(B)).multiply(U.T) + (B.dot(U)).multiply(U) + (U.dot(U.T)).multiply(B)
        A8 = (Y.dot(Y.T)).multiply(B)
        A9 = (Y.dot(Y.T)).multiply(U)
        A9 = A9 + A9.T
        A10  = Y.dot(Y.T) - A8 - A9
        # addition and row-normalization
        H_s = sum([A1, A2, A3, A4, A5, A6, A7])
        # add epsilon to avoid divide by zero Warning
        H_s = H_s.multiply(1.0 / (H_s.sum(axis=1) + 1e-7).reshape(-1, 1))
        H_j = sum([A8, A9])
        H_j = H_j.multiply(1.0 / (H_j.sum(axis=1) + 1e-7).reshape(-1, 1))
        H_p = A10
        H_p = H_p.multiply(H_p > 1)
        H_p = H_p.multiply(1.0 / (H_p.sum(axis=1) + 1e-7).reshape(-1, 1))
        return H_s, H_j, H_p

    def forward(self):
        # get ego embeddings
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight

        # self-gating
        user_embeddings_c1 = self.gating_c1(user_embeddings)
        user_embeddings_c2 = self.gating_c2(user_embeddings)
        user_embeddings_c3 = self.gating_c3(user_embeddings)
        simple_user_embeddings = self.gating_simple(user_embeddings)

        all_embeddings_c1 = [user_embeddings_c1]
        all_embeddings_c2 = [user_embeddings_c2]
        all_embeddings_c3 = [user_embeddings_c3]
        all_embeddings_simple = [simple_user_embeddings]
        all_embeddings_i = [item_embeddings]

        for layer_idx in range(self.n_layers):
            mixed_embedding = self.attention_layer(user_embeddings_c1, user_embeddings_c2, user_embeddings_c3) + simple_user_embeddings / 2
            
            # Channel S
            user_embeddings_c1 = self.bipartite_gcn_conv((user_embeddings_c1, user_embeddings_c1), self.H_s_edge_index.flip([0]), self.H_s_edge_weight, size=(self.n_users, self.n_users))
            norm_embeddings = F.normalize(user_embeddings_c1, p=2, dim=1)
            all_embeddings_c1 += [norm_embeddings]

            # Channel J
            user_embeddings_c2 = self.bipartite_gcn_conv((user_embeddings_c2, user_embeddings_c2), self.H_j_edge_index.flip([0]), self.H_j_edge_weight, size=(self.n_users, self.n_users))
            norm_embeddings = F.normalize(user_embeddings_c2, p=2, dim=1)
            all_embeddings_c2 += [norm_embeddings]

            # Channel P
            user_embeddings_c3 = self.bipartite_gcn_conv((user_embeddings_c3, user_embeddings_c3), self.H_p_edge_index.flip([0]), self.H_p_edge_weight, size=(self.n_users, self.n_users))
            norm_embeddings = F.normalize(user_embeddings_c3, p=2, dim=1)
            all_embeddings_c3 += [norm_embeddings]

            # item convolution
            new_item_embeddings = self.bipartite_gcn_conv((mixed_embedding, item_embeddings), self.R_item_edge_index.flip([0]), self.R_item_edge_weight, size=(self.n_users, self.n_items))
            norm_embeddings = F.normalize(new_item_embeddings, p=2, dim=1)
            all_embeddings_i += [norm_embeddings]
            simple_user_embeddings = self.bipartite_gcn_conv((item_embeddings, simple_user_embeddings), self.R_user_edge_index.flip([0]), self.R_user_edge_weight, size=(self.n_items, self.n_users))
            norm_embeddings = F.normalize(simple_user_embeddings, p=2, dim=1)
            all_embeddings_simple += [norm_embeddings]
            item_embeddings = new_item_embeddings

        # averaging the channel-specific embeddings
        user_embeddings_c1 = torch.stack(all_embeddings_c1, dim=0).sum(dim=0)
        user_embeddings_c2 = torch.stack(all_embeddings_c2, dim=0).sum(dim=0)
        user_embeddings_c3 = torch.stack(all_embeddings_c3, dim=0).sum(dim=0)
        simple_user_embeddings = torch.stack(all_embeddings_simple, dim=0).sum(dim=0)
        item_all_embeddings = torch.stack(all_embeddings_i, dim=0).sum(dim=0)

        # aggregating channel-specific embeddings
        user_all_embeddings = self.attention_layer(user_embeddings_c1, user_embeddings_c2, user_embeddings_c3)
        user_all_embeddings += simple_user_embeddings / 2

        return user_all_embeddings, item_all_embeddings

    def hierarchical_self_supervision(self, user_embeddings, edge_index, edge_weight):
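        # Hierarchical mutual information maximization: contrast node-level user
        # embeddings with their hyperedge-level aggregation (local MIM) and the
        # hyperedge embeddings with the graph-level mean (global MIM), using
        # row / row-column shuffles as negatives.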
        def row_shuffle(embedding):
            shuffled_embeddings = embedding[torch.randperm(embedding.size(0))]
            return shuffled_embeddings
        def row_column_shuffle(embedding):
            shuffled_embeddings = embedding[:, torch.randperm(embedding.size(1))]
            shuffled_embeddings = shuffled_embeddings[torch.randperm(embedding.size(0))]
            return shuffled_embeddings
        def score(x1, x2):
            return torch.sum(torch.mul(x1, x2), dim=1)

        # For Douban, normalization is needed.
        # user_embeddings = F.normalize(user_embeddings, p=2, dim=1) 
        edge_embeddings = self.bipartite_gcn_conv((user_embeddings, user_embeddings), edge_index.flip([0]), edge_weight, size=(self.n_users, self.n_users))
        # Local MIM
        pos = score(user_embeddings, edge_embeddings)
        neg1 = score(row_shuffle(user_embeddings), edge_embeddings)
        neg2 = score(row_column_shuffle(edge_embeddings), user_embeddings)
        local_loss = torch.sum(-torch.log(torch.sigmoid(pos - neg1)) - torch.log(torch.sigmoid(neg1 - neg2)))
        # Global MIM
        graph = torch.mean(edge_embeddings, dim=0, keepdim=True)
        pos = score(edge_embeddings, graph)
        neg1 = score(row_column_shuffle(edge_embeddings), graph)
        global_loss = torch.sum(-torch.log(torch.sigmoid(pos - neg1)))
        return global_loss + local_loss

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate self-supervised loss
        ss_loss = self.hierarchical_self_supervision(self.ss_gating_c1(user_all_embeddings), self.H_s_edge_index, self.H_s_edge_weight)
        ss_loss += self.hierarchical_self_supervision(self.ss_gating_c2(user_all_embeddings), self.H_j_edge_index, self.H_j_edge_weight)
        ss_loss += self.hierarchical_self_supervision(self.ss_gating_c3(user_all_embeddings), self.H_p_edge_index, self.H_p_edge_weight)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)
        loss = mf_loss + self.ssl_reg * ss_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)

================================================
FILE: recbole_gnn/model/social_recommender/sept.py
================================================
# @Time   : 2022/3/29
# @Author : Lanling Xu
# @Email  : xulanling_sherry@163.com

r"""
SEPT
################################################
Reference:
    Junliang Yu et al. "Socially-Aware Self-Supervised Tri-Training for Recommendation." in KDD 2021.

Reference code:
    https://github.com/Coder-Yu/QRec
"""

import numpy as np
import torch
import torch.nn.functional as F

from scipy.sparse import coo_matrix, eye
from torch_geometric.utils import degree

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import SocialRecommender
from recbole_gnn.model.layers import LightGCNConv


class SEPT(SocialRecommender):
    r"""SEPT is a socially-aware GCN-based SSL framework that integrates tri-training.

    Under the regime of tri-training for multi-view encoding, the framework builds three graph 
    encoders (one for recommendation) upon the augmented views and iteratively improves each 
    encoder with self-supervision signals from other users, generated by the other two encoders.

    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(SEPT, self).__init__(config, dataset)

        # load dataset info
        self.edge_index, self.edge_weight = dataset.get_norm_adj_mat()
        self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device)

        # generate intermediate data
        self.social_edge_index, self.social_edge_weight, self.sharing_edge_index, \
        self.sharing_edge_weight = self.get_user_view_matrix(dataset)

        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]

        self._src_user = dataset.net_feat[dataset.net_src_field]
        self._tgt_user = dataset.net_feat[dataset.net_tgt_field]

        # load parameters info
        self.latent_dim = config["embedding_size"]
        self.n_layers = int(config["n_layers"])
        self.drop_ratio = config["drop_ratio"]
        self.instance_cnt = config["instance_cnt"]
        self.reg_weight = config["reg_weight"]
        self.ssl_weight = config["ssl_weight"]
        self.ssl_tau = config["ssl_tau"]

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim)
        self.item_embedding = torch.nn.Embedding(self.n_items, self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.user_all_embeddings = None
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_norm_edge_weight(self, edge_index, node_num):
        r"""Get normalized edge weight using the laplace matrix.
        """
        deg = degree(edge_index[0], node_num)
        norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg))
        edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]]
        return edge_weight

    def get_user_view_matrix(self, dataset):
        # Friend View: A_f = (SS) ⊙ S
        social_mat = dataset.net_matrix()
        social_matrix = social_mat.dot(social_mat)
        social_matrix =  social_matrix.toarray() * social_mat.toarray() + eye(self.n_users)
        social_matrix = coo_matrix(social_matrix)
        social_edge_index = torch.stack([torch.LongTensor(social_matrix.row), torch.LongTensor(social_matrix.col)])
        social_edge_weight = self.get_norm_edge_weight(social_edge_index, self.n_users)

        # Sharing View: A_s = (RR^T) ⊙ S
        rating_mat = dataset.inter_matrix()
        sharing_matrix = rating_mat
SYMBOL INDEX (339 symbols across 31 files)

FILE: recbole_gnn/config.py
  class Config (line 9) | class Config(RecBole_Config):
    method __init__ (line 10) | def __init__(self, model=None, dataset=None, config_file_list=None, co...
    method compatibility_settings (line 24) | def compatibility_settings(self):
    method _get_model_and_dataset (line 35) | def _get_model_and_dataset(self, model, dataset):
    method _load_internal_config_dict (line 65) | def _load_internal_config_dict(self, model, model_class, dataset):

FILE: recbole_gnn/data/dataloader.py
  class CustomizedTrainDataLoader (line 9) | class CustomizedTrainDataLoader(TrainDataLoader):
    method __init__ (line 10) | def __init__(self, config, dataset, sampler, shuffle=False):
  class CustomizedNegSampleEvalDataLoader (line 16) | class CustomizedNegSampleEvalDataLoader(NegSampleEvalDataLoader):
    method __init__ (line 17) | def __init__(self, config, dataset, sampler, shuffle=False):
    method collate_fn (line 22) | def collate_fn(self, index):
  class CustomizedFullSortEvalDataLoader (line 55) | class CustomizedFullSortEvalDataLoader(FullSortEvalDataLoader):
    method __init__ (line 56) | def __init__(self, config, dataset, sampler, shuffle=False):

FILE: recbole_gnn/data/dataset.py
  class GeneralGraphDataset (line 24) | class GeneralGraphDataset(RecBoleDataset):
    method __init__ (line 25) | def __init__(self, config):
    method save (line 30) | def save(self):
    method edge_index_to_adj_t (line 42) | def edge_index_to_adj_t(edge_index, edge_weight, m_num_nodes, n_num_no...
    method get_norm_adj_mat (line 49) | def get_norm_adj_mat(self, enable_sparse=False):
    method get_bipartite_inter_mat (line 81) | def get_bipartite_inter_mat(self, row='user', row_norm=True):
  class SessionGraphDataset (line 109) | class SessionGraphDataset(SequentialDataset):
    method __init__ (line 110) | def __init__(self, config):
    method session_graph_construction (line 113) | def session_graph_construction(self):
    method build (line 138) | def build(self):
  class MultiBehaviorDataset (line 145) | class MultiBehaviorDataset(SessionGraphDataset):
    method session_graph_construction (line 147) | def session_graph_construction(self):
  class LESSRDataset (line 197) | class LESSRDataset(SessionGraphDataset):
    method session_graph_construction (line 198) | def session_graph_construction(self):
  class GCEGNNDataset (line 235) | class GCEGNNDataset(SequentialDataset):
    method __init__ (line 236) | def __init__(self, config):
    method reverse_session (line 239) | def reverse_session(self):
    method bidirectional_edge (line 246) | def bidirectional_edge(self, edge_index):
    method session_graph_construction (line 253) | def session_graph_construction(self):
    method build (line 295) | def build(self):
  class SocialDataset (line 303) | class SocialDataset(GeneralGraphDataset):
    method __init__ (line 322) | def __init__(self, config):
    method _get_field_from_config (line 325) | def _get_field_from_config(self):
    method _data_filtering (line 337) | def _data_filtering(self):
    method _filter_net_by_inter (line 342) | def _filter_net_by_inter(self):
    method _load_data (line 349) | def _load_data(self, token, dataset_path):
    method net_num (line 354) | def net_num(self):
    method __str__ (line 362) | def __str__(self):
    method _build_feat_name_list (line 369) | def _build_feat_name_list(self):
    method _load_net (line 375) | def _load_net(self, token, dataset_path):
    method _check_net (line 392) | def _check_net(self, net):
    method _init_alias (line 397) | def _init_alias(self):
    method get_norm_net_adj_mat (line 423) | def get_norm_net_adj_mat(self, row_norm=False):
    method net_matrix (line 448) | def net_matrix(self, form='coo', value_field=None):

FILE: recbole_gnn/data/transform.py
  function gnn_construct_transform (line 7) | def gnn_construct_transform(config):
  class SessionGraph (line 16) | class SessionGraph:
    method __init__ (line 17) | def __init__(self, config):
    method __call__ (line 21) | def __call__(self, dataset, interaction):

FILE: recbole_gnn/model/abstract_recommender.py
  class GeneralGraphRecommender (line 7) | class GeneralGraphRecommender(GeneralRecommender):
    method __init__ (line 13) | def __init__(self, config, dataset):
  class SocialRecommender (line 23) | class SocialRecommender(GeneralRecommender):
    method __init__ (line 29) | def __init__(self, config, dataset):

FILE: recbole_gnn/model/general_recommender/directau.py
  class DirectAU (line 24) | class DirectAU(GeneralGraphRecommender):
    method __init__ (line 27) | def __init__(self, config, dataset):
    method forward (line 50) | def forward(self, user, item):
    method alignment (line 55) | def alignment(x, y, alpha=2):
    method uniformity (line 59) | def uniformity(x, t=2):
    method calculate_loss (line 62) | def calculate_loss(self, interaction):
    method predict (line 75) | def predict(self, interaction):
    method full_sort_predict (line 82) | def full_sort_predict(self, interaction):
  class MFEncoder (line 96) | class MFEncoder(BPR):
    method __init__ (line 97) | def __init__(self, config, dataset):
    method forward (line 100) | def forward(self, user_id, item_id):
    method get_all_embeddings (line 103) | def get_all_embeddings(self):
  class LGCNEncoder (line 109) | class LGCNEncoder(LightGCN):
    method __init__ (line 110) | def __init__(self, config, dataset):
    method forward (line 113) | def forward(self, user_id, item_id):
    method get_all_embeddings (line 119) | def get_all_embeddings(self):

FILE: recbole_gnn/model/general_recommender/hmlet.py
  class Gating_Net (line 27) | class Gating_Net(nn.Module):
    method __init__ (line 28) | def __init__(self, embedding_dim, mlp_dims, dropout_p):
    method gumbel_softmax (line 46) | def gumbel_softmax(self, logits, temperature, hard):
    method gumbel_softmax_sample (line 65) | def gumbel_softmax_sample(self, logits, temperature):
    method sample_gumbel (line 71) | def sample_gumbel(self, logits):
    method forward (line 79) | def forward(self, feature, temperature, hard):
  class HMLET (line 87) | class HMLET(GeneralGraphRecommender):
    method __init__ (line 92) | def __init__(self, config, dataset):
    method _gating_freeze (line 129) | def _gating_freeze(self, model, freeze_flag):
    method __choosing_one (line 134) | def __choosing_one(self, features, gumbel_out):
    method __where (line 138) | def __where(self, idx, lst):
    method get_ego_embeddings (line 144) | def get_ego_embeddings(self):
    method forward (line 154) | def forward(self):
    method calculate_loss (line 179) | def calculate_loss(self, interaction):
    method predict (line 208) | def predict(self, interaction):
    method full_sort_predict (line 219) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/general_recommender/lightgcl.py
  class LightGCL (line 27) | class LightGCL(GeneralRecommender):
    method __init__ (line 38) | def __init__(self, config, dataset):
    method create_adjust_matrix (line 86) | def create_adjust_matrix(self):
    method coo2tensor (line 103) | def coo2tensor(self, matrix: sp.coo_matrix):
    method sparse_dropout (line 119) | def sparse_dropout(self, matrix, dropout):
    method forward (line 127) | def forward(self):
    method calculate_loss (line 144) | def calculate_loss(self, interaction):
    method calc_bpr_loss (line 157) | def calc_bpr_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list, ...
    method calc_ssl_loss (line 184) | def calc_ssl_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list):
    method predict (line 215) | def predict(self, interaction):
    method full_sort_predict (line 222) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/general_recommender/lightgcn.py
  class LightGCN (line 26) | class LightGCN(GeneralGraphRecommender):
    method __init__ (line 36) | def __init__(self, config, dataset):
    method get_ego_embeddings (line 60) | def get_ego_embeddings(self):
    method forward (line 70) | def forward(self):
    method calculate_loss (line 83) | def calculate_loss(self, interaction):
    method predict (line 112) | def predict(self, interaction):
    method full_sort_predict (line 123) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/general_recommender/ncl.py
  class NCL (line 20) | class NCL(GeneralGraphRecommender):
    method __init__ (line 23) | def __init__(self, config, dataset):
    method e_step (line 60) | def e_step(self):
    method run_kmeans (line 66) | def run_kmeans(self, x):
    method get_ego_embeddings (line 83) | def get_ego_embeddings(self):
    method forward (line 93) | def forward(self):
    method ProtoNCE_loss (line 106) | def ProtoNCE_loss(self, node_embedding, user, item):
    method ssl_layer_loss (line 135) | def ssl_layer_loss(self, current_embedding, previous_embedding, user, ...
    method calculate_loss (line 166) | def calculate_loss(self, interaction):
    method predict (line 201) | def predict(self, interaction):
    method full_sort_predict (line 212) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/general_recommender/ngcf.py
  class NGCF (line 28) | class NGCF(GeneralGraphRecommender):
    method __init__ (line 34) | def __init__(self, config, dataset):
    method get_ego_embeddings (line 62) | def get_ego_embeddings(self):
    method forward (line 73) | def forward(self):
    method calculate_loss (line 106) | def calculate_loss(self, interaction):
    method predict (line 128) | def predict(self, interaction):
    method full_sort_predict (line 139) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/general_recommender/sgl.py
  class SGL (line 29) | class SGL(GeneralGraphRecommender):
    method __init__ (line 43) | def __init__(self, config, dataset):
    method train (line 73) | def train(self, mode: bool = True):
    method graph_construction (line 82) | def graph_construction(self):
    method random_graph_augment (line 93) | def random_graph_augment(self):
    method forward (line 128) | def forward(self, graph=None):
    method calc_bpr_loss (line 147) | def calc_bpr_loss(self, user_emd, item_emd, user_list, pos_item_list, ...
    method calc_ssl_loss (line 176) | def calc_ssl_loss(self, user_list, pos_item_list, user_sub1, user_sub2...
    method calculate_loss (line 211) | def calculate_loss(self, interaction):
    method predict (line 227) | def predict(self, interaction):
    method full_sort_predict (line 235) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/general_recommender/simgcl.py
  class SimGCL (line 16) | class SimGCL(LightGCN):
    method __init__ (line 17) | def __init__(self, config, dataset):
    method forward (line 24) | def forward(self, perturbed=False):
    method calculate_cl_loss (line 40) | def calculate_cl_loss(self, x1, x2):
    method calculate_loss (line 48) | def calculate_loss(self, interaction):

FILE: recbole_gnn/model/general_recommender/ssl4rec.py
  class SSL4REC (line 22) | class SSL4REC(GeneralGraphRecommender):
    method __init__ (line 25) | def __init__(self, config, dataset):
    method forward (line 46) | def forward(self, user, item):
    method calculate_batch_softmax_loss (line 50) | def calculate_batch_softmax_loss(self, user_emb, item_emb, temperature):
    method calculate_loss (line 59) | def calculate_loss(self, interaction):
    method predict (line 77) | def predict(self, interaction):
    method full_sort_predict (line 88) | def full_sort_predict(self, interaction):
  class DNN_Encoder (line 102) | class DNN_Encoder(nn.Module):
    method __init__ (line 103) | def __init__(self, config, dataset):
    method reset_parameters (line 133) | def reset_parameters(self):
    method forward (line 137) | def forward(self, q, x):
    method item_encoding (line 146) | def item_encoding(self, x):
    method calculate_cl_loss (line 156) | def calculate_cl_loss(self, idx):

FILE: recbole_gnn/model/general_recommender/xsimgcl.py
  class XSimGCL (line 19) | class XSimGCL(LightGCN):
    method __init__ (line 20) | def __init__(self, config, dataset):
    method forward (line 28) | def forward(self, perturbed=False):
    method calculate_cl_loss (line 50) | def calculate_cl_loss(self, x1, x2):
    method calculate_loss (line 58) | def calculate_loss(self, interaction):

FILE: recbole_gnn/model/layers.py
  class LightGCNConv (line 8) | class LightGCNConv(MessagePassing):
    method __init__ (line 9) | def __init__(self, dim):
    method forward (line 13) | def forward(self, x, edge_index, edge_weight):
    method message (line 16) | def message(self, x_j, edge_weight):
    method message_and_aggregate (line 19) | def message_and_aggregate(self, adj_t, x):
    method __repr__ (line 22) | def __repr__(self):
  class BipartiteGCNConv (line 26) | class BipartiteGCNConv(MessagePassing):
    method __init__ (line 27) | def __init__(self, dim):
    method forward (line 31) | def forward(self, x, edge_index, edge_weight, size):
    method message (line 34) | def message(self, x_j, edge_weight):
    method __repr__ (line 37) | def __repr__(self):
  class BiGNNConv (line 41) | class BiGNNConv(MessagePassing):
    method __init__ (line 48) | def __init__(self, in_channels, out_channels):
    method forward (line 54) | def forward(self, x, edge_index, edge_weight):
    method message (line 60) | def message(self, x_j, edge_weight):
    method message_and_aggregate (line 63) | def message_and_aggregate(self, adj_t, x):
    method __repr__ (line 66) | def __repr__(self):
  class SRGNNConv (line 70) | class SRGNNConv(MessagePassing):
    method __init__ (line 71) | def __init__(self, dim):
    method forward (line 77) | def forward(self, x, edge_index):
  class SRGNNCell (line 82) | class SRGNNCell(nn.Module):
    method __init__ (line 83) | def __init__(self, dim):
    method forward (line 95) | def forward(self, hidden, edge_index):
    method _reset_parameters (line 111) | def _reset_parameters(self):

FILE: recbole_gnn/model/sequential_recommender/gcegnn.py
  class LocalAggregator (line 28) | class LocalAggregator(MessagePassing):
    method __init__ (line 29) | def __init__(self, dim, alpha):
    method forward (line 34) | def forward(self, x, edge_index, edge_attr):
    method message (line 37) | def message(self, x_j, x_i, edge_attr, index, ptr, size_i):
  class GlobalAggregator (line 46) | class GlobalAggregator(nn.Module):
    method __init__ (line 47) | def __init__(self, dim, dropout, act=torch.relu):
    method forward (line 58) | def forward(self, self_vectors, neighbor_vector, batch_size, masks, ne...
  class GCEGNN (line 76) | class GCEGNN(SequentialRecommender):
    method __init__ (line 77) | def __init__(self, config, dataset):
    method reset_parameters (line 124) | def reset_parameters(self):
    method _add_edge (line 129) | def _add_edge(self, graph, sid, tid):
    method construct_global_graph (line 134) | def construct_global_graph(self, dataset):
    method fusion (line 158) | def fusion(self, hidden, mask):
    method forward (line 174) | def forward(self, x, edge_index, edge_attr, alias_inputs, item_seq_len):
    method calculate_loss (line 234) | def calculate_loss(self, interaction):
    method predict (line 256) | def predict(self, interaction):
    method full_sort_predict (line 268) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/sequential_recommender/gcsan.py
  class GCSAN (line 23) | class GCSAN(SequentialRecommender):
    method __init__ (line 34) | def __init__(self, config, dataset):
    method _init_weights (line 80) | def _init_weights(self, module):
    method get_attention_mask (line 92) | def get_attention_mask(self, item_seq):
    method forward (line 108) | def forward(self, x, edge_index, alias_inputs, item_seq_len):
    method calculate_loss (line 124) | def calculate_loss(self, interaction):
    method predict (line 146) | def predict(self, interaction):
    method full_sort_predict (line 157) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/sequential_recommender/lessr.py
  class EOPA (line 24) | class EOPA(nn.Module):
    method __init__ (line 25) | def __init__(
    method reducer (line 36) | def reducer(self, nodes):
    method forward (line 45) | def forward(self, mg, feat):
  class SGAT (line 63) | class SGAT(nn.Module):
    method __init__ (line 64) | def __init__(
    method forward (line 82) | def forward(self, sg, feat):
  class AttnReadout (line 100) | class AttnReadout(nn.Module):
    method __init__ (line 101) | def __init__(
    method forward (line 122) | def forward(self, g, feat, last_nodes, batch):
  class LESSR (line 140) | class LESSR(SequentialRecommender):
    method __init__ (line 152) | def __init__(self, config, dataset):
    method forward (line 202) | def forward(self, x, edge_index_EOP, edge_index_shortcut, batch, is_la...
    method calculate_loss (line 223) | def calculate_loss(self, interaction):
    method predict (line 236) | def predict(self, interaction):
    method full_sort_predict (line 248) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/sequential_recommender/niser.py
  class NISER (line 23) | class NISER(SequentialRecommender):
    method __init__ (line 27) | def __init__(self, config, dataset):
    method _reset_parameters (line 59) | def _reset_parameters(self):
    method forward (line 64) | def forward(self, x, edge_index, alias_inputs, item_seq_len):
    method calculate_loss (line 91) | def calculate_loss(self, interaction):
    method predict (line 112) | def predict(self, interaction):
    method full_sort_predict (line 123) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/sequential_recommender/sgnnhn.py
  function layer_norm (line 29) | def layer_norm(x):
  class SGNNHN (line 37) | class SGNNHN(SequentialRecommender):
    method __init__ (line 42) | def __init__(self, config, dataset):
    method _reset_parameters (line 74) | def _reset_parameters(self):
    method att_out (line 79) | def att_out(self, hidden, star_node, batch):
    method forward (line 88) | def forward(self, x, edge_index, batch, alias_inputs, item_seq_len):
    method calculate_loss (line 118) | def calculate_loss(self, interaction):
    method predict (line 140) | def predict(self, interaction):
    method full_sort_predict (line 152) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/sequential_recommender/srgnn.py
  class SRGNN (line 25) | class SRGNN(SequentialRecommender):
    method __init__ (line 53) | def __init__(self, config, dataset):
    method _reset_parameters (line 81) | def _reset_parameters(self):
    method forward (line 86) | def forward(self, x, edge_index, alias_inputs, item_seq_len):
    method calculate_loss (line 103) | def calculate_loss(self, interaction):
    method predict (line 124) | def predict(self, interaction):
    method full_sort_predict (line 135) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/sequential_recommender/tagnn.py
  class TAGNN (line 26) | class TAGNN(SequentialRecommender):
    method __init__ (line 30) | def __init__(self, config, dataset):
    method _reset_parameters (line 57) | def _reset_parameters(self):
    method forward (line 62) | def forward(self, x, edge_index, alias_inputs, item_seq_len):
    method calculate_loss (line 89) | def calculate_loss(self, interaction):
    method predict (line 99) | def predict(self, interaction):
    method full_sort_predict (line 102) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/social_recommender/diffnet.py
  class DiffNet (line 27) | class DiffNet(SocialRecommender):
    method __init__ (line 33) | def __init__(self, config, dataset):
    method convertDistribution (line 78) | def convertDistribution(self, x):
    method forward (line 83) | def forward(self):
    method calculate_loss (line 108) | def calculate_loss(self, interaction):
    method predict (line 137) | def predict(self, interaction):
    method full_sort_predict (line 148) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/social_recommender/mhcn.py
  class GatingLayer (line 30) | class GatingLayer(nn.Module):
    method __init__ (line 31) | def __init__(self, dim):
    method forward (line 37) | def forward(self, emb):
  class AttLayer (line 44) | class AttLayer(nn.Module):
    method __init__ (line 45) | def __init__(self, dim):
    method forward (line 51) | def forward(self, *embs):
  class MHCN (line 63) | class MHCN(SocialRecommender):
    method __init__ (line 71) | def __init__(self, config, dataset):
    method get_bipartite_inter_mat (line 118) | def get_bipartite_inter_mat(self, dataset):
    method get_edge_index_weight (line 123) | def get_edge_index_weight(self, matrix):
    method get_motif_adj_matrix (line 129) | def get_motif_adj_matrix(self, dataset):
    method forward (line 160) | def forward(self):
    method hierarchical_self_supervision (line 217) | def hierarchical_self_supervision(self, user_embeddings, edge_index, e...
    method calculate_loss (line 243) | def calculate_loss(self, interaction):
    method predict (line 277) | def predict(self, interaction):
    method full_sort_predict (line 288) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/model/social_recommender/sept.py
  class SEPT (line 30) | class SEPT(SocialRecommender):
    method __init__ (line 41) | def __init__(self, config, dataset):
    method get_norm_edge_weight (line 83) | def get_norm_edge_weight(self, edge_index, node_num):
    method get_user_view_matrix (line 91) | def get_user_view_matrix(self, dataset):
    method subgraph_construction (line 111) | def subgraph_construction(self):
    method get_ego_embeddings (line 135) | def get_ego_embeddings(self):
    method forward (line 145) | def forward(self, graph=None):
    method user_view_forward (line 165) | def user_view_forward(self):
    method label_prediction (line 189) | def label_prediction(self, emb, aug_emb):
    method sampling (line 194) | def sampling(self, logits):
    method generate_pesudo_labels (line 197) | def generate_pesudo_labels(self, prob1, prob2):
    method calculate_ssl_loss (line 202) | def calculate_ssl_loss(self, aug_emb, positive, emb):
    method calculate_rec_loss (line 211) | def calculate_rec_loss(self, interaction):
    method calculate_loss (line 240) | def calculate_loss(self, interaction):
    method predict (line 281) | def predict(self, interaction):
    method full_sort_predict (line 292) | def full_sort_predict(self, interaction):

FILE: recbole_gnn/quick_start.py
  function run_recbole_gnn (line 9) | def run_recbole_gnn(model=None, dataset=None, config_file_list=None, con...
  function objective_function (line 66) | def objective_function(config_dict=None, config_file_list=None, saved=Tr...

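quick_start.py exposes the package's two entry points listed above. A minimal usage sketch, with the model name, dataset name and extra yaml file as illustrative values rather than defaults shipped with the repository:

```python
# Hypothetical one-call training run via recbole_gnn.quick_start;
# 'LightGCN', 'ml-100k' and the yaml path are illustrative assumptions.
from recbole_gnn.quick_start import run_recbole_gnn

if __name__ == '__main__':
    run_recbole_gnn(
        model='LightGCN',                        # any model name registered in the package
        dataset='ml-100k',                       # assumed dataset name
        config_file_list=['my_overrides.yaml'],  # hypothetical extra config file
    )
```
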
FILE: recbole_gnn/trainer.py
  class NCLTrainer (line 9) | class NCLTrainer(Trainer):
    method __init__ (line 10) | def __init__(self, config, model):
    method fit (line 16) | def fit(self, train_data, valid_data=None, verbose=True, saved=True, s...
    method _train_epoch (line 100) | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_pro...
  class HMLETTrainer (line 147) | class HMLETTrainer(Trainer):
    method __init__ (line 148) | def __init__(self, config, model):
    method _train_epoch (line 157) | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_pro...
  class SEPTTrainer (line 169) | class SEPTTrainer(Trainer):
    method __init__ (line 170) | def __init__(self, config, model):
    method _train_epoch (line 174) | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_pro...

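trainer.py follows one pattern: subclass RecBole's Trainer, adjust state in __init__, and override _train_epoch (NCLTrainer also overrides fit). Below is a hedged sketch of that pattern; MyTrainer and its warm-up behaviour are hypothetical, loosely modelled on the SEPT and HMLET trainers indexed above.

```python
# Sketch of the trainer-subclassing pattern used in recbole_gnn/trainer.py.
# MyTrainer and the warm-up hook are assumptions; only the overridden methods
# (__init__ and _train_epoch) mirror the indexed trainers.
from recbole.trainer import Trainer


class MyTrainer(Trainer):
    def __init__(self, config, model):
        super().__init__(config, model)
        # 'warm_up_epochs' is borrowed from SEPT.yaml below as an example key.
        self.warm_up_epochs = config['warm_up_epochs']

    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        # Trigger a model-side step once warm-up is over (hypothetical condition).
        if epoch_idx >= self.warm_up_epochs and hasattr(self.model, 'subgraph_construction'):
            self.model.subgraph_construction()
        return super()._train_epoch(train_data, epoch_idx,
                                    loss_func=loss_func, show_progress=show_progress)
```
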
FILE: recbole_gnn/utils.py
  function create_dataset (line 16) | def create_dataset(config):
  function get_model (line 62) | def get_model(model_name):
  function _get_customized_dataloader (line 88) | def _get_customized_dataloader(config, phase):
  function data_preparation (line 99) | def data_preparation(config, dataset):
  function get_trainer (line 145) | def get_trainer(model_type, model_name):
  class ModelType (line 159) | class ModelType(Enum):

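utils.py mirrors RecBole's top-level helpers, so a training run can also be wired by hand instead of going through run_recbole_gnn. A sketch under the assumption that recbole_gnn.config exposes a Config wrapper (its preview below imports RecBole's Config) and that the usual RecBole constructor arguments and config keys apply:

```python
# Hand-rolled pipeline built from the helpers indexed above; the Config kwargs
# and config keys follow RecBole conventions and are assumptions here.
from recbole_gnn.config import Config
from recbole_gnn.utils import create_dataset, data_preparation, get_model, get_trainer

config = Config(model='LightGCN', dataset='ml-100k')        # assumed kwargs
dataset = create_dataset(config)                            # graph-aware dataset
train_data, valid_data, test_data = data_preparation(config, dataset)

model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

best_valid_score, best_valid_result = trainer.fit(train_data, valid_data)
print(trainer.evaluate(test_data))
```
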
FILE: run_hyper.py
  function main (line 7) | def main():

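run_hyper.py pairs RecBole's HyperTuning with the package's objective_function (see the file preview further below). A hedged sketch of that pairing; the keyword arguments and file names are assumptions based on RecBole conventions, not a copy of main():

```python
# Sketch of a hyper-parameter search in the style of run_hyper.py; the
# HyperTuning keywords and file names are assumptions.
from recbole.trainer import HyperTuning
from recbole_gnn.quick_start import objective_function

hp = HyperTuning(
    objective_function,
    algo='exhaustive',                             # assumed search algorithm name
    params_file='hyper.test',                      # hypothetical grid file
    fixed_config_file_list=['fixed_config.yaml'],  # hypothetical base config
)
hp.run()
print('best params:', hp.best_params)
```
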
FILE: tests/test_model.py
  function quick_test (line 10) | def quick_test(config_dict):
  class TestGeneralRecommender (line 14) | class TestGeneralRecommender(unittest.TestCase):
    method test_bpr (line 15) | def test_bpr(self):
    method test_neumf (line 21) | def test_neumf(self):
    method test_ngcf (line 27) | def test_ngcf(self):
    method test_lightgcn (line 33) | def test_lightgcn(self):
    method test_sgl (line 39) | def test_sgl(self):
    method test_hmlet (line 45) | def test_hmlet(self):
    method test_ncl (line 51) | def test_ncl(self):
    method test_simgcl (line 58) | def test_simgcl(self):
    method test_xsimgcl (line 64) | def test_xsimgcl(self):
    method test_lightgcl (line 70) | def test_lightgcl(self):
    method test_directau (line 76) | def test_directau(self):
    method test_ssl4rec (line 82) | def test_ssl4rec(self):
  class TestSequentialRecommender (line 89) | class TestSequentialRecommender(unittest.TestCase):
    method test_gru4rec (line 90) | def test_gru4rec(self):
    method test_narm (line 96) | def test_narm(self):
    method test_sasrec (line 102) | def test_sasrec(self):
    method test_srgnn (line 108) | def test_srgnn(self):
    method test_srgnn_uni100 (line 114) | def test_srgnn_uni100(self):
    method test_gcsan (line 125) | def test_gcsan(self):
    method test_niser (line 131) | def test_niser(self):
    method test_lessr (line 137) | def test_lessr(self):
    method test_tagnn (line 143) | def test_tagnn(self):
    method test_gcegnn (line 149) | def test_gcegnn(self):
    method test_sgnnhn (line 155) | def test_sgnnhn(self):
  class TestSocialRecommender (line 162) | class TestSocialRecommender(unittest.TestCase):
    method test_diffnet (line 163) | def test_diffnet(self):
    method test_mhcn (line 169) | def test_mhcn(self):
    method test_sept (line 175) | def test_sept(self):
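
The tests above all run a one-epoch smoke test on the bundled tests/test_data/test dataset by handing a small config dict to objective_function; each test_* method only swaps in a different model name. A hedged reconstruction of that pattern (the keyword names in quick_test follow the indexed signature of objective_function):

```python
# Hypothetical reconstruction of the quick_test helper in tests/test_model.py;
# the keyword names passed to objective_function come from its indexed signature.
import os
from recbole_gnn.quick_start import objective_function

current_path = os.path.dirname(os.path.realpath(__file__))
config_file_list = [os.path.join(current_path, 'test_model.yaml')]


def quick_test(config_dict):
    objective_function(config_dict=config_dict,
                       config_file_list=config_file_list,
                       saved=False)


quick_test({'model': 'LightGCN'})  # one-epoch run per tests/test_model.yaml
```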

Condensed preview: 74 files, each showing path, character count, and a content snippet (full structured content: 408K chars).
[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "chars": 845,
    "preview": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: \"[\\U0001F41BBUG] Describe your problem in one sent"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report_CN.md",
    "chars": 480,
    "preview": "---\nname: Bug 报告\nabout: 提交一份 bug 报告,帮助 RecBole-GNN 变得更好\ntitle: \"[\\U0001F41BBUG] 用一句话描述您的问题。\"\nlabels: bug\nassignees: ''\n\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "chars": 674,
    "preview": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: \"[\\U0001F4A1SUG] Description of what you want t"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request_CN.md",
    "chars": 309,
    "preview": "---\nname: 请求添加新功能\nabout: 提出一个关于本项目新功能/新特性的建议\ntitle: \"[\\U0001F4A1SUG] 一句话描述您希望新增的功能或特性\"\nlabels: enhancement\nassignees: ''"
  },
  {
    "path": ".github/workflows/python-package.yml",
    "chars": 1462,
    "preview": "name: RecBole-GNN tests\n\n# Controls when the action will run. \non:\n  # Triggers the workflow on push or pull request eve"
  },
  {
    "path": ".gitignore",
    "chars": 1843,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 1065,
    "preview": "MIT License\n\nCopyright (c) 2021 RUCAIBox\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
  },
  {
    "path": "README.md",
    "chars": 10168,
    "preview": "# RecBole-GNN\n\n![](asset/recbole-gnn-logo.png)\n\n-----\n\n*Updates*:\n\n* [Oct 29, 2023] Add [SSL4Rec](https://github.com/RUC"
  },
  {
    "path": "recbole_gnn/config.py",
    "chars": 3582,
    "preview": "import os\nimport recbole\nfrom recbole.config.configurator import Config as RecBole_Config\nfrom recbole.utils import Mode"
  },
  {
    "path": "recbole_gnn/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "recbole_gnn/data/dataloader.py",
    "chars": 2553,
    "preview": "import numpy as np\nimport torch\nfrom recbole.data.interaction import cat_interactions\nfrom recbole.data.dataloader.gener"
  },
  {
    "path": "recbole_gnn/data/dataset.py",
    "chars": 18495,
    "preview": "import os\nimport torch\nimport numpy as np\nimport pandas as pd\n\nfrom tqdm import tqdm\nfrom torch_geometric.nn.conv.gcn_co"
  },
  {
    "path": "recbole_gnn/data/transform.py",
    "chars": 1980,
    "preview": "from logging import getLogger\nimport torch\nfrom torch.nn.utils.rnn import pad_sequence\nfrom recbole.data.interaction imp"
  },
  {
    "path": "recbole_gnn/model/abstract_recommender.py",
    "chars": 1417,
    "preview": "from recbole.model.abstract_recommender import GeneralRecommender\nfrom recbole.utils import ModelType as RecBoleModelTyp"
  },
  {
    "path": "recbole_gnn/model/general_recommender/__init__.py",
    "chars": 638,
    "preview": "from recbole_gnn.model.general_recommender.lightgcn import LightGCN\nfrom recbole_gnn.model.general_recommender.hmlet imp"
  },
  {
    "path": "recbole_gnn/model/general_recommender/directau.py",
    "chars": 4090,
    "preview": "# r\"\"\"\n# DiretAU\n# ################################################\n# Reference:\n#     Chenyang Wang et al. \"Towards Rep"
  },
  {
    "path": "recbole_gnn/model/general_recommender/hmlet.py",
    "chars": 10066,
    "preview": "# @Time   : 2022/3/21\n# @Author : Yupeng Hou\n# @Email  : houyupeng@ruc.edu.cn\n\nr\"\"\"\nHMLET\n##############################"
  },
  {
    "path": "recbole_gnn/model/general_recommender/lightgcl.py",
    "chars": 9555,
    "preview": "# -*- coding: utf-8 -*-\r\n# @Time   : 2023/04/12\r\n# @Author : Wanli Yang\r\n# @Email  : 2013774@mail.nankai.edu.cn\r\n\r\nr\"\"\"\r"
  },
  {
    "path": "recbole_gnn/model/general_recommender/lightgcn.py",
    "chars": 5610,
    "preview": "# @Time   : 2022/3/8\n# @Author : Lanling Xu\n# @Email  : xulanling_sherry@163.com\n\nr\"\"\"\nLightGCN\n########################"
  },
  {
    "path": "recbole_gnn/model/general_recommender/ncl.py",
    "chars": 10067,
    "preview": "# -*- coding: utf-8 -*-\nr\"\"\"\nNCL\n################################################\nReference:\n    Zihan Lin*, Changxin Ti"
  },
  {
    "path": "recbole_gnn/model/general_recommender/ngcf.py",
    "chars": 6376,
    "preview": "# @Time   : 2022/3/8\n# @Author : Changxin Tian\n# @Email  : cx.tian@outlook.com\nr\"\"\"\nNGCF\n###############################"
  },
  {
    "path": "recbole_gnn/model/general_recommender/sgl.py",
    "chars": 10371,
    "preview": "# -*- coding: utf-8 -*-\n# @Time   : 2022/3/8\n# @Author : Changxin Tian\n# @Email  : cx.tian@outlook.com\nr\"\"\"\nSGL\n########"
  },
  {
    "path": "recbole_gnn/model/general_recommender/simgcl.py",
    "chars": 2524,
    "preview": "# -*- coding: utf-8 -*-\nr\"\"\"\nSimGCL\n################################################\nReference:\n    Junliang Yu, Hongzhi"
  },
  {
    "path": "recbole_gnn/model/general_recommender/ssl4rec.py",
    "chars": 5733,
    "preview": "r\"\"\"\nSSL4REC\n################################################\nReference:\n    Tiansheng Yao et al. \"Self-supervised Learn"
  },
  {
    "path": "recbole_gnn/model/general_recommender/xsimgcl.py",
    "chars": 4028,
    "preview": "# -*- coding: utf-8 -*-\nr\"\"\"\nXSimGCL\n################################################\nReference:\n    Junliang Yu, Xin Xi"
  },
  {
    "path": "recbole_gnn/model/layers.py",
    "chars": 3699,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch_geometric.nn import MessagePassing\nfrom torch_sparse im"
  },
  {
    "path": "recbole_gnn/model/sequential_recommender/__init__.py",
    "chars": 459,
    "preview": "from recbole_gnn.model.sequential_recommender.gcegnn import GCEGNN\nfrom recbole_gnn.model.sequential_recommender.gcsan i"
  },
  {
    "path": "recbole_gnn/model/sequential_recommender/gcegnn.py",
    "chars": 12026,
    "preview": "# @Time   : 2022/3/22\n# @Author : Yupeng Hou\n# @Email  : houyupeng@ruc.edu.cn\n\nr\"\"\"\nGCE-GNN\n############################"
  },
  {
    "path": "recbole_gnn/model/sequential_recommender/gcsan.py",
    "chars": 7214,
    "preview": "# @Time   : 2022/3/7\n# @Author : Yupeng Hou\n# @Email  : houyupeng@ruc.edu.cn\n\nr\"\"\"\nGCSAN\n###############################"
  },
  {
    "path": "recbole_gnn/model/sequential_recommender/lessr.py",
    "chars": 9703,
    "preview": "# @Time   : 2022/3/11\n# @Author : Yupeng Hou\n# @Email  : houyupeng@ruc.edu.cn\n\nr\"\"\"\nLESSR\n##############################"
  },
  {
    "path": "recbole_gnn/model/sequential_recommender/niser.py",
    "chars": 5597,
    "preview": "# @Time   : 2022/3/7\n# @Author : Yupeng Hou\n# @Email  : houyupeng@ruc.edu.cn\n\nr\"\"\"\nNISER\n###############################"
  },
  {
    "path": "recbole_gnn/model/sequential_recommender/sgnnhn.py",
    "chars": 6744,
    "preview": "# @Time   : 2022/3/28\n# @Author : Yupeng Hou\n# @Email  : houyupeng@ruc.edu.cn\n\nr\"\"\"\nSRGNN\n##############################"
  },
  {
    "path": "recbole_gnn/model/sequential_recommender/srgnn.py",
    "chars": 5552,
    "preview": "# @Time   : 2022/3/7\n# @Author : Yupeng Hou\n# @Email  : houyupeng@ruc.edu.cn\n\nr\"\"\"\nSRGNN\n###############################"
  },
  {
    "path": "recbole_gnn/model/sequential_recommender/tagnn.py",
    "chars": 4056,
    "preview": "# @Time   : 2022/3/17\n# @Author : Yupeng Hou\n# @Email  : houyupeng@ruc.edu.cn\n\nr\"\"\"\nTAGNN\n##############################"
  },
  {
    "path": "recbole_gnn/model/social_recommender/__init__.py",
    "chars": 182,
    "preview": "from recbole_gnn.model.social_recommender.diffnet import DiffNet\nfrom recbole_gnn.model.social_recommender.mhcn import M"
  },
  {
    "path": "recbole_gnn/model/social_recommender/diffnet.py",
    "chars": 7345,
    "preview": "# @Time   : 2022/3/15\n# @Author : Lanling Xu\n# @Email  : xulanling_sherry@163.com\n\nr\"\"\"\nDiffNet\n########################"
  },
  {
    "path": "recbole_gnn/model/social_recommender/mhcn.py",
    "chars": 13866,
    "preview": "# @Time   : 2022/4/5\n# @Author : Lanling Xu\n# @Email  : xulanling_sherry@163.com\n\nr\"\"\"\nMHCN\n############################"
  },
  {
    "path": "recbole_gnn/model/social_recommender/sept.py",
    "chars": 13443,
    "preview": "# @Time   : 2022/3/29\n# @Author : Lanling Xu\n# @Email  : xulanling_sherry@163.com\n\nr\"\"\"\nSEPT\n###########################"
  },
  {
    "path": "recbole_gnn/properties/model/DiffNet.yaml",
    "chars": 73,
    "preview": "embedding_size: 64\nn_layers: 2\nreg_weight: 1e-05\npretrained_review: False"
  },
  {
    "path": "recbole_gnn/properties/model/DirectAU.yaml",
    "chars": 144,
    "preview": "embedding_size: 64\nencoder: \"MF\"   # \"MF\" or \"lightGCN\"\ngamma: 0.5\nweight_decay: 1e-6\ntrain_batch_size: 256\n\n# n_layers:"
  },
  {
    "path": "recbole_gnn/properties/model/GCEGNN.yaml",
    "chars": 191,
    "preview": "embedding_size: 64\nleakyrelu_alpha: 0.2\ndropout_local: 0.\ndropout_global: 0.5\ndropout_gcn: 0.\nloss_type: CE\ngnn_transfor"
  },
  {
    "path": "recbole_gnn/properties/model/GCSAN.yaml",
    "chars": 247,
    "preview": "n_layers: 1\nn_heads: 1\nhidden_size: 64\ninner_size: 256\nhidden_dropout_prob: 0.2\nattn_dropout_prob: 0.2\nhidden_act: 'gelu"
  },
  {
    "path": "recbole_gnn/properties/model/HMLET.yaml",
    "chars": 251,
    "preview": "embedding_size: 64\nn_layers: 4\nreg_weight: 1e-05\nrequire_pow: True\ngate_layer_ids: [2,3]\ngating_mlp_dims: [64,16,2]\ndrop"
  },
  {
    "path": "recbole_gnn/properties/model/LESSR.yaml",
    "chars": 103,
    "preview": "embedding_size: 64\nn_layers: 4\nbatch_norm: True\nfeat_drop: 0.2\nloss_type: CE\ngnn_transform: sess_graph\n"
  },
  {
    "path": "recbole_gnn/properties/model/LightGCL.yaml",
    "chars": 550,
    "preview": "embedding_size: 64              # (int) The embedding size of users and items.\r\nn_layers: 2                     # (int) "
  },
  {
    "path": "recbole_gnn/properties/model/LightGCN.yaml",
    "chars": 66,
    "preview": "embedding_size: 64\nn_layers: 2\nreg_weight: 1e-05\nrequire_pow: True"
  },
  {
    "path": "recbole_gnn/properties/model/MHCN.yaml",
    "chars": 63,
    "preview": "embedding_size: 64\nn_layers: 2\nssl_reg: 1e-05\nreg_weight: 1e-05"
  },
  {
    "path": "recbole_gnn/properties/model/NCL.yaml",
    "chars": 167,
    "preview": "embedding_size: 64\nn_layers: 3\nreg_weight: 1e-4\n\nssl_temp: 0.1\nssl_reg: 1e-7\nhyper_layers: 1\n\nalpha: 1\n\nproto_reg: 8e-8\n"
  },
  {
    "path": "recbole_gnn/properties/model/NGCF.yaml",
    "chars": 103,
    "preview": "embedding_size: 64\nhidden_size_list: [64,64,64]\nnode_dropout: 0.0\nmessage_dropout: 0.1\nreg_weight: 1e-5"
  },
  {
    "path": "recbole_gnn/properties/model/NISER.yaml",
    "chars": 96,
    "preview": "embedding_size: 64\nstep: 1\nsigma: 16\nitem_dropout: 0.1\nloss_type: 'CE'\ngnn_transform: sess_graph"
  },
  {
    "path": "recbole_gnn/properties/model/SEPT.yaml",
    "chars": 132,
    "preview": "warm_up_epochs: 100\nembedding_size: 64\nn_layers: 2\ndrop_ratio: 0.3\ninstance_cnt: 10\nreg_weight: 1e-05\nssl_weight: 1e-07\n"
  },
  {
    "path": "recbole_gnn/properties/model/SGL.yaml",
    "chars": 104,
    "preview": "type: \"ED\"\nn_layers: 3\nssl_tau: 0.5\nreg_weight: 1e-5\nssl_weight: 0.05\ndrop_ratio: 0.1\nembedding_size: 64"
  },
  {
    "path": "recbole_gnn/properties/model/SGNNHN.yaml",
    "chars": 79,
    "preview": "embedding_size: 64\nstep: 6\nscale: 12\nloss_type: 'CE'\ngnn_transform: sess_graph\n"
  },
  {
    "path": "recbole_gnn/properties/model/SRGNN.yaml",
    "chars": 68,
    "preview": "embedding_size: 64\nstep: 1\nloss_type: 'CE'\ngnn_transform: sess_graph"
  },
  {
    "path": "recbole_gnn/properties/model/SSL4REC.yaml",
    "chars": 97,
    "preview": "embedding_size: 64\ndrop_ratio: 0.1\ntau: 0.1\nreg_weight: 1e-04\nssl_weight: 1e-05\nrequire_pow: True"
  },
  {
    "path": "recbole_gnn/properties/model/SimGCL.yaml",
    "chars": 87,
    "preview": "embedding_size: 64\nn_layers: 2\nreg_weight: 1e-4\n\nlambda: 0.5\neps: 0.1\ntemperature: 0.2\n"
  },
  {
    "path": "recbole_gnn/properties/model/TAGNN.yaml",
    "chars": 69,
    "preview": "embedding_size: 64\nstep: 1\nloss_type: 'CE'\ngnn_transform: sess_graph\n"
  },
  {
    "path": "recbole_gnn/properties/model/XSimGCL.yaml",
    "chars": 119,
    "preview": "embedding_size: 64\nn_layers: 2\nreg_weight: 0.0001\n\nlambda: 0.1\neps: 0.2\ntemperature: 0.2\nlayer_cl: 1\nrequire_pow: True\n"
  },
  {
    "path": "recbole_gnn/properties/quick_start_config/sequential_base.yaml",
    "chars": 25,
    "preview": "train_neg_sample_args: ~\n"
  },
  {
    "path": "recbole_gnn/properties/quick_start_config/social_base.yaml",
    "chars": 210,
    "preview": "NET_SOURCE_ID_FIELD: source_id\nNET_TARGET_ID_FIELD: target_id\n\nload_col: \n    inter: ['user_id', 'item_id', 'rating', 't"
  },
  {
    "path": "recbole_gnn/quick_start.py",
    "chars": 4201,
    "preview": "import logging\nfrom logging import getLogger\nfrom recbole.utils import init_logger, init_seed, set_color\n\nfrom recbole_g"
  },
  {
    "path": "recbole_gnn/trainer.py",
    "chars": 8817,
    "preview": "from time import time\nimport math\nfrom torch.nn.utils.clip_grad import clip_grad_norm_\nfrom tqdm import tqdm\nfrom recbol"
  },
  {
    "path": "recbole_gnn/utils.py",
    "chars": 7177,
    "preview": "import os\nimport pickle\nimport importlib\nfrom logging import getLogger\nfrom recbole.data.utils import load_split_dataloa"
  },
  {
    "path": "results/README.md",
    "chars": 186,
    "preview": "## General Model Results\n\n* [ml-1m](general/ml-1m.md)\n\n## Sequential Model Results\n\n* [diginetica](sequential/diginetica"
  },
  {
    "path": "results/general/ml-1m.md",
    "chars": 6490,
    "preview": "# Experimental Setting\n\n**Dataset:** [MovieLens-1M](https://grouplens.org/datasets/movielens/)\n\n**Filtering:** Remove in"
  },
  {
    "path": "results/sequential/diginetica.md",
    "chars": 4436,
    "preview": "# Experimental Setting\n\n**Dataset:** diginetica-not-merged\n\n**Filtering:** Remove users and items with less than 5 inter"
  },
  {
    "path": "results/social/lastfm.md",
    "chars": 3738,
    "preview": "# Experimental Setting\n\n**Dataset:** [LastFM](http://files.grouplens.org/datasets/hetrec2011/)\n\n> Note that datasets for"
  },
  {
    "path": "run_hyper.py",
    "chars": 1061,
    "preview": "import argparse\n\nfrom recbole.trainer import HyperTuning\nfrom recbole_gnn.quick_start import objective_function\n\n\ndef ma"
  },
  {
    "path": "run_recbole_gnn.py",
    "chars": 638,
    "preview": "import argparse\n\nfrom recbole_gnn.quick_start import run_recbole_gnn\n\n\nif __name__ == '__main__':\n    parser = argparse."
  },
  {
    "path": "run_test.sh",
    "chars": 82,
    "preview": "#!/bin/bash\n\n\npython -m pytest -v tests/test_model.py\necho \"model tests finished\"\n"
  },
  {
    "path": "tests/test_data/test/test.inter",
    "chars": 117305,
    "preview": "user_id:token\titem_id:token\trating:float\ttimestamp:float\n196\t242\t3\t881250949\n186\t302\t3\t891717742\n22\t377\t1\t878887116\n244\t"
  },
  {
    "path": "tests/test_data/test/test.net",
    "chars": 4419,
    "preview": "source_id:token\ttarget_id:token\n187\t100\n119\t40\n96\t119\n12\t52\n153\t131\n259\t232\n191\t307\n83\t150\n86\t255\n177\t4\n210\t192\n25\t323\n9"
  },
  {
    "path": "tests/test_model.py",
    "chars": 3942,
    "preview": "import os\nimport unittest\n\nfrom recbole_gnn.quick_start import objective_function\n\ncurrent_path = os.path.dirname(os.pat"
  },
  {
    "path": "tests/test_model.yaml",
    "chars": 917,
    "preview": "dataset: test\nepochs: 1\nstate: ERROR\ndata_path: tests/test_data/\n\n# Atomic File Format\nfield_separator: \"\\t\"\nseq_separat"
  }
]

About this extraction

This page contains the full source code of the RUCAIBox/RecBole-GNN GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 74 files (361.6 KB), approximately 125.2k tokens, and a symbol index with 339 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.
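
The condensed preview above is a plain JSON array of objects with path, chars and preview fields, so it can be filtered programmatically before being handed to a model. A minimal sketch, assuming the array has been saved locally as recbole_gnn_preview.json (a hypothetical filename):

```python
# Load the condensed preview (a JSON array of {"path", "chars", "preview"} objects)
# and keep only the model source files; the input filename is an assumption.
import json

with open('recbole_gnn_preview.json', encoding='utf-8') as fh:
    files = json.load(fh)

model_files = [e for e in files if e['path'].startswith('recbole_gnn/model/')]
for entry in sorted(model_files, key=lambda e: e['chars'], reverse=True):
    print(f"{entry['chars']:>6}  {entry['path']}")
```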

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
