Repository: RUCAIBox/CRSLab
Branch: main
Commit: 649793891999
Files: 253
Total size: 590.2 KB

Directory structure:
gitextract_o03n7yor/

├── .gitattributes
├── .gitignore
├── .readthedocs.yml
├── LICENSE
├── README.md
├── README_CN.md
├── config/
│   ├── conversation/
│   │   ├── gpt2/
│   │   │   ├── durecdial.yaml
│   │   │   ├── gorecdial.yaml
│   │   │   ├── inspired.yaml
│   │   │   ├── opendialkg.yaml
│   │   │   ├── redial.yaml
│   │   │   └── tgredial.yaml
│   │   └── transformer/
│   │       ├── durecdial.yaml
│   │       ├── gorecdial.yaml
│   │       ├── inspired.yaml
│   │       ├── opendialkg.yaml
│   │       ├── redial.yaml
│   │       └── tgredial.yaml
│   ├── crs/
│   │   ├── inspired/
│   │   │   ├── durecdial.yaml
│   │   │   ├── gorecdial.yaml
│   │   │   ├── inspired.yaml
│   │   │   ├── opendialkg.yaml
│   │   │   ├── redial.yaml
│   │   │   └── tgredial.yaml
│   │   ├── kbrd/
│   │   │   ├── durecdial.yaml
│   │   │   ├── gorecdial.yaml
│   │   │   ├── inspired.yaml
│   │   │   ├── opendialkg.yaml
│   │   │   ├── redial.yaml
│   │   │   └── tgredial.yaml
│   │   ├── kgsf/
│   │   │   ├── durecdial.yaml
│   │   │   ├── gorecdial.yaml
│   │   │   ├── inspired.yaml
│   │   │   ├── opendialkg.yaml
│   │   │   ├── redial.yaml
│   │   │   └── tgredial.yaml
│   │   ├── ntrd/
│   │   │   └── tgredial.yaml
│   │   ├── redial/
│   │   │   ├── durecdial.yaml
│   │   │   ├── gorecdial.yaml
│   │   │   ├── inspired.yaml
│   │   │   ├── opendialkg.yaml
│   │   │   ├── redial.yaml
│   │   │   └── tgredial.yaml
│   │   └── tgredial/
│   │       ├── durecdial.yaml
│   │       ├── gorecdial.yaml
│   │       ├── inspired.yaml
│   │       ├── opendialkg.yaml
│   │       ├── redial.yaml
│   │       └── tgredial.yaml
│   ├── policy/
│   │   ├── conv_bert/
│   │   │   └── tgredial.yaml
│   │   ├── mgcg/
│   │   │   └── tgredial.yaml
│   │   ├── pmi/
│   │   │   └── tgredial.yaml
│   │   ├── profile_bert/
│   │   │   └── tgredial.yaml
│   │   └── topic_bert/
│   │       └── tgredial.yaml
│   └── recommendation/
│       ├── bert/
│       │   ├── durecdial.yaml
│       │   ├── gorecdial.yaml
│       │   ├── inspired.yaml
│       │   ├── opendialkg.yaml
│       │   ├── redial.yaml
│       │   └── tgredial.yaml
│       ├── gru4rec/
│       │   ├── durecdial.yaml
│       │   ├── gorecdial.yaml
│       │   ├── inspired.yaml
│       │   ├── opendialkg.yaml
│       │   ├── redial.yaml
│       │   └── tgredial.yaml
│       ├── popularity/
│       │   ├── durecdial.yaml
│       │   ├── gorecdial.yaml
│       │   ├── inspired.yaml
│       │   ├── opendialkg.yaml
│       │   ├── redial.yaml
│       │   └── tgredial.yaml
│       ├── sasrec/
│       │   ├── durecdial.yaml
│       │   ├── gorecdial.yaml
│       │   ├── inspired.yaml
│       │   ├── opendialkg.yaml
│       │   ├── redial.yaml
│       │   └── tgredial.yaml
│       └── textcnn/
│           ├── durecdial.yaml
│           ├── gorecdial.yaml
│           ├── inspired.yaml
│           ├── opendialkg.yaml
│           ├── redial.yaml
│           └── tgredial.yaml
├── crslab/
│   ├── __init__.py
│   ├── config/
│   │   ├── __init__.py
│   │   └── config.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── dataloader/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── inspired.py
│   │   │   ├── kbrd.py
│   │   │   ├── kgsf.py
│   │   │   ├── ntrd.py
│   │   │   ├── redial.py
│   │   │   ├── tgredial.py
│   │   │   └── utils.py
│   │   └── dataset/
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── durecdial/
│   │       │   ├── __init__.py
│   │       │   ├── durecdial.py
│   │       │   └── resources.py
│   │       ├── gorecdial/
│   │       │   ├── __init__.py
│   │       │   ├── gorecdial.py
│   │       │   └── resources.py
│   │       ├── inspired/
│   │       │   ├── __init__.py
│   │       │   ├── inspired.py
│   │       │   └── resources.py
│   │       ├── opendialkg/
│   │       │   ├── __init__.py
│   │       │   ├── opendialkg.py
│   │       │   └── resources.py
│   │       ├── redial/
│   │       │   ├── __init__.py
│   │       │   ├── redial.py
│   │       │   └── resources.py
│   │       └── tgredial/
│   │           ├── __init__.py
│   │           ├── resources.py
│   │           └── tgredial.py
│   ├── download.py
│   ├── evaluator/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── conv.py
│   │   ├── embeddings.py
│   │   ├── end2end.py
│   │   ├── metrics/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── gen.py
│   │   │   └── rec.py
│   │   ├── rec.py
│   │   ├── standard.py
│   │   └── utils.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── conversation/
│   │   │   ├── __init__.py
│   │   │   ├── gpt2/
│   │   │   │   ├── __init__.py
│   │   │   │   └── gpt2.py
│   │   │   └── transformer/
│   │   │       ├── __init__.py
│   │   │       └── transformer.py
│   │   ├── crs/
│   │   │   ├── __init__.py
│   │   │   ├── inspired/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── inspired_conv.py
│   │   │   │   ├── inspired_rec.py
│   │   │   │   └── modules.py
│   │   │   ├── kbrd/
│   │   │   │   ├── __init__.py
│   │   │   │   └── kbrd.py
│   │   │   ├── kgsf/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── kgsf.py
│   │   │   │   ├── modules.py
│   │   │   │   └── resources.py
│   │   │   ├── ntrd/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modules.py
│   │   │   │   ├── ntrd.py
│   │   │   │   └── resources.py
│   │   │   ├── redial/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modules.py
│   │   │   │   ├── redial_conv.py
│   │   │   │   └── redial_rec.py
│   │   │   └── tgredial/
│   │   │       ├── __init__.py
│   │   │       ├── tg_conv.py
│   │   │       ├── tg_policy.py
│   │   │       └── tg_rec.py
│   │   ├── policy/
│   │   │   ├── __init__.py
│   │   │   ├── conv_bert/
│   │   │   │   ├── __init__.py
│   │   │   │   └── conv_bert.py
│   │   │   ├── mgcg/
│   │   │   │   ├── __init__.py
│   │   │   │   └── mgcg.py
│   │   │   ├── pmi/
│   │   │   │   ├── __init__.py
│   │   │   │   └── pmi.py
│   │   │   ├── profile_bert/
│   │   │   │   ├── __init__.py
│   │   │   │   └── profile_bert.py
│   │   │   └── topic_bert/
│   │   │       ├── __init__.py
│   │   │       └── topic_bert.py
│   │   ├── pretrained_models.py
│   │   ├── recommendation/
│   │   │   ├── __init__.py
│   │   │   ├── bert/
│   │   │   │   ├── __init__.py
│   │   │   │   └── bert.py
│   │   │   ├── gru4rec/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── gru4rec.py
│   │   │   │   └── modules.py
│   │   │   ├── popularity/
│   │   │   │   ├── __init__.py
│   │   │   │   └── popularity.py
│   │   │   ├── sasrec/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modules.py
│   │   │   │   └── sasrec.py
│   │   │   └── textcnn/
│   │   │       ├── __init__.py
│   │   │       └── textcnn.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── functions.py
│   │       └── modules/
│   │           ├── __init__.py
│   │           ├── attention.py
│   │           └── transformer.py
│   ├── quick_start/
│   │   ├── __init__.py
│   │   └── quick_start.py
│   └── system/
│       ├── __init__.py
│       ├── base.py
│       ├── inspired.py
│       ├── kbrd.py
│       ├── kgsf.py
│       ├── ntrd.py
│       ├── redial.py
│       ├── tgredial.py
│       └── utils/
│           ├── __init__.py
│           ├── functions.py
│           └── lr_scheduler.py
├── docs/
│   ├── Makefile
│   ├── make.bat
│   ├── requirements.txt
│   ├── requirements_geometric.txt
│   ├── requirements_sphinx.txt
│   ├── requirements_torch.txt
│   └── source/
│       ├── api/
│       │   ├── crslab.config.rst
│       │   ├── crslab.data.dataloader.rst
│       │   ├── crslab.data.dataset.durecdial.rst
│       │   ├── crslab.data.dataset.gorecdial.rst
│       │   ├── crslab.data.dataset.inspired.rst
│       │   ├── crslab.data.dataset.opendialkg.rst
│       │   ├── crslab.data.dataset.redial.rst
│       │   ├── crslab.data.dataset.rst
│       │   ├── crslab.data.dataset.tgredial.rst
│       │   ├── crslab.data.rst
│       │   ├── crslab.evaluator.metrics.rst
│       │   ├── crslab.evaluator.rst
│       │   ├── crslab.model.conversation.gpt2.rst
│       │   ├── crslab.model.conversation.rst
│       │   ├── crslab.model.conversation.transformer.rst
│       │   ├── crslab.model.crs.kbrd.rst
│       │   ├── crslab.model.crs.kgsf.rst
│       │   ├── crslab.model.crs.redial.rst
│       │   ├── crslab.model.crs.rst
│       │   ├── crslab.model.crs.tgredial.rst
│       │   ├── crslab.model.policy.conv_bert.rst
│       │   ├── crslab.model.policy.mgcg.rst
│       │   ├── crslab.model.policy.pmi.rst
│       │   ├── crslab.model.policy.profile_bert.rst
│       │   ├── crslab.model.policy.rst
│       │   ├── crslab.model.policy.topic_bert.rst
│       │   ├── crslab.model.recommendation.bert.rst
│       │   ├── crslab.model.recommendation.gru4rec.rst
│       │   ├── crslab.model.recommendation.popularity.rst
│       │   ├── crslab.model.recommendation.rst
│       │   ├── crslab.model.recommendation.sasrec.rst
│       │   ├── crslab.model.recommendation.textcnn.rst
│       │   ├── crslab.model.rst
│       │   ├── crslab.model.utils.modules.rst
│       │   ├── crslab.model.utils.rst
│       │   ├── crslab.quick_start.rst
│       │   ├── crslab.rst
│       │   ├── crslab.system.rst
│       │   └── modules.rst
│       ├── conf.py
│       └── index.md
├── requirements.txt
├── run_crslab.py
└── setup.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
* text=auto eol=lf
*.{cmd,[cC][mM][dD]} text eol=crlf
*.{bat,[bB][aA][tT]} text eol=crlf

================================================
FILE: .gitignore
================================================
# Created by .ignore support plugin (hsz.mobi)
### Project

data
log
save
!crslab/data
runs

### VisualStudioCode template
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

# Local History for Visual Studio Code
.history/

### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn.  Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
.idea
*.iml
out
gen

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### JupyterNotebooks template
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

*/.ipynb_checkpoints/*

# Remove previous ipynb_checkpoints
#   git rm -r .ipynb_checkpoints/

### macOS template
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk


================================================
FILE: .readthedocs.yml
================================================
# Required
version: 2

# Build documentation in the docs/ directory with Sphinx
sphinx:
  configuration: docs/source/conf.py

# Build documentation with MkDocs
#mkdocs:
#  configuration: mkdocs.yml

# Optionally build your docs in additional formats such as PDF
formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
  version: 3.6
  install:
    - requirements: docs/requirements_torch.txt
    - requirements: docs/requirements_geometric.txt
    - requirements: docs/requirements.txt
    - requirements: docs/requirements_sphinx.txt

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 RUCAIBox

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# CRSLab

[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab)
[![Release](https://img.shields.io/github/v/release/rucaibox/crslab.svg)](https://github.com/rucaibox/crslab/releases)
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE)
[![arXiv](https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B)](https://arxiv.org/abs/2101.00939)
[![Documentation Status](https://readthedocs.org/projects/crslab/badge/?version=latest)](https://crslab.readthedocs.io/en/latest/?badge=latest)

[Paper](https://arxiv.org/pdf/2101.00939.pdf) | [Docs](https://crslab.readthedocs.io/en/latest/?badge=latest)
| [中文版](./README_CN.md)

**CRSLab** is an open-source toolkit for building Conversational Recommender Systems (CRS). It is built on
Python and PyTorch. CRSLab has the following highlights:

- **Comprehensive benchmark models and datasets**: We have integrated 6 commonly used datasets and 18 models, including graph neural network and pre-trained models such as R-GCN, BERT and GPT-2. We have preprocessed these datasets to support these models and released them for download.
- **Extensive and standard evaluation protocols**: We support a series of widely adopted evaluation protocols for testing and comparing different CRSs.
- **General and extensible structure**: We design a general and extensible structure to unify various conversational recommendation datasets and models, in which we integrate various built-in interfaces and functions for quick development.
- **Easy to get started**: We provide simple yet flexible configuration so that new researchers can quickly get started with our library.
- **Human-machine interaction interfaces**: We provide flexible human-machine interaction interfaces for researchers to conduct qualitative analysis.

<p align="center">
  <img src="https://i.loli.net/2020/12/30/6TPVG4pBg2rcDf9.png" alt="CRSLab architecture" width="400">
  <br>
  <b>Figure 1</b>: The overall framework of CRSLab
</p>




- [Installation](#Installation)
- [Quick-Start](#Quick-Start)
- [Models](#Models)
- [Datasets](#Datasets)
- [Performance](#Performance)
- [Releases](#Releases)
- [Contributions](#Contributions)
- [Citing](#Citing)
- [Team](#Team)
- [License](#License)



## Installation

CRSLab works with the following operating systems:

- Linux
- Windows 10
- macOS

CRSLab requires Python version 3.7 or later.

CRSLab requires PyTorch (`torch`) version 1.8. To use CRSLab with a GPU, ensure that your CUDA or cudatoolkit version is 10.2 or later, and use one of the version combinations listed at this [link](https://pytorch-geometric.com/whl/) so that PyTorch Geometric works correctly.



### Install PyTorch

Use the PyTorch [local installation](https://pytorch.org/get-started/locally/) or [previous versions](https://pytorch.org/get-started/previous-versions/) commands to install PyTorch. For example, on Linux and Windows 10:

```bash
# CUDA 10.2
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch

# CUDA 11.1
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge

# CPU Only
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cpuonly -c pytorch
```

If you want to use CRSLab with GPU, make sure the following command prints `True` after installation:

```bash
$ python -c "import torch; print(torch.cuda.is_available())"
>>> True
```
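
The requirements above can also be checked from Python before anything is run. The following is a minimal sketch rather than part of CRSLab: the helper name `check_environment` is hypothetical, and it relies only on the standard library so that it still runs when PyTorch is missing.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 7)):
    """Return a dict describing whether the basic requirements are met.

    `check_environment` is a hypothetical helper, not a CRSLab function.
    """
    report = {
        # The README asks for Python 3.7 or later.
        "python_ok": sys.version_info[:2] >= min_python,
        # find_spec returns None when a top-level package is not installed.
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check works without PyTorch

        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If `cuda_available` is `False` on a machine with a GPU, the installed PyTorch build and the CUDA toolkit version are the first things to compare.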



### Install PyTorch Geometric

Ensure that at least PyTorch 1.8.0 is installed:

```bash
$ python -c "import torch; print(torch.__version__)"
>>> 1.8.0
```

Find the CUDA version PyTorch was installed with:

```bash
$ python -c "import torch; print(torch.version.cuda)"
>>> 11.1
```

For Linux:

Install the relevant packages:

```bash
conda install pyg -c pyg
```

For other operating systems:

See the PyG [installation documentation](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) to install the relevant packages.



### Install CRSLab

You can install from pip:

```bash
pip install crslab
```

Or install from source:

```bash
git clone https://github.com/RUCAIBox/CRSLab && cd CRSLab
pip install -e .
```



## Quick-Start

With the source code, you can use the provided script to try out the library; it runs on the CPU by default:

```bash
python run_crslab.py --config config/crs/kgsf/redial.yaml
```

The system preprocesses the data, then trains, validates and tests each model in turn, and finally reports the evaluation results of the specified models.
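
The path passed to `--config` follows the layout `config/<category>/<model>/<dataset>.yaml` visible in this repository's directory tree. As a small illustration, the sketch below builds such a path; `config_path` and `CATEGORIES` are hypothetical names introduced here, not CRSLab functions, and the sketch does not check that a given model and dataset combination actually exists.

```python
from pathlib import Path

# Sub-directories found directly under config/ in this repository.
CATEGORIES = ("conversation", "crs", "policy", "recommendation")


def config_path(category, model, dataset, root="config"):
    """Build config/<category>/<model>/<dataset>.yaml (hypothetical helper)."""
    if category not in CATEGORIES:
        raise ValueError(f"unknown category {category!r}; expected one of {CATEGORIES}")
    return Path(root) / category / model / f"{dataset}.yaml"


# as_posix() keeps forward slashes on every operating system.
print(config_path("crs", "kgsf", "redial").as_posix())  # config/crs/kgsf/redial.yaml
```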

If you want to save pre-processed datasets and training results of models, you can use the following command:

```bash
python run_crslab.py --config config/crs/kgsf/redial.yaml --save_data --save_system
```

In summary, `run_crslab.py` accepts the following arguments:

- `--config` or `-c`: relative path to the configuration file (YAML).
- `--gpu` or `-g`: GPU id(s) to use; multiple GPUs are supported. Defaults to CPU (`-1`).
- `--save_data` or `-sd`: save pre-processed dataset.
- `--restore_data` or `-rd`: restore pre-processed dataset from file.
- `--save_system` or `-ss`: save trained system.
- `--restore_system` or `-rs`: restore trained system from file.
- `--debug` or `-d`: use validation dataset to debug your system.
- `--interact` or `-i`: interact with your system instead of training.
- `--tensorboard` or `-tb`: enable TensorBoard to monitor training performance.
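
To make the flag list concrete, here is a hedged sketch of an `argparse` parser exposing the same options. It is an illustration only, not the actual contents of `run_crslab.py`: making `--config` required and treating the GPU value as a plain string are assumptions made for this example.

```python
import argparse


def build_parser():
    """Illustrative parser for the flags listed above (not the real run_crslab.py)."""
    parser = argparse.ArgumentParser(description="CRSLab launcher (sketch)")
    # Assumption: the configuration file is mandatory in this sketch.
    parser.add_argument("-c", "--config", type=str, required=True,
                        help="relative path to the YAML configuration file")
    # Assumption: the value is kept as a string; "-1" stands for CPU.
    parser.add_argument("-g", "--gpu", type=str, default="-1",
                        help="GPU id(s) to use; -1 means CPU")
    switches = [
        ("-sd", "--save_data", "save the pre-processed dataset"),
        ("-rd", "--restore_data", "restore the pre-processed dataset from file"),
        ("-ss", "--save_system", "save the trained system"),
        ("-rs", "--restore_system", "restore the trained system from file"),
        ("-d", "--debug", "use the validation dataset to debug the system"),
        ("-i", "--interact", "interact with the system instead of training"),
        ("-tb", "--tensorboard", "enable TensorBoard monitoring"),
    ]
    for short, long, text in switches:
        # Every switch is a boolean flag that defaults to False.
        parser.add_argument(short, long, action="store_true", help=text)
    return parser


args = build_parser().parse_args(
    ["-c", "config/crs/kgsf/redial.yaml", "-g", "0", "--save_data", "-ss"]
)
print(args.config, args.gpu, args.save_data, args.save_system, args.debug)
```

Boolean switches use `action="store_true"`, so omitting a flag leaves the corresponding attribute `False`.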



## Models

In CRSLab, we unify the task description of conversational recommendation into three sub-tasks: recommendation (recommending user-preferred items), conversation (generating proper responses) and policy (selecting proper interactive actions). The recommendation and conversation sub-tasks are the core of a CRS and have been studied in most works. The policy sub-task is required by recent works, in which the CRS interacts with users through a purposeful strategy.
In the first release, we implemented 18 models in four categories: CRS models, recommendation models, conversation models and policy models.

|       Category       |                            Model                             |      Graph Neural Network?      |       Pre-training Model?       |
| :------------------: | :----------------------------------------------------------: | :-----------------------------: | :-----------------------------: |
|      CRS Model       | [ReDial](https://arxiv.org/abs/1812.07617)<br/>[KBRD](https://arxiv.org/abs/1908.05391)<br/>[KGSF](https://arxiv.org/abs/2007.04032)<br/>[TG-ReDial](https://arxiv.org/abs/2010.04125)<br/>[INSPIRED](https://www.aclweb.org/anthology/2020.emnlp-main.654.pdf) |       ×<br/>√<br/>√<br/>×<br/>×       |       ×<br/>×<br/>×<br/>√<br/>√       |
| Recommendation model | Popularity<br/>[GRU4Rec](https://arxiv.org/abs/1511.06939)<br/>[SASRec](https://arxiv.org/abs/1808.09781)<br/>[TextCNN](https://arxiv.org/abs/1408.5882)<br/>[R-GCN](https://arxiv.org/abs/1703.06103)<br/>[BERT](https://arxiv.org/abs/1810.04805) | ×<br/>×<br/>×<br/>×<br/>√<br/>× | ×<br/>×<br/>×<br/>×<br/>×<br/>√ |
|  Conversation model  | [HERD](https://arxiv.org/abs/1507.04808)<br/>[Transformer](https://arxiv.org/abs/1706.03762)<br/>[GPT-2](http://www.persagen.com/files/misc/radford2019language.pdf) |          ×<br/>×<br/>×          |          ×<br/>×<br/>√          |
|     Policy model     | PMI<br/>[MGCG](https://arxiv.org/abs/2005.03954)<br/>[Conv-BERT](https://arxiv.org/abs/2010.04125)<br/>[Topic-BERT](https://arxiv.org/abs/2010.04125)<br/>[Profile-BERT](https://arxiv.org/abs/2010.04125) |    ×<br/>×<br/>×<br/>×<br/>×    |    ×<br/>×<br/>√<br/>√<br/>√    |

Among them, the CRS models integrate a recommendation model and a conversation model so that each improves the other, while the other models each address a single task.

For the recommendation, conversation and policy models, we have implemented the following commonly used automatic evaluation metrics:

|        Category        |                           Metrics                            |
| :--------------------: | :----------------------------------------------------------: |
| Recommendation Metrics |      Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50}      |
|  Conversation Metrics  | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} |
|     Policy Metrics     |                 Accuracy, Hit@{1, 3, 5}                      |
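
For readers unfamiliar with the ranking metrics in the table, the sketch below computes Hit@k, MRR@k and NDCG@k for a single ranked list with one ground-truth item. It uses the standard textbook definitions and is not taken from CRSLab's evaluator; the function names are chosen for this example.

```python
import math


def rank_of(ranked_items, target):
    """Return the 0-based position of `target` in `ranked_items`, or None if absent."""
    try:
        return ranked_items.index(target)
    except ValueError:
        return None


def hit_at_k(ranked_items, target, k):
    """1.0 if the target appears in the top k, else 0.0."""
    return 1.0 if rank_of(ranked_items[:k], target) is not None else 0.0


def mrr_at_k(ranked_items, target, k):
    """Reciprocal of the 1-based rank if the target is in the top k, else 0.0."""
    rank = rank_of(ranked_items[:k], target)
    return 1.0 / (rank + 1) if rank is not None else 0.0


def ndcg_at_k(ranked_items, target, k):
    """With one relevant item the ideal DCG is 1, so NDCG equals the log discount."""
    rank = rank_of(ranked_items[:k], target)
    return 1.0 / math.log2(rank + 2) if rank is not None else 0.0


ranked = ["item_a", "item_b", "item_c"]  # model output, best first
print(hit_at_k(ranked, "item_b", 1))    # 0.0
print(hit_at_k(ranked, "item_b", 10))   # 1.0
print(mrr_at_k(ranked, "item_b", 10))   # 0.5
print(ndcg_at_k(ranked, "item_b", 10))
```

Averaging these per-example values over the test set gives corpus-level scores of the kind reported in result tables.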



## Datasets

We have collected and preprocessed 6 commonly used human-annotated datasets, and matched each dataset with suitable knowledge graphs (KGs), as shown below:

|                           Dataset                            | Dialogs | Utterances |   Domains    | Task Definition | Entity KG  |  Word KG   |
| :----------------------------------------------------------: | :-----: | :--------: | :----------: | :-------------: | :--------: | :--------: |
|       [ReDial](https://redialdata.github.io/website/)        | 10,006  |  182,150   |    Movie     |       --        |  DBpedia   | ConceptNet |
|      [TG-ReDial](https://github.com/RUCAIBox/TG-ReDial)      | 10,000  |  129,392   |    Movie     |   Topic Guide   | CN-DBpedia |   HowNet   |
|        [GoRecDial](https://arxiv.org/abs/1909.03922)         |  9,125  |  170,904   |    Movie     |  Action Choice  |  DBpedia   | ConceptNet |
|        [DuRecDial](https://arxiv.org/abs/2005.03954)         | 10,200  |  156,000   | Movie, Music |    Goal Plan    | CN-DBpedia |   HowNet   |
|      [INSPIRED](https://github.com/sweetpeach/Inspired)      |  1,001  |   35,811   |    Movie     | Social Strategy |  DBpedia   | ConceptNet |
| [OpenDialKG](https://github.com/facebookresearch/opendialkg) | 13,802  |   91,209   | Movie, Book  |  Path Generate  |  DBpedia   | ConceptNet |



## Performance

We have trained and tested the integrated models on the TG-ReDial dataset, which is split into training, validation and test sets using a ratio of 8:1:1. For each conversation, we start from the first utterance and let the model generate reply utterances or recommendations turn by turn. We perform the evaluation on the three sub-tasks.
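
An 8:1:1 split can be sketched as follows. This is an illustration under stated assumptions, not the procedure CRSLab uses: the function name `split_dataset` is hypothetical, and shuffling with a fixed seed before splitting is an assumption made for reproducibility.

```python
import random


def split_dataset(conversations, ratios=(8, 1, 1), seed=0):
    """Split a sequence into train/validation/test parts by integer ratios."""
    items = list(conversations)
    # Assumption: shuffle deterministically before splitting.
    random.Random(seed).shuffle(items)
    total = sum(ratios)
    n_train = len(items) * ratios[0] // total
    n_valid = len(items) * ratios[1] // total
    train = items[:n_train]
    valid = items[n_train:n_train + n_valid]
    test = items[n_train + n_valid:]  # the remainder goes to the test set
    return train, valid, test


train, valid, test = split_dataset(range(10000))
print(len(train), len(valid), len(test))  # 8000 1000 1000
```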

### Recommendation Task

|   Model   |    Hit@1    |   Hit@10   |   Hit@50   |    MRR@1    |   MRR@10   |   MRR@50   |   NDCG@1    |  NDCG@10   |  NDCG@50   |
| :-------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: |
|  SASRec   |  0.000446   |  0.00134   |   0.0160   |   0.000446  |  0.000576  |  0.00114   |  0.000445   |  0.00075   |  0.00380   |
|  TextCNN  |   0.00267   |   0.0103   |   0.0236   |   0.00267   |  0.00434   |  0.00493   |   0.00267   |  0.00570   |  0.00860   |
|   BERT    |   0.00722   |  0.00490   |   0.0281   |   0.00722   |   0.0106   |   0.0124   |   0.00490   |   0.0147   |   0.0239   |
|   KBRD    |   0.00401   |   0.0254   |   0.0588   |   0.00401   |  0.00891   |   0.0103   |   0.00401   |   0.0127   |   0.0198   |
|   KGSF    |   0.00535   | **0.0285** | **0.0771** |   0.00535   |   0.0114   | **0.0135** |   0.00535   | **0.0154** | **0.0259** |
| TG-ReDial | **0.00793** |   0.0251   |   0.0524   | **0.00793** | **0.0122** |   0.0134   | **0.00793** |   0.0152   |   0.0211   |


### Conversation Task

|    Model    |  BLEU@1   |  BLEU@2   |   BLEU@3   |   BLEU@4   |  Dist@1  |  Dist@2  |  Dist@3  |  Dist@4  |  Average  |  Extreme  |  Greedy   |   PPL    |
| :---------: | :-------: | :-------: | :--------: | :--------: | :------: | :------: | :------: | :------: | :-------: | :-------: | :-------: | :------: |
|    HERD     |   0.120   |  0.0141   |  0.00136   |  0.000350  |  0.181   |  0.369   |  0.847   |   1.30   |   0.697   |   0.382   |   0.639   |   472    |
| Transformer |   0.266   |  0.0440   |   0.0145   |  0.00651   |  0.324   |  0.837   |   2.02   |   3.06   |   0.879   |   0.438   |   0.680   |   30.9   |
|    GPT2     |  0.0858   |  0.0119   |  0.00377   |   0.0110   | **2.35** | **4.62** | **8.84** | **12.5** |   0.763   |   0.297   |   0.583   |   9.26   |
|    KBRD     |   0.267   |  0.0458   |   0.0134   |  0.00579   |  0.469   |   1.50   |   3.40   |   4.90   |   0.863   |   0.398   |   0.710   |   52.5   |
|    KGSF     | **0.383** | **0.115** | **0.0444** | **0.0200** |  0.340   |  0.910   |   3.50   |   6.20   | **0.888** | **0.477** | **0.767** |   50.1   |
|  TG-ReDial  |   0.125   |  0.0204   |  0.00354   |  0.000803  |  0.881   |   1.75   |   7.00   |   12.0   |   0.810   |   0.332   |   0.598   | **7.41** |


### Policy Task

|   Model    |   Hit@1   |  Hit@10   |  Hit@50   |   MRR@1   |  MRR@10   |  MRR@50   |  NDCG@1   |  NDCG@10  |  NDCG@50  |
| :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
|    MGCG    |   0.591   |   0.818   |   0.883   |   0.591   |   0.680   |   0.683   |   0.591   |   0.712   |   0.729   |
| Conv-BERT  |   0.597   |   0.814   |   0.881   |   0.597   |   0.684   |   0.687   |   0.597   |   0.716   |   0.731   |
| Topic-BERT |   0.598   |   0.828   |   0.885   |   0.598   |   0.690   |   0.693   |   0.598   |   0.724   |   0.737   |
| TG-ReDial  | **0.600** | **0.830** | **0.893** | **0.600** | **0.693** | **0.696** | **0.600** | **0.727** | **0.741** |

The above results were obtained with CRSLab in preliminary experiments. These algorithms were implemented and tuned based on our own understanding and experience, so they may not have reached their optimal performance. If you obtain a better result for a specific algorithm, please let us know. We will update this table once the results are verified.

## Releases

| Releases |     Date      |   Features   |
| :------: | :-----------: | :----------: |
|  v0.1.1  | 1 / 4 / 2021  | Basic CRSLab |
|  v0.1.2  | 3 / 28 / 2021 |    CRSLab    |



## Contributions

Please let us know if you encounter a bug or have any suggestions by [filing an issue](https://github.com/RUCAIBox/CRSLab/issues).

We welcome all contributions, from bug fixes to new features and extensions.

We expect all contributions to be discussed in the issue tracker and to go through PRs.

We thank [@shubaoyu](https://github.com/shubaoyu) and [@ToheartZhang](https://github.com/ToheartZhang) for their contributions through PRs.



## Citing

If you find CRSLab useful for your research or development, please cite our [Paper](https://arxiv.org/pdf/2101.00939.pdf):

```
@article{crslab,
    title={CRSLab: An Open-Source Toolkit for Building Conversational Recommender System},
    author={Kun Zhou and Xiaolei Wang and Yuanhang Zhou and Chenzhan Shang and Yuan Cheng and Wayne Xin Zhao and Yaliang Li and Ji-Rong Wen},
    year={2021},
    journal={arXiv preprint arXiv:2101.00939}
}
```



## Team

**CRSLab** is developed and maintained by the [AI Box](http://aibox.ruc.edu.cn/) group at RUC.



## License

**CRSLab** uses [MIT License](./LICENSE).



================================================
FILE: README_CN.md
================================================
# CRSLab

[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab)
[![Release](https://img.shields.io/github/v/release/rucaibox/crslab.svg)](https://github.com/rucaibox/crslab/releases)
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE)
[![arXiv](https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B)](https://arxiv.org/abs/2101.00939)
[![Documentation Status](https://readthedocs.org/projects/crslab/badge/?version=latest)](https://crslab.readthedocs.io/en/latest/?badge=latest)

[Paper](https://arxiv.org/pdf/2101.00939.pdf) | [Docs](https://crslab.readthedocs.io/en/latest/?badge=latest)
| [English Version](./README.md)

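**CRSLab** is an open-source toolkit for building Conversational Recommender Systems (CRS). It is implemented in PyTorch, designed primarily for researchers, and has the following features:

- **Comprehensive benchmark models and datasets**: We integrate 6 commonly used datasets and 18 models, including graph neural network based and pre-trained models such as GCN, BERT and GPT-2. We have also preprocessed the datasets to support these models and release the preprocessed versions for download.
- **Extensive and standard evaluation protocols**: We support a series of widely adopted evaluation protocols for testing and comparing different CRSs.
- **General and extensible structure**: We design a general and extensible structure to unify various conversational recommendation datasets and models, and integrate a variety of built-in interfaces and functions for quick development.
- **Easy to get started**: We provide simple yet flexible configuration so that newcomers can quickly start the models integrated in CRSLab.
- **Human-machine interaction interfaces**: We provide user-friendly human-machine interaction interfaces for researchers to compare and test different model systems.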

<p align="center">
  <img src="https://i.loli.net/2020/12/30/6TPVG4pBg2rcDf9.png" alt="CRSLab architecture" width="400">
  <br>
  <b>Figure</b>: The overall framework of CRSLab
</p>




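- [Installation](#installation)
- [Quick Start](#quick-start)
- [Models](#models)
- [Datasets](#datasets)
- [Performance](#performance)
- [Releases](#releases)
- [Contributions](#contributions)
- [Citing](#citing)
- [Team](#team)
- [Disclaimer](#disclaimer)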



## Installation

CRSLab runs on the following operating systems:

- Linux
- Windows 10
- macOS X

CRSLab requires Python 3.7 or later.

CRSLab requires torch 1.8. To run CRSLab on a GPU, make sure that your CUDA or CUDAToolkit version is 10.2 or later. To ensure that PyTorch Geometric works properly, install it as described at [this link](https://pytorch-geometric.com/whl/).



### Install PyTorch

Use the PyTorch [local installation](https://pytorch.org/get-started/locally/) or [previous versions](https://pytorch.org/get-started/previous-versions/) commands to install PyTorch. For example, on Linux and Windows:

```bash
# CUDA 10.2
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch

# CUDA 11.1
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge

# CPU Only
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cpuonly -c pytorch
```

After installation, if you want to run CRSLab on a GPU, make sure the following command prints `True`:

```bash
$ python -c "import torch; print(torch.cuda.is_available())"
>>> True
```



### Install PyTorch Geometric

Ensure that PyTorch 1.8.0 or later is installed:

```bash
$ python -c "import torch; print(torch.__version__)"
>>> 1.8.0
```

Find the CUDA version that PyTorch was installed with:

```bash
$ python -c "import torch; print(torch.version.cuda)"
>>> 11.1
```
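
The two checks above can also be scripted. The following is a minimal sketch and not part of CRSLab: `parse_version` and `meets_minimum` are hypothetical helper names, and the parsing assumes version strings shaped like `1.8.0` or `1.8.0+cu111`.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn a string such as '1.8.0+cu111' into a tuple of ints, e.g. (1, 8, 0)."""
    # Drop any local build suffix after '+', then keep the leading digits of each part.
    core = version.split("+", 1)[0]
    parts = []
    for piece in core.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(version: str, minimum: str) -> bool:
    """Return True if `version` is at least `minimum`, comparing component by component."""
    have, need = parse_version(version), parse_version(minimum)
    width = max(len(have), len(need))
    # Pad the shorter tuple with zeros so that '1.8' and '1.8.0' compare as equal.
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


if __name__ == "__main__":
    try:
        import torch  # only needed for the live check
    except ImportError:
        print("PyTorch is not installed")
    else:
        print("PyTorch >= 1.8.0:", meets_minimum(torch.__version__, "1.8.0"))
        print("CUDA build:", torch.version.cuda)  # None for CPU-only builds
```

Comparing parsed integer tuples avoids the pitfall of comparing version strings lexically, where `"1.10.0"` would sort before `"1.8.0"`.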

On Linux:

Install the relevant packages:

```bash
conda install pyg -c pyg
```

On other systems:

Follow the PyG [official installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) to install the relevant packages.

### Install CRSLab

You can install it with pip:

```bash
pip install crslab
```

Or install it from source:

```bash
git clone https://github.com/RUCAIBox/CRSLab && cd CRSLab
pip install -e .
```



## Quick Start

After downloading CRSLab from GitHub, you can use the provided script to run and test it quickly. It uses the CPU by default:

```bash
python run_crslab.py --config config/crs/kgsf/redial.yaml
```

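The system will preprocess the data, then train, validate and test each module in turn, and finally report the evaluation results of the specified model.

To save the preprocessed data and the trained model, use the following command: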

```bash
python run_crslab.py --config config/crs/kgsf/redial.yaml --save_data --save_system
```

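In summary, `run_crslab.py` accepts the following arguments:

- `--config` or `-c`: relative path to the configuration file, which specifies the model and dataset to run.
- `--gpu` or `-g`: the GPU id(s) to use; multiple GPUs are supported. Defaults to the CPU (-1).
- `--save_data` or `-sd`: save the preprocessed data.
- `--restore_data` or `-rd`: load preprocessed data from file.
- `--save_system` or `-ss`: save the trained CRS system.
- `--restore_system` or `-rs`: load a previously trained system from file.
- `--debug` or `-d`: use the validation set in place of the training set to ease debugging.
- `--interact` or `-i`: chat with your system interactively instead of training it.
- `--tensorboard` or `-tb`: use tensorboardX to monitor training performance.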
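
For orientation, the options listed above could be declared with `argparse` roughly as follows. This is a sketch, not the actual contents of `run_crslab.py`: the function name `build_parser`, the choice to make `--config` required, and the comma-separated format for multiple GPU ids are assumptions made for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Declare the documented command-line options (illustrative sketch only)."""
    parser = argparse.ArgumentParser(description="Sketch of the documented CRSLab options")
    # Assumption: the config path is made mandatory here; the real script may provide a default.
    parser.add_argument("-c", "--config", type=str, required=True,
                        help="relative path to the YAML configuration file")
    # Assumption: several GPUs are given as a comma-separated string such as '0,1'.
    parser.add_argument("-g", "--gpu", type=str, default="-1",
                        help="GPU id(s) to use; -1 selects the CPU")
    # The remaining options are plain on/off switches.
    parser.add_argument("-sd", "--save_data", action="store_true",
                        help="save the preprocessed data")
    parser.add_argument("-rd", "--restore_data", action="store_true",
                        help="load preprocessed data from file")
    parser.add_argument("-ss", "--save_system", action="store_true",
                        help="save the trained CRS system")
    parser.add_argument("-rs", "--restore_system", action="store_true",
                        help="load a previously trained system from file")
    parser.add_argument("-d", "--debug", action="store_true",
                        help="use the validation set in place of the training set")
    parser.add_argument("-i", "--interact", action="store_true",
                        help="chat with the system instead of training it")
    parser.add_argument("-tb", "--tensorboard", action="store_true",
                        help="monitor training with tensorboardX")
    return parser


if __name__ == "__main__":
    example = ["--config", "config/crs/kgsf/redial.yaml", "--save_data", "--save_system"]
    print(build_parser().parse_args(example))
```

Declaring the switches with `action="store_true"` means that each one is off unless it appears on the command line, which matches the way the flags are used in the commands above.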



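## Models

In the first release, we implement 18 models in 4 categories. We divide the conversational recommendation task into three sub-tasks: recommendation (producing the recommended items), conversation (generating responses) and policy (planning the conversational recommendation strategy). All conversational recommender systems have the recommendation and conversation tasks, which are the core functions of a CRS. The policy task is an auxiliary task that aims to better control the CRS, and it may be implemented differently across models (e.g., TG-ReDial uses a topic prediction model, while DuRecDial uses a dialog planning model):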



|       Category       |                            Model                             |      Graph Neural Network?      |       Pre-training Model?       |
| :------------------: | :----------------------------------------------------------: | :-----------------------------: | :-----------------------------: |
|      CRS Model       | [ReDial](https://arxiv.org/abs/1812.07617)<br/>[KBRD](https://arxiv.org/abs/1908.05391)<br/>[KGSF](https://arxiv.org/abs/2007.04032)<br/>[TG-ReDial](https://arxiv.org/abs/2010.04125)<br/>[INSPIRED](https://www.aclweb.org/anthology/2020.emnlp-main.654.pdf) |    ×<br/>√<br/>√<br/>×<br/>×    |    ×<br/>×<br/>×<br/>√<br/>√    |
| Recommendation model | Popularity<br/>[GRU4Rec](https://arxiv.org/abs/1511.06939)<br/>[SASRec](https://arxiv.org/abs/1808.09781)<br/>[TextCNN](https://arxiv.org/abs/1408.5882)<br/>[R-GCN](https://arxiv.org/abs/1703.06103)<br/>[BERT](https://arxiv.org/abs/1810.04805) | ×<br/>×<br/>×<br/>×<br/>√<br/>× | ×<br/>×<br/>×<br/>×<br/>×<br/>√ |
|  Conversation model  | [HERD](https://arxiv.org/abs/1507.04808)<br/>[Transformer](https://arxiv.org/abs/1706.03762)<br/>[GPT-2](http://www.persagen.com/files/misc/radford2019language.pdf) |          ×<br/>×<br/>×          |          ×<br/>×<br/>√          |
|     Policy model     | PMI<br/>[MGCG](https://arxiv.org/abs/2005.03954)<br/>[Conv-BERT](https://arxiv.org/abs/2010.04125)<br/>[Topic-BERT](https://arxiv.org/abs/2010.04125)<br/>[Profile-BERT](https://arxiv.org/abs/2010.04125) |    ×<br/>×<br/>×<br/>×<br/>×    |    ×<br/>×<br/>√<br/>√<br/>√    |


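Among them, CRS models integrate a recommendation model and a conversation model so that the two enhance each other, and therefore usually already contain recommendation, conversation and policy components. The other categories (recommendation, conversation and policy models) usually focus on only one of these tasks.

For these categories of models, we also implement the following automatic evaluation metrics: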

|        Category        |                           Metrics                            |
| :--------------------: | :----------------------------------------------------------: |
| Recommendation metrics |      Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50}      |
|  Conversation metrics  | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} |
|     Policy metrics     | Accuracy, Hit@{1,3,5} |
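
As a reference for the ranking metrics in the table, here is a minimal sketch of their textbook definitions for a single sample. It assumes exactly one ground-truth item per sample; the function names are illustrative, and CRSLab's own evaluator may differ in details such as averaging and tie handling.

```python
import math
from typing import Sequence


def hit_at_k(ranked: Sequence[int], target: int, k: int) -> float:
    """1.0 if the target appears among the top-k ranked items, else 0.0."""
    return float(target in list(ranked[:k]))


def mrr_at_k(ranked: Sequence[int], target: int, k: int) -> float:
    """Reciprocal of the target's 1-based rank if it is in the top k, else 0.0."""
    top = list(ranked[:k])
    if target in top:
        return 1.0 / (top.index(target) + 1)
    return 0.0


def ndcg_at_k(ranked: Sequence[int], target: int, k: int) -> float:
    """NDCG with one relevant item: the ideal DCG is 1, so the score is 1 / log2(rank + 1)."""
    top = list(ranked[:k])
    if target in top:
        # index is 0-based, so the 1-based rank is index + 1 and the discount is log2(index + 2).
        return 1.0 / math.log2(top.index(target) + 2)
    return 0.0


if __name__ == "__main__":
    ranking = [5, 3, 9, 1]  # item ids sorted by predicted score, best first
    for k in (1, 3):
        print(k, hit_at_k(ranking, 9, k), mrr_at_k(ranking, 9, k), ndcg_at_k(ranking, 9, k))
```

Under these definitions the three metrics coincide at k = 1, since each of them equals 1 exactly when the top-ranked item is the target.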





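## Datasets

We have collected 6 commonly used human-annotated datasets and preprocessed them (including linking to external knowledge graphs) to fit a unified CRS task. The statistics of the datasets are as follows: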

|                           Dataset                            | Dialogs | Utterances |   Domains    | Task Definition | Entity KG  |  Word KG   |
| :----------------------------------------------------------: | :-----: | :--------: | :----------: | :-------------: | :--------: | :--------: |
|       [ReDial](https://redialdata.github.io/website/)        | 10,006  |  182,150   |    Movie     |       --        |  DBpedia   | ConceptNet |
|      [TG-ReDial](https://github.com/RUCAIBox/TG-ReDial)      | 10,000  |  129,392   |    Movie     |   Topic Guide   | CN-DBpedia |   HowNet   |
|        [GoRecDial](https://arxiv.org/abs/1909.03922)         |  9,125  |  170,904   |    Movie     |  Action Choice  |  DBpedia   | ConceptNet |
|        [DuRecDial](https://arxiv.org/abs/2005.03954)         | 10,200  |  156,000   | Movie, Music |    Goal Plan    | CN-DBpedia |   HowNet   |
|      [INSPIRED](https://github.com/sweetpeach/Inspired)      |  1,001  |   35,811   |    Movie     | Social Strategy |  DBpedia   | ConceptNet |
| [OpenDialKG](https://github.com/facebookresearch/opendialkg) | 13,802  |   91,209   | Movie, Book  |  Path Generate  |  DBpedia   | ConceptNet |



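## Performance

We trained and tested the models on the TG-ReDial dataset, which we split into training, validation and test sets with a ratio of 8:1:1. For each conversation, we start from the first turn and perform the recommendation, policy and response generation tasks turn by turn. The following tables record the evaluation results.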
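
The split and the turn-by-turn protocol described above can be sketched as follows. This is an illustration under stated assumptions rather than CRSLab's preprocessing code: the helper names `split_8_1_1` and `turn_by_turn` are hypothetical, shuffling with a fixed seed is an assumption, and each sample is assumed to pair the preceding utterances (the context) with the next utterance (the response).

```python
import random
from typing import List, Sequence, Tuple


def split_8_1_1(dialogs: Sequence, seed: int = 0) -> Tuple[list, list, list]:
    """Shuffle the dialogs and split them into train/valid/test parts with ratio 8:1:1."""
    shuffled = list(dialogs)
    random.Random(seed).shuffle(shuffled)  # assumption: a seeded shuffle precedes the split
    n_train = len(shuffled) * 8 // 10
    n_valid = len(shuffled) // 10
    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    test = shuffled[n_train + n_valid:]  # the remainder goes to the test set
    return train, valid, test


def turn_by_turn(dialog: List[str]) -> List[Tuple[List[str], str]]:
    """Expand one dialog into (context, response) samples, one per turn after the first."""
    return [(dialog[:i], dialog[i]) for i in range(1, len(dialog))]


if __name__ == "__main__":
    train, valid, test = split_8_1_1(range(100))
    print(len(train), len(valid), len(test))
    for context, response in turn_by_turn(["hi", "hello, any movie in mind?", "something funny"]):
        print(context, "->", response)
```

Expanding every dialog in this way means that each model is evaluated on every turn of a conversation, not only on its final utterance.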

### Recommendation Task

|   Model   |    Hit@1    |   Hit@10   |   Hit@50   |    MRR@1    |   MRR@10   |   MRR@50   |   NDCG@1    |  NDCG@10   |  NDCG@50   |
| :-------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: |
|  SASRec   |  0.000446   |  0.00134   |   0.0160   |  0.000446   |  0.000576  |  0.00114   |  0.000445   |  0.00075   |  0.00380   |
|  TextCNN  |   0.00267   |   0.0103   |   0.0236   |   0.00267   |  0.00434   |  0.00493   |   0.00267   |  0.00570   |  0.00860   |
|   BERT    |   0.00722   |  0.00490   |   0.0281   |   0.00722   |   0.0106   |   0.0124   |   0.00490   |   0.0147   |   0.0239   |
|   KBRD    |   0.00401   |   0.0254   |   0.0588   |   0.00401   |  0.00891   |   0.0103   |   0.00401   |   0.0127   |   0.0198   |
|   KGSF    |   0.00535   | **0.0285** | **0.0771** |   0.00535   |   0.0114   | **0.0135** |   0.00535   | **0.0154** | **0.0259** |
| TG-ReDial | **0.00793** |   0.0251   |   0.0524   | **0.00793** | **0.0122** |   0.0134   | **0.00793** |   0.0152   |   0.0211   |



### Conversation Task

|    Model    |  BLEU@1   |  BLEU@2   |   BLEU@3   |   BLEU@4   |  Dist@1  |  Dist@2  |  Dist@3  |  Dist@4  |  Average  |  Extreme  |  Greedy   |   PPL    |
| :---------: | :-------: | :-------: | :--------: | :--------: | :------: | :------: | :------: | :------: | :-------: | :-------: | :-------: | :------: |
|    HERD     |   0.120   |  0.0141   |  0.00136   |  0.000350  |  0.181   |  0.369   |  0.847   |   1.30   |   0.697   |   0.382   |   0.639   |   472    |
| Transformer |   0.266   |  0.0440   |   0.0145   |  0.00651   |  0.324   |  0.837   |   2.02   |   3.06   |   0.879   |   0.438   |   0.680   |   30.9   |
|    GPT2     |  0.0858   |  0.0119   |  0.00377   |   0.0110   | **2.35** | **4.62** | **8.84** | **12.5** |   0.763   |   0.297   |   0.583   |   9.26   |
|    KBRD     |   0.267   |  0.0458   |   0.0134   |  0.00579   |  0.469   |   1.50   |   3.40   |   4.90   |   0.863   |   0.398   |   0.710   |   52.5   |
|    KGSF     | **0.383** | **0.115** | **0.0444** | **0.0200** |  0.340   |  0.910   |   3.50   |   6.20   | **0.888** | **0.477** | **0.767** |   50.1   |
|  TG-ReDial  |   0.125   |  0.0204   |  0.00354   |  0.000803  |  0.881   |   1.75   |   7.00   |   12.0   |   0.810   |   0.332   |   0.598   | **7.41** |



### Policy Task

|   Model    |   Hit@1   |  Hit@10   |  Hit@50   |   MRR@1   |  MRR@10   |  MRR@50   |  NDCG@1   |  NDCG@10  |  NDCG@50  |
| :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
|    MGCG    |   0.591   |   0.818   |   0.883   |   0.591   |   0.680   |   0.683   |   0.591   |   0.712   |   0.729   |
| Conv-BERT  |   0.597   |   0.814   |   0.881   |   0.597   |   0.684   |   0.687   |   0.597   |   0.716   |   0.731   |
| Topic-BERT |   0.598   |   0.828   |   0.885   |   0.598   |   0.690   |   0.693   |   0.598   |   0.724   |   0.737   |
| TG-ReDial  | **0.600** | **0.830** | **0.893** | **0.600** | **0.693** | **0.696** | **0.600** | **0.727** | **0.741** |

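The above results were obtained from our experiments with CRSLab. However, these algorithms were implemented and tuned based on our own experience and understanding, so they may not have reached their optimal performance. If you obtain a better result for a specific algorithm, please let us know. We will update this table once the results are verified.

## Releases

| Releases |     Date      |   Features   |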
| :----: | :-----------: | :----------: |
| v0.1.1 | 1 / 4 / 2021  | Basic CRSLab |
| v0.1.2 | 3 / 28 / 2021 |    CRSLab    |



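## Contributions

If you encounter a bug or have any suggestions, please let us know by [filing an issue](https://github.com/RUCAIBox/CRSLab/issues).

We welcome all contributions, from bug fixes to new features.

If you would like to contribute code, please raise it in an issue first and then submit a PR.

We thank [@shubaoyu](https://github.com/shubaoyu) and [@ToheartZhang](https://github.com/ToheartZhang) for the new features they contributed through PRs.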



## Citing

If you find CRSLab useful for your research, please cite our [paper](https://arxiv.org/pdf/2101.00939.pdf):

```
@article{crslab,
    title={CRSLab: An Open-Source Toolkit for Building Conversational Recommender System},
    author={Kun Zhou and Xiaolei Wang and Yuanhang Zhou and Chenzhan Shang and Yuan Cheng and Wayne Xin Zhao and Yaliang Li and Ji-Rong Wen},
    year={2021},
    journal={arXiv preprint arXiv:2101.00939}
}
```



## Team

**CRSLab** is developed and maintained by the [AI Box](http://aibox.ruc.edu.cn/) group at Renmin University of China.



## Disclaimer

**CRSLab** is developed under the [MIT License](./LICENSE). All data and code in this project may be used for academic purposes only.


================================================
FILE: config/conversation/gpt2/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
conv_model: GPT2
# optim
conv:
  epoch: 1
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000


================================================
FILE: config/conversation/gpt2/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
conv_model: GPT2
# optim
conv:
  epoch: 1
  batch_size: 4
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000


================================================
FILE: config/conversation/gpt2/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
conv_model: GPT2
# optim
conv:
  epoch: 1
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000


================================================
FILE: config/conversation/gpt2/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
conv_model: GPT2
# optim
conv:
  epoch: 1
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000


================================================
FILE: config/conversation/gpt2/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
conv_model: GPT2
# optim
conv:
  epoch: 1
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000


================================================
FILE: config/conversation/gpt2/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
conv_model: GPT2
# optim
conv:
  epoch: 50
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  early_stop: true
  stop_mode: min
  impatience: 3
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000

================================================
FILE: config/conversation/transformer/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  conv: jieba
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
conv_model: Transformer
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
conv:
  epoch: 1
  batch_size: 64
  early_stop: True
  stop_mode: min
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5

================================================
FILE: config/conversation/transformer/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  conv: nltk
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
conv_model: Transformer
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
conv:
  epoch: 1
  batch_size: 256
  optimizer:
    name: Adam
    lr: !!float 3e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1
  early_stop: true
  stop_mode: min
  impatience: 3

================================================
FILE: config/conversation/transformer/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  conv: nltk
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
conv_model: Transformer
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
conv:
  epoch: 1
  batch_size: 256
  optimizer:
    name: Adam
    lr: !!float 3e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1
  early_stop: true
  stop_mode: min
  impatience: 3

================================================
FILE: config/conversation/transformer/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  conv: nltk
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
conv_model: Transformer
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
conv:
  epoch: 1
  batch_size: 256
  optimizer:
    name: Adam
    lr: !!float 3e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1
  early_stop: true
  stop_mode: min
  impatience: 3

================================================
FILE: config/conversation/transformer/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  conv: nltk
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
conv_model: Transformer
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
conv:
  epoch: 1
  batch_size: 64
  early_stop: True
  stop_mode: min
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5

================================================
FILE: config/conversation/transformer/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  conv: pkuseg
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
conv_model: Transformer
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
conv:
  epoch: 50
  batch_size: 64
  early_stop: True
  stop_mode: min
  patience: 3
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    factor: 0.5

================================================
FILE: config/crs/inspired/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
# rec
rec_model: InspiredRec
# conv
conv_model: InspiredConv
# embedding: word2vec
embedding_dim: 300
use_dropout: False
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  early_stop: true
  stop_mode: max
  impatience: 3
  lr_bert: !!float 1e-5
conv:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 3e-5
    eps: !!float 1e-06
    weight_decay: !!float 0.01
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 100
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/inspired/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
# rec
rec_model: InspiredRec
# conv
conv_model: InspiredConv
# embedding: word2vec
embedding_dim: 300
use_dropout: False
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  early_stop: true
  stop_mode: max
  impatience: 3
  lr_bert: !!float 1e-5
conv:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 3e-5
    eps: !!float 1e-06
    weight_decay: !!float 0.01
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 100
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/inspired/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
# rec
rec_model: InspiredRec
# conv
conv_model: InspiredConv
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  early_stop: true
  stop_mode: max
  impatience: 3
  lr_bert: !!float 1e-5
conv:
  epoch: 50
  batch_size: 1
  optimizer:
    name: AdamW
    lr: !!float 3e-5
    eps: !!float 1e-06
    weight_decay: !!float 0.01
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 100
  early_stop: true
  impatience: 3
  stop_mode: min
  label_smoothing: -1


================================================
FILE: config/crs/inspired/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
# rec
rec_model: InspiredRec
# conv
conv_model: InspiredConv
# embedding: word2vec
embedding_dim: 300
use_dropout: False
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  early_stop: true
  stop_mode: max
  impatience: 3
  lr_bert: !!float 1e-5
conv:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 3e-5
    eps: !!float 1e-06
    weight_decay: !!float 0.01
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 100
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/inspired/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
# rec
rec_model: InspiredRec
# conv
conv_model: InspiredConv
# embedding: word2vec
embedding_dim: 300
use_dropout: False
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  early_stop: true
  stop_mode: max
  impatience: 3
  lr_bert: !!float 1e-5
conv:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 3e-5
    eps: !!float 1e-06
    weight_decay: !!float 0.01
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 100
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/inspired/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
# rec
rec_model: InspiredRec
# conv
conv_model: InspiredConv
# embedding: word2vec
embedding_dim: 300
use_dropout: False
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  early_stop: true
  stop_mode: max
  impatience: 3
  lr_bert: !!float 1e-5
conv:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 3e-5
    eps: !!float 1e-06
    weight_decay: !!float 0.01
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 100
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/kbrd/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize: jieba
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
model: KBRD
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# optim
rec:
  epoch: 1
  batch_size: 4096
  optimizer:
    name: Adam
    lr: !!float 3e-3
conv:
  epoch: 1
  batch_size: 64
  early_stop: True
  stop_mode: min
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5

================================================
FILE: config/crs/kbrd/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize: nltk
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
model: KBRD
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# optim
rec:
  epoch: 1
  batch_size: 4096
  optimizer:
    name: Adam
    lr: !!float 3e-3
conv:
  epoch: 1
  batch_size: 256
  optimizer:
    name: Adam
    lr: !!float 3e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1
  early_stop: true
  stop_mode: min
  impatience: 3

================================================
FILE: config/crs/kbrd/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize: nltk
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
model: KBRD
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 3e-3
conv:
  epoch: 1
  batch_size: 64
  early_stop: True
  stop_mode: min
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5

================================================
FILE: config/crs/kbrd/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize: nltk
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
model: KBRD
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 3e-3
conv:
  epoch: 1
  batch_size: 64
  early_stop: True
  stop_mode: min
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5

================================================
FILE: config/crs/kbrd/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize: nltk
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
model: KBRD
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# optim
rec:
  epoch: 10
  batch_size: 4096
  optimizer:
    name: Adam
    lr: !!float 3e-3
conv:
  epoch: 10
  batch_size: 32
  early_stop: True
  stop_mode: min
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5

================================================
FILE: config/crs/kbrd/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize: pkuseg
# dataloader
context_truncate: 1024
response_truncate: 1024
scale: 1
# model
model: KBRD
token_emb_dim: 300
n_relation: 56
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# optim
rec:
  epoch: 100
  batch_size: 64
  early_stop: True
  stop_mode: max
  patience: 3
  optimizer:
    name: Adam
    lr: !!float 3e-3
conv:
  epoch: 100
  batch_size: 16
  early_stop: True
  stop_mode: min
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5

================================================
FILE: config/crs/kgsf/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize: jieba
embedding: word2vec.npy
# dataloader
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: KGSF
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
pretrain:
  epoch: 1
  batch_size: 4096
  optimizer:
    name: Adam
    lr: !!float 3e-3
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 3e-3
  early_stop: true
  stop_mode: max
  impatience: 3
conv:
  epoch: 1
  batch_size: 256
  optimizer:
    name: Adam
    lr: !!float 3e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1


================================================
FILE: config/crs/kgsf/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize: nltk
embedding: word2vec.npy
# dataloader
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: KGSF
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
pretrain:
  epoch: 1
  batch_size: 64
  optimizer:
    name: Adam
    lr: !!float 3e-3
rec:
  epoch: 1
  batch_size: 64
  optimizer:
    name: Adam
    lr: !!float 3e-3
  early_stop: true
  stop_mode: max
  impatience: 3
conv:
  epoch: 1
  batch_size: 64
  optimizer:
    name: Adam
    lr: !!float 3e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1


================================================
FILE: config/crs/kgsf/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize: nltk
embedding: word2vec.npy
# dataloader
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: KGSF
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
pretrain:
  epoch: 1
  batch_size: 4096
  optimizer:
    name: Adam
    lr: !!float 3e-3
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 3e-3
  early_stop: true
  stop_mode: max
  impatience: 3
conv:
  epoch: 1
  batch_size: 256
  optimizer:
    name: Adam
    lr: !!float 3e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1


================================================
FILE: config/crs/kgsf/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize: nltk
embedding: word2vec.npy
# dataloader
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: KGSF
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
pretrain:
  epoch: 1
  batch_size: 4096
  optimizer:
    name: Adam
    lr: !!float 3e-3
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 3e-3
  early_stop: true
  stop_mode: max
  impatience: 3
conv:
  epoch: 1
  batch_size: 256
  optimizer:
    name: Adam
    lr: !!float 3e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1


================================================
FILE: config/crs/kgsf/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize: nltk
embedding: word2vec.npy
# dataloader
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: KGSF
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
pretrain:
  epoch: 3
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
rec:
  epoch: 9
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
conv:
  epoch: 90
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1


================================================
FILE: config/crs/kgsf/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize: pkuseg
embedding: word2vec.npy
# dataloader
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: KGSF
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
# optim
pretrain:
  epoch: 50
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
rec:
  epoch: 20
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  stop_mode: max
  impatience: 3
conv:
  epoch: 10
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1


================================================
FILE: config/crs/ntrd/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize: pkuseg
embedding: word2vec.npy
# dataloader
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: NTRD
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
gen_loss_weight: 5
n_movies: 62287
replace_token: '[ITEM]'
# optim
pretrain:
  epoch: 50
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
rec:
  epoch: 20
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  stop_mode: max
  impatience: 3
conv:
  epoch: 10
  batch_size: 64
  optimizer:
    name: Adam
    lr: !!float 1e-3
  lr_scheduler:
    name: ReduceLROnPlateau
    patience: 3
    factor: 0.5
  gradient_clip: 0.1


================================================
FILE: config/crs/redial/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  rec: jieba
  conv: jieba
# dataloader
utterance_truncate: 80
conversation_truncate: 40
scale: 1
# model
# rec
rec_model: ReDialRec
autorec_layer_sizes: [ 1000 ]
autorec_f: sigmoid
autorec_g: sigmoid
# conv
conv_model: ReDialConv
# embedding: word2vec
embedding_dim: 300
utterance_encoder_hidden_size: 256
dialog_encoder_hidden_size: 256
dialog_encoder_num_layers: 1
use_dropout: false
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min
conv:
  epoch: 1
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/redial/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  rec: nltk
  conv: nltk
# dataloader
utterance_truncate: 80
conversation_truncate: 40
scale: 1
# model
# rec
rec_model: ReDialRec
autorec_layer_sizes: [ 1000 ]
autorec_f: sigmoid
autorec_g: sigmoid
# conv
conv_model: ReDialConv
# embedding: word2vec
embedding_dim: 300
utterance_encoder_hidden_size: 256
dialog_encoder_hidden_size: 256
dialog_encoder_num_layers: 1
use_dropout: false
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min
conv:
  epoch: 1
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/redial/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  rec: nltk
  conv: nltk
# dataloader
utterance_truncate: 80
conversation_truncate: 40
scale: 1
# model
# rec
rec_model: ReDialRec
autorec_layer_sizes: [ 1000 ]
autorec_f: sigmoid
autorec_g: sigmoid
# conv
conv_model: ReDialConv
# embedding: word2vec
embedding_dim: 300
utterance_encoder_hidden_size: 256
dialog_encoder_hidden_size: 256
dialog_encoder_num_layers: 1
use_dropout: false
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min
conv:
  epoch: 1
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/redial/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  rec: nltk
  conv: nltk
# dataloader
utterance_truncate: 80
conversation_truncate: 40
scale: 1
# model
# rec
rec_model: ReDialRec
autorec_layer_sizes: [ 1000 ]
autorec_f: sigmoid
autorec_g: sigmoid
# conv
conv_model: ReDialConv
# embedding: word2vec
embedding_dim: 300
utterance_encoder_hidden_size: 256
dialog_encoder_hidden_size: 256
dialog_encoder_num_layers: 1
use_dropout: false
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min
conv:
  epoch: 1
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/redial/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  rec: nltk
  conv: nltk
# dataloader
utterance_truncate: 80
conversation_truncate: 40
scale: 1
# model
# rec
rec_model: ReDialRec
autorec_layer_sizes: [ 1000 ]
autorec_f: sigmoid
autorec_g: sigmoid
# conv
conv_model: ReDialConv
# embedding: word2vec
embedding_dim: 300
utterance_encoder_hidden_size: 256
dialog_encoder_hidden_size: 256
dialog_encoder_num_layers: 1
use_dropout: false
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 50
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min
conv:
  epoch: 50
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/redial/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  rec: pkuseg
  conv: pkuseg
# dataloader
utterance_truncate: 80
conversation_truncate: 40
scale: 1
# model
# rec
rec_model: ReDialRec
autorec_layer_sizes: [ 1000 ]
autorec_f: sigmoid
autorec_g: sigmoid
# conv
conv_model: ReDialConv
# embedding: word2vec
embedding_dim: 300
utterance_encoder_hidden_size: 256
dialog_encoder_hidden_size: 256
dialog_encoder_num_layers: 1
use_dropout: false
dropout: 0.3
decoder_hidden_size: 256
decoder_num_layers: 1
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min
conv:
  epoch: 1
  batch_size: 128
  optimizer:
    name: Adam
    lr: !!float 1e-3
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/tgredial/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TGRec
conv_model: TGConv
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5
  early_stop: true
  impatience: 3
  stop_mode: max
conv:
  epoch: 1
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/tgredial/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TGRec
conv_model: TGConv
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5
  early_stop: true
  impatience: 3
  stop_mode: max
conv:
  epoch: 1
  batch_size: 4
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/tgredial/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TGRec
conv_model: TGConv
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5
  early_stop: true
  impatience: 3
  stop_mode: max
conv:
  epoch: 1
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/tgredial/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TGRec
conv_model: TGConv
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5
  early_stop: true
  impatience: 3
  stop_mode: max
conv:
  epoch: 1
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/tgredial/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  rec: bert
  conv: gpt2
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TGRec
conv_model: TGConv
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 10
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-4
    weight_decay: 0
  lr_bert: !!float 1e-5
  early_stop: true
  impatience: 3
  stop_mode: max
conv:
  epoch: 10
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000
  early_stop: true
  impatience: 3
  stop_mode: min


================================================
FILE: config/crs/tgredial/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  rec: bert
  conv: gpt2
  policy: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TGRec
conv_model: TGConv
policy_model: TGPolicy
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 50
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5
  early_stop: true
  impatience: 3
  stop_mode: max
conv:
  epoch: 50
  batch_size: 8
  gradient_clip: 1.0
  update_freq: 1
  optimizer:
    name: AdamW
    lr: !!float 1.5e-4
  lr_scheduler:
    name: TransformersLinearLR
    warmup_steps: 2000
  early_stop: true
  impatience: 3
  stop_mode: min
policy:
  epoch: 50
  batch_size: 8
  weight_decay: 0.01
  optimizer:
    name: AdamW
    lr: !!float 1e-5
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: config/policy/conv_bert/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  policy: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
policy_model: ConvBERT
# optim
policy:
  epoch: 50
  batch_size: 8
  weight_decay: 0.01
  optimizer:
    name: AdamW
    lr: !!float 1e-5
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: config/policy/mgcg/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  policy: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
policy_model: MGCG
dropout_hidden: 0
num_layers: 1
hidden_size: 300
embedding_dim: 300
n_sent: 10
# optim
policy:
  epoch: 100
  batch_size: 1024
  weight_decay: 0.01
  optimizer:
    name: AdamW
    lr: !!float 1e-4
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: config/policy/pmi/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  policy: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
policy_model: PMI
# optim
policy:
  epoch: 1
  batch_size: 1024
  weight_decay: 0.01
  optimizer:
    name: AdamW
    lr: !!float 1e-5
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: config/policy/profile_bert/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  policy: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
policy_model: ProfileBERT
n_sent: 10
# optim
policy:
  epoch: 50
  batch_size: 8
  weight_decay: 0.01
  optimizer:
    name: AdamW
    lr: !!float 1e-5
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: config/policy/topic_bert/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  policy: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
policy_model: TopicBERT
# optim
policy:
  epoch: 50
  batch_size: 8
  weight_decay: 0.01
  optimizer:
    name: AdamW
    lr: !!float 1e-5
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: config/recommendation/bert/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: BERT
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/bert/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: BERT
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/bert/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: BERT
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/bert/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: BERT
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/bert/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: BERT
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/bert/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: BERT
# optim
rec:
  epoch: 20
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  early_stop: true
  stop_mode: max
  impatience: 3
  lr_bert: !!float 1e-5

================================================
FILE: config/recommendation/gru4rec/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: GRU4REC
gru_hidden_size: 50
num_layers: 3
embedding_dim: 50
dropout_input: 0
dropout_hidden: 0.0
hidden_size: 50
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/gru4rec/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: GRU4REC
gru_hidden_size: 50
num_layers: 3
embedding_dim: 50
dropout_input: 0
dropout_hidden: 0.0
hidden_size: 50
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-2
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/gru4rec/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: GRU4REC
gru_hidden_size: 50
num_layers: 3
embedding_dim: 50
dropout_input: 0
dropout_hidden: 0.0
hidden_size: 50
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/gru4rec/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: GRU4REC
gru_hidden_size: 50
num_layers: 3
embedding_dim: 50
dropout_input: 0
dropout_hidden: 0.0
hidden_size: 50
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/gru4rec/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: GRU4REC
gru_hidden_size: 50
num_layers: 3
embedding_dim: 50
dropout_input: 0
dropout_hidden: 0.0
hidden_size: 50
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-2
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/gru4rec/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: GRU4REC
gru_hidden_size: 50
num_layers: 3
embedding_dim: 50
dropout_input: 0
dropout_hidden: 0.0
hidden_size: 50
# optim
rec:
  epoch: 50
  batch_size: 64
  optimizer:
    name: Adam
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: config/recommendation/popularity/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: Popularity
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/popularity/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: Popularity
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/popularity/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: Popularity
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/popularity/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: Popularity
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/popularity/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: Popularity
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/popularity/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: Popularity
# optim
rec:
  epoch: 1
  batch_size: 1024
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5

================================================
FILE: config/recommendation/sasrec/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: SASREC
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/sasrec/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: SASREC
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-2
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/sasrec/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: SASREC
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/sasrec/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: SASREC
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/sasrec/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: SASREC
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/sasrec/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  rec: bert
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: SASREC
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
# optim
rec:
  epoch: 50
  batch_size: 256
  optimizer:
    name: Adam
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: config/recommendation/textcnn/durecdial.yaml
================================================
# dataset
dataset: DuRecDial
tokenize:
  rec: jieba
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TextCNN
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
num_filters: 256
embed: 300
filter_sizes: (2, 3, 4)  # plain scalar: YAML loads this as the string "(2, 3, 4)", not a list
dropout: 0.5
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/textcnn/gorecdial.yaml
================================================
# dataset
dataset: GoRecDial
tokenize:
  rec: nltk
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TextCNN
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
num_filters: 256
embed: 300
filter_sizes: (2, 3, 4)  # plain scalar: YAML loads this as the string "(2, 3, 4)", not a list
dropout: 0.5
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/textcnn/inspired.yaml
================================================
# dataset
dataset: Inspired
tokenize:
  rec: nltk
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TextCNN
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
num_filters: 256
embed: 300
filter_sizes: (2, 3, 4)
dropout: 0.5
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/textcnn/opendialkg.yaml
================================================
# dataset
dataset: OpenDialKG
tokenize:
  rec: nltk
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TextCNN
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
num_filters: 256
embed: 300
filter_sizes: (2, 3, 4)
dropout: 0.5
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/textcnn/redial.yaml
================================================
# dataset
dataset: ReDial
tokenize:
  rec: nltk
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TextCNN
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
num_filters: 256
embed: 300
filter_sizes: (2, 3, 4)
dropout: 0.5
# optim
rec:
  epoch: 1
  batch_size: 8
  optimizer:
    name: AdamW
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5


================================================
FILE: config/recommendation/textcnn/tgredial.yaml
================================================
# dataset
dataset: TGReDial
tokenize:
  rec: sougou
# dataloader
context_truncate: 256
response_truncate: 30
item_truncate: 100
scale: 1
# model
rec_model: TextCNN
hidden_dropout_prob: 0.2
initializer_range: 0.02
hidden_size: 50
max_history_items: 100
num_attention_heads: 1
attention_probs_dropout_prob: 0.2
hidden_act: gelu
num_hidden_layers: 2
num_filters: 256
embed: 300
filter_sizes: (2, 3, 4)
dropout: 0.5
# optim
rec:
  epoch: 50
  batch_size: 64
  optimizer:
    name: Adam
    lr: !!float 1e-3
    weight_decay: !!float 0.0000
  lr_bert: !!float 1e-5
  early_stop: true
  stop_mode: max
  impatience: 3

================================================
FILE: crslab/__init__.py
================================================
__version__ = '0.0.1'


================================================
FILE: crslab/config/__init__.py
================================================
# -*- encoding: utf-8 -*-
# @Time    :   2020/12/22
# @Author  :   Xiaolei Wang
# @email   :   wxl1999@foxmail.com

# UPDATE
# @Time    :   2020/12/29
# @Author  :   Xiaolei Wang
# @email   :   wxl1999@foxmail.com

"""Config module which loads parameters for the whole system.

Attributes:
    SAVE_PATH (str): directory where system outputs are saved.
    DATASET_PATH (str): directory where datasets are saved.
    MODEL_PATH (str): directory where model-related data is saved.
    PRETRAIN_PATH (str): directory where pretrained models are saved.
    EMBEDDING_PATH (str): directory where pretrained embeddings are saved; used to evaluate embedding-related metrics.
"""

import os
from os.path import dirname, realpath

from .config import Config

ROOT_PATH = dirname(dirname(dirname(realpath(__file__))))
SAVE_PATH = os.path.join(ROOT_PATH, 'save')
DATA_PATH = os.path.join(ROOT_PATH, 'data')
DATASET_PATH = os.path.join(DATA_PATH, 'dataset')
MODEL_PATH = os.path.join(DATA_PATH, 'model')
PRETRAIN_PATH = os.path.join(MODEL_PATH, 'pretrain')
EMBEDDING_PATH = os.path.join(DATA_PATH, 'embedding')


================================================
FILE: crslab/config/config.py
================================================
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/23, 2021/1/9
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

import json
import os
import time
from pprint import pprint

import yaml
import torch
from loguru import logger
from tqdm import tqdm


class Config:
    """Configurator module that loads the defined parameters."""

    def __init__(self, config_file, gpu='-1', debug=False):
        """Load parameters and set log level.

        Args:
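            config_file (str): path to the config file, which should be in ``yaml`` format.
                You can use the default configs provided in the `Github repo`_, or write your own.
            gpu (str, optional): comma-separated ids of the GPUs to use, exported as ``CUDA_VISIBLE_DEVICES``.
                ``'-1'`` means no GPU is used. Defaults to ``'-1'``.
            debug (bool, optional): whether to enable debug mode (``DEBUG`` log level) while running.
                Defaults to False.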

        .. _Github repo:
            https://github.com/RUCAIBox/CRSLab

        """

        self.opt = self.load_yaml_configs(config_file)
        # gpu
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu
        if gpu != '-1':
            self.opt['gpu'] = list(range(len(gpu.split(','))))
        else:
            self.opt['gpu'] = [-1]
        # dataset
        dataset = self.opt['dataset']
        tokenize = self.opt['tokenize']
        if isinstance(tokenize, dict):
            tokenize = ', '.join(tokenize.values())
        # model
        model = self.opt.get('model', None)
        rec_model = self.opt.get('rec_model', None)
        conv_model = self.opt.get('conv_model', None)
        policy_model = self.opt.get('policy_model', None)
        if model:
            model_name = model
        else:
            models = []
            if rec_model:
                models.append(rec_model)
            if conv_model:
                models.append(conv_model)
            if policy_model:
                models.append(policy_model)
            model_name = '_'.join(models)
        self.opt['model_name'] = model_name
        # log
        default_log_name = dataset + '_' + model_name + '_' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        log_name = self.opt.get("log_name", default_log_name) + ".log"
        os.makedirs("log", exist_ok=True)
        logger.remove()
        level = 'DEBUG' if debug else 'INFO'
        logger.add(os.path.join("log", log_name), level=level)
        logger.add(lambda msg: tqdm.write(msg, end=''), colorize=True, level=level)

        logger.info(f"[Dataset: {dataset} tokenized in {tokenize}]")
        if model:
            logger.info(f'[Model: {model}]')
        if rec_model:
            logger.info(f'[Recommendation Model: {rec_model}]')
        if conv_model:
            logger.info(f'[Conversation Model: {conv_model}]')
        if policy_model:
            logger.info(f'[Policy Model: {policy_model}]')
        logger.info("[Config]" + '\n' + json.dumps(self.opt, indent=4))

    @staticmethod
    def load_yaml_configs(filename):
        """This function reads ``yaml`` file to build config dictionary

        Args:
            filename (str): path to ``yaml`` config

        Returns:
            dict: config

        """
        with open(filename, 'r', encoding='utf-8') as f:
            # an empty file makes ``yaml.safe_load`` return None, so fall back to an empty dict
            config_dict = yaml.safe_load(f) or {}
        return config_dict

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        self.opt[key] = value

    def __getitem__(self, item):
        if item in self.opt:
            return self.opt[item]
        else:
            return None

    def get(self, item, default=None):
        """Get the value of the corresponding item in config

        Args:
            item (str): key to query in config
            default (optional): default value for item if not found in config. Defaults to None.

        Returns:
            value of the corresponding item in config

        """
        if item in self.opt:
            return self.opt[item]
        else:
            return default

    def __contains__(self, key):
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        return key in self.opt

    def __str__(self):
        return str(self.opt)

    def __repr__(self):
        return self.__str__()


if __name__ == '__main__':
    opt_dict = Config('../../config/crs/kbrd/redial.yaml')
    pprint(opt_dict)


================================================
FILE: crslab/data/__init__.py
================================================
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/24, 2020/12/29, 2020/12/17
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com

# @Time   : 2021/10/06
# @Author : Zhipeng Zhao
# @Email  : oran_official@outlook.com

"""Data module which reads, processes and batches data for the whole system

Attributes:
    dataset_register_table (dict): maps each supported dataset name to its dataset class
    dataset_language_map (dict): maps each dataset name to its language code
    dataloader_register_table (dict): maps each model name to its dataloader class

"""

from crslab.data.dataloader import *
from crslab.data.dataset import *

dataset_register_table = {
    'ReDial': ReDialDataset,
    'TGReDial': TGReDialDataset,
    'GoRecDial': GoRecDialDataset,
    'OpenDialKG': OpenDialKGDataset,
    'Inspired': InspiredDataset,
    'DuRecDial': DuRecDialDataset
}

dataset_language_map = {
    'ReDial': 'en',
    'TGReDial': 'zh',
    'GoRecDial': 'en',
    'OpenDialKG': 'en',
    'Inspired': 'en',
    'DuRecDial': 'zh'
}

dataloader_register_table = {
    'KGSF': KGSFDataLoader,
    'KBRD': KBRDDataLoader,
    'TGReDial': TGReDialDataLoader,
    'TGRec': TGReDialDataLoader,
    'TGConv': TGReDialDataLoader,
    'TGPolicy': TGReDialDataLoader,
    'TGRec_TGConv': TGReDialDataLoader,
    'TGRec_TGConv_TGPolicy': TGReDialDataLoader,
    'ReDialRec': ReDialDataLoader,
    'ReDialConv': ReDialDataLoader,
    'ReDialRec_ReDialConv': ReDialDataLoader,
    'InspiredRec_InspiredConv': InspiredDataLoader,
    'BERT': TGReDialDataLoader,
    'SASREC': TGReDialDataLoader,
    'TextCNN': TGReDialDataLoader,
    'GRU4REC': TGReDialDataLoader,
    'Popularity': TGReDialDataLoader,
    'Transformer': KGSFDataLoader,
    'GPT2': TGReDialDataLoader,
    'ConvBERT': TGReDialDataLoader,
    'TopicBERT': TGReDialDataLoader,
    'ProfileBERT': TGReDialDataLoader,
    'MGCG': TGReDialDataLoader,
    'PMI': TGReDialDataLoader,
    'NTRD': NTRDDataLoader
}


def get_dataset(opt, tokenize, restore, save) -> BaseDataset:
    """get and process dataset

    Args:
        opt (Config or dict): config for dataset or the whole system.
        tokenize (str): how to tokenize the dataset.
        restore (bool): whether to restore saved dataset which has been processed.
        save (bool): whether to save dataset after processing.

    Returns:
        processed dataset

    """
    dataset = opt['dataset']
    if dataset in dataset_register_table:
        return dataset_register_table[dataset](opt, tokenize, restore, save)
    else:
        raise NotImplementedError(f'The dataset [{dataset}] has not been implemented')


def get_dataloader(opt, dataset, vocab) -> BaseDataLoader:
    """get dataloader to batchify dataset

    Args:
        opt (Config or dict): config for dataloader or the whole system.
        dataset: processed raw data, no side data.
        vocab (dict): all kinds of useful size, idx and map between token and idx.

    Returns:
        dataloader

    """
    model_name = opt['model_name']
    if model_name in dataloader_register_table:
        return dataloader_register_table[model_name](opt, dataset, vocab)
    else:
        raise NotImplementedError(f'The dataloader [{model_name}] has not been implemented')
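

# A minimal sketch (not part of the library API) of how the ``model_name`` key
# looked up in ``get_dataloader`` is composed. It mirrors the logic in
# ``Config.__init__``; the helper name ``_sketch_model_name`` is hypothetical.
def _sketch_model_name(opt):
    """Return ``opt['model']`` if set, else join the configured sub-models with ``'_'``."""
    if opt.get('model'):
        return opt['model']
    keys = ('rec_model', 'conv_model', 'policy_model')
    return '_'.join(opt[key] for key in keys if opt.get(key))

# For example, ``{'rec_model': 'TGRec', 'conv_model': 'TGConv'}`` gives
# ``'TGRec_TGConv'``, which is a key of ``dataloader_register_table``.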


================================================
FILE: crslab/data/dataloader/__init__.py
================================================
from .base import BaseDataLoader
from .inspired import InspiredDataLoader
from .kbrd import KBRDDataLoader
from .kgsf import KGSFDataLoader
from .redial import ReDialDataLoader
from .tgredial import TGReDialDataLoader
from .ntrd import NTRDDataLoader


================================================
FILE: crslab/data/dataloader/base.py
================================================
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/23, 2020/12/29
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

import random
from abc import ABC

from loguru import logger
from math import ceil
from tqdm import tqdm


class BaseDataLoader(ABC):
    """Abstract class of dataloader

    Notes:
        ``'scale'`` can be set in config to limit the size of dataset.

    """

    def __init__(self, opt, dataset):
        """
        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: dataset

        """
        self.opt = opt
        self.dataset = dataset
        self.scale = opt.get('scale', 1)
        assert 0 < self.scale <= 1

    def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None):
        """Collate batch data for system to fit

        Args:
            batch_fn (func): function to collate data
            batch_size (int): number of samples per batch.
            shuffle (bool, optional): whether to shuffle the data before batching. Defaults to True.
            process_fn (func, optional): function to process dataset before batchify. Defaults to None.

        Yields:
            tuple or dict of torch.Tensor: batch data for system to fit

        """
        dataset = self.dataset
        if process_fn is not None:
            dataset = process_fn()
            logger.info('[Finish dataset process before batchify]')
        dataset = dataset[:ceil(len(dataset) * self.scale)]
        logger.debug(f'[Dataset size: {len(dataset)}]')

        batch_num = ceil(len(dataset) / batch_size)
        idx_list = list(range(len(dataset)))
        if shuffle:
            random.shuffle(idx_list)

        for start_idx in tqdm(range(batch_num)):
            batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size]
            batch = [dataset[idx] for idx in batch_idx]
            batch = batch_fn(batch)
            # batch_fn may return False to signal that the batch should be skipped
            if batch is False:
                continue
            yield batch

    def get_conv_data(self, batch_size, shuffle=True):
        """get_data wrapper for conversation.

        You can implement your own process_fn in ``conv_process_fn``, batch_fn in ``conv_batchify``.

        Args:
            batch_size (int): number of samples per batch.
            shuffle (bool, optional): whether to shuffle the data before batching. Defaults to True.

        Yields:
            tuple or dict of torch.Tensor: batch data for conversation.

        """
        return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn)

    def get_rec_data(self, batch_size, shuffle=True):
        """get_data wrapper for recommendation.

        You can implement your own process_fn in ``rec_process_fn``, batch_fn in ``rec_batchify``.

        Args:
            batch_size (int): number of samples per batch.
            shuffle (bool, optional): whether to shuffle the data before batching. Defaults to True.

        Yields:
            tuple or dict of torch.Tensor: batch data for recommendation.

        """
        return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn)

    def get_policy_data(self, batch_size, shuffle=True):
        """get_data wrapper for policy.

        You can implement your own process_fn in ``policy_process_fn``, batch_fn in ``policy_batchify``.

        Args:
            batch_size (int): number of samples per batch.
            shuffle (bool, optional): whether to shuffle the data before batching. Defaults to True.

        Yields:
            tuple or dict of torch.Tensor: batch data for policy.

        """
        return self.get_data(self.policy_batchify, batch_size, shuffle, self.policy_process_fn)

    def conv_process_fn(self):
        """Process whole data for conversation before batch_fn.

        Returns:
            processed dataset. Defaults to return the same as `self.dataset`.

        """
        return self.dataset

    def conv_batchify(self, batch):
        """batchify data for conversation after process.

        Args:
            batch (list): processed batch dataset.

        Returns:
            batch data for the system to train conversation part.
        """
        raise NotImplementedError('dataloader must implement conv_batchify() method')

    def rec_process_fn(self):
        """Process whole data for recommendation before batch_fn.

        Returns:
            processed dataset. Defaults to return the same as `self.dataset`.

        """
        return self.dataset

    def rec_batchify(self, batch):
        """batchify data for recommendation after process.

        Args:
            batch (list): processed batch dataset.

        Returns:
            batch data for the system to train recommendation part.
        """
        raise NotImplementedError('dataloader must implement rec_batchify() method')

    def policy_process_fn(self):
        """Process whole data for policy before batch_fn.

        Returns:
            processed dataset. Defaults to return the same as `self.dataset`.

        """
        return self.dataset

    def policy_batchify(self, batch):
        """batchify data for policy after process.

        Args:
            batch (list): processed batch dataset.

        Returns:
            batch data for the system to train policy part.
        """
        raise NotImplementedError('dataloader must implement policy_batchify() method')

    def retain_recommender_target(self):
        """keep data whose role is recommender.

        Returns:
            Recommender part of ``self.dataset``.

        """
        dataset = []
        for conv_dict in tqdm(self.dataset):
            if conv_dict['role'] == 'Recommender':
                dataset.append(conv_dict)
        return dataset

    def rec_interact(self, data):
        """process user input data for system to recommend.

        Args:
            data: user input data.

        Returns:
            data for system to recommend.
        """
        pass

    def conv_interact(self, data):
        """Process user input data for system to converse.

        Args:
            data: user input data.

        Returns:
            data for system to converse.
        """
        pass
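

# A self-contained sketch of the batching scheme used by
# ``BaseDataLoader.get_data``: keep the first ``scale`` fraction of the data,
# optionally shuffle the indices, cut them into batches and skip batches for
# which ``batch_fn`` returns ``False``. The name ``_sketch_iter_batches`` is
# hypothetical and the progress bar is left out; this is an illustration,
# not the library implementation.
import random  # repeated here so the sketch runs on its own
from math import ceil


def _sketch_iter_batches(dataset, batch_fn, batch_size, shuffle=True, scale=1):
    """Yield ``batch_fn(batch)`` for consecutive index slices of ``dataset``."""
    dataset = dataset[:ceil(len(dataset) * scale)]
    idx_list = list(range(len(dataset)))
    if shuffle:
        random.shuffle(idx_list)
    for start in range(0, len(idx_list), batch_size):
        batch = batch_fn([dataset[idx] for idx in idx_list[start:start + batch_size]])
        if batch is not False:
            yield batch

# Usage idea: ``list(_sketch_iter_batches(list(range(10)), len, 4, scale=0.5))``
# keeps 5 samples and yields two batches, of sizes 4 and 1.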


================================================
FILE: crslab/data/dataloader/inspired.py
================================================
# @Time   : 2021/3/11
# @Author : Beichen Zhang
# @Email  : zhangbeichen724@gmail.com

from copy import deepcopy

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt


class InspiredDataLoader(BaseDataLoader):
    """Dataloader for model Inspired.

    Notes:
        You can set the following parameters in config:

        - ``'context_truncate'``: the maximum length of context.
        - ``'response_truncate'``: the maximum length of response.
        - ``'entity_truncate'``: the maximum length of mentioned entities in context.
        - ``'word_truncate'``: the maximum length of mentioned words in context.
        - ``'item_truncate'``: the maximum length of mentioned items in context.

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'unk'``
        - ``'pad_entity'``
        - ``'pad_word'``

        The values above are the ids of the required special tokens.

        - ``'ind2tok'``: map from index to token.
        - ``'tok2ind'``: map from token to index.
        - ``'vocab_size'``: size of vocab.
        - ``'id2entity'``: map from index to entity.
        - ``'n_entity'``: number of entities in the entity KG of dataset.
        - ``'sent_split'`` (optional): token used to split sentence. Defaults to ``'end'``.
        - ``'word_split'`` (optional): token used to split word. Defaults to ``'end'``.

    """

    def __init__(self, opt, dataset, vocab):
        """

        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful size, idx and map between token and idx.

        """
        super().__init__(opt, dataset)

        self.n_entity = vocab['n_entity']
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.unk_token_idx = vocab['unk']
        self.conv_bos_id = vocab['start']
        self.cls_id = vocab['start']
        self.sep_id = vocab['end']
        if 'sent_split' in vocab:
            self.sent_split_idx = vocab['sent_split']
        else:
            self.sent_split_idx = vocab['end']

        self.pad_entity_idx = vocab['pad_entity']
        self.pad_word_idx = vocab['pad_word']

        self.tok2ind = vocab['tok2ind']
        self.ind2tok = vocab['ind2tok']
        self.id2entity = vocab['id2entity']

        self.context_truncate = opt.get('context_truncate', None)
        self.response_truncate = opt.get('response_truncate', None)

    def rec_process_fn(self, *args, **kwargs):
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            if conv_dict['role'] == 'Recommender':
                for movie in conv_dict['items']:
                    augment_conv_dict = deepcopy(conv_dict)
                    augment_conv_dict['item'] = movie
                    augment_dataset.append(augment_conv_dict)
        return augment_dataset

    def _process_rec_context(self, context_tokens):
        compact_context = []
        for i, utterance in enumerate(context_tokens):
            if i != 0:
                # prepend the sentence separator without mutating the input utterance
                utterance = [self.sent_split_idx] + utterance
            compact_context.append(utterance)
        # reserve two positions for the start and end tokens added below
        compact_context = truncate(merge_utt(compact_context),
                                   self.context_truncate - 2,
                                   truncate_tail=False)
        compact_context = add_start_end_token_idx(compact_context,
                                                  self.start_token_idx,
                                                  self.end_token_idx)
        return compact_context

    def rec_batchify(self, batch):
        batch_context = []
        batch_movie_id = []

        for conv_dict in batch:
            context = self._process_rec_context(conv_dict['context_tokens'])
            batch_context.append(context)

            item_id = conv_dict['item']
            batch_movie_id.append(item_id)

        batch_context = padded_tensor(batch_context,
                                      self.pad_token_idx,
                                      max_len=self.context_truncate)
        batch_mask = (batch_context != self.pad_token_idx).long()

        return (batch_context, batch_mask, torch.tensor(batch_movie_id))

    def conv_batchify(self, batch):
        """get batch and corresponding roles
        """
        batch_roles = []
        batch_context_tokens = []
        batch_response = []

        for conv_dict in batch:
            batch_roles.append(0 if conv_dict['role'] == 'Seeker' else 1)
            context_tokens = [utter + [self.conv_bos_id] for utter in conv_dict['context_tokens']]
            context_tokens[-1] = context_tokens[-1][:-1]
            batch_context_tokens.append(
                truncate(merge_utt(context_tokens), max_length=self.context_truncate, truncate_tail=False),
            )
            batch_response.append(
                add_start_end_token_idx(
                    truncate(conv_dict['response'], max_length=self.response_truncate - 2),
                    start_token_idx=self.start_token_idx,
                    end_token_idx=self.end_token_idx
                )
            )

        batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                             pad_idx=self.pad_token_idx,
                                             max_len=self.context_truncate,
                                             pad_tail=False)
        batch_response = padded_tensor(batch_response,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.response_truncate,
                                       pad_tail=True)
        batch_input_ids = torch.cat((batch_context_tokens, batch_response), dim=1)
        batch_roles = torch.tensor(batch_roles)

        return (batch_roles,
                batch_input_ids,
                batch_context_tokens,
                batch_response)

    def policy_batchify(self, batch):
        pass


================================================
FILE: crslab/data/dataloader/kbrd.py
================================================
# @Time   : 2020/11/27
# @Author : Xiaolei Wang
# @Email  : wxl1999@foxmail.com

# UPDATE:
# @Time   : 2020/12/2
# @Author : Xiaolei Wang
# @Email  : wxl1999@foxmail.com

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt


class KBRDDataLoader(BaseDataLoader):
    """Dataloader for model KBRD.

    Notes:
        You can set the following parameters in config:

        - ``'context_truncate'``: the maximum length of context.
        - ``'response_truncate'``: the maximum length of response.
        - ``'entity_truncate'``: the maximum length of mentioned entities in context.

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'pad_entity'``

        The values above are the ids of the required special tokens.

    """

    def __init__(self, opt, dataset, vocab):
        """

        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful size, idx and map between token and idx.

        """
        super().__init__(opt, dataset)
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.pad_entity_idx = vocab['pad_entity']
        self.context_truncate = opt.get('context_truncate', None)
        self.response_truncate = opt.get('response_truncate', None)
        self.entity_truncate = opt.get('entity_truncate', None)

    def rec_process_fn(self):
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            if conv_dict['role'] == 'Recommender':
                for movie in conv_dict['items']:
                    augment_conv_dict = {'context_entities': conv_dict['context_entities'], 'item': movie}
                    augment_dataset.append(augment_conv_dict)
        return augment_dataset

    def rec_batchify(self, batch):
        batch_context_entities = []
        batch_movies = []
        for conv_dict in batch:
            batch_context_entities.append(conv_dict['context_entities'])
            batch_movies.append(conv_dict['item'])

        return {
            "context_entities": batch_context_entities,
            "item": torch.tensor(batch_movies, dtype=torch.long)
        }

    def conv_process_fn(self, *args, **kwargs):
        return self.retain_recommender_target()

    def conv_batchify(self, batch):
        batch_context_tokens = []
        batch_context_entities = []
        batch_response = []
        for conv_dict in batch:
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']), self.context_truncate, truncate_tail=False))
            batch_context_entities.append(conv_dict['context_entities'])
            batch_response.append(
                add_start_end_token_idx(truncate(conv_dict['response'], self.response_truncate - 2),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))

        return {
            "context_tokens": padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
            "context_entities": batch_context_entities,
            "response": padded_tensor(batch_response, self.pad_token_idx)
        }

    def policy_batchify(self, *args, **kwargs):
        pass


================================================
FILE: crslab/data/dataloader/kgsf.py
================================================
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/23, 2020/12/2
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

from copy import deepcopy

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, get_onehot, truncate, merge_utt


class KGSFDataLoader(BaseDataLoader):
    """Dataloader for model KGSF.

    Notes:
        You can set the following parameters in config:

        - ``'context_truncate'``: the maximum length of context.
        - ``'response_truncate'``: the maximum length of response.
        - ``'entity_truncate'``: the maximum length of mentioned entities in context.
        - ``'word_truncate'``: the maximum length of mentioned words in context.

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'pad_entity'``
        - ``'pad_word'``

        The values above are the ids of the required special tokens.

        - ``'n_entity'``: the number of entities in the entity KG of dataset.

    """

    def __init__(self, opt, dataset, vocab):
        """

        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful size, idx and map between token and idx.

        """
        super().__init__(opt, dataset)
        self.n_entity = vocab['n_entity']
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.pad_entity_idx = vocab['pad_entity']
        self.pad_word_idx = vocab['pad_word']
        self.context_truncate = opt.get('context_truncate', None)
        self.response_truncate = opt.get('response_truncate', None)
        self.entity_truncate = opt.get('entity_truncate', None)
        self.word_truncate = opt.get('word_truncate', None)

    def get_pretrain_data(self, batch_size, shuffle=True):
        return self.get_data(self.pretrain_batchify, batch_size, shuffle, self.retain_recommender_target)

    def pretrain_batchify(self, batch):
        batch_context_entities = []
        batch_context_words = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))

        return (padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity))

    def rec_process_fn(self):
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            if conv_dict['role'] == 'Recommender':
                for movie in conv_dict['items']:
                    augment_conv_dict = deepcopy(conv_dict)
                    augment_conv_dict['item'] = movie
                    augment_dataset.append(augment_conv_dict)
        return augment_dataset

    def rec_batchify(self, batch):
        batch_context_entities = []
        batch_context_words = []
        batch_item = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))
            batch_item.append(conv_dict['item'])

        return (padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity),
                torch.tensor(batch_item, dtype=torch.long))

    def conv_process_fn(self, *args, **kwargs):
        return self.retain_recommender_target()

    def conv_batchify(self, batch):
        batch_context_tokens = []
        batch_context_entities = []
        batch_context_words = []
        batch_response = []
        for conv_dict in batch:
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']), self.context_truncate, truncate_tail=False))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))
            batch_response.append(
                add_start_end_token_idx(truncate(conv_dict['response'], self.response_truncate - 2),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))

        return (padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
                padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                padded_tensor(batch_response, self.pad_token_idx))

    def policy_batchify(self, *args, **kwargs):
        pass
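

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. The batchify methods
# above pass ``truncate_tail=False`` and ``pad_tail=False``, which is assumed
# here to mean "keep the most recent ids" and "pad on the left". The two
# helpers below reproduce that idea on plain Python lists. Their names
# (``_sketch_keep_recent``, ``_sketch_left_pad``) are hypothetical; the real
# ``truncate`` and ``padded_tensor`` live in ``crslab.data.dataloader.utils``
# and may differ in detail.
# ---------------------------------------------------------------------------
def _sketch_keep_recent(ids, max_length=None):
    """Keep the last ``max_length`` ids; ``None`` disables truncation."""
    if max_length is None or len(ids) <= max_length:
        return list(ids)
    if max_length == 0:
        # ``ids[-0:]`` would return the whole list, so handle zero explicitly.
        return []
    return list(ids[-max_length:])


def _sketch_left_pad(rows, pad_idx=0):
    """Left-pad every row with ``pad_idx`` up to the length of the longest row."""
    width = max((len(row) for row in rows), default=0)
    return [[pad_idx] * (width - len(row)) + list(row) for row in rows]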


================================================
FILE: crslab/data/dataloader/ntrd.py
================================================
# @Time   : 2021/10/06
# @Author : Zhipeng Zhao
# @Email  : oran_official@outlook.com

from copy import deepcopy

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import (add_start_end_token_idx, merge_utt_replace, padded_tensor, get_onehot,
                                          truncate, merge_utt)


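class NTRDDataLoader(BaseDataLoader):
    """Dataloader for model NTRD.

    Notes:
        You can set the following parameters in config:

        - ``'context_truncate'``: the maximum length of context.
        - ``'response_truncate'``: the maximum length of response.
        - ``'entity_truncate'``: the maximum length of mentioned entities in context.
        - ``'word_truncate'``: the maximum length of mentioned words in context.
        - ``'replace_token'``: the token that marks an item slot in a response. It is also used as a key of ``vocab``.

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'pad_entity'``
        - ``'pad_word'``
        - the value of ``'replace_token'``

        The above values specify the ids of the required special tokens.

        - ``'n_entity'``: the number of entities in the entity KG of dataset.

    """
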
    def __init__(self, opt, dataset, vocab):
        """

        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful size, idx and map between token and idx.

        """
        super().__init__(opt, dataset)
        self.n_entity = vocab['n_entity']
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.pad_entity_idx = vocab['pad_entity']
        self.pad_word_idx = vocab['pad_word']
        self.context_truncate = opt.get('context_truncate', None)
        self.response_truncate = opt.get('response_truncate', None)
        self.entity_truncate = opt.get('entity_truncate', None)
        self.word_truncate = opt.get('word_truncate', None)
        self.replace_token = opt.get('replace_token', None)
        # id of the slot token that stands in for a recommended item in a response
        self.replace_token_idx = vocab[self.replace_token]

    def get_pretrain_data(self, batch_size, shuffle=True):
        return self.get_data(self.pretrain_batchify, batch_size, shuffle, self.retain_recommender_target)

    def pretrain_batchify(self, batch):
        batch_context_entities = []
        batch_context_words = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))

        return (padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity))

    def rec_process_fn(self):
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            if conv_dict['role'] == 'Recommender':
                for movie in conv_dict['items']:
                    augment_conv_dict = deepcopy(conv_dict)
                    augment_conv_dict['item'] = movie
                    augment_dataset.append(augment_conv_dict)
        return augment_dataset

    def rec_batchify(self, batch):
        batch_context_entities = []
        batch_context_words = []
        batch_item = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))
            batch_item.append(conv_dict['item'])

        return (padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity),
                torch.tensor(batch_item, dtype=torch.long))

    def conv_process_fn(self, *args, **kwargs):
        return self.retain_recommender_target()

    def conv_batchify(self, batch):
        batch_context_tokens = []
        batch_context_entities = []
        batch_context_words = []
        batch_response = []
        has_slot = False  # becomes True once any response in the batch contains a slot token
        batch_all_movies = []
        for conv_dict in batch:
            response = add_start_end_token_idx(truncate(conv_dict['response'], self.response_truncate - 2),
                                               start_token_idx=self.start_token_idx,
                                               end_token_idx=self.end_token_idx)
            n_slot = response.count(self.replace_token_idx)
            if n_slot != 0:
                has_slot = True
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']), self.context_truncate, truncate_tail=False))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))
            batch_response.append(response)
            # only use movies, not all entities, and keep at most one movie per slot
            batch_all_movies.append(truncate(conv_dict['items'], n_slot, truncate_tail=False))
        if not has_slot:  # no slot anywhere in the batch
            return False

        return (padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
                padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                padded_tensor(batch_response, self.pad_token_idx),
                padded_tensor(batch_all_movies, self.pad_entity_idx, pad_tail=False))

    def policy_batchify(self, *args, **kwargs):
        pass

================================================
FILE: crslab/data/dataloader/redial.py
================================================
# @Time   : 2020/11/22
# @Author : Chenzhan Shang
# @Email  : czshang@outlook.com

# UPDATE:
# @Time   : 2020/12/16
# @Author : Xiaolei Wang
# @Email  : wxl1999@foxmail.com

import re
from copy import copy

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import padded_tensor, get_onehot, truncate

movie_pattern = re.compile(r'^@\d{5,6}$')


class ReDialDataLoader(BaseDataLoader):
    """Dataloader for model ReDial.

    Notes:
        You can set the following parameters in config:

        - ``'utterance_truncate'``: the maximum number of tokens in a single utterance.
        - ``'conversation_truncate'``: the maximum number of utterances kept from the conversation context.

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'unk'``

        The above values specify the ids of the required special tokens.

        - ``'ind2tok'``: map from index to token.
        - ``'n_entity'``: number of entities in the entity KG of dataset.
        - ``'vocab_size'``: size of vocab.

    """

    def __init__(self, opt, dataset, vocab):
        """

        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful size, idx and map between token and idx.

        """
        super().__init__(opt, dataset)
        self.ind2tok = vocab['ind2tok']
        self.n_entity = vocab['n_entity']
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.unk_token_idx = vocab['unk']
        self.item_token_idx = vocab['vocab_size']
        self.conversation_truncate = self.opt.get('conversation_truncate', None)
        self.utterance_truncate = self.opt.get('utterance_truncate', None)

    def rec_process_fn(self, *args, **kwargs):
        dataset = []
        for conversation in self.dataset:
            if conversation['role'] == 'Recommender':
                for item in conversation['items']:
                    context_entities = conversation['context_entities']
                    dataset.append({'context_entities': context_entities, 'item': item})
        return dataset

    def rec_batchify(self, batch):
        batch_context_entities = []
        batch_item = []
        for conversation in batch:
            batch_context_entities.append(conversation['context_entities'])
            batch_item.append(conversation['item'])
        context_entities = get_onehot(batch_context_entities, self.n_entity)
        return {'context_entities': context_entities, 'item': torch.tensor(batch_item, dtype=torch.long)}

    def conv_process_fn(self):
        dataset = []
        for conversation in tqdm(self.dataset):
            if conversation['role'] != 'Recommender':
                continue
            context_tokens = [truncate(utterance, self.utterance_truncate, truncate_tail=True) for utterance in
                              conversation['context_tokens']]
            context_tokens = truncate(context_tokens, self.conversation_truncate, truncate_tail=True)
            context_length = len(context_tokens)
            utterance_lengths = [len(utterance) for utterance in context_tokens]
            request = context_tokens[-1]
            response = truncate(conversation['response'], self.utterance_truncate, truncate_tail=True)
            dataset.append({'context_tokens': context_tokens, 'context_length': context_length,
                            'utterance_lengths': utterance_lengths, 'request': request, 'response': response})
        return dataset

    def conv_batchify(self, batch):
        max_utterance_length = max([max(conversation['utterance_lengths']) for conversation in batch])
        max_response_length = max([len(conversation['response']) for conversation in batch])
        max_utterance_length = max(max_utterance_length, max_response_length)
        max_context_length = max([conversation['context_length'] for conversation in batch])
        batch_context = []
        batch_context_length = []
        batch_utterance_lengths = []
        batch_request = []  # tensor
        batch_request_length = []
        batch_response = []

        for conversation in batch:
            padded_context = padded_tensor(conversation['context_tokens'], pad_idx=self.pad_token_idx,
                                           pad_tail=True, max_len=max_utterance_length)
            if len(conversation['context_tokens']) < max_context_length:
                pad_tensor = padded_context.new_full(
                    (max_context_length - len(conversation['context_tokens']), max_utterance_length), self.pad_token_idx
                )
                padded_context = torch.cat((padded_context, pad_tensor), 0)
            batch_context.append(padded_context)
            batch_context_length.append(conversation['context_length'])
            batch_utterance_lengths.append(conversation['utterance_lengths'] +
                                           [0] * (max_context_length - len(conversation['context_tokens'])))

            request = conversation['request']
            batch_request_length.append(len(request))
            batch_request.append(request)

            response = copy(conversation['response'])
            # replace tokens matching movie_pattern ('^@\d{5,6}$') with the '__item__' index
            for i in range(len(response)):
                if movie_pattern.match(self.ind2tok[response[i]]):
                    response[i] = self.item_token_idx
            batch_response.append(response)

        context = torch.stack(batch_context, dim=0)
        request = padded_tensor(batch_request, self.pad_token_idx, pad_tail=True, max_len=max_utterance_length)
        response = padded_tensor(batch_response, self.pad_token_idx, pad_tail=True,
                                 max_len=max_utterance_length)  # (bs, utt_len)

        return {'context': context, 'context_lengths': torch.tensor(batch_context_length),
                'utterance_lengths': torch.tensor(batch_utterance_lengths), 'request': request,
                'request_lengths': torch.tensor(batch_request_length), 'response': response}

    def policy_batchify(self, batch):
        pass
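

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. ``conv_batchify``
# above replaces every response token that looks like a movie mention
# ("@" followed by 5 or 6 digits) with a single item index. The helper below
# shows that step in isolation on a copy of the response. Its name and the
# toy ``ind2tok`` mapping used to try it out are hypothetical.
# ---------------------------------------------------------------------------
import re


def _sketch_mask_movie_tokens(response, ind2tok, item_token_idx):
    """Return a new list where movie-mention token ids become ``item_token_idx``."""
    pattern = re.compile(r'^@\d{5,6}$')
    return [item_token_idx if pattern.match(ind2tok[idx]) else idx for idx in response]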


================================================
FILE: crslab/data/dataloader/tgredial.py
================================================
# @Time   : 2020/12/9
# @Author : Yuanhang Zhou
# @Email  : sdzyh002@gmail.com

# UPDATE:
# @Time   : 2020/12/29, 2020/12/15
# @Author : Xiaolei Wang, Yuanhang Zhou
# @Email  : wxl1999@foxmail.com, sdzyh002@gmail

import random
from copy import deepcopy

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt


class TGReDialDataLoader(BaseDataLoader):
    """Dataloader for model TGReDial.

    Notes:
        You can set the following parameters in config:

        - ``'context_truncate'``: the maximum length of context.
        - ``'response_truncate'``: the maximum length of response.
        - ``'entity_truncate'``: the maximum length of mentioned entities in context.
        - ``'word_truncate'``: the maximum length of mentioned words in context.
        - ``'item_truncate'``: the maximum length of mentioned items in context.

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'unk'``
        - ``'pad_entity'``
        - ``'pad_word'``

        The above values specify the ids of the required special tokens.

        - ``'ind2tok'``: map from index to token.
        - ``'tok2ind'``: map from token to index.
        - ``'vocab_size'``: size of vocab.
        - ``'id2entity'``: map from index to entity.
        - ``'n_entity'``: number of entities in the entity KG of dataset.
        - ``'sent_split'`` (optional): token used to split sentence. Defaults to ``'end'``.
        - ``'word_split'`` (optional): token used to split word. Defaults to ``'end'``.
        - ``'pad_topic'`` (optional): token used to pad topic.
        - ``'ind2topic'`` (optional): map from index to topic.

    """

    def __init__(self, opt, dataset, vocab):
        """

        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful size, idx and map between token and idx.

        """
        super().__init__(opt, dataset)

        self.n_entity = vocab['n_entity']
        self.item_size = self.n_entity
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.unk_token_idx = vocab['unk']
        self.conv_bos_id = vocab['start']
        self.cls_id = vocab['start']
        self.sep_id = vocab['end']
        if 'sent_split' in vocab:
            self.sent_split_idx = vocab['sent_split']
        else:
            self.sent_split_idx = vocab['end']
        if 'word_split' in vocab:
            self.word_split_idx = vocab['word_split']
        else:
            self.word_split_idx = vocab['end']

        self.pad_entity_idx = vocab['pad_entity']
        self.pad_word_idx = vocab['pad_word']
        if 'pad_topic' in vocab:
            self.pad_topic_idx = vocab['pad_topic']

        self.tok2ind = vocab['tok2ind']
        self.ind2tok = vocab['ind2tok']
        self.id2entity = vocab['id2entity']
        if 'ind2topic' in vocab:
            self.ind2topic = vocab['ind2topic']

        self.context_truncate = opt.get('context_truncate', None)
        self.response_truncate = opt.get('response_truncate', None)
        self.entity_truncate = opt.get('entity_truncate', None)
        self.word_truncate = opt.get('word_truncate', None)
        self.item_truncate = opt.get('item_truncate', None)

    def rec_process_fn(self, *args, **kwargs):
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            for movie in conv_dict['items']:
                augment_conv_dict = deepcopy(conv_dict)
                augment_conv_dict['item'] = movie
                augment_dataset.append(augment_conv_dict)
        return augment_dataset

    def _process_rec_context(self, context_tokens):
        compact_context = []
        for i, utterance in enumerate(context_tokens):
            if i != 0:
                # build a new list instead of inserting in place, so the stored
                # conversation does not gain an extra split token on every call
                utterance = [self.sent_split_idx] + utterance
            compact_context.append(utterance)
        compact_context = truncate(merge_utt(compact_context),
                                   self.context_truncate - 2,
                                   truncate_tail=False)
        compact_context = add_start_end_token_idx(compact_context,
                                                  self.start_token_idx,
                                                  self.end_token_idx)
        return compact_context

    def _neg_sample(self, item_set):
        item = random.randint(1, self.item_size)
        while item in item_set:
            item = random.randint(1, self.item_size)
        return item

    def _process_history(self, context_items, item_id=None):
        input_ids = truncate(context_items,
                             max_length=self.item_truncate,
                             truncate_tail=False)
        input_mask = [1] * len(input_ids)
        sample_negs = []
        seq_set = set(input_ids)
        for _ in input_ids:
            sample_negs.append(self._neg_sample(seq_set))

        if item_id is not None:
            target_pos = input_ids[1:] + [item_id]
            return input_ids, target_pos, input_mask, sample_negs
        else:
            return input_ids, input_mask, sample_negs

    def rec_batchify(self, batch):
        batch_context = []
        batch_movie_id = []
        batch_input_ids = []
        batch_target_pos = []
        batch_input_mask = []
        batch_sample_negs = []

        for conv_dict in batch:
            context = self._process_rec_context(conv_dict['context_tokens'])
            batch_context.append(context)

            item_id = conv_dict['item']
            batch_movie_id.append(item_id)

            if 'interaction_history' in conv_dict:
                context_items = conv_dict['interaction_history'] + conv_dict[
                    'context_items']
            else:
                context_items = conv_dict['context_items']

            input_ids, target_pos, input_mask, sample_negs = self._process_history(
                context_items, item_id)
            batch_input_ids.append(input_ids)
            batch_target_pos.append(target_pos)
            batch_input_mask.append(input_mask)
            batch_sample_negs.append(sample_negs)

        batch_context = padded_tensor(batch_context,
                                      self.pad_token_idx,
                                      max_len=self.context_truncate)
        batch_mask = (batch_context != self.pad_token_idx).long()

        return (batch_context, batch_mask,
                padded_tensor(batch_input_ids,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_target_pos,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_input_mask,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_sample_negs,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                torch.tensor(batch_movie_id))

    def rec_interact(self, data):
        context = [self._process_rec_context(data['context_tokens'])]
        if 'interaction_history' in data:
            context_items = data['interaction_history'] + data['context_items']
        else:
            context_items = data['context_items']
        input_ids, input_mask, sample_negs = self._process_history(context_items)
        input_ids, input_mask, sample_negs = [input_ids], [input_mask], [sample_negs]

        context = padded_tensor(context,
                                self.pad_token_idx,
                                max_len=self.context_truncate)
        mask = (context != self.pad_token_idx).long()

        return (context, mask,
                padded_tensor(input_ids,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                None,
                padded_tensor(input_mask,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(sample_negs,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                None)

    def conv_batchify(self, batch):
        batch_context_tokens = []
        batch_enhanced_context_tokens = []
        batch_response = []
        batch_context_entities = []
        batch_context_words = []
        for conv_dict in batch:
            context_tokens = [utter + [self.conv_bos_id] for utter in conv_dict['context_tokens']]
            context_tokens[-1] = context_tokens[-1][:-1]
            batch_context_tokens.append(
                truncate(merge_utt(context_tokens), max_length=self.context_truncate, truncate_tail=False),
            )
            batch_response.append(
                add_start_end_token_idx(
                    truncate(conv_dict['response'], max_length=self.response_truncate - 2),
                    start_token_idx=self.start_token_idx,
                    end_token_idx=self.end_token_idx
                )
            )
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))

            enhanced_topic = []
            if 'target' in conv_dict:
                for target_policy in conv_dict['target']:
                    topic_variable = target_policy[1]
                    if isinstance(topic_variable, list):
                        for topic in topic_variable:
                            enhanced_topic.append(topic)
                enhanced_topic = [[
                    self.tok2ind.get(token, self.unk_token_idx) for token in self.ind2topic[topic_id]
                ] for topic_id in enhanced_topic]
                enhanced_topic = merge_utt(enhanced_topic, self.word_split_idx, False, self.sent_split_idx)

            enhanced_movie = []
            if 'items' in conv_dict:
                for movie_id in conv_dict['items']:
                    enhanced_movie.append(movie_id)
                enhanced_movie = [
                    [self.tok2ind.get(token, self.unk_token_idx) for token in self.id2entity[movie_id].split('(')[0]]
                    for movie_id in enhanced_movie]
                enhanced_movie = truncate(merge_utt(enhanced_movie, self.word_split_idx, self.sent_split_idx),
                                          self.item_truncate, truncate_tail=False)

            if len(enhanced_movie) != 0:
                enhanced_context_tokens = enhanced_movie + truncate(batch_context_tokens[-1],
                                                                    max_length=self.context_truncate - len(
                                                                        enhanced_movie), truncate_tail=False)
            elif len(enhanced_topic) != 0:
                enhanced_context_tokens = enhanced_topic + truncate(batch_context_tokens[-1],
                                                                    max_length=self.context_truncate - len(
                                                                        enhanced_topic), truncate_tail=False)
            else:
                enhanced_context_tokens = batch_context_tokens[-1]
            batch_enhanced_context_tokens.append(
                enhanced_context_tokens
            )

        batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                             pad_idx=self.pad_token_idx,
                                             max_len=self.context_truncate,
                                             pad_tail=False)
        batch_response = padded_tensor(batch_response,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.response_truncate,
                                       pad_tail=True)
        batch_input_ids = torch.cat((batch_context_tokens, batch_response), dim=1)
        batch_enhanced_context_tokens = padded_tensor(items=batch_enhanced_context_tokens,
                                                      pad_idx=self.pad_token_idx,
                                                      max_len=self.context_truncate,
                                                      pad_tail=False)
        batch_enhanced_input_ids = torch.cat((batch_enhanced_context_tokens, batch_response), dim=1)

        return (batch_enhanced_input_ids, batch_enhanced_context_tokens,
                batch_input_ids, batch_context_tokens,
                padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False), batch_response)

    def conv_interact(self, data):
        context_tokens = [utter + [self.conv_bos_id] for utter in data['context_tokens']]
        context_tokens[-1] = context_tokens[-1][:-1]
        context_tokens = [truncate(merge_utt(context_tokens), max_length=self.context_truncate, truncate_tail=False)]
        context_tokens = padded_tensor(items=context_tokens,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.context_truncate,
                                       pad_tail=False)
        context_entities = [truncate(data['context_entities'], self.entity_truncate, truncate_tail=False)]
        context_words = [truncate(data['context_words'], self.word_truncate, truncate_tail=False)]

        return (context_tokens, context_tokens,
                context_tokens, context_tokens,
                padded_tensor(context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(context_words,
                              self.pad_word_idx,
                              pad_tail=False), None)

    def policy_process_fn(self, *args, **kwargs):
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            for target_policy in conv_dict['target']:
                topic_variable = target_policy[1]
                for topic in topic_variable:
                    augment_conv_dict = deepcopy(conv_dict)
                    augment_conv_dict['target_topic'] = topic
                    augment_dataset.append(augment_conv_dict)
        return augment_dataset

    def policy_batchify(self, batch):
        batch_context = []
        batch_context_policy = []
        batch_user_profile = []
        batch_target = []

        for conv_dict in batch:
            final_topic = conv_dict['final']
            final_topic = [[
                self.tok2ind.get(token, self.unk_token_idx) for token in self.ind2topic[topic_id]
            ] for topic_id in final_topic[1]]
            final_topic = merge_utt(final_topic, self.word_split_idx, False, self.sep_id)

            context = conv_dict['context_tokens']
            context = merge_utt(context,
                                self.sent_split_idx,
                                False,
                                self.sep_id)
            context += final_topic
            context = add_start_end_token_idx(
                truncate(context, max_length=self.context_truncate - 1, truncate_tail=False),
                start_token_idx=self.cls_id)
            batch_context.append(context)

            # [topic, topic, ..., topic]
            context_policy = []
            for policies_one_turn in conv_dict['context_policy']:
                if len(policies_one_turn) != 0:
                    for policy in policies_one_turn:
                        for topic_id in policy[1]:
                            if topic_id != self.pad_topic_idx:
                                # use a separate name so the loop variable ``policy`` is not shadowed
                                topic_tokens = []
                                for token in self.ind2topic[topic_id]:
                                    topic_tokens.append(self.tok2ind.get(token, self.unk_token_idx))
                                context_policy.append(topic_tokens)
            context_policy = merge_utt(context_policy, self.word_split_idx, False)
            context_policy = add_start_end_token_idx(
                context_policy,
                start_token_idx=self.cls_id,
                end_token_idx=self.sep_id)
            context_policy += final_topic
            batch_context_policy.append(context_policy)

            batch_user_profile.extend(conv_dict['user_profile'])

            batch_target.append(conv_dict['target_topic'])

        batch_context = padded_tensor(batch_context,
                                      pad_idx=self.pad_token_idx,
                                      pad_tail=True,
                                      max_len=self.context_truncate)
        batch_context_mask = (batch_context != self.pad_token_idx).long()
        batch_context_policy = padded_tensor(batch_context_policy,
                                             pad_idx=self.pad_token_idx,
                                             pad_tail=True)
        batch_context_policy_mask = (batch_context_policy != self.pad_token_idx).long()
        batch_user_profile = padded_tensor(batch_user_profile,
                                           pad_idx=self.pad_token_idx,
                                           pad_tail=True)
        batch_user_profile_mask = (batch_user_profile != self.pad_token_idx).long()
        batch_target = torch.tensor(batch_target, dtype=torch.long)

        return (batch_context, batch_context_mask, batch_context_policy,
                batch_context_policy_mask, batch_user_profile,
                batch_user_profile_mask, batch_target)


================================================
FILE: crslab/data/dataloader/utils.py
================================================
# -*- encoding: utf-8 -*-
# @Time    :   2020/12/10
# @Author  :   Xiaolei Wang
# @email   :   wxl1999@foxmail.com

# UPDATE
# @Time    :   2020/12/20, 2020/12/15
# @Author  :   Xiaolei Wang, Yuanhang Zhou
# @email   :   wxl1999@foxmail.com, sdzyh002@gmail

# UPDATE
# @Time   : 2021/10/06
# @Author : Zhipeng Zhao
# @Email  : oran_official@outlook.com


from copy import copy

import torch
from typing import List, Union, Optional


def padded_tensor(
        items: List[Union[List[int], torch.LongTensor]],
        pad_idx: int = 0,
        pad_tail: bool = True,
        max_len: Optional[int] = None,
) -> torch.LongTensor:
    """Create a padded matrix from an uneven list of lists.

    Returns padded matrix.

    Matrix is right-padded (filled to the right) by default, but is
    left-padded if ``pad_tail`` is set to False.

    If the items are tensors, the output keeps their type and device
    (e.g. cuda).

    :param list[iter[int]] items: List of items
    :param int pad_idx: the value to use for padding
    :param bool pad_tail: if True, pad on the right; otherwise pad on the left
    :param int max_len: if None, the max length is the maximum item length

    :returns: padded tensor.
    :rtype: Tensor[int64]

    """
    # number of items
    n = len(items)
    # length of each item
    lens: List[int] = [len(item) for item in items]  # type: ignore
    # max in time dimension
    t = max(lens) if max_len is None else max_len
    # if input tensors are empty, we should expand to nulls
    t = max(t, 1)

    if isinstance(items[0], torch.Tensor):
        # keep type of input tensors, they may already be cuda ones
        output = items[0].new(n, t)  # type: ignore
    else:
        output = torch.LongTensor(n, t)  # type: ignore
    output.fill_(pad_idx)

    for i, (item, length) in enumerate(zip(items, lens)):
        if length == 0:
            # skip empty items
            continue
        if not isinstance(item, torch.Tensor):
            # put non-tensors into a tensor
            item = torch.tensor(item, dtype=torch.long)  # type: ignore
        if pad_tail:
            # place at beginning
            output[i, :length] = item
        else:
            # place at end
            output[i, t - length:] = item

    return output


def get_onehot(data_list, categories) -> torch.Tensor:
    """Transform lists of label into one-hot.

    Args:
        data_list (list of list of int): source data.
        categories (int): #label class.

    Returns:
        torch.Tensor: one-hot labels.

    """
    onehot_labels = []
    for label_list in data_list:
        onehot_label = torch.zeros(categories)
        for label in label_list:
            onehot_label[label] = 1.0 / len(label_list)
        onehot_labels.append(onehot_label)
    return torch.stack(onehot_labels, dim=0)


def add_start_end_token_idx(vec: list, start_token_idx: Optional[int] = None, end_token_idx: Optional[int] = None):
    """Optionally add a start token at the beginning and an end token at the end.

    Args:
        vec: source list composed of indexes.
        start_token_idx: index of start token. Nothing is added if None.
        end_token_idx: index of end token. Nothing is added if None.

    Returns:
        list: copy of vec with the start and/or end token index added.

    """
    res = copy(vec)
    # compare against None so that a token index of 0 is still added
    if start_token_idx is not None:
        res.insert(0, start_token_idx)
    if end_token_idx is not None:
        res.append(end_token_idx)
    return res


def truncate(vec, max_length, truncate_tail=True):
    """truncate vec to make its length no more than max length.

    Args:
        vec (list): source list.
        max_length (int)
        truncate_tail (bool, optional): Defaults to True.

    Returns:
        list: truncated vec.

    """
    if max_length is None:
        return vec
    if len(vec) <= max_length:
        return vec
    if max_length == 0:
        return []
    if truncate_tail:
        return vec[:max_length]
    else:
        return vec[-max_length:]


def merge_utt(conversation, split_token_idx=None, keep_split_in_tail=False, final_token_idx=None):
    """merge utterances in one conversation.

    Args:
        conversation (list of list of int): conversation consist of utterances consist of tokens.
        split_token_idx (int): index of split token. Defaults to None.
        keep_split_in_tail (bool): split in tail or head. Defaults to False.
        final_token_idx (int): index of final token. Defaults to None.

    Returns:
        list: tokens of all utterances in one list.

    """
    merged_conv = []
    for utt in conversation:
        for token in utt:
            merged_conv.append(token)
        if split_token_idx is not None:
            merged_conv.append(split_token_idx)
    if split_token_idx is not None and not keep_split_in_tail:
        merged_conv = merged_conv[:-1]
    if final_token_idx is not None:
        merged_conv.append(final_token_idx)
    return merged_conv

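def merge_utt_replace(conversation, detect_token=None, replace_token=None, method="in"):
    """Flatten a conversation into one token list while replacing tokens.

    With ``method='in'``, every token containing ``detect_token`` is replaced
    by ``replace_token`` as a whole; otherwise only the matching substring
    of each token is replaced.

    """
    if method == 'in':
        replaced_conv = []
        for utt in conversation:
            for token in utt:
                if detect_token in token:
                    replaced_conv.append(replace_token)
                else:
                    replaced_conv.append(token)
        return replaced_conv
    else:
        return [token.replace(detect_token, replace_token) for utt in conversation for token in utt]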
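

# Usage sketch (not part of the library): how merging, head truncation,
# start-token insertion and right padding compose into a fixed-length batch,
# in the spirit of the batching code that calls these helpers. The functions
# below are simplified, dependency-free stand-ins rather than the helpers
# themselves, and every token id is a hypothetical value picked for the example.

```python
def merge_utt(conversation, split_token_idx=None, final_token_idx=None):
    """Flatten utterances, placing split_token_idx between them."""
    merged = []
    for i, utt in enumerate(conversation):
        if i > 0 and split_token_idx is not None:
            merged.append(split_token_idx)
        merged.extend(utt)
    if final_token_idx is not None:
        merged.append(final_token_idx)
    return merged


def truncate_head(vec, max_length):
    """Keep at most the last max_length tokens (drop from the head)."""
    if max_length <= 0:
        return []
    return vec[-max_length:]


def pad_right(rows, pad_idx, max_len):
    """Right-pad every row to max_len (rows are assumed to fit already)."""
    return [row + [pad_idx] * (max_len - len(row)) for row in rows]


# Hypothetical special-token ids and length limit.
PAD, SENT_SPLIT, CLS, SEP = 0, 2, 101, 102
MAX_LEN = 8

conversation = [[11, 12, 13], [21, 22], [31, 32, 33]]
merged = merge_utt(conversation, split_token_idx=SENT_SPLIT, final_token_idx=SEP)

# Reserve one position for the start token, then keep the most recent tokens.
context = [CLS] + truncate_head(merged, MAX_LEN - 1)

batch = pad_right([context, [CLS, 41, SEP]], PAD, MAX_LEN)
mask = [[int(tok != PAD) for tok in row] for row in batch]
```

# Truncating from the head keeps the most recent utterances, which is why
# the context is cut with truncate_tail=False before the start token is added.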


================================================
FILE: crslab/data/dataset/__init__.py
================================================
from .base import BaseDataset
from .durecdial import DuRecDialDataset
from .gorecdial import GoRecDialDataset
from .inspired import InspiredDataset
from .opendialkg import OpenDialKGDataset
from .redial import ReDialDataset
from .tgredial import TGReDialDataset


================================================
FILE: crslab/data/dataset/base.py
================================================
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/23, 2020/12/13
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

import os
import pickle as pkl
from abc import ABC, abstractmethod

import numpy as np
from loguru import logger

from crslab.download import build


class BaseDataset(ABC):
    """Abstract class of dataset

    Notes:
        ``'embedding'`` can be specified in config to use pretrained word embedding.

    """

    def __init__(self, opt, dpath, resource, restore=False, save=False):
        """Download resource, load, process data. Support restore and save processed dataset.

        Args:
            opt (Config or dict): config for dataset or the whole system.
            dpath (str): where to store dataset.
            resource (dict): version, download file and special token idx of tokenized dataset.
            restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
            save (bool): whether to save dataset after processing. Defaults to False.

        """
        self.opt = opt
        self.dpath = dpath

        # download
        dfile = resource['file']
        build(dpath, dfile, version=resource['version'])

        if not restore:
            # load and process
            train_data, valid_data, test_data, self.vocab = self._load_data()
            logger.info('[Finish data load]')
            self.train_data, self.valid_data, self.test_data, self.side_data = self._data_preprocess(train_data,
                                                                                                     valid_data,
                                                                                                     test_data)
            embedding = opt.get('embedding', None)
            if embedding:
                self.side_data["embedding"] = np.load(os.path.join(self.dpath, embedding))
                logger.debug(f'[Load pretrained embedding {embedding}]')
            logger.info('[Finish data preprocess]')
        else:
            self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab = self._load_from_restore()

        if save:
            data = (self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab)
            self._save_to_one(data)

    @abstractmethod
    def _load_data(self):
        """Load dataset.

        Returns:
            (any, any, any, dict):

            raw train, valid and test data.

            vocab: all kinds of useful size, idx and map between token and idx.

        """
        pass

    @abstractmethod
    def _data_preprocess(self, train_data, valid_data, test_data):
        """Process raw train, valid, test data.

        Args:
            train_data: train dataset.
            valid_data: valid dataset.
            test_data: test dataset.

        Returns:
            (list of dict, list of dict, list of dict, dict):

            train/valid/test_data, each dict is in the following format::

                 {
                    'role' (str):
                        'Seeker' or 'Recommender',
                    'user_profile' (list of list of int):
                        id of tokens of sentences of user profile,
                    'context_tokens' (list of list of int):
                        token ids of preprocessed contextual dialogs,
                    'response' (list of int):
                        token ids of the ground-truth response,
                    'interaction_history' (list of int):
                        id of items which have interaction of the user in current turn,
                    'context_items' (list of int):
                        item ids mentioned in context,
                    'items' (list of int):
                        item ids mentioned in current turn, we only keep
                        those in entity kg for comparison,
                    'context_entities' (list of int):
                        if necessary, id of entities in context,
                    'context_words' (list of int):
                        if necessary, id of words in context,
                    'context_policy' (list of list of list):
                        policy of each context turn, one turn may have several policies,
                        where first is action and second is keyword,
                    'target' (list): policy of current turn,
                    'final' (list): final goal for current turn
                }

            side_data, which is in the following format::

                {
                    'entity_kg': {
                        'edge' (list of tuple): (head_entity_id, tail_entity_id, relation_id),
                        'n_relation' (int): number of distinct relations,
                        'entity' (list of str): str of entities, used for entity linking
                    },
                    'word_kg': {
                        'edge' (list of tuple): (head_entity_id, tail_entity_id),
                        'entity' (list of str): str of entities, used for entity linking
                    },
                    'item_entity_ids' (list of int): entity id of each item
                }

        """
        pass

    def _load_from_restore(self, file_name="all_data.pkl"):
        """Restore saved dataset.

        Args:
            file_name (str): file of saved dataset. Defaults to "all_data.pkl".

        """
        if not os.path.exists(os.path.join(self.dpath, file_name)):
            raise ValueError(f'Saved dataset [{file_name}] does not exist')
        with open(os.path.join(self.dpath, file_name), 'rb') as f:
            dataset = pkl.load(f)
        logger.info(f'Restore dataset from [{file_name}]')
        return dataset

    def _save_to_one(self, data, file_name="all_data.pkl"):
        """Save all processed dataset and vocab into one file.

        Args:
            data (tuple): all dataset and vocab.
            file_name (str, optional): file to save dataset. Defaults to "all_data.pkl".

        """
        if not os.path.exists(self.dpath):
            os.makedirs(self.dpath)
        save_path = os.path.join(self.dpath, file_name)
        with open(save_path, 'wb') as f:
            pkl.dump(data, f)
        logger.info(f'[Save dataset to {file_name}]')


================================================
FILE: crslab/data/dataset/durecdial/__init__.py
================================================
from .durecdial import DuRecDialDataset


================================================
FILE: crslab/data/dataset/durecdial/durecdial.py
================================================
# @Time   : 2020/12/21
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/12/21, 2021/1/2
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

r"""
DuRecDial
=========
References:
    Liu, Zeming, et al. `"Towards Conversational Recommendation over Multi-Type Dialogs."`_ in ACL 2020.

.. _"Towards Conversational Recommendation over Multi-Type Dialogs.":
   https://www.aclweb.org/anthology/2020.acl-main.98/

"""

import json
import os
from copy import copy

from loguru import logger
from tqdm import tqdm

from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources


class DuRecDialDataset(BaseDataset):
    """

    Attributes:
        train_data: train dataset.
        valid_data: valid dataset.
        test_data: test dataset.
        vocab (dict): ::

            {
                'tok2ind': map from token to index,
                'ind2tok': map from index to token,
                'entity2id': map from entity to index,
                'id2entity': map from index to entity,
                'word2id': map from word to index,
                'vocab_size': len(self.tok2ind),
                'n_entity': max(self.entity2id.values()) + 1,
                'n_word': max(self.word2id.values()) + 1,
            }

    Notes:
        ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.

    """

    def __init__(self, opt, tokenize, restore=False, save=False):
        """

        Args:
            opt (Config or dict): config for dataset or the whole system.
            tokenize (str): how to tokenize dataset.
            restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
            save (bool): whether to save dataset after processing. Defaults to False.

        """
        resource = resources[tokenize]
        self.special_token_idx = resource['special_token_idx']
        self.unk_token_idx = self.special_token_idx['unk']
        dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize)
        super().__init__(opt, dpath, resource, restore, save)

    def _load_data(self):
        train_data, valid_data, test_data = self._load_raw_data()
        self._load_vocab()
        self._load_other_data()

        vocab = {
            'tok2ind': self.tok2ind,
            'ind2tok': self.ind2tok,
            'entity2id': self.entity2id,
            'id2entity': self.id2entity,
            'word2id': self.word2id,
            'vocab_size': len(self.tok2ind),
            'n_entity': self.n_entity,
            'n_word': self.n_word,
        }
        vocab.update(self.special_token_idx)

        return train_data, valid_data, test_data, vocab

    def _load_raw_data(self):
        with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
            train_data = json.load(f)
            logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
            valid_data = json.load(f)
            logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
            test_data = json.load(f)
            logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")

        return train_data, valid_data, test_data

    def _load_vocab(self):
        with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f:
            self.tok2ind = json.load(f)
        self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}

        logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
        logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
        logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")

    def _load_other_data(self):
        # entity kg
        with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f:
            self.entity2id = json.load(f)  # {entity: entity_id}
        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
        self.n_entity = max(self.entity2id.values()) + 1
        # each line: head_entity \t relation \t tail_entity
        with open(os.path.join(self.dpath, 'entity_subkg.txt'), encoding='utf-8') as f:
            self.entity_kg = f.readlines()
        logger.debug(
            f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]")

        # hownet
        # {concept: concept_id}
        with open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8') as f:
            self.word2id = json.load(f)
        self.n_word = max(self.word2id.values()) + 1
        # {concept \t relation\t concept}
        with open(os.path.join(self.dpath, 'hownet_subkg.txt'), encoding='utf-8') as f:
            self.word_kg = f.readlines()
        logger.debug(
            f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet_subkg.txt')}]")

    def _data_preprocess(self, train_data, valid_data, test_data):
        processed_train_data = self._raw_data_process(train_data)
        logger.debug("[Finish train data process]")
        processed_valid_data = self._raw_data_process(valid_data)
        logger.debug("[Finish valid data process]")
        processed_test_data = self._raw_data_process(test_data)
        logger.debug("[Finish test data process]")
        processed_side_data = self._side_data_process()
        logger.debug("[Finish side data process]")
        return processed_train_data, processed_valid_data, processed_test_data, processed_side_data

    def _raw_data_process(self, raw_data):
        augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]
        augmented_conv_dicts = []
        for conv in tqdm(augmented_convs):
            augmented_conv_dicts.extend(self._augment_and_add(conv))
        return augmented_conv_dicts

    def _convert_to_id(self, conversation):
        augmented_convs = []
        last_role = None
        for utt in conversation['dialog']:
            assert utt['role'] != last_role, utt

            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
            item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id]
            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
            word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]

            augmented_convs.append({
                "role": utt["role"],
                "text": text_token_ids,
                "entity": entity_ids,
                "movie": item_ids,
                "word": word_ids
            })
            last_role = utt["role"]

        return augmented_convs

    def _augment_and_add(self, raw_conv_dict):
        augmented_conv_dicts = []
        context_tokens, context_entities, context_words, context_items = [], [], [], []
        entity_set, word_set = set(), set()
        for i, conv in enumerate(raw_conv_dict):
            text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"]
            if len(context_tokens) > 0:
                conv_dict = {
                    'role': conv['role'],
                    "context_tokens": copy(context_tokens),
                    "response": text_tokens,
                    "context_entities": copy(context_entities),
                    "context_words": copy(context_words),
                    'context_items': copy(context_items),
                    "items": movies
                }
                augmented_conv_dicts.append(conv_dict)

            context_tokens.append(text_tokens)
            context_items += movies
            for entity in entities + movies:
                if entity not in entity_set:
                    entity_set.add(entity)
                    context_entities.append(entity)
            for word in words:
                if word not in word_set:
                    word_set.add(word)
                    context_words.append(word)

        return augmented_conv_dicts

    def _side_data_process(self):
        processed_entity_kg = self._entity_kg_process()
        logger.debug("[Finish entity KG process]")
        processed_word_kg = self._word_kg_process()
        logger.debug("[Finish word KG process]")
        with open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8') as f:
            item_entity_ids = json.load(f)
        logger.debug('[Load movie entity ids]')

        side_data = {
            "entity_kg": processed_entity_kg,
            "word_kg": processed_word_kg,
            "item_entity_ids": item_entity_ids,
        }
        return side_data

    def _entity_kg_process(self):
        edge_list = []  # [(entity, entity, relation)]
        for line in self.entity_kg:
            triple = line.strip().split('\t')
            e0 = self.entity2id[triple[0]]
            e1 = self.entity2id[triple[2]]
            r = triple[1]
            edge_list.append((e0, e1, r))
            edge_list.append((e1, e0, r))
            edge_list.append((e0, e0, 'SELF_LOOP'))
            if e1 != e0:
                edge_list.append((e1, e1, 'SELF_LOOP'))

        relation2id, edges, entities = dict(), set(), set()
        for h, t, r in edge_list:
            if r not in relation2id:
                relation2id[r] = len(relation2id)
            edges.add((h, t, relation2id[r]))
            entities.add(self.id2entity[h])
            entities.add(self.id2entity[t])

        return {
            'edge': list(edges),
            'n_relation': len(relation2id),
            'entity': list(entities)
        }

    def _word_kg_process(self):
        edges = set()  # {(entity, entity)}
        entities = set()
        for line in self.word_kg:
            triple = line.strip().split('\t')
            entities.add(triple[0])
            entities.add(triple[2])
            e0 = self.word2id[triple[0]]
            e1 = self.word2id[triple[2]]
            edges.add((e0, e1))
            edges.add((e1, e0))
        # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
        return {
            'edge': list(edges),
            'entity': list(entities)
        }
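

# Illustration (not part of the library): a minimal sketch of the edge
# construction performed by _entity_kg_process above, run on a tiny made-up
# knowledge graph. The function name build_entity_edges, the entity names and
# the relation names are all hypothetical and exist only for this example.

```python
def build_entity_edges(lines, entity2id):
    """Turn 'head\\trelation\\ttail' lines into id-based, bidirectional edges."""
    edge_list = []
    for line in lines:
        head, relation, tail = line.strip().split('\t')
        e0, e1 = entity2id[head], entity2id[tail]
        # every triple yields both directions plus a self loop per endpoint
        edge_list.append((e0, e1, relation))
        edge_list.append((e1, e0, relation))
        edge_list.append((e0, e0, 'SELF_LOOP'))
        if e1 != e0:
            edge_list.append((e1, e1, 'SELF_LOOP'))

    # relation ids are assigned in order of first appearance;
    # the set removes duplicate edges such as repeated self loops
    relation2id, edges = {}, set()
    for h, t, r in edge_list:
        if r not in relation2id:
            relation2id[r] = len(relation2id)
        edges.add((h, t, relation2id[r]))
    return sorted(edges), relation2id


# Made-up data for the example.
entity2id = {'A': 0, 'B': 1, 'C': 2}
lines = ['A\tdirected_by\tB\n', 'A\tstarring\tC\n']
edges, relation2id = build_entity_edges(lines, entity2id)
```

# The self loop on entity 0 is produced twice (once per triple) but stored
# once, because the edges are collected in a set before being listed.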


================================================
FILE: crslab/data/dataset/durecdial/resources.py
================================================
# -*- encoding: utf-8 -*-
# @Time    :   2020/12/22
# @Author  :   Xiaolei Wang
# @email   :   wxl1999@foxmail.com

# UPDATE
# @Time    :   2020/12/22
# @Author  :   Xiaolei Wang
# @email   :   wxl1999@foxmail.com

from crslab.download import DownloadableFile

resources = {
    'jieba': {
        'version': '0.3',
        'file': DownloadableFile(
            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1',
            'durecdial_jieba.zip',
            'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813',
        ),
        'special_token_idx': {
            'pad': 0,
            'start': 1,
            'end': 2,
            'unk': 3,
            'pad_entity': 0,
            'pad_word': 0,
        },
    },
    'bert': {
        'version': '0.3',
        'file': DownloadableFile(
            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1',
            'durecdial_bert.zip',
            '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5'
        ),
        'special_token_idx': {
            'pad': 0,
            'start': 101,
            'end': 102,
            'unk': 100,
            'sent_split': 2,
            'word_split': 3,
            'pad_entity': 0,
            'pad_word': 0,
            'pad_topic': 0
        },
    },
    'gpt2': {
        'version': '0.3',
        'file': DownloadableFile(
            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1',
            'durecdial_gpt2.zip',
            'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4'
        ),
        'special_token_idx': {
            'pad': 0,
            'start': 101,
            'end': 102,
            'unk': 100,
            'cls': 101,
            'sep': 102,
            'sent_split': 2,
            'word_split': 3,
            'pad_entity': 0,
            'pad_word': 0,
            'pad_topic': 0,
        },
    }
}


================================================
FILE: crslab/data/dataset/gorecdial/__init__.py
================================================
from .gorecdial import GoRecDialDataset


================================================
FILE: crslab/data/dataset/gorecdial/gorecdial.py
================================================
# @Time   : 2020/12/12
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/12/13, 2021/1/2, 2020/12/19
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail

r"""
GoRecDial
=========
References:
    Kang, Dongyeop, et al. `"Recommendation as a Communication Game: Self-Supervised Bot-Play for Goal-oriented Dialogue."`_ in EMNLP 2019.

.. _`"Recommendation as a Communication Game: Self-Supervised Bot-Play for Goal-oriented Dialogue."`:
   https://www.aclweb.org/anthology/D19-1203/

"""

import json
import os
from copy import copy

from loguru import logger
from tqdm import tqdm

from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources


class GoRecDialDataset(BaseDataset):
    """

    Attributes:
        train_data: train dataset.
        valid_data: valid dataset.
        test_data: test dataset.
        vocab (dict): ::

            {
                'tok2ind': map from token to index,
                'ind2tok': map from index to token,
                'entity2id': map from entity to index,
                'id2entity': map from index to entity,
                'word2id': map from word to index,
                'vocab_size': len(self.tok2ind),
                'n_entity': max(self.entity2id.values()) + 1,
                'n_word': max(self.word2id.values()) + 1,
            }

    Notes:
        ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.

    """

    def __init__(self, opt, tokenize, restore=False, save=False):
        """Specify tokenized resource and init base dataset.

        Args:
            opt (Config or dict): config for dataset or the whole system.
            tokenize (str): how to tokenize dataset.
            restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
            save (bool): whether to save dataset after processing. Defaults to False.

        """
        resource = resources[tokenize]
        self.special_token_idx = resource['special_token_idx']
        self.unk_token_idx = self.special_token_idx['unk']
        dpath = os.path.join(DATASET_PATH, 'gorecdial', tokenize)
        super().__init__(opt, dpath, resource, restore, save)

    def _load_data(self):
        train_data, valid_data, test_data = self._load_raw_data()
        self._load_vocab()
        self._load_other_data()

        vocab = {
            'tok2ind': self.tok2ind,
            'ind2tok': self.ind2tok,
            'entity2id': self.entity2id,
            'id2entity': self.id2entity,
            'word2id': self.word2id,
            'vocab_size': len(self.tok2ind),
            'n_entity': self.n_entity,
            'n_word': self.n_word,
        }
        vocab.update(self.special_token_idx)

        return train_data, valid_data, test_data, vocab

    def _load_raw_data(self):
        # load train/valid/test data
        with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
            train_data = json.load(f)
            logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
            valid_data = json.load(f)
            logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
            test_data = json.load(f)
            logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")

        return train_data, valid_data, test_data

    def _load_vocab(self):
        with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f:
            self.tok2ind = json.load(f)
        self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}

        logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
        logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
        logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")

    def _load_other_data(self):
        # dbpedia
        with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f:
            self.entity2id = json.load(f)  # {entity: entity_id}
        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
        self.n_entity = max(self.entity2id.values()) + 1
        # {head_entity_id: [(relation_id, tail_entity_id)]}
        self.entity_kg = open(os.path.join(self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8')
        logger.debug(
            f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]")

        # conceptnet
        # {concept: concept_id}
        with open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8') as f:
            self.word2id = json.load(f)
        self.n_word = max(self.word2id.values()) + 1
        # {concept \t relation\t concept}
        self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), encoding='utf-8')
        logger.debug(
            f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]")

    def _data_preprocess(self, train_data, valid_data, test_data):
        processed_train_data = self._raw_data_process(train_data)
        logger.debug("[Finish train data process]")
        processed_valid_data = self._raw_data_process(valid_data)
        logger.debug("[Finish valid data process]")
        processed_test_data = self._raw_data_process(test_data)
        logger.debug("[Finish test data process]")
        processed_side_data = self._side_data_process()
        logger.debug("[Finish side data process]")
        return processed_train_data, processed_valid_data, processed_test_data, processed_side_data

    def _raw_data_process(self, raw_data):
        augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]
        augmented_conv_dicts = []
        for conv in tqdm(augmented_convs):
            augmented_conv_dicts.extend(self._augment_and_add(conv))
        return augmented_conv_dicts

    def _convert_to_id(self, conversation):
        augmented_convs = []
        last_role = None
        for utt in conversation['dialog']:
            assert utt['role'] != last_role

            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
            movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]
            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
            word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
            policy = utt['decide']

            augmented_convs.append({
                "role": utt["role"],
                "text": text_token_ids,
                "entity": entity_ids,
                "movie": movie_ids,
                "word": word_ids,
                'policy': policy
            })
            last_role = utt["role"]

        return augmented_convs

    def _augment_and_add(self, raw_conv_dict):
        augmented_conv_dicts = []
        context_tokens, context_entities, context_words, context_items = [], [], [], []
        entity_set, word_set = set(), set()
        for conv in raw_conv_dict:
            text_tokens, entities, movies, words, policies = (
                conv["text"], conv["entity"], conv["movie"], conv["word"], conv['policy'])
            if len(context_tokens) > 0 and len(text_tokens) > 0:
                conv_dict = {
                    'role': conv['role'],
                    "context_tokens": copy(context_tokens),
                    "response": text_tokens,
                    "context_entities": copy(context_entities),
                    "context_words": copy(context_words),
                    'context_items': copy(context_items),
                    "items": movies,
                    'policy': policies,
                }
                augmented_conv_dicts.append(conv_dict)

            if len(text_tokens) > 0:
                context_tokens.append(text_tokens)
                context_items += movies
                for entity in entities + movies:
                    if entity not in entity_set:
                        entity_set.add(entity)
                        context_entities.append(entity)
                for word in words:
                    if word not in word_set:
                        word_set.add(word)
                        context_words.append(word)

        return augmented_conv_dicts

    def _side_data_process(self):
        processed_entity_kg = self._entity_kg_process()
        logger.debug("[Finish entity KG process]")
        processed_word_kg = self._word_kg_process()
        logger.debug("[Finish word KG process]")
        with open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8') as f:
            movie_entity_ids = json.load(f)
        logger.debug('[Load movie entity ids]')

        side_data = {
            "entity_kg": processed_entity_kg,
            "word_kg": processed_word_kg,
            "item_entity_ids": movie_entity_ids,
        }
        return side_data

    def _entity_kg_process(self):
        edge_list = []  # [(entity, entity, relation)]
        for line in self.entity_kg:
            triple = line.strip().split('\t')
            e0 = self.entity2id[triple[0]]
            e1 = self.entity2id[triple[2]]
            r = triple[1]
            edge_list.append((e0, e1, r))
            edge_list.append((e1, e0, r))
            edge_list.append((e0, e0, 'SELF_LOOP'))
            if e1 != e0:
                edge_list.append((e1, e1, 'SELF_LOOP'))

        relation2id, edges, entities = dict(), set(), set()
        for h, t, r in edge_list:
            if r not in relation2id:
                relation2id[r] = len(relation2id)
            edges.add((h, t, relation2id[r]))
            entities.add(self.id2entity[h])
            entities.add(self.id2entity[t])

        return {
            'edge': list(edges),
            'n_relation': len(relation2id),
            'entity': list(entities)
        }

    def _word_kg_process(self):
        edges = set()  # {(entity, entity)}
        entities = set()
        for line in self.word_kg:
            triple = line.strip().split('\t')
            entities.add(triple[0])
            entities.add(triple[2])
            e0 = self.word2id[triple[0]]
            e1 = self.word2id[triple[2]]
            edges.add((e0, e1))
            edges.add((e1, e0))
        # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
        return {
            'edge': list(edges),
            'entity': list(entities)
        }

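The context accumulation in `_augment_and_add` above can be sketched as a standalone function. This is a minimal illustration under stated assumptions, not the library's API: the name `augment_conversation` and the toy ids are hypothetical, and the word and policy fields are omitted for brevity.

```python
from copy import copy


def augment_conversation(turns):
    """Turn id-converted utterances into (context, response) samples.

    Hypothetical helper that mirrors the accumulation logic of
    ``_augment_and_add`` on a reduced set of fields.
    """
    samples = []
    context_tokens, context_entities, context_items = [], [], []
    seen_entities = set()
    for turn in turns:
        tokens, entities, movies = turn["text"], turn["entity"], turn["movie"]
        # Emit a sample only when prior context exists and the response is non-empty.
        if context_tokens and tokens:
            samples.append({
                "role": turn["role"],
                "context_tokens": copy(context_tokens),
                "response": tokens,
                "context_entities": copy(context_entities),
                "context_items": copy(context_items),
                "items": movies,
            })
        # Only non-empty utterances extend the running context.
        if tokens:
            context_tokens.append(tokens)
            context_items += movies
            for entity in entities + movies:
                # Keep first-seen order while dropping repeats.
                if entity not in seen_entities:
                    seen_entities.add(entity)
                    context_entities.append(entity)
    return samples


# Toy conversation with made-up ids (assumption, not dataset content).
toy_turns = [
    {"role": "Seeker", "text": [5, 6], "entity": [10], "movie": []},
    {"role": "Recommender", "text": [7, 8, 9], "entity": [10, 11], "movie": [42]},
    {"role": "Seeker", "text": [4], "entity": [], "movie": []},
]
samples = augment_conversation(toy_turns)
```

Each utterance after the first yields one sample whose context holds everything said before it, so a conversation with three non-empty turns produces two samples here.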

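The edge construction in `_entity_kg_process` above (both directions per triple, plus a self-loop per entity, deduplicated through a set) can be sketched on a toy graph. The function name `build_entity_edges`, the entity names, and the relation labels are all hypothetical and chosen only for illustration.

```python
def build_entity_edges(triples, entity2id):
    """Build a deduplicated, bidirectional edge set with self-loops.

    Hypothetical helper following the same steps as ``_entity_kg_process``.
    """
    edge_list = []
    for head, relation, tail in triples:
        h, t = entity2id[head], entity2id[tail]
        edge_list.append((h, t, relation))
        edge_list.append((t, h, relation))
        edge_list.append((h, h, "SELF_LOOP"))
        if t != h:
            edge_list.append((t, t, "SELF_LOOP"))

    # Relation ids are assigned in order of first appearance.
    relation2id, edges = {}, set()
    for h, t, r in edge_list:
        if r not in relation2id:
            relation2id[r] = len(relation2id)
        edges.add((h, t, relation2id[r]))
    return edges, relation2id


# Toy inputs (assumptions, not taken from the real KG files).
toy_entity2id = {"A": 0, "B": 1, "C": 2}
toy_triples = [("A", "directed_by", "B"), ("A", "starring", "C")]
edges, relation2id = build_entity_edges(toy_triples, toy_entity2id)
```

Because the edges are collected in a set, the self-loop on entity `A` is stored once even though `A` appears in two triples.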
================================================
FILE: crslab/data/dataset/gorecdial/resources.py
================================================
# -*- encoding: utf-8 -*-
# @Time    :   2020/12/14
# @Author  :   Xiaolei Wang
# @email   :   wxl1999@foxmail.com

# UPDATE
# @Time    :   2020/12/22
# @Author  :   Xiaolei Wang
# @email   :   wxl1999@foxmail.com

from crslab.download import DownloadableFile

resources = {
    'nltk': {
        'version': '0.31',
        'file': DownloadableFile(
            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESIqjwAg0ItAu7WGfukIt3cBXjzi7AZ9L_lcbFT1aS1qYQ?download=1',
            'gorecdial_nltk.zip',
            '58cd368f8f83c0c8555becc314a0017990545f71aefb7e93a52581c97d1b8e9b',
        ),
        'special_token_idx': {
            'pad': 0,
            'start': 1,
            'end': 2,
            'unk': 3,
            'pad_entity': 0,
            'pad_word': 0,
            'pad_topic': 0
        },
    },
    'bert': {
        'version': '0.31',
        'file': DownloadableFile(
            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed1HT8gzvRpDosVT83BEj5QBnzKpjR3Zbf5u49yyWP-k6Q?download=1',
            'gorecdial_bert.zip',
            '4fa10c3fe8ba538af0f393c99892739fcb376d832616aa7028334c594b3fec10'
        ),
        'special_token_idx': {
            'pad': 0,
            'start': 101,
            'end': 102,
            'unk': 100,
            'sent_split': 2,
            'word_split': 3,
            'pad_entity': 0,
            'pad_word': 0,
            'pad_topic': 0
        }
    },
    'gpt2': {
        'version': '0.31',
        'file': DownloadableFile(
            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EUJOHmX8v79DkZMq0x5r9d4B0UJlfw85v-VdciwKfAhpng?download=1',
            'gorecdial_gpt2.zip',
            '44a15637e014b2e6628102ff654e1aef7ec1cbfa34b7ada1a03f294f72ddd4b1'
        ),
        'special_token_idx': {
            'pad': 0,
            'start': 1,
            'end': 2,
            'unk': 3,
            'sent_split': 4,
            'word_split': 5,
            'pad_entity': 0,
            'pad_word': 0
        },
    }
}

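The `special_token_idx` entries above are read by the dataset classes: the `'unk'` index is used as the fallback id for out-of-vocabulary tokens, and the whole mapping is merged into the vocab dict. A minimal sketch of that flow follows, assuming a toy resource entry and a toy token table (both hypothetical, with no download involved).

```python
# Toy stand-ins for one resources entry and a token2id table (assumptions).
toy_resource = {
    "special_token_idx": {"pad": 0, "start": 1, "end": 2, "unk": 3},
}
toy_tok2ind = {"__pad__": 0, "__start__": 1, "__end__": 2, "__unk__": 3, "hello": 4}

special_token_idx = toy_resource["special_token_idx"]
unk_token_idx = special_token_idx["unk"]

# The special indices sit next to the regular vocab fields in one flat dict.
vocab = {"tok2ind": toy_tok2ind, "vocab_size": len(toy_tok2ind)}
vocab.update(special_token_idx)

# Unknown tokens fall back to the 'unk' index.
token_ids = [toy_tok2ind.get(tok, unk_token_idx) for tok in ["hello", "world"]]
```

This is why `'unk'` has to be present in every tokenizer's `special_token_idx`: without it, the lookup of the fallback index would raise a `KeyError`.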

================================================
FILE: crslab/data/dataset/inspired/__init__.py
================================================
from .inspired import InspiredDataset


================================================
FILE: crslab/data/dataset/inspired/inspired.py
================================================
# @Time   : 2020/12/19
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/12/20, 2021/1/2
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

r"""
Inspired
========
References:
    Hayati, Shirley Anugrah, et al. `"INSPIRED: Toward Sociable Recommendation Dialog Systems."`_ in EMNLP 2020.

.. _`"INSPIRED: Toward Sociable Recommendation Dialog Systems."`:
   https://www.aclweb.org/anthology/2020.emnlp-main.654/

"""

import json
import os
from copy import copy

from loguru import logger
from tqdm import tqdm

from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources


class InspiredDataset(BaseDataset):
    """INSPIRED dataset.

    Attributes:
        train_data: train dataset.
        valid_data: valid dataset.
        test_data: test dataset.
        vocab (dict): ::

            {
                'tok2ind': map from token to index,
                'ind2tok': map from index to token,
                'entity2id': map from entity to index,
                'id2entity': map from index to entity,
                'word2id': map from word to index,
                'vocab_size': len(self.tok2ind),
                'n_entity': max(self.entity2id.values()) + 1,
                'n_word': max(self.word2id.values()) + 1,
            }

    Notes:
        ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.

    """

    def __init__(self, opt, tokenize, restore=False, save=False):
        """Specify tokenized resource and init base dataset.

        Args:
            opt (Config or dict): config for dataset or the whole system.
            tokenize (str): how to tokenize dataset.
            restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
            save (bool): whether to save dataset after processing. Defaults to False.

        """
        resource = resources[tokenize]
        self.special_token_idx = resource['special_token_idx']
        self.unk_token_idx = self.special_token_idx['unk']
        dpath = os.path.join(DATASET_PATH, 'inspired', tokenize)
        super().__init__(opt, dpath, resource, restore, save)

    def _load_data(self):
        train_data, valid_data, test_data = self._load_raw_data()
        self._load_vocab()
        self._load_other_data()

        vocab = {
            'tok2ind': self.tok2ind,
            'ind2tok': self.ind2tok,
            'entity2id': self.entity2id,
            'id2entity': self.id2entity,
            'word2id': self.word2id,
            'vocab_size': len(self.tok2ind),
            'n_entity': self.n_entity,
            'n_word': self.n_word,
        }
        vocab.update(self.special_token_idx)

        return train_data, valid_data, test_data, vocab

    def _load_raw_data(self):
        # load train/valid/test data
        with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
            train_data = json.load(f)
            logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
            valid_data = json.load(f)
            logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
            test_data = json.load(f)
            logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")

        return train_data, valid_data, test_data

    def _load_vocab(self):
        with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f:
            self.tok2ind = json.load(f)
        self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}

        logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
        logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
        logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")

    def _load_other_data(self):
        # dbpedia
        with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f:
            self.entity2id = json.load(f)  # {entity: entity_id}
        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
        self.n_entity = max(self.entity2id.values()) + 1
        # one triple per line: head_entity \t relation \t tail_entity
        self.entity_kg = open(os.path.join(self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8')
        logger.debug(
            f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.txt')}]")

        # conceptnet
        # {concept: concept_id}
        with open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8') as f:
            self.word2id = json.load(f)
        self.n_word = max(self.word2id.values()) + 1
        # {concept \t relation\t concept}
        self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8')
        logger.debug(
            f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]")

    def _data_preprocess(self, train_data, valid_data, test_data):
        processed_train_data = self._raw_data_process(train_data)
        logger.debug("[Finish train data process]")
        processed_valid_data = self._raw_data_process(valid_data)
        logger.debug("[Finish valid data process]")
        processed_test_data = self._raw_data_process(test_data)
        logger.debug("[Finish test data process]")
        processed_side_data = self._side_data_process()
        logger.debug("[Finish side data process]")
        return processed_train_data, processed_valid_data, processed_test_data, processed_side_data

    def _raw_data_process(self, raw_data):
        augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]
        augmented_conv_dicts = []
        for conv in tqdm(augmented_convs):
            augmented_conv_dicts.extend(self._augment_and_add(conv))
        return augmented_conv_dicts

    def _convert_to_id(self, conversation):
        augmented_convs = []
        last_role = None
        for utt in conversation['dialog']:
            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
            movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]
            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
            word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]

            if utt["role"] == last_role:
                augmented_convs[-1]["text"] += text_token_ids
                augmented_convs[-1]["movie"] += movie_ids
                augmented_convs[-1]["entity"] += entity_ids
                augmented_convs[-1]["word"] += word_ids
            else:
                augmented_convs.append({
                    "role": utt["role"],
                    "text": text_token_ids,
                    "entity": entity_ids,
                    "movie": movie_ids,
                    "word": word_ids
                })
            last_role = utt["role"]

        return augmented_convs

    def _augment_and_add(self, raw_conv_dict):
        augmented_conv_dicts = []
        context_tokens, context_entities, context_words, context_items = [], [], [], []
        entity_set, word_set = set(), set()
        for conv in raw_conv_dict:
            text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"]
            if len(context_tokens) > 0:
                conv_dict = {
                    'role': conv['role'],
                    "context_tokens": copy(context_tokens),
                    "response": text_tokens,
                    "context_entities": copy(context_entities),
                    "context_words": copy(context_words),
                    'context_items': copy(context_items),
                    "items": movies,
                }
                augmented_conv_dicts.append(conv_dict)

            context_tokens.append(text_tokens)
            context_items += movies
            for entity in entities + movies:
                if entity not in entity_set:
                    entity_set.add(entity)
                    context_entities.append(entity)
            for word in words:
                if word not in word_set:
                    word_set.add(word)
                    context_words.append(word)

        return augmented_conv_dicts

    def _side_data_process(self):
        processed_entity_kg = self._entity_kg_process()
        logger.debug("[Finish entity KG process]")
        processed_word_kg = self._word_kg_process()
        logger.debug("[Finish word KG process]")
        with open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8') as f:
            movie_entity_ids = json.load(f)
        logger.debug('[Load movie entity ids]')

        side_data = {
            "entity_kg": processed_entity_kg,
            "word_kg": processed_word_kg,
SYMBOL INDEX (674 symbols across 72 files)

FILE: crslab/config/config.py
  class Config (line 21) | class Config:
    method __init__ (line 24) | def __init__(self, config_file, gpu='-1', debug=False):
    method load_yaml_configs (line 91) | def load_yaml_configs(filename):
    method __setitem__ (line 106) | def __setitem__(self, key, value):
    method __getitem__ (line 111) | def __getitem__(self, item):
    method get (line 117) | def get(self, item, default=None):
    method __contains__ (line 133) | def __contains__(self, key):
    method __str__ (line 138) | def __str__(self):
    method __repr__ (line 141) | def __repr__(self):

FILE: crslab/data/__init__.py
  function get_dataset (line 73) | def get_dataset(opt, tokenize, restore, save) -> BaseDataset:
  function get_dataloader (line 93) | def get_dataloader(opt, dataset, vocab) -> BaseDataLoader:

FILE: crslab/data/dataloader/base.py
  class BaseDataLoader (line 18) | class BaseDataLoader(ABC):
    method __init__ (line 26) | def __init__(self, opt, dataset):
    method get_data (line 38) | def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None):
    method get_conv_data (line 72) | def get_conv_data(self, batch_size, shuffle=True):
    method get_rec_data (line 87) | def get_rec_data(self, batch_size, shuffle=True):
    method get_policy_data (line 102) | def get_policy_data(self, batch_size, shuffle=True):
    method conv_process_fn (line 117) | def conv_process_fn(self):
    method conv_batchify (line 126) | def conv_batchify(self, batch):
    method rec_process_fn (line 137) | def rec_process_fn(self):
    method rec_batchify (line 146) | def rec_batchify(self, batch):
    method policy_process_fn (line 157) | def policy_process_fn(self):
    method policy_batchify (line 166) | def policy_batchify(self, batch):
    method retain_recommender_target (line 177) | def retain_recommender_target(self):
    method rec_interact (line 190) | def rec_interact(self, data):
    method conv_interact (line 201) | def conv_interact(self, data):

FILE: crslab/data/dataloader/inspired.py
  class InspiredDataLoader (line 14) | class InspiredDataLoader(BaseDataLoader):
    method __init__ (line 47) | def __init__(self, opt, dataset, vocab):
    method rec_process_fn (line 81) | def rec_process_fn(self, *args, **kwargs):
    method _process_rec_context (line 91) | def _process_rec_context(self, context_tokens):
    method rec_batchify (line 105) | def rec_batchify(self, batch):
    method conv_batchify (line 123) | def conv_batchify(self, batch):
    method policy_batchify (line 161) | def policy_batchify(self, batch):

FILE: crslab/data/dataloader/kbrd.py
  class KBRDDataLoader (line 17) | class KBRDDataLoader(BaseDataLoader):
    method __init__ (line 38) | def __init__(self, opt, dataset, vocab):
    method rec_process_fn (line 56) | def rec_process_fn(self):
    method rec_batchify (line 65) | def rec_batchify(self, batch):
    method conv_process_fn (line 77) | def conv_process_fn(self, *args, **kwargs):
    method conv_batchify (line 80) | def conv_batchify(self, batch):
    method policy_batchify (line 99) | def policy_batchify(self, *args, **kwargs):

FILE: crslab/data/dataloader/kgsf.py
  class KGSFDataLoader (line 19) | class KGSFDataLoader(BaseDataLoader):
    method __init__ (line 44) | def __init__(self, opt, dataset, vocab):
    method get_pretrain_data (line 65) | def get_pretrain_data(self, batch_size, shuffle=True):
    method pretrain_batchify (line 68) | def pretrain_batchify(self, batch):
    method rec_process_fn (line 79) | def rec_process_fn(self):
    method rec_batchify (line 89) | def rec_batchify(self, batch):
    method conv_process_fn (line 104) | def conv_process_fn(self, *args, **kwargs):
    method conv_batchify (line 107) | def conv_batchify(self, batch):
    method policy_batchify (line 128) | def policy_batchify(self, *args, **kwargs):

FILE: crslab/data/dataloader/ntrd.py
  class NTRDDataLoader (line 14) | class NTRDDataLoader(BaseDataLoader):
    method __init__ (line 15) | def __init__(self, opt, dataset, vocab):
    method get_pretrain_data (line 38) | def get_pretrain_data(self, batch_size, shuffle=True):
    method pretrain_batchify (line 41) | def pretrain_batchify(self, batch):
    method rec_process_fn (line 52) | def rec_process_fn(self):
    method rec_batchify (line 62) | def rec_batchify(self, batch):
    method conv_process_fn (line 77) | def conv_process_fn(self, *args, **kwargs):
    method conv_batchify (line 80) | def conv_batchify(self, batch):
    method policy_batchify (line 115) | def policy_batchify(self, *args, **kwargs):

FILE: crslab/data/dataloader/redial.py
  class ReDialDataLoader (line 22) | class ReDialDataLoader(BaseDataLoader):
    method __init__ (line 46) | def __init__(self, opt, dataset, vocab):
    method rec_process_fn (line 66) | def rec_process_fn(self, *args, **kwargs):
    method rec_batchify (line 75) | def rec_batchify(self, batch):
    method conv_process_fn (line 84) | def conv_process_fn(self):
    method conv_batchify (line 100) | def conv_batchify(self, batch):
    method policy_batchify (line 145) | def policy_batchify(self, batch):

FILE: crslab/data/dataloader/tgredial.py
  class TGReDialDataLoader (line 20) | class TGReDialDataLoader(BaseDataLoader):
    method __init__ (line 55) | def __init__(self, opt, dataset, vocab):
    method rec_process_fn (line 101) | def rec_process_fn(self, *args, **kwargs):
    method _process_rec_context (line 110) | def _process_rec_context(self, context_tokens):
    method _neg_sample (line 124) | def _neg_sample(self, item_set):
    method _process_history (line 130) | def _process_history(self, context_items, item_id=None):
    method rec_batchify (line 146) | def rec_batchify(self, batch):
    method rec_interact (line 198) | def rec_interact(self, data):
    method conv_batchify (line 228) | def conv_batchify(self, batch):
    method conv_interact (line 316) | def conv_interact(self, data):
    method policy_process_fn (line 336) | def policy_process_fn(self, *args, **kwargs):
    method policy_batchify (line 347) | def policy_batchify(self, batch):

FILE: crslab/data/dataloader/utils.py
  function padded_tensor (line 23) | def padded_tensor(
  function get_onehot (line 80) | def get_onehot(data_list, categories) -> torch.Tensor:
  function add_start_end_token_idx (line 100) | def add_start_end_token_idx(vec: list, start_token_idx: int = None, end_...
  function truncate (line 120) | def truncate(vec, max_length, truncate_tail=True):
  function merge_utt (line 144) | def merge_utt(conversation, split_token_idx=None, keep_split_in_tail=Fal...
  function merge_utt_replace (line 169) | def merge_utt_replace(conversation,detect_token=None,replace_token=None,...

FILE: crslab/data/dataset/base.py
  class BaseDataset (line 20) | class BaseDataset(ABC):
    method __init__ (line 28) | def __init__(self, opt, dpath, resource, restore=False, save=False):
    method _load_data (line 66) | def _load_data(self):
    method _data_preprocess (line 80) | def _data_preprocess(self, train_data, valid_data, test_data):
    method _load_from_restore (line 138) | def _load_from_restore(self, file_name="all_data.pkl"):
    method _save_to_one (line 152) | def _save_to_one(self, data, file_name="all_data.pkl"):

FILE: crslab/data/dataset/durecdial/durecdial.py
  class DuRecDialDataset (line 33) | class DuRecDialDataset(BaseDataset):
    method __init__ (line 58) | def __init__(self, opt, tokenize, restore=False, save=False):
    method _load_data (line 74) | def _load_data(self):
    method _load_raw_data (line 93) | def _load_raw_data(self):
    method _load_vocab (line 106) | def _load_vocab(self):
    method _load_other_data (line 114) | def _load_other_data(self):
    method _data_preprocess (line 135) | def _data_preprocess(self, train_data, valid_data, test_data):
    method _raw_data_process (line 146) | def _raw_data_process(self, raw_data):
    method _convert_to_id (line 153) | def _convert_to_id(self, conversation):
    method _augment_and_add (line 175) | def _augment_and_add(self, raw_conv_dict):
    method _side_data_process (line 206) | def _side_data_process(self):
    method _entity_kg_process (line 222) | def _entity_kg_process(self):
    method _word_kg_process (line 249) | def _word_kg_process(self):

FILE: crslab/data/dataset/gorecdial/gorecdial.py
  class GoRecDialDataset (line 33) | class GoRecDialDataset(BaseDataset):
    method __init__ (line 58) | def __init__(self, opt, tokenize, restore=False, save=False):
    method _load_data (line 74) | def _load_data(self):
    method _load_raw_data (line 93) | def _load_raw_data(self):
    method _load_vocab (line 107) | def _load_vocab(self):
    method _load_other_data (line 115) | def _load_other_data(self):
    method _data_preprocess (line 135) | def _data_preprocess(self, train_data, valid_data, test_data):
    method _raw_data_process (line 146) | def _raw_data_process(self, raw_data):
    method _convert_to_id (line 153) | def _convert_to_id(self, conversation):
    method _augment_and_add (line 177) | def _augment_and_add(self, raw_conv_dict):
    method _side_data_process (line 211) | def _side_data_process(self):
    method _entity_kg_process (line 226) | def _entity_kg_process(self):
    method _word_kg_process (line 253) | def _word_kg_process(self):

FILE: crslab/data/dataset/inspired/inspired.py
  class InspiredDataset (line 33) | class InspiredDataset(BaseDataset):
    method __init__ (line 58) | def __init__(self, opt, tokenize, restore=False, save=False):
    method _load_data (line 74) | def _load_data(self):
    method _load_raw_data (line 93) | def _load_raw_data(self):
    method _load_vocab (line 107) | def _load_vocab(self):
    method _load_other_data (line 116) | def _load_other_data(self):
    method _data_preprocess (line 137) | def _data_preprocess(self, train_data, valid_data, test_data):
    method _raw_data_process (line 148) | def _raw_data_process(self, raw_data):
    method _convert_to_id (line 155) | def _convert_to_id(self, conversation):
    method _augment_and_add (line 181) | def _augment_and_add(self, raw_conv_dict):
    method _side_data_process (line 212) | def _side_data_process(self):
    method _entity_kg_process (line 228) | def _entity_kg_process(self):
    method _word_kg_process (line 255) | def _word_kg_process(self):

FILE: crslab/data/dataset/opendialkg/opendialkg.py
  class OpenDialKGDataset (line 34) | class OpenDialKGDataset(BaseDataset):
    method __init__ (line 59) | def __init__(self, opt, tokenize, restore=False, save=False):
    method _load_data (line 75) | def _load_data(self):
    method _load_raw_data (line 94) | def _load_raw_data(self):
    method _load_vocab (line 108) | def _load_vocab(self):
    method _load_other_data (line 116) | def _load_other_data(self):
    method _data_preprocess (line 136) | def _data_preprocess(self, train_data, valid_data, test_data):
    method _raw_data_process (line 147) | def _raw_data_process(self, raw_data):
    method _convert_to_id (line 154) | def _convert_to_id(self, conversation):
    method _augment_and_add (line 180) | def _augment_and_add(self, raw_conv_dict):
    method _side_data_process (line 211) | def _side_data_process(self):
    method _entity_kg_process (line 226) | def _entity_kg_process(self):
    method _word_kg_process (line 258) | def _word_kg_process(self):

FILE: crslab/data/dataset/redial/redial.py
  class ReDialDataset (line 34) | class ReDialDataset(BaseDataset):
    method __init__ (line 59) | def __init__(self, opt, tokenize, restore=False, save=False):
    method _load_data (line 75) | def _load_data(self):
    method _load_raw_data (line 94) | def _load_raw_data(self):
    method _load_vocab (line 108) | def _load_vocab(self):
    method _load_other_data (line 116) | def _load_other_data(self):
    method _data_preprocess (line 136) | def _data_preprocess(self, train_data, valid_data, test_data):
    method _raw_data_process (line 147) | def _raw_data_process(self, raw_data):
    method _merge_conv_data (line 154) | def _merge_conv_data(self, dialog):
    method _augment_and_add (line 180) | def _augment_and_add(self, raw_conv_dict):
    method _side_data_process (line 211) | def _side_data_process(self):
    method _entity_kg_process (line 226) | def _entity_kg_process(self, SELF_LOOP_ID=185):
    method _word_kg_process (line 253) | def _word_kg_process(self):

FILE: crslab/data/dataset/tgredial/tgredial.py
  class TGReDialDataset (line 34) | class TGReDialDataset(BaseDataset):
    method __init__ (line 62) | def __init__(self, opt, tokenize, restore=False, save=False):
    method _load_data (line 87) | def _load_data(self):
    method _load_raw_data (line 109) | def _load_raw_data(self):
    method _load_vocab (line 123) | def _load_vocab(self):
    method _load_other_data (line 148) | def _load_other_data(self):
    method _data_preprocess (line 177) | def _data_preprocess(self, train_data, valid_data, test_data):
    method _raw_data_process (line 188) | def _raw_data_process(self, raw_data):
    method _convert_to_id (line 195) | def _convert_to_id(self, conversation):
    method _augment_and_add (line 241) | def _augment_and_add(self, raw_conv_dict):
    method _side_data_process (line 283) | def _side_data_process(self):
    method _entity_kg_process (line 298) | def _entity_kg_process(self):
    method _word_kg_process (line 327) | def _word_kg_process(self):

FILE: crslab/download.py
  class DownloadableFile (line 22) | class DownloadableFile:
    method __init__ (line 44) | def __init__(self, url, file_name, hashcode, zipped=True, from_google=...
    method checksum (line 51) | def checksum(self, dpath):
    method download_file (line 71) | def download_file(self, dpath):
  function download (line 83) | def download(url, path, fname, redownload=False, num_retries=5):
  function _get_confirm_token (line 155) | def _get_confirm_token(response):
  function download_from_google_drive (line 162) | def download_from_google_drive(gd_id, destination):
  function move (line 185) | def move(path1, path2):
  function untar (line 192) | def untar(path, fname, deleteTar=True):
  function make_dir (line 212) | def make_dir(path):
  function remove_dir (line 221) | def remove_dir(path):
  function check_build (line 228) | def check_build(path, version_string=None):
  function mark_done (line 247) | def mark_done(path, version_string=None):
  function build (line 266) | def build(dpath, dfile, version=None):

FILE: crslab/evaluator/__init__.py
  function get_evaluator (line 25) | def get_evaluator(evaluator_name, dataset, tensorboard=False):

FILE: crslab/evaluator/base.py
  class BaseEvaluator (line 13) | class BaseEvaluator(ABC):
    method rec_evaluate (line 16) | def rec_evaluate(self, preds, label):
    method gen_evaluate (line 19) | def gen_evaluate(self, preds, label):
    method policy_evaluate (line 22) | def policy_evaluate(self, preds, label):
    method report (line 26) | def report(self, epoch, mode):
    method reset_metrics (line 30) | def reset_metrics(self):

FILE: crslab/evaluator/conv.py
  class ConvEvaluator (line 26) | class ConvEvaluator(BaseEvaluator):
    method __init__ (line 37) | def __init__(self, tensorboard=False):
    method _load_embedding (line 48) | def _load_embedding(self, language):
    method _get_sent_embedding (line 57) | def _get_sent_embedding(self, sent):
    method gen_evaluate (line 60) | def gen_evaluate(self, hyp, refs):
    method report (line 78) | def report(self, epoch=-1, mode='test'):
    method reset_metrics (line 89) | def reset_metrics(self):

FILE: crslab/evaluator/metrics/base.py
  class Metric (line 21) | class Metric(ABC):
    method value (line 29) | def value(self) -> float:
    method __add__ (line 36) | def __add__(self, other: Any) -> 'Metric':
    method __iadd__ (line 39) | def __iadd__(self, other):
    method __radd__ (line 42) | def __radd__(self, other: Any):
    method __str__ (line 47) | def __str__(self) -> str:
    method __repr__ (line 50) | def __repr__(self) -> str:
    method __float__ (line 53) | def __float__(self) -> float:
    method __int__ (line 56) | def __int__(self) -> int:
    method __eq__ (line 59) | def __eq__(self, other: Any) -> bool:
    method __lt__ (line 65) | def __lt__(self, other: Any) -> bool:
    method __sub__ (line 71) | def __sub__(self, other: Any) -> float:
    method __rsub__ (line 79) | def __rsub__(self, other: Any) -> float:
    method as_number (line 90) | def as_number(cls, obj: TScalar) -> Union[int, float]:
    method as_float (line 99) | def as_float(cls, obj: TScalar) -> float:
    method as_int (line 103) | def as_int(cls, obj: TScalar) -> int:
    method many (line 107) | def many(cls, *objs: List[TVector]) -> List['Metric']:
  class SumMetric (line 119) | class SumMetric(Metric):
    method __init__ (line 129) | def __init__(self, sum_: TScalar = 0):
    method __add__ (line 136) | def __add__(self, other: Optional['SumMetric']) -> 'SumMetric':
    method value (line 145) | def value(self) -> float:
  class AverageMetric (line 149) | class AverageMetric(Metric):
    method __init__ (line 159) | def __init__(self, numer: TScalar, denom: TScalar = 1):
    method __add__ (line 163) | def __add__(self, other: Optional['AverageMetric']) -> 'AverageMetric':
    method value (line 173) | def value(self) -> float:
  function aggregate_unnamed_reports (line 182) | def aggregate_unnamed_reports(reports: List[Dict[str, Metric]]) -> Dict[...
  class Metrics (line 193) | class Metrics(object):
    method __init__ (line 198) | def __init__(self):
    method __str__ (line 201) | def __str__(self):
    method __repr__ (line 204) | def __repr__(self):
    method get (line 207) | def get(self, key: str):
    method __getitem__ (line 213) | def __getitem__(self, item):
    method add (line 216) | def add(self, key: str, value: Optional[Metric]) -> None:
    method report (line 222) | def report(self):
    method clear (line 228) | def clear(self):

FILE: crslab/evaluator/metrics/gen.py
  class PPLMetric (line 27) | class PPLMetric(AverageMetric):
    method value (line 28) | def value(self):
  function normalize_answer (line 32) | def normalize_answer(s):
  class ExactMatchMetric (line 45) | class ExactMatchMetric(AverageMetric):
    method compute (line 47) | def compute(guess: str, answers: List[str]) -> 'ExactMatchMetric':
  class F1Metric (line 56) | class F1Metric(AverageMetric):
    method _prec_recall_f1_score (line 62) | def _prec_recall_f1_score(pred_items, gold_items):
    method compute (line 81) | def compute(guess: str, answers: List[str]) -> 'F1Metric':
  class BleuMetric (line 92) | class BleuMetric(AverageMetric):
    method compute (line 94) | def compute(guess: str, answers: List[str], k: int) -> Optional['BleuM...
  class DistMetric (line 109) | class DistMetric(SumMetric):
    method compute (line 111) | def compute(sent: str, k: int) -> 'DistMetric':
  class EmbeddingAverage (line 118) | class EmbeddingAverage(AverageMetric):
    method _avg_embedding (line 120) | def _avg_embedding(embedding):
    method compute (line 124) | def compute(hyp_embedding, ref_embeddings) -> 'EmbeddingAverage':
  class VectorExtrema (line 131) | class VectorExtrema(AverageMetric):
    method _extreme_embedding (line 133) | def _extreme_embedding(embedding):
    method compute (line 142) | def compute(hyp_embedding, ref_embeddings) -> 'VectorExtrema':
  class GreedyMatch (line 149) | class GreedyMatch(AverageMetric):
    method compute (line 151) | def compute(hyp_embedding, ref_embeddings) -> 'GreedyMatch':

FILE: crslab/evaluator/metrics/rec.py
  class HitMetric (line 14) | class HitMetric(AverageMetric):
    method compute (line 16) | def compute(ranks, label, k) -> 'HitMetric':
  class NDCGMetric (line 20) | class NDCGMetric(AverageMetric):
    method compute (line 22) | def compute(ranks, label, k) -> 'NDCGMetric':
  class MRRMetric (line 29) | class MRRMetric(AverageMetric):
    method compute (line 31) | def compute(ranks, label, k) -> 'MRRMetric':

FILE: crslab/evaluator/rec.py
  class RecEvaluator (line 20) | class RecEvaluator(BaseEvaluator):
    method __init__ (line 28) | def __init__(self, tensorboard=False):
    method rec_evaluate (line 37) | def rec_evaluate(self, ranks, label):
    method report (line 44) | def report(self, epoch=-1, mode='test'):
    method reset_metrics (line 52) | def reset_metrics(self):

FILE: crslab/evaluator/standard.py
  class StandardEvaluator (line 27) | class StandardEvaluator(BaseEvaluator):
    method __init__ (line 38) | def __init__(self, language, tensorboard=False):
    method _load_embedding (line 55) | def _load_embedding(self, language):
    method _get_sent_embedding (line 64) | def _get_sent_embedding(self, sent):
    method rec_evaluate (line 67) | def rec_evaluate(self, ranks, label):
    method gen_evaluate (line 74) | def gen_evaluate(self, hyp, refs):
    method report (line 90) | def report(self, epoch=-1, mode='test'):
    method reset_metrics (line 100) | def reset_metrics(self):

FILE: crslab/evaluator/utils.py
  function _line_width (line 23) | def _line_width():
  function float_formatter (line 32) | def float_formatter(f: Union[float, int]) -> str:
  function round_sigfigs (line 60) | def round_sigfigs(x: Union[float, 'torch.Tensor'], sigfigs=4) -> float:
  function _report_sort_key (line 86) | def _report_sort_key(report_key: str) -> Tuple[str, str]:
  function nice_report (line 103) | def nice_report(report) -> str:

FILE: crslab/model/__init__.py
  function get_model (line 48) | def get_model(config, model_name, device, vocab, side_data=None):

FILE: crslab/model/base.py
  class BaseModel (line 17) | class BaseModel(ABC, nn.Module):
    method __init__ (line 20) | def __init__(self, opt, device, dpath=None, resource=None):
    method build_model (line 33) | def build_model(self, *args, **kwargs):
    method recommend (line 37) | def recommend(self, batch, mode):
    method converse (line 46) | def converse(self, batch, mode):
    method guide (line 55) | def guide(self, batch, mode):

FILE: crslab/model/conversation/gpt2/gpt2.py
  class GPT2Model (line 33) | class GPT2Model(BaseModel):
    method __init__ (line 43) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 62) | def build_model(self):
    method forward (line 67) | def forward(self, batch, mode):
    method generate (line 85) | def generate(self, context):
    method calculate_loss (line 111) | def calculate_loss(self, logit, labels):
    method generate_bs (line 122) | def generate_bs(self, context, beam=4):

FILE: crslab/model/conversation/transformer/transformer.py
  class TransformerModel (line 31) | class TransformerModel(BaseModel):
    method __init__ (line 61) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 107) | def build_model(self):
    method _init_embeddings (line 111) | def _init_embeddings(self):
    method _build_conversation_layer (line 123) | def _build_conversation_layer(self):
    method _starts (line 158) | def _starts(self, batch_size):
    method _decode_forced_with_kg (line 162) | def _decode_forced_with_kg(self, token_encoding, response):
    method _decode_greedy_with_kg (line 173) | def _decode_greedy_with_kg(self, token_encoding):
    method _decode_beam_search_with_kg (line 193) | def _decode_beam_search_with_kg(self, token_encoding, beam=4):
    method forward (line 247) | def forward(self, batch, mode):

FILE: crslab/model/crs/inspired/inspired_conv.py
  class InspiredConvModel (line 17) | class InspiredConvModel(BaseModel):
    method __init__ (line 27) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 47) | def build_model(self):
    method converse (line 53) | def converse(self, batch, mode):
    method generate (line 101) | def generate(self, roles, context):
    method calculate_loss (line 136) | def calculate_loss(self, logit, labels):

FILE: crslab/model/crs/inspired/inspired_rec.py
  class InspiredRecModel (line 33) | class InspiredRecModel(BaseModel):
    method __init__ (line 41) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 58) | def build_model(self):
    method recommend (line 70) | def recommend(self, batch, mode='train'):

FILE: crslab/model/crs/inspired/modules.py
  class SequenceCrossEntropyLoss (line 10) | class SequenceCrossEntropyLoss(nn.Module):
    method __init__ (line 19) | def __init__(self, ignore_index=None, label_smoothing=-1):
    method forward (line 24) | def forward(self, logits, labels):

FILE: crslab/model/crs/kbrd/kbrd.py
  class KBRDModel (line 34) | class KBRDModel(BaseModel):
    method __init__ (line 64) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 109) | def build_model(self, *args, **kwargs):
    method _build_embedding (line 115) | def _build_embedding(self):
    method _build_kg_layer (line 126) | def _build_kg_layer(self):
    method _build_recommendation_layer (line 131) | def _build_recommendation_layer(self):
    method _build_conversation_layer (line 136) | def _build_conversation_layer(self):
    method encode_user (line 174) | def encode_user(self, entity_lists, kg_embedding):
    method recommend (line 185) | def recommend(self, batch, mode):
    method _starts (line 193) | def _starts(self, batch_size):
    method decode_forced (line 197) | def decode_forced(self, encoder_states, user_embedding, resp):
    method decode_greedy (line 209) | def decode_greedy(self, encoder_states, user_embedding):
    method decode_beam_search (line 231) | def decode_beam_search(self, encoder_states, user_embedding, beam=4):
    method converse (line 287) | def converse(self, batch, mode):
    method forward (line 303) | def forward(self, batch, mode, stage):
    method freeze_parameters (line 312) | def freeze_parameters(self):

FILE: crslab/model/crs/kgsf/kgsf.py
  class KGSFModel (line 39) | class KGSFModel(BaseModel):
    method __init__ (line 70) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 124) | def build_model(self):
    method _init_embeddings (line 131) | def _init_embeddings(self):
    method _build_kg_layer (line 147) | def _build_kg_layer(self):
    method _build_infomax_layer (line 161) | def _build_infomax_layer(self):
    method _build_recommendation_layer (line 168) | def _build_recommendation_layer(self):
    method _build_conversation_layer (line 174) | def _build_conversation_layer(self):
    method pretrain_infomax (line 218) | def pretrain_infomax(self, batch):
    method recommend (line 241) | def recommend(self, batch, mode):
    method freeze_parameters (line 277) | def freeze_parameters(self):
    method _starts (line 284) | def _starts(self, batch_size):
    method _decode_forced_with_kg (line 288) | def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_e...
    method _decode_greedy_with_kg (line 308) | def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_e...
    method _decode_beam_search_with_kg (line 335) | def _decode_beam_search_with_kg(self, token_encoding, entity_reps, ent...
    method converse (line 407) | def converse(self, batch, mode):
    method forward (line 444) | def forward(self, batch, stage, mode):

FILE: crslab/model/crs/kgsf/modules.py
  class GateLayer (line 10) | class GateLayer(nn.Module):
    method __init__ (line 11) | def __init__(self, input_dim):
    method forward (line 16) | def forward(self, input1, input2):
  class TransformerDecoderLayerKG (line 23) | class TransformerDecoderLayerKG(nn.Module):
    method __init__ (line 24) | def __init__(
    method forward (line 61) | def forward(self, x, encoder_output, encoder_mask, kg_encoder_output, ...
  class TransformerDecoderKG (line 115) | class TransformerDecoderKG(nn.Module):
    method __init__ (line 140) | def __init__(
    method forward (line 190) | def forward(self, input, encoder_state, kg_encoder_output, kg_encoder_...

FILE: crslab/model/crs/ntrd/modules.py
  class GateLayer (line 14) | class GateLayer(nn.Module):
    method __init__ (line 15) | def __init__(self, input_dim):
    method forward (line 20) | def forward(self, input1, input2):
  class TransformerDecoderLayerKG (line 27) | class TransformerDecoderLayerKG(nn.Module):
    method __init__ (line 28) | def __init__(
    method forward (line 65) | def forward(self, x, encoder_output, encoder_mask, kg_encoder_output, ...
  class TransformerDecoderLayerSelection (line 118) | class TransformerDecoderLayerSelection(nn.Module):
    method __init__ (line 119) | def __init__(
    method forward (line 147) | def forward(self, x, encoder_output, encoder_mask, movie_embed, movie_...
  class TransformerDecoderKG (line 179) | class TransformerDecoderKG(nn.Module):
    method __init__ (line 204) | def __init__(
    method forward (line 254) | def forward(self, input, encoder_state, kg_encoder_output, kg_encoder_...
  class TransformerDecoderSelection (line 273) | class TransformerDecoderSelection(nn.Module):
    method __init__ (line 274) | def __init__(
    method forward (line 313) | def forward(self, input, encoder_state,movie_embed,movie_embed_mask,in...

FILE: crslab/model/crs/ntrd/ntrd.py
  class NTRDModel (line 34) | class NTRDModel(BaseModel):
    method __init__ (line 35) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 97) | def build_model(self):
    method _init_embeddings (line 105) | def _init_embeddings(self):
    method _build_kg_layer (line 121) | def _build_kg_layer(self):
    method _build_infomax_layer (line 135) | def _build_infomax_layer(self):
    method _build_recommendation_layer (line 142) | def _build_recommendation_layer(self):
    method _build_conversation_layer (line 148) | def _build_conversation_layer(self):
    method pretrain_infomax (line 198) | def pretrain_infomax(self, batch):
    method _build_movie_selector (line 221) | def _build_movie_selector(self):
    method recommend (line 240) | def recommend(self, batch, mode):
    method freeze_parameters (line 276) | def freeze_parameters(self):
    method _starts (line 283) | def _starts(self, batch_size):
    method converse (line 287) | def converse(self, batch, mode):
    method _decode_greedy_with_kg (line 351) | def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_e...
    method _decode_forced_with_kg (line 381) | def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_e...
    method forward (line 404) | def forward(self, batch, stage, mode):

FILE: crslab/model/crs/redial/modules.py
  class HRNN (line 18) | class HRNN(nn.Module):
    method __init__ (line 19) | def __init__(self,
    method get_utterance_encoding (line 52) | def get_utterance_encoding(self, context, utterance_lengths):
    method forward (line 95) | def forward(self, context, utterance_lengths, dialog_lengths):
  class SwitchingDecoder (line 116) | class SwitchingDecoder(nn.Module):
    method __init__ (line 117) | def __init__(self, hidden_size, context_size, num_layers, vocab_size, ...
    method forward (line 134) | def forward(self, request, request_lengths, context_state):

FILE: crslab/model/crs/redial/redial_conv.py
  class ReDialConvModel (line 28) | class ReDialConvModel(BaseModel):
    method __init__ (line 50) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 83) | def build_model(self):
    method forward (line 111) | def forward(self, batch, mode):

FILE: crslab/model/crs/redial/redial_rec.py
  class ReDialRecModel (line 26) | class ReDialRecModel(BaseModel):
    method __init__ (line 36) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 52) | def build_model(self):
    method forward (line 79) | def forward(self, batch, mode):

FILE: crslab/model/crs/tgredial/tg_conv.py
  class TGConvModel (line 33) | class TGConvModel(BaseModel):
    method __init__ (line 43) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 62) | def build_model(self):
    method forward (line 67) | def forward(self, batch, mode):
    method generate (line 86) | def generate(self, context):
    method generate_bs (line 112) | def generate_bs(self, context, beam=4):
    method calculate_loss (line 154) | def calculate_loss(self, logit, labels):

FILE: crslab/model/crs/tgredial/tg_policy.py
  class TGPolicyModel (line 33) | class TGPolicyModel(BaseModel):
    method __init__ (line 34) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 52) | def build_model(self, *args, **kwargs):
    method forward (line 64) | def forward(self, batch, mode):

FILE: crslab/model/crs/tgredial/tg_rec.py
  class TGRecModel (line 35) | class TGRecModel(BaseModel):
    method __init__ (line 51) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 76) | def build_model(self):
    method forward (line 94) | def forward(self, batch, mode):

FILE: crslab/model/policy/conv_bert/conv_bert.py
  class ConvBERTModel (line 32) | class ConvBERTModel(BaseModel):
    method __init__ (line 40) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 56) | def build_model(self, *args, **kwargs):
    method forward (line 66) | def forward(self, batch, mode):

FILE: crslab/model/policy/mgcg/mgcg.py
  class MGCGModel (line 28) | class MGCGModel(BaseModel):
    method __init__ (line 42) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 62) | def build_model(self, *args, **kwargs):
    method get_length (line 87) | def get_length(self, input):
    method forward (line 90) | def forward(self, batch, mode):

FILE: crslab/model/policy/pmi/pmi.py
  class PMIModel (line 22) | class PMIModel(BaseModel):
    method __init__ (line 31) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 45) | def build_model(self, *args, **kwargs):
    method forward (line 51) | def forward(self, batch, mode):

FILE: crslab/model/policy/profile_bert/profile_bert.py
  class ProfileBERTModel (line 33) | class ProfileBERTModel(BaseModel):
    method __init__ (line 42) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 60) | def build_model(self, *args, **kwargs):
    method forward (line 70) | def forward(self, batch, mode):

FILE: crslab/model/policy/topic_bert/topic_bert.py
  class TopicBERTModel (line 32) | class TopicBERTModel(BaseModel):
    method __init__ (line 40) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 58) | def build_model(self, *args, **kwargs):
    method forward (line 68) | def forward(self, batch, mode):

FILE: crslab/model/recommendation/bert/bert.py
  class BERTModel (line 33) | class BERTModel(BaseModel):
    method __init__ (line 41) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 58) | def build_model(self):
    method forward (line 70) | def forward(self, batch, mode='train'):

FILE: crslab/model/recommendation/gru4rec/gru4rec.py
  class GRU4RECModel (line 30) | class GRU4RECModel(BaseModel):
    method __init__ (line 44) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 64) | def build_model(self):
    method reconstruct_input (line 74) | def reconstruct_input(self, input_ids):
    method cross_entropy (line 96) | def cross_entropy(self, seq_out, pos_ids, neg_ids, input_mask):
    method forward (line 121) | def forward(self, batch, mode):

FILE: crslab/model/recommendation/gru4rec/modules.py
  class Embedding (line 5) | class Embedding(nn.Module):
    method __init__ (line 6) | def __init__(self, item_size, embedding_dim):
    method forward (line 10) | def forward(self, input: torch.Tensor):
  class GRU4REC (line 14) | class GRU4REC(nn.Module):
    method __init__ (line 15) | def __init__(self, item_size, embedding_dim, hidden_size, num_layers, ...
    method cross_entropy (line 37) | def cross_entropy(self, seq_out, pos_ids, neg_ids):
    method forward (line 62) | def forward(self, input: torch.Tensor):

FILE: crslab/model/recommendation/popularity/popularity.py
  class PopularityModel (line 23) | class PopularityModel(BaseModel):
    method __init__ (line 31) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 44) | def build_model(self):
    method forward (line 48) | def forward(self, batch, mode):

FILE: crslab/model/recommendation/sasrec/modules.py
  class SASRec (line 18) | class SASRec(nn.Module):
    method __init__ (line 19) | def __init__(self, hidden_dropout_prob, device, initializer_range,
    method build_model (line 37) | def build_model(self):
    method init_model (line 49) | def init_model(self):
    method forward (line 52) | def forward(self,
    method init_sas_weights (line 92) | def init_sas_weights(self, module):
    method save_model (line 105) | def save_model(self, file_name):
    method load_model (line 109) | def load_model(self, path):
    method compute_loss (line 120) | def compute_loss(self, y_pred, y, subset='test'):
    method cross_entropy (line 123) | def cross_entropy(self, seq_out, pos_ids, neg_ids):
  function gelu (line 149) | def gelu(x):
  function swish (line 162) | def swish(x):
  class LayerNorm (line 169) | class LayerNorm(nn.Module):
    method __init__ (line 170) | def __init__(self, hidden_size, eps=1e-12):
    method forward (line 177) | def forward(self, x):
  class Embeddings (line 184) | class Embeddings(nn.Module):
    method __init__ (line 187) | def __init__(self, item_size, hidden_size, max_seq_length,
    method forward (line 197) | def forward(self, input_ids):
  class SelfAttention (line 214) | class SelfAttention(nn.Module):
    method __init__ (line 215) | def __init__(self, hidden_size, num_attention_heads, hidden_dropout_prob,
    method transpose_for_scores (line 236) | def transpose_for_scores(self, x):
    method forward (line 251) | def forward(self, input_tensor, attention_mask):
  class Intermediate (line 291) | class Intermediate(nn.Module):
    method __init__ (line 292) | def __init__(self, hidden_size, hidden_act, hidden_dropout_prob):
    method forward (line 304) | def forward(self, input_tensor):
  class Layer (line 315) | class Layer(nn.Module):
    method __init__ (line 316) | def __init__(self, hidden_size, num_attention_heads, hidden_dropout_prob,
    method forward (line 324) | def forward(self, hidden_states, attention_mask):
  class Encoder (line 330) | class Encoder(nn.Module):
    method __init__ (line 331) | def __init__(self, num_hidden_layers, hidden_size, num_attention_heads,
    method forward (line 340) | def forward(self,

FILE: crslab/model/recommendation/sasrec/sasrec.py
  class SASRECModel (line 29) | class SASRECModel(BaseModel):
    method __init__ (line 45) | def __init__(self, opt, device, vocab, side_data):
    method build_model (line 67) | def build_model(self):
    method forward (line 81) | def forward(self, batch, mode):

FILE: crslab/model/recommendation/textcnn/textcnn.py
  class TextCNNModel (line 29) | class TextCNNModel(BaseModel):
    method __init__ (line 41) | def __init__(self, opt, device, vocab, side_data):
    method conv_and_pool (line 58) | def conv_and_pool(self, x, conv):
    method build_model (line 63) | def build_model(self):
    method forward (line 76) | def forward(self, batch, mode):

FILE: crslab/model/utils/functions.py
  function edge_to_pyg_format (line 14) | def edge_to_pyg_format(edge, type='RGCN'):
  function sort_for_packed_sequence (line 27) | def sort_for_packed_sequence(lengths: torch.Tensor):

FILE: crslab/model/utils/modules/attention.py
  class SelfAttentionBatch (line 16) | class SelfAttentionBatch(nn.Module):
    method __init__ (line 17) | def __init__(self, dim, da, alpha=0.2, dropout=0.5):
    method forward (line 28) | def forward(self, h):
  class SelfAttentionSeq (line 35) | class SelfAttentionSeq(nn.Module):
    method __init__ (line 36) | def __init__(self, dim, da, alpha=0.2, dropout=0.5):
    method forward (line 47) | def forward(self, h, mask=None, return_logits=False):

FILE: crslab/model/utils/modules/transformer.py
  function neginf (line 22) | def neginf(dtype):
  function _create_selfattn_mask (line 30) | def _create_selfattn_mask(x):
  function create_position_codes (line 41) | def create_position_codes(n_pos, dim, out):
  function _normalize (line 53) | def _normalize(tensor, norm_layer):
  class MultiHeadAttention (line 59) | class MultiHeadAttention(nn.Module):
    method __init__ (line 60) | def __init__(self, n_heads, dim, dropout=.0):
    method forward (line 78) | def forward(self, query, key=None, value=None, mask=None):
  class TransformerFFN (line 143) | class TransformerFFN(nn.Module):
    method __init__ (line 144) | def __init__(self, dim, dim_hidden, relu_dropout=.0):
    method forward (line 153) | def forward(self, x):
  class TransformerEncoderLayer (line 160) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 161) | def __init__(
    method forward (line 182) | def forward(self, tensor, mask):
  class TransformerEncoder (line 191) | class TransformerEncoder(nn.Module):
    method __init__ (line 218) | def __init__(
    method forward (line 286) | def forward(self, input):
  class TransformerDecoderLayer (line 314) | class TransformerDecoderLayer(nn.Module):
    method __init__ (line 315) | def __init__(
    method forward (line 342) | def forward(self, x, encoder_output, encoder_mask):
    method _create_selfattn_mask (line 372) | def _create_selfattn_mask(self, x):
  class TransformerDecoder (line 383) | class TransformerDecoder(nn.Module):
    method __init__ (line 406) | def __init__(
    method forward (line 456) | def forward(self, input, encoder_state, incr_state=None):

FILE: crslab/quick_start/quick_start.py
  function run_crslab (line 16) | def run_crslab(config, save_data=False, restore_data=False, save_system=...

FILE: crslab/system/__init__.py
  function get_system (line 48) | def get_system(opt, train_dataloader, valid_dataloader, test_dataloader,...

FILE: crslab/system/base.py
  class BaseSystem (line 41) | class BaseSystem(ABC):
    method __init__ (line 44) | def __init__(self, opt, train_dataloader, valid_dataloader, test_datal...
    method init_optim (line 109) | def init_optim(self, opt, parameters):
    method build_optimizer (line 138) | def build_optimizer(self, parameters):
    method build_lr_scheduler (line 144) | def build_lr_scheduler(self):
    method reset_early_stop_state (line 161) | def reset_early_stop_state(self):
    method fit (line 174) | def fit(self):
    method step (line 179) | def step(self, batch, stage, mode):
    method backward (line 189) | def backward(self, loss):
    method _zero_grad (line 204) | def _zero_grad(self):
    method _update_params (line 210) | def _update_params(self):
    method adjust_lr (line 236) | def adjust_lr(self, metric=None):
    method early_stop (line 247) | def early_stop(self, metric):
    method save_model (line 261) | def save_model(self):
    method restore_model (line 277) | def restore_model(self):
    method interact (line 293) | def interact(self):
    method init_interact (line 296) | def init_interact(self):
    method update_context (line 312) | def update_context(self, stage, token_ids=None, entity_ids=None, item_...
    method get_input (line 328) | def get_input(self, language):
    method tokenize (line 343) | def tokenize(self, text, tokenizer, path=None):
    method nltk_tokenize (line 350) | def nltk_tokenize(self, text):
    method bert_tokenize (line 354) | def bert_tokenize(self, text, path):
    method gpt2_tokenize (line 360) | def gpt2_tokenize(self, text, path):
    method pkuseg_tokenize (line 366) | def pkuseg_tokenize(self, text):
    method link (line 372) | def link(self, tokens, entities):

FILE: crslab/system/inspired.py
  class InspiredSystem (line 16) | class InspiredSystem(BaseSystem):
    method __init__ (line 19) | def __init__(self, opt, train_dataloader, valid_dataloader, test_datal...
    method rec_evaluate (line 66) | def rec_evaluate(self, rec_predict, item_label):
    method conv_evaluate (line 76) | def conv_evaluate(self, prediction, response):
    method step (line 91) | def step(self, batch, stage, mode):
    method train_recommender (line 132) | def train_recommender(self):
    method train_conversation (line 173) | def train_conversation(self):
    method fit (line 202) | def fit(self):
    method interact (line 208) | def interact(self):

FILE: crslab/system/kbrd.py
  class KBRDSystem (line 22) | class KBRDSystem(BaseSystem):
    method __init__ (line 25) | def __init__(self, opt, train_dataloader, valid_dataloader, test_datal...
    method rec_evaluate (line 56) | def rec_evaluate(self, rec_predict, item_label):
    method conv_evaluate (line 66) | def conv_evaluate(self, prediction, response):
    method step (line 74) | def step(self, batch, stage, mode):
    method train_recommender (line 105) | def train_recommender(self):
    method train_conversation (line 134) | def train_conversation(self):
    method fit (line 169) | def fit(self):
    method interact (line 173) | def interact(self):

FILE: crslab/system/kgsf.py
  class KGSFSystem (line 21) | class KGSFSystem(BaseSystem):
    method __init__ (line 24) | def __init__(self, opt, train_dataloader, valid_dataloader, test_datal...
    method rec_evaluate (line 58) | def rec_evaluate(self, rec_predict, item_label):
    method conv_evaluate (line 68) | def conv_evaluate(self, prediction, response):
    method step (line 76) | def step(self, batch, stage, mode):
    method pretrain (line 115) | def pretrain(self):
    method train_recommender (line 125) | def train_recommender(self):
    method train_conversation (line 154) | def train_conversation(self):
    method fit (line 183) | def fit(self):
    method interact (line 188) | def interact(self):

FILE: crslab/system/ntrd.py
  class NTRDSystem (line 18) | class NTRDSystem(BaseSystem):
    method __init__ (line 20) | def __init__(self, opt, train_dataloader, valid_dataloader, test_datal...
    method rec_evaluate (line 43) | def rec_evaluate(self, rec_predict, item_label):
    method conv_evaluate (line 53) | def conv_evaluate(self, prediction,movie_prediction,response,movie_res...
    method step (line 73) | def step(self, batch, stage, mode):
    method pretrain (line 127) | def pretrain(self):
    method train_recommender (line 137) | def train_recommender(self):
    method train_conversation (line 166) | def train_conversation(self):
    method fit (line 195) | def fit(self):
    method interact (line 200) | def interact(self):

FILE: crslab/system/redial.py
  class ReDialSystem (line 20) | class ReDialSystem(BaseSystem):
    method __init__ (line 23) | def __init__(self, opt, train_dataloader, valid_dataloader, test_datal...
    method rec_evaluate (line 56) | def rec_evaluate(self, rec_predict, item_label):
    method conv_evaluate (line 66) | def conv_evaluate(self, prediction, response):
    method step (line 74) | def step(self, batch, stage, mode):
    method train_recommender (line 102) | def train_recommender(self):
    method train_conversation (line 131) | def train_conversation(self):
    method fit (line 160) | def fit(self):
    method interact (line 164) | def interact(self):

FILE: crslab/system/tgredial.py
  class TGReDialSystem (line 24) | class TGReDialSystem(BaseSystem):
    method __init__ (line 27) | def __init__(self, opt, train_dataloader, valid_dataloader, test_datal...
    method rec_evaluate (line 79) | def rec_evaluate(self, rec_predict, item_label):
    method policy_evaluate (line 89) | def policy_evaluate(self, rec_predict, movie_label):
    method conv_evaluate (line 97) | def conv_evaluate(self, prediction, response):
    method step (line 112) | def step(self, batch, stage, mode):
    method train_recommender (line 168) | def train_recommender(self):
    method train_conversation (line 212) | def train_conversation(self):
    method train_policy (line 241) | def train_policy(self):
    method fit (line 286) | def fit(self):
    method interact (line 294) | def interact(self):
    method process_input (line 329) | def process_input(self, input_text, stage):
    method convert_to_id (line 348) | def convert_to_id(self, text, stage):

FILE: crslab/system/utils/functions.py
  function compute_grad_norm (line 18) | def compute_grad_norm(parameters, norm_type=2.0):
  function ind2txt (line 41) | def ind2txt(inds, ind2tok, end_token_idx=None, unk_token='unk'):
  function ind2txt_with_slots (line 51) | def ind2txt_with_slots(inds,slots,ind2tok, end_token_idx=None, unk_token...
  function ind2slot (line 65) | def ind2slot(inds,ind2slot):

FILE: crslab/system/utils/lr_scheduler.py
  class LRScheduler (line 18) | class LRScheduler(ABC):
    method __init__ (line 30) | def __init__(self, optimizer, warmup_steps: int = 0):
    method _warmup_lr (line 45) | def _warmup_lr(self, step):
    method _init_warmup_scheduler (line 51) | def _init_warmup_scheduler(self, optimizer):
    method _is_lr_warming_up (line 57) | def _is_lr_warming_up(self):
    method train_step (line 67) | def train_step(self):
    method valid_step (line 80) | def valid_step(self, metric=None):
    method train_adjust (line 88) | def train_adjust(self):
    method valid_adjust (line 97) | def valid_adjust(self, metric):
  class ReduceLROnPlateau (line 110) | class ReduceLROnPlateau(LRScheduler):
    method __init__ (line 115) | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, ver...
    method train_adjust (line 123) | def train_adjust(self):
    method valid_adjust (line 126) | def valid_adjust(self, metric):
  class StepLR (line 130) | class StepLR(LRScheduler):
    method __init__ (line 135) | def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, war...
    method train_adjust (line 139) | def train_adjust(self):
    method valid_adjust (line 142) | def valid_adjust(self, metric=None):
  class ConstantLR (line 146) | class ConstantLR(LRScheduler):
    method __init__ (line 147) | def __init__(self, optimizer, warmup_steps=0):
    method train_adjust (line 150) | def train_adjust(self):
    method valid_adjust (line 153) | def valid_adjust(self, metric):
  class InvSqrtLR (line 157) | class InvSqrtLR(LRScheduler):
    method __init__ (line 162) | def __init__(self, optimizer, invsqrt_lr_decay_gamma=-1, last_epoch=-1...
    method _invsqrt_lr (line 182) | def _invsqrt_lr(self, step):
    method train_adjust (line 185) | def train_adjust(self):
    method valid_adjust (line 188) | def valid_adjust(self, metric):
  class CosineAnnealingLR (line 193) | class CosineAnnealingLR(LRScheduler):
    method __init__ (line 198) | def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, warmup_...
    method train_adjust (line 208) | def train_adjust(self):
    method valid_adjust (line 211) | def valid_adjust(self, metric):
  class CosineAnnealingWarmRestartsLR (line 215) | class CosineAnnealingWarmRestartsLR(LRScheduler):
    method __init__ (line 216) | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1,...
    method train_adjust (line 220) | def train_adjust(self):
    method valid_adjust (line 223) | def valid_adjust(self, metric):
  class TransformersLinearLR (line 227) | class TransformersLinearLR(LRScheduler):
    method __init__ (line 232) | def __init__(self, optimizer, training_steps, warmup_steps=0):
    method _linear_lr (line 242) | def _linear_lr(self, step):
    method train_adjust (line 245) | def train_adjust(self):
    method valid_adjust (line 248) | def valid_adjust(self, metric):
  class TransformersCosineLR (line 252) | class TransformersCosineLR(LRScheduler):
    method __init__ (line 253) | def __init__(self, optimizer, training_steps: int, num_cycles: float =...
    method _cosine_lr (line 260) | def _cosine_lr(self, step):
    method train_adjust (line 264) | def train_adjust(self):
    method valid_adjust (line 267) | def valid_adjust(self, metric):
  class TransformersCosineWithHardRestartsLR (line 271) | class TransformersCosineWithHardRestartsLR(LRScheduler):
    method __init__ (line 272) | def __init__(self, optimizer, training_steps: int, num_cycles: int = 1...
    method _cosine_with_hard_restarts_lr (line 279) | def _cosine_with_hard_restarts_lr(self, step):
    method train_adjust (line 285) | def train_adjust(self):
    method valid_adjust (line 288) | def valid_adjust(self, metric):
  class TransformersPolynomialDecayLR (line 292) | class TransformersPolynomialDecayLR(LRScheduler):
    method __init__ (line 293) | def __init__(self, optimizer, training_steps, lr_end=1e-7, power=1.0, ...
    method _polynomial_decay_lr (line 302) | def _polynomial_decay_lr(self, step):
    method train_adjust (line 312) | def train_adjust(self):
    method valid_adjust (line 315) | def valid_adjust(self, metric):

FILE: docs/source/conf.py
  function setup (line 76) | def setup(app):
Condensed preview: 253 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitattributes",
    "chars": 88,
    "preview": "* text=auto eol=lf\n*.{cmd,[cC][mM][dD]} text eol=crlf\n*.{bat,[bB][aA][tT]} text eol=crlf"
  },
  {
    "path": ".gitignore",
    "chars": 4521,
    "preview": "# Created by .ignore support plugin (hsz.mobi)\n### Project\n\ndata\nlog\nsave\n!crslab/data\nruns\n\n### VisualStudioCode templa"
  },
  {
    "path": ".readthedocs.yml",
    "chars": 584,
    "preview": "# Required\nversion: 2\n\n# Build documentation in the docs/ directory with Sphinx\nsphinx:\n  configuration: docs/source/con"
  },
  {
    "path": "LICENSE",
    "chars": 1065,
    "preview": "MIT License\n\nCopyright (c) 2021 RUCAIBox\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
  },
  {
    "path": "README.md",
    "chars": 14828,
    "preview": "# CRSLab\n\n[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab)\n[![Release](ht"
  },
  {
    "path": "README_CN.md",
    "chars": 11331,
    "preview": "# CRSLab\n\n[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab)\n[![Release](ht"
  },
  {
    "path": "config/conversation/gpt2/durecdial.yaml",
    "chars": 363,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "config/conversation/gpt2/gorecdial.yaml",
    "chars": 363,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "config/conversation/gpt2/inspired.yaml",
    "chars": 362,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/conversation/gpt2/opendialkg.yaml",
    "chars": 364,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunc"
  },
  {
    "path": "config/conversation/gpt2/redial.yaml",
    "chars": 360,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate:"
  },
  {
    "path": "config/conversation/gpt2/tgredial.yaml",
    "chars": 414,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/conversation/transformer/durecdial.yaml",
    "chars": 588,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  conv: jieba\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: "
  },
  {
    "path": "config/conversation/transformer/gorecdial.yaml",
    "chars": 625,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  conv: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1"
  },
  {
    "path": "config/conversation/transformer/inspired.yaml",
    "chars": 624,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  conv: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n"
  },
  {
    "path": "config/conversation/transformer/opendialkg.yaml",
    "chars": 626,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  conv: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: "
  },
  {
    "path": "config/conversation/transformer/redial.yaml",
    "chars": 584,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  conv: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# "
  },
  {
    "path": "config/conversation/transformer/tgredial.yaml",
    "chars": 587,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  conv: pkuseg\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: "
  },
  {
    "path": "config/crs/inspired/durecdial.yaml",
    "chars": 795,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30"
  },
  {
    "path": "config/crs/inspired/gorecdial.yaml",
    "chars": 795,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30"
  },
  {
    "path": "config/crs/inspired/inspired.yaml",
    "chars": 697,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\n"
  },
  {
    "path": "config/crs/inspired/opendialkg.yaml",
    "chars": 796,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 3"
  },
  {
    "path": "config/crs/inspired/redial.yaml",
    "chars": 792,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nit"
  },
  {
    "path": "config/crs/inspired/tgredial.yaml",
    "chars": 794,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\n"
  },
  {
    "path": "config/crs/kbrd/durecdial.yaml",
    "chars": 671,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize: jieba\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# mode"
  },
  {
    "path": "config/crs/kbrd/gorecdial.yaml",
    "chars": 708,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model"
  },
  {
    "path": "config/crs/kbrd/inspired.yaml",
    "chars": 669,
    "preview": "# dataset\ndataset: Inspired\ntokenize: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\n"
  },
  {
    "path": "config/crs/kbrd/opendialkg.yaml",
    "chars": 671,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# mode"
  },
  {
    "path": "config/crs/kbrd/redial.yaml",
    "chars": 669,
    "preview": "# dataset\ndataset: ReDial\ntokenize: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nmo"
  },
  {
    "path": "config/crs/kbrd/tgredial.yaml",
    "chars": 738,
    "preview": "# dataset\ndataset: TGReDial\ntokenize: pkuseg\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# mode"
  },
  {
    "path": "config/crs/kgsf/durecdial.yaml",
    "chars": 801,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize: jieba\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncat"
  },
  {
    "path": "config/crs/kgsf/gorecdial.yaml",
    "chars": 795,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize: nltk\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate"
  },
  {
    "path": "config/crs/kgsf/inspired.yaml",
    "chars": 799,
    "preview": "# dataset\ndataset: Inspired\ntokenize: nltk\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate:"
  },
  {
    "path": "config/crs/kgsf/opendialkg.yaml",
    "chars": 801,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize: nltk\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncat"
  },
  {
    "path": "config/crs/kgsf/redial.yaml",
    "chars": 744,
    "preview": "# dataset\ndataset: ReDial\ntokenize: nltk\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 3"
  },
  {
    "path": "config/crs/kgsf/tgredial.yaml",
    "chars": 802,
    "preview": "# dataset\ndataset: TGReDial\ntokenize: pkuseg\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncat"
  },
  {
    "path": "config/crs/ntrd/tgredial.yaml",
    "chars": 860,
    "preview": "# dataset\ndataset: TGReDial\ntokenize: pkuseg\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncat"
  },
  {
    "path": "config/crs/redial/durecdial.yaml",
    "chars": 766,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: jieba\n  conv: jieba\n# dataloader\nutterance_truncate: 80\nconversation_trunc"
  },
  {
    "path": "config/crs/redial/gorecdial.yaml",
    "chars": 763,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: nltk\n  conv: nltk\n# dataloader\nutterance_truncate: 80\nconversation_truncat"
  },
  {
    "path": "config/crs/redial/inspired.yaml",
    "chars": 763,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  rec: nltk\n  conv: nltk\n# dataloader\nutterance_truncate: 80\nconversation_truncate"
  },
  {
    "path": "config/crs/redial/opendialkg.yaml",
    "chars": 765,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: nltk\n  conv: nltk\n# dataloader\nutterance_truncate: 80\nconversation_trunca"
  },
  {
    "path": "config/crs/redial/redial.yaml",
    "chars": 763,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  rec: nltk\n  conv: nltk\n# dataloader\nutterance_truncate: 80\nconversation_truncate: "
  },
  {
    "path": "config/crs/redial/tgredial.yaml",
    "chars": 766,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: pkuseg\n  conv: pkuseg\n# dataloader\nutterance_truncate: 80\nconversation_trun"
  },
  {
    "path": "config/crs/tgredial/durecdial.yaml",
    "chars": 820,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30"
  },
  {
    "path": "config/crs/tgredial/gorecdial.yaml",
    "chars": 820,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30"
  },
  {
    "path": "config/crs/tgredial/inspired.yaml",
    "chars": 819,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\n"
  },
  {
    "path": "config/crs/tgredial/opendialkg.yaml",
    "chars": 821,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 3"
  },
  {
    "path": "config/crs/tgredial/redial.yaml",
    "chars": 804,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nit"
  },
  {
    "path": "config/crs/tgredial/tgredial.yaml",
    "chars": 1017,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n  conv: gpt2\n  policy: bert\n# dataloader\ncontext_truncate: 256\nrespons"
  },
  {
    "path": "config/policy/conv_bert/tgredial.yaml",
    "chars": 335,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunc"
  },
  {
    "path": "config/policy/mgcg/tgredial.yaml",
    "chars": 414,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunc"
  },
  {
    "path": "config/policy/pmi/tgredial.yaml",
    "chars": 332,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunc"
  },
  {
    "path": "config/policy/profile_bert/tgredial.yaml",
    "chars": 349,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunc"
  },
  {
    "path": "config/policy/topic_bert/tgredial.yaml",
    "chars": 336,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunc"
  },
  {
    "path": "config/recommendation/bert/durecdial.yaml",
    "chars": 307,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/bert/gorecdial.yaml",
    "chars": 307,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/bert/inspired.yaml",
    "chars": 306,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/bert/opendialkg.yaml",
    "chars": 308,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "config/recommendation/bert/redial.yaml",
    "chars": 304,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: "
  },
  {
    "path": "config/recommendation/bert/tgredial.yaml",
    "chars": 358,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/gru4rec/durecdial.yaml",
    "chars": 415,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/gru4rec/gorecdial.yaml",
    "chars": 415,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/gru4rec/inspired.yaml",
    "chars": 414,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/gru4rec/opendialkg.yaml",
    "chars": 416,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "config/recommendation/gru4rec/redial.yaml",
    "chars": 412,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: "
  },
  {
    "path": "config/recommendation/gru4rec/tgredial.yaml",
    "chars": 466,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/popularity/durecdial.yaml",
    "chars": 316,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/popularity/gorecdial.yaml",
    "chars": 316,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/popularity/inspired.yaml",
    "chars": 315,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/popularity/opendialkg.yaml",
    "chars": 317,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "config/recommendation/popularity/redial.yaml",
    "chars": 313,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: "
  },
  {
    "path": "config/recommendation/popularity/tgredial.yaml",
    "chars": 314,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/sasrec/durecdial.yaml",
    "chars": 492,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/sasrec/gorecdial.yaml",
    "chars": 492,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/sasrec/inspired.yaml",
    "chars": 491,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/sasrec/opendialkg.yaml",
    "chars": 493,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "config/recommendation/sasrec/redial.yaml",
    "chars": 489,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: "
  },
  {
    "path": "config/recommendation/sasrec/tgredial.yaml",
    "chars": 544,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/textcnn/durecdial.yaml",
    "chars": 559,
    "preview": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: jieba\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "config/recommendation/textcnn/gorecdial.yaml",
    "chars": 558,
    "preview": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: nltk\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncat"
  },
  {
    "path": "config/recommendation/textcnn/inspired.yaml",
    "chars": 557,
    "preview": "# dataset\ndataset: Inspired\ntokenize:\n  rec: nltk\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate"
  },
  {
    "path": "config/recommendation/textcnn/opendialkg.yaml",
    "chars": 559,
    "preview": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: nltk\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "config/recommendation/textcnn/redial.yaml",
    "chars": 555,
    "preview": "# dataset\ndataset: ReDial\ntokenize:\n  rec: nltk\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: "
  },
  {
    "path": "config/recommendation/textcnn/tgredial.yaml",
    "chars": 611,
    "preview": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: sougou\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_trunca"
  },
  {
    "path": "crslab/__init__.py",
    "chars": 22,
    "preview": "__version__ = '0.0.1'\n"
  },
  {
    "path": "crslab/config/__init__.py",
    "chars": 1031,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/config/config.py",
    "chars": 4550,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2021"
  },
  {
    "path": "crslab/data/__init__.py",
    "chars": 3356,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020"
  },
  {
    "path": "crslab/data/dataloader/__init__.py",
    "chars": 251,
    "preview": "from .base import BaseDataLoader\nfrom .inspired import InspiredDataLoader\nfrom .kbrd import KBRDDataLoader\nfrom .kgsf im"
  },
  {
    "path": "crslab/data/dataloader/base.py",
    "chars": 6084,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2020"
  },
  {
    "path": "crslab/data/dataloader/inspired.py",
    "chars": 6239,
    "preview": "# @Time   : 2021/3/11\n# @Author : Beichen Zhang\n# @Email  : zhangbeichen724@gmail.com\n\nfrom copy import deepcopy\n\nimport"
  },
  {
    "path": "crslab/data/dataloader/kbrd.py",
    "chars": 3563,
    "preview": "# @Time   : 2020/11/27\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/2\n# @Auth"
  },
  {
    "path": "crslab/data/dataloader/kgsf.py",
    "chars": 5386,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2020"
  },
  {
    "path": "crslab/data/dataloader/ntrd.py",
    "chars": 5397,
    "preview": "# @Time   : 2021/10/06\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\nfrom copy import deepcopy\n\nimport"
  },
  {
    "path": "crslab/data/dataloader/redial.py",
    "chars": 6299,
    "preview": "# @Time   : 2020/11/22\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE:\n# @Time   : 2020/12/16\n# @A"
  },
  {
    "path": "crslab/data/dataloader/tgredial.py",
    "chars": 18621,
    "preview": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2020/12/29, 2020/1"
  },
  {
    "path": "crslab/data/dataloader/utils.py",
    "chars": 5256,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/10\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/data/dataset/__init__.py",
    "chars": 262,
    "preview": "from .base import BaseDataset\nfrom .durecdial import DuRecDialDataset\nfrom .gorecdial import GoRecDialDataset\nfrom .insp"
  },
  {
    "path": "crslab/data/dataset/base.py",
    "chars": 6372,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2020"
  },
  {
    "path": "crslab/data/dataset/durecdial/__init__.py",
    "chars": 40,
    "preview": "from .durecdial import DuRecDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/durecdial/durecdial.py",
    "chars": 10549,
    "preview": "# @Time   : 2020/12/21\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/21, 2021"
  },
  {
    "path": "crslab/data/dataset/durecdial/resources.py",
    "chars": 2169,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/data/dataset/gorecdial/__init__.py",
    "chars": 40,
    "preview": "from .gorecdial import GoRecDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/gorecdial/gorecdial.py",
    "chars": 10979,
    "preview": "# @Time   : 2020/12/12\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/13, 2021"
  },
  {
    "path": "crslab/data/dataset/gorecdial/resources.py",
    "chars": 2113,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/14\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/data/dataset/inspired/__init__.py",
    "chars": 38,
    "preview": "from .inspired import InspiredDataset\n"
  },
  {
    "path": "crslab/data/dataset/inspired/inspired.py",
    "chars": 10942,
    "preview": "# @Time   : 2020/12/19\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/20, 2021"
  },
  {
    "path": "crslab/data/dataset/inspired/resources.py",
    "chars": 2054,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/data/dataset/opendialkg/__init__.py",
    "chars": 42,
    "preview": "from .opendialkg import OpenDialKGDataset\n"
  },
  {
    "path": "crslab/data/dataset/opendialkg/opendialkg.py",
    "chars": 11254,
    "preview": "# @Time   : 2020/12/19\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/20, 2021"
  },
  {
    "path": "crslab/data/dataset/opendialkg/resources.py",
    "chars": 2059,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/21\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/data/dataset/redial/__init__.py",
    "chars": 34,
    "preview": "from .redial import ReDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/redial/redial.py",
    "chars": 11294,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2021"
  },
  {
    "path": "crslab/data/dataset/redial/resources.py",
    "chars": 2054,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/1\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPD"
  },
  {
    "path": "crslab/data/dataset/tgredial/__init__.py",
    "chars": 38,
    "preview": "from .tgredial import TGReDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/tgredial/resources.py",
    "chars": 2194,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/4\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPD"
  },
  {
    "path": "crslab/data/dataset/tgredial/tgredial.py",
    "chars": 15248,
    "preview": "# @Time   : 2020/12/4\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/6, 2021/1"
  },
  {
    "path": "crslab/download.py",
    "chars": 8619,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/7\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPD"
  },
  {
    "path": "crslab/evaluator/__init__.py",
    "chars": 1108,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/evaluator/base.py",
    "chars": 574,
    "preview": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/11/30\n# @Aut"
  },
  {
    "path": "crslab/evaluator/conv.py",
    "chars": 3531,
    "preview": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/18\n# @Aut"
  },
  {
    "path": "crslab/evaluator/embeddings.py",
    "chars": 990,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/18\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/evaluator/end2end.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "crslab/evaluator/metrics/__init__.py",
    "chars": 246,
    "preview": "from .base import Metric, Metrics, aggregate_unnamed_reports, AverageMetric\nfrom .gen import BleuMetric, ExactMatchMetri"
  },
  {
    "path": "crslab/evaluator/metrics/base.py",
    "chars": 6620,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020"
  },
  {
    "path": "crslab/evaluator/metrics/gen.py",
    "chars": 5029,
    "preview": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/18\n# @Aut"
  },
  {
    "path": "crslab/evaluator/metrics/rec.py",
    "chars": 913,
    "preview": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/2\n# @Auth"
  },
  {
    "path": "crslab/evaluator/rec.py",
    "chars": 2018,
    "preview": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/17\n# @Aut"
  },
  {
    "path": "crslab/evaluator/standard.py",
    "chars": 4209,
    "preview": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/18\n# @Aut"
  },
  {
    "path": "crslab/evaluator/utils.py",
    "chars": 4382,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/17\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/model/__init__.py",
    "chars": 1963,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020"
  },
  {
    "path": "crslab/model/base.py",
    "chars": 1621,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020"
  },
  {
    "path": "crslab/model/conversation/__init__.py",
    "chars": 70,
    "preview": "from .gpt2 import GPT2Model\nfrom .transformer import TransformerModel\n"
  },
  {
    "path": "crslab/model/conversation/gpt2/__init__.py",
    "chars": 28,
    "preview": "from .gpt2 import GPT2Model\n"
  },
  {
    "path": "crslab/model/conversation/gpt2/gpt2.py",
    "chars": 5966,
    "preview": "# @Time   : 2020/12/14\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2021/1/7\n# @Author"
  },
  {
    "path": "crslab/model/conversation/transformer/__init__.py",
    "chars": 42,
    "preview": "from .transformer import TransformerModel\n"
  },
  {
    "path": "crslab/model/conversation/transformer/transformer.py",
    "chars": 11817,
    "preview": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1"
  },
  {
    "path": "crslab/model/crs/__init__.py",
    "chars": 130,
    "preview": "from .inspired import *\nfrom .kbrd import *\nfrom .kgsf import *\nfrom .redial import *\nfrom .tgredial import *\nfrom .ntrd"
  },
  {
    "path": "crslab/model/crs/inspired/__init__.py",
    "chars": 88,
    "preview": "from .inspired_conv import InspiredConvModel\nfrom .inspired_rec import InspiredRecModel\n"
  },
  {
    "path": "crslab/model/crs/inspired/inspired_conv.py",
    "chars": 5505,
    "preview": "# @Time   : 2021/3/10\n# @Author : Beichen Zhang\n# @Email  : zhangbeichen724@gmail.com\n\nimport os\n\nimport torch\nfrom tran"
  },
  {
    "path": "crslab/model/crs/inspired/inspired_rec.py",
    "chars": 2376,
    "preview": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4"
  },
  {
    "path": "crslab/model/crs/inspired/modules.py",
    "chars": 2326,
    "preview": "# @Time   : 2021/3/10\n# @Author : Beichen Zhang\n# @Email  : zhangbeichen724@gmail.com\n\nimport torch\nimport torch.nn as n"
  },
  {
    "path": "crslab/model/crs/kbrd/__init__.py",
    "chars": 28,
    "preview": "from .kbrd import KBRDModel\n"
  },
  {
    "path": "crslab/model/crs/kbrd/kbrd.py",
    "chars": 14269,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/4\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPD"
  },
  {
    "path": "crslab/model/crs/kgsf/__init__.py",
    "chars": 28,
    "preview": "from .kgsf import KGSFModel\n"
  },
  {
    "path": "crslab/model/crs/kgsf/kgsf.py",
    "chars": 22081,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020"
  },
  {
    "path": "crslab/model/crs/kgsf/modules.py",
    "chars": 7405,
    "preview": "import numpy as np\nimport torch\nfrom torch import nn as nn\n\nfrom crslab.model.utils.modules.transformer import MultiHead"
  },
  {
    "path": "crslab/model/crs/kgsf/resources.py",
    "chars": 2487,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/13\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/model/crs/ntrd/__init__.py",
    "chars": 27,
    "preview": "from .ntrd import NTRDModel"
  },
  {
    "path": "crslab/model/crs/ntrd/modules.py",
    "chars": 10871,
    "preview": "# @Time   : 2021/10/06\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\nimport numpy as np\nimport torch\nf"
  },
  {
    "path": "crslab/model/crs/ntrd/ntrd.py",
    "chars": 19693,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2021/10/1\n# @Author  :   Zhipeng Zhao\n# @email   :   oran_official@outlook.com\n"
  },
  {
    "path": "crslab/model/crs/ntrd/resources.py",
    "chars": 2487,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/13\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/model/crs/redial/__init__.py",
    "chars": 80,
    "preview": "from .redial_conv import ReDialConvModel\nfrom .redial_rec import ReDialRecModel\n"
  },
  {
    "path": "crslab/model/crs/redial/modules.py",
    "chars": 8556,
    "preview": "# @Time   : 2020/12/4\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE:\n# @Time   : 2020/12/16\n# @Au"
  },
  {
    "path": "crslab/model/crs/redial/redial_conv.py",
    "chars": 6014,
    "preview": "# @Time   : 2020/12/4\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/"
  },
  {
    "path": "crslab/model/crs/redial/redial_rec.py",
    "chars": 3185,
    "preview": "# @Time   : 2020/12/4\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/"
  },
  {
    "path": "crslab/model/crs/tgredial/__init__.py",
    "chars": 101,
    "preview": "from .tg_conv import TGConvModel\nfrom .tg_policy import TGPolicyModel\nfrom .tg_rec import TGRecModel\n"
  },
  {
    "path": "crslab/model/crs/tgredial/tg_conv.py",
    "chars": 6124,
    "preview": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2021/1/7, 2020/12/"
  },
  {
    "path": "crslab/model/crs/tgredial/tg_policy.py",
    "chars": 3098,
    "preview": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2021/1/7, 2020/12/"
  },
  {
    "path": "crslab/model/crs/tgredial/tg_rec.py",
    "chars": 4350,
    "preview": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2021/1/7, 2021/1/4"
  },
  {
    "path": "crslab/model/policy/__init__.py",
    "chars": 173,
    "preview": "from .conv_bert import ConvBERTModel\nfrom .mgcg import MGCGModel\nfrom .pmi import PMIModel\nfrom .profile_bert import Pro"
  },
  {
    "path": "crslab/model/policy/conv_bert/__init__.py",
    "chars": 37,
    "preview": "from .conv_bert import ConvBERTModel\n"
  },
  {
    "path": "crslab/model/policy/conv_bert/conv_bert.py",
    "chars": 2413,
    "preview": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4\n# @"
  },
  {
    "path": "crslab/model/policy/mgcg/__init__.py",
    "chars": 28,
    "preview": "from .mgcg import MGCGModel\n"
  },
  {
    "path": "crslab/model/policy/mgcg/mgcg.py",
    "chars": 5970,
    "preview": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n#"
  },
  {
    "path": "crslab/model/policy/pmi/__init__.py",
    "chars": 26,
    "preview": "from .pmi import PMIModel\n"
  },
  {
    "path": "crslab/model/policy/pmi/pmi.py",
    "chars": 3400,
    "preview": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n#"
  },
  {
    "path": "crslab/model/policy/profile_bert/__init__.py",
    "chars": 43,
    "preview": "from .profile_bert import ProfileBERTModel\n"
  },
  {
    "path": "crslab/model/policy/profile_bert/profile_bert.py",
    "chars": 2760,
    "preview": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4\n# @"
  },
  {
    "path": "crslab/model/policy/topic_bert/__init__.py",
    "chars": 39,
    "preview": "from .topic_bert import TopicBERTModel\n"
  },
  {
    "path": "crslab/model/policy/topic_bert/topic_bert.py",
    "chars": 2441,
    "preview": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4\n# @"
  },
  {
    "path": "crslab/model/pretrained_models.py",
    "chars": 2156,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2021/1/6\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDA"
  },
  {
    "path": "crslab/model/recommendation/__init__.py",
    "chars": 168,
    "preview": "from .bert import BERTModel\nfrom .gru4rec import GRU4RECModel\nfrom .popularity import PopularityModel\nfrom .sasrec impor"
  },
  {
    "path": "crslab/model/recommendation/bert/__init__.py",
    "chars": 28,
    "preview": "from .bert import BERTModel\n"
  },
  {
    "path": "crslab/model/recommendation/bert/bert.py",
    "chars": 2408,
    "preview": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4"
  },
  {
    "path": "crslab/model/recommendation/gru4rec/__init__.py",
    "chars": 34,
    "preview": "from .gru4rec import GRU4RECModel\n"
  },
  {
    "path": "crslab/model/recommendation/gru4rec/gru4rec.py",
    "chars": 5500,
    "preview": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1"
  },
  {
    "path": "crslab/model/recommendation/gru4rec/modules.py",
    "chars": 2277,
    "preview": "import torch\nfrom torch import nn\n\n\nclass Embedding(nn.Module):\n    def __init__(self, item_size, embedding_dim):\n      "
  },
  {
    "path": "crslab/model/recommendation/popularity/__init__.py",
    "chars": 40,
    "preview": "from .popularity import PopularityModel\n"
  },
  {
    "path": "crslab/model/recommendation/popularity/popularity.py",
    "chars": 1669,
    "preview": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1"
  },
  {
    "path": "crslab/model/recommendation/sasrec/__init__.py",
    "chars": 32,
    "preview": "from .sasrec import SASRECModel\n"
  },
  {
    "path": "crslab/model/recommendation/sasrec/modules.py",
    "chars": 14196,
    "preview": "# @Time   : 2020/12/13\n# @Author : Kun Zhou\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE\n# @Time   : 2020/12/13, 2021/1/4\n#"
  },
  {
    "path": "crslab/model/recommendation/sasrec/sasrec.py",
    "chars": 3714,
    "preview": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1"
  },
  {
    "path": "crslab/model/recommendation/textcnn/__init__.py",
    "chars": 34,
    "preview": "from .textcnn import TextCNNModel\n"
  },
  {
    "path": "crslab/model/recommendation/textcnn/textcnn.py",
    "chars": 2743,
    "preview": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1"
  },
  {
    "path": "crslab/model/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "crslab/model/utils/functions.py",
    "chars": 1195,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/11/26\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UP"
  },
  {
    "path": "crslab/model/utils/modules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "crslab/model/utils/modules/attention.py",
    "chars": 2472,
    "preview": "# -*- coding: utf-8 -*-\n# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @T"
  },
  {
    "path": "crslab/model/utils/modules/transformer.py",
    "chars": 16780,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24\n# @Au"
  },
  {
    "path": "crslab/quick_start/__init__.py",
    "chars": 36,
    "preview": "from .quick_start import run_crslab\n"
  },
  {
    "path": "crslab/quick_start/quick_start.py",
    "chars": 3103,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2021/1/8\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDA"
  },
  {
    "path": "crslab/system/__init__.py",
    "chars": 1967,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020"
  },
  {
    "path": "crslab/system/base.py",
    "chars": 14524,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2021"
  },
  {
    "path": "crslab/system/inspired.py",
    "chars": 9596,
    "preview": "# @Time   : 2021/3/1\n# @Author : Beichen Zhang\n# @Email  : zhangbeichen724@gmail.com\n\nimport torch\nfrom loguru import lo"
  },
  {
    "path": "crslab/system/kbrd.py",
    "chars": 7557,
    "preview": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/4\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPD"
  },
  {
    "path": "crslab/system/kgsf.py",
    "chars": 8504,
    "preview": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2021"
  },
  {
    "path": "crslab/system/ntrd.py",
    "chars": 9038,
    "preview": "# @Time   : 2021/10/05\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\nimport os\nfrom crslab.evaluator.m"
  },
  {
    "path": "crslab/system/redial.py",
    "chars": 7438,
    "preview": "# @Time   : 2020/12/4\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE\n# @Time   : 2021/1/3\n# @Autho"
  },
  {
    "path": "crslab/system/tgredial.py",
    "chars": 16894,
    "preview": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2021/1/3\n# @Author"
  }
]

// ... and 53 more files (download for full content)

