[
  {
    "path": ".gitattributes",
    "content": "* text=auto eol=lf\n*.[cC][mM][dD] text eol=crlf\n*.[bB][aA][tT] text eol=crlf"
  },
  {
    "path": ".gitignore",
    "content": "# Created by .ignore support plugin (hsz.mobi)\n### Project\n\ndata\nlog\nsave\n!crslab/data\nruns\n\n### VisualStudioCode template\n.vscode/*\n!.vscode/settings.json\n!.vscode/tasks.json\n!.vscode/launch.json\n!.vscode/extensions.json\n*.code-workspace\n\n# Local History for Visual Studio Code\n.history/\n\n### Python template\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; 
used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n### JetBrains template\n# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider\n# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839\n\n# User-specific stuff\n.idea/**/workspace.xml\n.idea/**/tasks.xml\n.idea/**/usage.statistics.xml\n.idea/**/dictionaries\n.idea/**/shelf\n\n# Generated files\n.idea/**/contentModel.xml\n\n# Sensitive or high-churn files\n.idea/**/dataSources/\n.idea/**/dataSources.ids\n.idea/**/dataSources.local.xml\n.idea/**/sqlDataSources.xml\n.idea/**/dynamic.xml\n.idea/**/uiDesigner.xml\n.idea/**/dbnavigator.xml\n\n# Gradle\n.idea/**/gradle.xml\n.idea/**/libraries\n\n# Gradle and Maven with auto-import\n# When using Gradle or Maven with auto-import, you should exclude module files,\n# since they will be recreated, and may cause churn.  
Uncomment if using\n# auto-import.\n# .idea/artifacts\n# .idea/compiler.xml\n# .idea/jarRepositories.xml\n# .idea/modules.xml\n# .idea/*.iml\n# .idea/modules\n# *.iml\n# *.ipr\n\n# CMake\ncmake-build-*/\n\n# Mongo Explorer plugin\n.idea/**/mongoSettings.xml\n\n# File-based project format\n*.iws\n\n# IntelliJ\n.idea\n*.iml\nout\ngen\n\n# mpeltonen/sbt-idea plugin\n.idea_modules/\n\n# JIRA plugin\natlassian-ide-plugin.xml\n\n# Cursive Clojure plugin\n.idea/replstate.xml\n\n# Crashlytics plugin (for Android Studio and IntelliJ)\ncom_crashlytics_export_strings.xml\ncrashlytics.properties\ncrashlytics-build.properties\nfabric.properties\n\n# Editor-based Rest Client\n.idea/httpRequests\n\n# Android studio 3.1+ serialized cache file\n.idea/caches/build_file_checksums.ser\n\n### JupyterNotebooks template\n# gitignore template for Jupyter Notebooks\n# website: http://jupyter.org/\n\n*/.ipynb_checkpoints/*\n\n# Remove previous ipynb_checkpoints\n#   git rm -r .ipynb_checkpoints/\n\n### macOS template\n# General\n.DS_Store\n.AppleDouble\n.LSOverride\n\n# Icon must end with two \\r\nIcon\n\n# Thumbnails\n._*\n\n# Files that might appear in the root of a volume\n.DocumentRevisions-V100\n.fseventsd\n.Spotlight-V100\n.TemporaryItems\n.Trashes\n.VolumeIcon.icns\n.com.apple.timemachine.donotpresent\n\n# Directories potentially created on remote AFP share\n.AppleDB\n.AppleDesktop\nNetwork Trash Folder\nTemporary Items\n.apdisk\n"
  },
  {
    "path": ".readthedocs.yml",
    "content": "# Required\nversion: 2\n\n# Build documentation in the docs/ directory with Sphinx\nsphinx:\n  configuration: docs/source/conf.py\n\n# Build documentation with MkDocs\n#mkdocs:\n#  configuration: mkdocs.yml\n\n# Optionally build your docs in additional formats such as PDF\nformats: all\n\n# Optionally set the version of Python and requirements required to build your docs\npython:\n  version: 3.7\n  install:\n    - requirements: docs/requirements_torch.txt\n    - requirements: docs/requirements_geometric.txt\n    - requirements: docs/requirements.txt\n    - requirements: docs/requirements_sphinx.txt"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 RUCAIBox\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# CRSLab\n\n[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab)\n[![Release](https://img.shields.io/github/v/release/rucaibox/crslab.svg)](https://github.com/rucaibox/crslab/releases)\n[![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE)\n[![arXiv](https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B)](https://arxiv.org/abs/2101.00939)\n[![Documentation Status](https://readthedocs.org/projects/crslab/badge/?version=latest)](https://crslab.readthedocs.io/en/latest/?badge=latest)\n\n[Paper](https://arxiv.org/pdf/2101.00939.pdf) | [Docs](https://crslab.readthedocs.io/en/latest/?badge=latest)\n| [中文版](./README_CN.md)\n\n**CRSLab** is an open-source toolkit for building Conversational Recommender System (CRS). It is developed based on\nPython and PyTorch. CRSLab has the following highlights:\n\n- **Comprehensive benchmark models and datasets**: We have integrated commonly-used 6 datasets and 18 models, including graph neural network and pre-training models such as R-GCN, BERT and GPT-2. We have preprocessed these datasets to support these models, and release for downloading.\n- **Extensive and standard evaluation protocols**: We support a series of widely-adopted evaluation protocols for testing and comparing different CRS.\n- **General and extensible structure**: We design a general and extensible structure to unify various conversational recommendation datasets and models, in which we integrate various built-in interfaces and functions for quickly development.\n- **Easy to get started**: We provide simple yet flexible configuration for new researchers to quickly start in our library. 
\n- **Human-machine interaction interfaces**: We provide flexible human-machine interaction interfaces for researchers to conduct qualitative analysis.\n\n<p align=\"center\">\n  <img src=\"https://i.loli.net/2020/12/30/6TPVG4pBg2rcDf9.png\" alt=\"RecBole v0.1 architecture\" width=\"400\">\n  <br>\n  <b>Figure 1</b>: The overall framework of CRSLab\n</p>\n\n\n\n\n- [Installation](#Installation)\n- [Quick-Start](#Quick-Start)\n- [Models](#Models)\n- [Datasets](#Datasets)\n- [Performance](#Performance)\n- [Releases](#Releases)\n- [Contributions](#Contributions)\n- [Citing](#Citing)\n- [Team](#Team)\n- [License](#License)\n\n\n\n## Installation\n\nCRSLab works with the following operating systems：\n\n- Linux\n- Windows 10\n- macOS X\n\nCRSLab requires Python version 3.7 or later.\n\nCRSLab requires torch version 1.8. If you want to use CRSLab with GPU, please ensure that CUDA or CUDAToolkit version is 10.2 or later. Please use the combinations shown in this [Link](https://pytorch-geometric.com/whl/) to ensure the normal operation of PyTorch Geometric.\n\n\n\n### Install PyTorch\n\nUse PyTorch [Locally Installation](https://pytorch.org/get-started/locally/) or [Previous Versions Installation](https://pytorch.org/get-started/previous-versions/) commands to install PyTorch. 
For example, on Linux and Windows 10:\n\n```bash\n# CUDA 10.2\nconda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch\n\n# CUDA 11.1\nconda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge\n\n# CPU Only\nconda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cpuonly -c pytorch\n```\n\nIf you want to use CRSLab with GPU, make sure the following command prints `True` after installation:\n\n```bash\n$ python -c \"import torch; print(torch.cuda.is_available())\"\n>>> True\n```\n\n\n\n### Install PyTorch Geometric\n\nEnsure that at least PyTorch 1.8.0 is installed:\n\n```bash\n$ python -c \"import torch; print(torch.__version__)\"\n>>> 1.8.0\n```\n\nFind the CUDA version PyTorch was installed with:\n\n```bash\n$ python -c \"import torch; print(torch.version.cuda)\"\n>>> 11.1\n```\n\nFor Linux:\n\nInstall the relevant packages:\n\n```\nconda install pyg -c pyg\n```\n\nFor others：\n\nCheck PyG [installation documents](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) to install the relevant packages.\n\n\n\n### Install CRSLab\n\nYou can install from pip:\n\n```bash\npip install crslab\n```\n\nOR install from source:\n\n```bash\ngit clone https://github.com/RUCAIBox/CRSLab && cd CRSLab\npip install -e .\n```\n\n\n\n## Quick-Start\n\nWith the source code, you can use the provided script for initial usage of our library with cpu by default:\n\n```bash\npython run_crslab.py --config config/crs/kgsf/redial.yaml\n```\n\nThe system will complete the data preprocessing, and training, validation, testing of each model in turn. 
Finally it will get the evaluation results of specified models.\n\nIf you want to save pre-processed datasets and training results of models, you can use the following command:\n\n```bash\npython run_crslab.py --config config/crs/kgsf/redial.yaml --save_data --save_system\n```\n\nIn summary, there are following arguments in `run_crslab.py`:\n\n- `--config` or `-c`: relative path for configuration file(yaml).\n- `--gpu` or `-g`: specify GPU id(s) to use, we now support multiple GPUs. Defaults to CPU(-1).\n- `--save_data` or `-sd`: save pre-processed dataset.\n- `--restore_data` or `-rd`: restore pre-processed dataset from file.\n- `--save_system` or `-ss`: save trained system.\n- `--restore_system` or `-rs`: restore trained system from file.\n- `--debug` or `-d`: use validation dataset to debug your system.\n- `--interact` or `-i`: interact with your system instead of training.\n- `--tensorboard` or `-tb`: enable tensorboard to monitor train performance.\n\n\n\n## Models\n\nIn CRSLab, we unify the task description of conversational recommendation into three sub-tasks, namely recommendation (recommend user-preferred items), conversation (generate proper responses) and policy (select proper interactive action). The recommendation and conversation sub-tasks are the core of a CRS and have been studied in most of works. The policy sub-task is needed by recent works, by which the CRS can interact with users through purposeful strategy.\nAs the first release version, we have implemented 18 models in the four categories of CRS model, Recommendation model, Conversation model and Policy model.\n\n|       Category       |                            Model                             |      Graph Neural Network?      |       Pre-training Model?       
|\n| :------------------: | :----------------------------------------------------------: | :-----------------------------: | :-----------------------------: |\n|      CRS Model       | [ReDial](https://arxiv.org/abs/1812.07617)<br/>[KBRD](https://arxiv.org/abs/1908.05391)<br/>[KGSF](https://arxiv.org/abs/2007.04032)<br/>[TG-ReDial](https://arxiv.org/abs/2010.04125)<br/>[INSPIRED](https://www.aclweb.org/anthology/2020.emnlp-main.654.pdf) |       ×<br/>√<br/>√<br/>×<br/>×       |       ×<br/>×<br/>×<br/>√<br/>√       |\n| Recommendation model | Popularity<br/>[GRU4Rec](https://arxiv.org/abs/1511.06939)<br/>[SASRec](https://arxiv.org/abs/1808.09781)<br/>[TextCNN](https://arxiv.org/abs/1408.5882)<br/>[R-GCN](https://arxiv.org/abs/1703.06103)<br/>[BERT](https://arxiv.org/abs/1810.04805) | ×<br/>×<br/>×<br/>×<br/>√<br/>× | ×<br/>×<br/>×<br/>×<br/>×<br/>√ |\n|  Conversation model  | [HERD](https://arxiv.org/abs/1507.04808)<br/>[Transformer](https://arxiv.org/abs/1706.03762)<br/>[GPT-2](http://www.persagen.com/files/misc/radford2019language.pdf) |          ×<br/>×<br/>×          |          ×<br/>×<br/>√          |\n|     Policy model     | PMI<br/>[MGCG](https://arxiv.org/abs/2005.03954)<br/>[Conv-BERT](https://arxiv.org/abs/2010.04125)<br/>[Topic-BERT](https://arxiv.org/abs/2010.04125)<br/>[Profile-BERT](https://arxiv.org/abs/2010.04125) |    ×<br/>×<br/>×<br/>×<br/>×    |    ×<br/>×<br/>√<br/>√<br/>√    |\n\nAmong them, the four CRS models integrate the recommendation model and the conversation model to improve each other, while others only specify an individual task.\n\nFor Recommendation model and Conversation model, we have respectively implemented the following commonly-used automatic evaluation metrics:\n\n|        Category        |                           Metrics                            |\n| :--------------------: | :----------------------------------------------------------: |\n| Recommendation Metrics |      Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50} 
     |\n|  Conversation Metrics  | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} |\n|     Policy Metrics     |        Accuracy, Hit@{1,3,5}           |\n\n\n\n## Datasets\n\nWe have collected and preprocessed 6 commonly-used human-annotated datasets, and each dataset was matched with proper KGs as shown below:\n\n|                           Dataset                            | Dialogs | Utterances |   Domains    | Task Definition | Entity KG  |  Word KG   |\n| :----------------------------------------------------------: | :-----: | :--------: | :----------: | :-------------: | :--------: | :--------: |\n|       [ReDial](https://redialdata.github.io/website/)        | 10,006  |  182,150   |    Movie     |       --        |  DBpedia   | ConceptNet |\n|      [TG-ReDial](https://github.com/RUCAIBox/TG-ReDial)      | 10,000  |  129,392   |    Movie     |   Topic Guide   | CN-DBpedia |   HowNet   |\n|        [GoRecDial](https://arxiv.org/abs/1909.03922)         |  9,125  |  170,904   |    Movie     |  Action Choice  |  DBpedia   | ConceptNet |\n|        [DuRecDial](https://arxiv.org/abs/2005.03954)         | 10,200  |  156,000   | Movie, Music |    Goal Plan    | CN-DBpedia |   HowNet   |\n|      [INSPIRED](https://github.com/sweetpeach/Inspired)      |  1,001  |   35,811   |    Movie     | Social Strategy |  DBpedia   | ConceptNet |\n| [OpenDialKG](https://github.com/facebookresearch/opendialkg) | 13,802  |   91,209   | Movie, Book  |  Path Generate  |  DBpedia   | ConceptNet |\n\n\n\n## Performance\n\nWe have trained and test the integrated models on the TG-Redial dataset, which is split into training, validation and test sets using a ratio of 8:1:1. For each conversation, we start from the first utterance, and generate reply utterances or recommendations in turn by our model. 
We perform the evaluation on the three sub-tasks.\n\n### Recommendation Task\n\n|   Model   |    Hit@1    |   Hit@10   |   Hit@50   |    MRR@1    |   MRR@10   |   MRR@50   |   NDCG@1    |  NDCG@10   |  NDCG@50   |\n| :-------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: |\n|  SASRec   |  0.000446   |  0.00134   |   0.0160   |   0.000446  |  0.000576  |  0.00114   |  0.000445   |  0.00075   |  0.00380   |\n|  TextCNN  |   0.00267   |   0.0103   |   0.0236   |   0.00267   |  0.00434   |  0.00493   |   0.00267   |  0.00570   |  0.00860   |\n|   BERT    |   0.00722   |  0.00490   |   0.0281   |   0.00722   |   0.0106   |   0.0124   |   0.00490   |   0.0147   |   0.0239   |\n|   KBRD    |   0.00401   |   0.0254   |   0.0588   |   0.00401   |  0.00891   |   0.0103   |   0.00401   |   0.0127   |   0.0198   |\n|   KGSF    |   0.00535   | **0.0285** | **0.0771** |   0.00535   |   0.0114   | **0.0135** |   0.00535   | **0.0154** | **0.0259** |\n| TG-ReDial | **0.00793** |   0.0251   |   0.0524   | **0.00793** | **0.0122** |   0.0134   | **0.00793** |   0.0152   |   0.0211   |\n\n\n### Conversation Task\n\n|    Model    |  BLEU@1   |  BLEU@2   |   BLEU@3   |   BLEU@4   |  Dist@1  |  Dist@2  |  Dist@3  |  Dist@4  |  Average  |  Extreme  |  Greedy   |   PPL    |\n| :---------: | :-------: | :-------: | :--------: | :--------: | :------: | :------: | :------: | :------: | :-------: | :-------: | :-------: | :------: |\n|    HERD     |   0.120   |  0.0141   |  0.00136   |  0.000350  |  0.181   |  0.369   |  0.847   |   1.30   |   0.697   |   0.382   |   0.639   |   472    |\n| Transformer |   0.266   |  0.0440   |   0.0145   |  0.00651   |  0.324   |  0.837   |   2.02   |   3.06   |   0.879   |   0.438   |   0.680   |   30.9   |\n|    GPT2     |  0.0858   |  0.0119   |  0.00377   |   0.0110   | **2.35** | **4.62** | **8.84** | **12.5** |   0.763   |   0.297   |   0.583   |   9.26   |\n|    KBRD     |   
0.267   |  0.0458   |   0.0134   |  0.00579   |  0.469   |   1.50   |   3.40   |   4.90   |   0.863   |   0.398   |   0.710   |   52.5   |\n|    KGSF     | **0.383** | **0.115** | **0.0444** | **0.0200** |  0.340   |  0.910   |   3.50   |   6.20   | **0.888** | **0.477** | **0.767** |   50.1   |\n|  TG-ReDial  |   0.125   |  0.0204   |  0.00354   |  0.000803  |  0.881   |   1.75   |   7.00   |   12.0   |   0.810   |   0.332   |   0.598   | **7.41** |\n\n\n### Policy Task\n\n|   Model    |   Hit@1   |  Hit@10   |  Hit@50   |   MRR@1   |  MRR@10   |  MRR@50   |  NDCG@1   |  NDCG@10  |  NDCG@50  |\n| :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |\n|    MGCG    |   0.591   |   0.818   |   0.883   |   0.591   |   0.680   |   0.683   |   0.591   |   0.712   |   0.729   |\n| Conv-BERT  |   0.597   |   0.814   |   0.881   |   0.597   |   0.684   |   0.687   |   0.597   |   0.716   |   0.731   |\n| Topic-BERT |   0.598   |   0.828   |   0.885   |   0.598   |   0.690   |   0.693   |   0.598   |   0.724   |   0.737   |\n| TG-ReDial  | **0.600** | **0.830** | **0.893** | **0.600** | **0.693** | **0.696** | **0.600** | **0.727** | **0.741** |\n\nThe above results were obtained from our CRSLab in preliminary experiments. However, these algorithms were implemented and tuned based on our understanding and experiences, which may not achieve their optimal performance. If you could yield a better result for some specific algorithm, please kindly let us know. 
We will update this table after the results are verified.\n\n## Releases\n\n| Releases |     Date      |   Features   |\n| :------: | :-----------: | :----------: |\n|  v0.1.1  | 1 / 4 / 2021  | Basic CRSLab |\n|  v0.1.2  | 3 / 28 / 2021 |    CRSLab    |\n\n\n\n## Contributions\n\nPlease let us know if you encounter a bug or have any suggestions by [filing an issue](https://github.com/RUCAIBox/CRSLab/issues).\n\nWe welcome all contributions from bug fixes to new features and extensions.\n\nWe expect all contributions discussed in the issue tracker and going through PRs.\n\nWe thank the nice contributions through PRs from [@shubaoyu](https://github.com/shubaoyu), [@ToheartZhang](https://github.com/ToheartZhang).\n\n\n\n## Citing\n\nIf you find CRSLab useful for your research or development, please cite our [Paper](https://arxiv.org/pdf/2101.00939.pdf):\n\n```\n@article{crslab,\n    title={CRSLab: An Open-Source Toolkit for Building Conversational Recommender System},\n    author={Kun Zhou, Xiaolei Wang, Yuanhang Zhou, Chenzhan Shang, Yuan Cheng, Wayne Xin Zhao, Yaliang Li, Ji-Rong Wen},\n    year={2021},\n    journal={arXiv preprint arXiv:2101.00939}\n}\n```\n\n\n\n## Team\n\n**CRSLab** was developed and maintained by [AI Box](http://aibox.ruc.edu.cn/) group in RUC.\n\n\n\n## License\n\n**CRSLab** uses [MIT License](./LICENSE).\n\n"
  },
  {
    "path": "README_CN.md",
    "content": "# CRSLab\n\n[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab)\n[![Release](https://img.shields.io/github/v/release/rucaibox/crslab.svg)](https://github.com/rucaibox/crslab/releases)\n[![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE)\n[![arXiv](https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B)](https://arxiv.org/abs/2101.00939)\n[![Documentation Status](https://readthedocs.org/projects/crslab/badge/?version=latest)](https://crslab.readthedocs.io/en/latest/?badge=latest)\n\n[论文](https://arxiv.org/pdf/2101.00939.pdf) | [文档](https://crslab.readthedocs.io/en/latest/?badge=latest)\n| [English Version](./README.md)\n\n**CRSLab** 是一个用于构建对话推荐系统（CRS）的开源工具包，其基于 PyTorch 实现、主要面向研究者使用，并具有如下特色：\n\n- **全面的基准模型和数据集**：我们集成了常用的 6 个数据集和 18 个模型，包括基于图神经网络和预训练模型，比如  GCN，BERT 和 GPT-2；我们还对数据集进行相关处理以支持这些模型，并提供预处理后的版本供大家下载。\n- **大规模的标准评测**：我们支持一系列被广泛认可的评估方式来测试和比较不同的 CRS。\n- **通用和可扩展的结构**：我们设计了通用和可扩展的结构来统一各种对话推荐数据集和模型，并集成了多种内置接口和函数以便于快速开发。\n- **便捷的使用方法**：我们为新手提供了简单而灵活的配置，方便其快速启动集成在 CRSLab 中的模型。\n- **人性化的人机交互接口**：我们提供了人性化的人机交互界面，以供研究者对比和测试不同的模型系统。\n\n<p align=\"center\">\n  <img src=\"https://i.loli.net/2020/12/30/6TPVG4pBg2rcDf9.png\" alt=\"RecBole v0.1 architecture\" width=\"400\">\n  <br>\n  <b>图片</b>: CRSLab 的总体架构\n</p>\n\n\n\n\n- [安装](#安装)\n- [快速上手](#快速上手)\n- [模型](#模型)\n- [数据集](#数据集)\n- [评测结果](#评测结果)\n- [发行版本](#发行版本)\n- [贡献](#贡献)\n- [引用](#引用)\n- [项目团队](#项目团队)\n- [免责声明](#免责声明)\n\n\n\n## 安装\n\nCRSLab 可以在以下几种系统上运行：\n\n- Linux\n- Windows 10\n- macOS X\n\nCRSLab 需要在 Python 3.7 或更高的环境下运行。\n\nCRSLab 要求 torch 版本为1.8，如果你想在 GPU 上运行 CRSLab，请确保你的 CUDA 版本或者 CUDAToolkit 版本在 10.2 及以上。为保证 PyTorch Geometric 库的正常运行，请使用[链接](https://pytorch-geometric.com/whl/)所示的安装方式。\n\n\n\n### 安装 PyTorch\n\n使用 PyTorch [本地安装](https://pytorch.org/get-started/locally/)命令或者[先前版本安装](https://pytorch.org/get-started/previous-versions/)命令安装 PyTorch，比如在 Linux 和 Windows 下：\n\n```bash\n# CUDA 10.2\nconda install pytorch==1.8.0 torchvision==0.9.0 
torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch\n\n# CUDA 11.1\nconda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge\n\n# CPU Only\nconda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cpuonly -c pytorch\n```\n\n安装完成后，如果你想在 GPU 上运行 CRSLab，请确保如下命令输出`True`：\n\n```bash\n$ python -c \"import torch; print(torch.cuda.is_available())\"\n>>> True\n```\n\n\n\n### 安装 PyTorch Geometric\n\n确保安装的 PyTorch 版本至少为 1.8.0：\n\n```bash\n$ python -c \"import torch; print(torch.__version__)\"\n>>> 1.8.0\n```\n\n找到安装好的 PyTorch 对应的 CUDA 版本：\n\n```bash\n$ python -c \"import torch; print(torch.version.cuda)\"\n>>> 11.1\n```\n\n在Linux下：\n\n安装相关的包：\n\n```bash\nconda install pyg -c pyg\n```\n\n在其他系统下：\n\n查看PyG[官方下载文档](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)安装相关的包。\n\n### 安装 CRSLab\n\n你可以通过 pip 来安装：\n\n```bash\npip install crslab\n```\n\n也可以通过源文件进行进行安装：\n\n```bash\ngit clone https://github.com/RUCAIBox/CRSLab && cd CRSLab\npip install -e .\n```\n\n\n\n## 快速上手\n\n从 GitHub 下载 CRSLab 后，可以使用提供的脚本快速运行和测试，默认使用CPU：\n\n```bash\npython run_crslab.py --config config/crs/kgsf/redial.yaml\n```\n\n系统将依次完成数据的预处理，以及各模块的训练、验证和测试，并得到指定的模型评测结果。\n\n如果你希望保存数据预处理结果与模型训练结果，可以使用如下命令：\n\n```bash\npython run_crslab.py --config config/crs/kgsf/redial.yaml --save_data --save_system\n```\n\n总的来说，`run_crslab.py`有如下参数可供调用：\n\n- `--config` 或 `-c`：配置文件的相对路径，以指定运行的模型与数据集。\n- `--gpu` or `-g`：指定 GPU id，支持多 GPU，默认使用 CPU（-1）。\n- `--save_data` 或 `-sd`：保存预处理的数据。\n- `--restore_data` 或 `-rd`：从文件读取预处理的数据。\n- `--save_system` 或 `-ss`：保存训练好的 CRS 系统。\n- `--restore_system` 或 `-rs`：从文件载入提前训练好的系统。\n- `--debug` 或 `-d`：用验证集代替训练集以方便调试。\n- `--interact` 或 `-i`：与你的系统进行对话交互，而非进行训练。\n- `--tensorboard` or `-tb`：使用 tensorboardX 组件来监测训练表现。\n\n\n\n## 模型\n\n在第一个发行版中，我们实现了 4 类共 18 个模型。这里我们将对话推荐任务主要拆分成三个任务：推荐任务（生成推荐的商品），对话任务（生成对话的回复）和策略任务（规划对话推荐的策略）。其中所有的对话推荐系统都具有对话和推荐任务，他们是对话推荐系统的核心功能。而策略任务是一个辅助任务，其致力于更好的控制对话推荐系统，在不同的模型中的实现也可能不同（如 TG-ReDial 
采用一个主题预测模型，DuRecDial 中采用一个对话规划模型等）：\n\n\n\n|   类别   |                             模型                             |      Graph Neural Network?      |       Pre-training Model?       |\n| :------: | :----------------------------------------------------------: | :-----------------------------: | :-----------------------------: |\n| CRS 模型 | [ReDial](https://arxiv.org/abs/1812.07617)<br/>[KBRD](https://arxiv.org/abs/1908.05391)<br/>[KGSF](https://arxiv.org/abs/2007.04032)<br/>[TG-ReDial](https://arxiv.org/abs/2010.04125)<br/>[INSPIRED](https://www.aclweb.org/anthology/2020.emnlp-main.654.pdf) |    ×<br/>√<br/>√<br/>×<br/>×    |    ×<br/>×<br/>×<br/>√<br/>√    |\n| 推荐模型 | Popularity<br/>[GRU4Rec](https://arxiv.org/abs/1511.06939)<br/>[SASRec](https://arxiv.org/abs/1808.09781)<br/>[TextCNN](https://arxiv.org/abs/1408.5882)<br/>[R-GCN](https://arxiv.org/abs/1703.06103)<br/>[BERT](https://arxiv.org/abs/1810.04805) | ×<br/>×<br/>×<br/>×<br/>√<br/>× | ×<br/>×<br/>×<br/>×<br/>×<br/>√ |\n| 对话模型 | [HERD](https://arxiv.org/abs/1507.04808)<br/>[Transformer](https://arxiv.org/abs/1706.03762)<br/>[GPT-2](http://www.persagen.com/files/misc/radford2019language.pdf) |          ×<br/>×<br/>×          |          ×<br/>×<br/>√          |\n| 策略模型 | PMI<br/>[MGCG](https://arxiv.org/abs/2005.03954)<br/>[Conv-BERT](https://arxiv.org/abs/2010.04125)<br/>[Topic-BERT](https://arxiv.org/abs/2010.04125)<br/>[Profile-BERT](https://arxiv.org/abs/2010.04125) |    ×<br/>×<br/>×<br/>×<br/>×    |    ×<br/>×<br/>√<br/>√<br/>√    |\n\n\n其中，CRS 模型是指直接融合推荐模型和对话模型，以相互增强彼此的效果，故其内部往往已经包含了推荐、对话和策略模型。其他如推荐模型、对话模型、策略模型往往只关注以上任务中的某一个。\n\n我们对于这几类模型，我们还分别实现了如下的自动评测指标模块：\n\n|   类别   |                             指标                             |\n| :------: | :----------------------------------------------------------: |\n| 推荐指标 |      Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50}      |\n| 对话指标 | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} |\n| 策略指标 | Accuracy, Hit@{1,3,5} 
|\n\n\n\n\n\n## 数据集\n\n我们收集了 6 个常用的人工标注数据集，并对它们进行了预处理（包括引入外部知识图谱），以融入统一的 CRS 任务中。如下为相关数据集的统计数据：\n\n|                           Dataset                            | Dialogs | Utterances |   Domains    | Task Definition | Entity KG  |  Word KG   |\n| :----------------------------------------------------------: | :-----: | :--------: | :----------: | :-------------: | :--------: | :--------: |\n|       [ReDial](https://redialdata.github.io/website/)        | 10,006  |  182,150   |    Movie     |       --        |  DBpedia   | ConceptNet |\n|      [TG-ReDial](https://github.com/RUCAIBox/TG-ReDial)      | 10,000  |  129,392   |    Movie     |   Topic Guide   | CN-DBpedia |   HowNet   |\n|        [GoRecDial](https://arxiv.org/abs/1909.03922)         |  9,125  |  170,904   |    Movie     |  Action Choice  |  DBpedia   | ConceptNet |\n|        [DuRecDial](https://arxiv.org/abs/2005.03954)         | 10,200  |  156,000   | Movie, Music |    Goal Plan    | CN-DBpedia |   HowNet   |\n|      [INSPIRED](https://github.com/sweetpeach/Inspired)      |  1,001  |   35,811   |    Movie     | Social Strategy |  DBpedia   | ConceptNet |\n| [OpenDialKG](https://github.com/facebookresearch/opendialkg) | 13,802  |   91,209   | Movie, Book  |  Path Generate  |  DBpedia   | ConceptNet |\n\n\n\n## 评测结果\n\n我们在 TG-ReDial 数据集上对模型进行了训练和测试，这里我们将数据集按照 8:1:1 切分。其中对于每条数据，我们从对话的第一轮开始，一轮一轮的进行推荐、策略生成、回复生成任务。下表记录了相关的评测结果。\n\n### 推荐任务\n\n|   模型    |    Hit@1    |   Hit@10   |   Hit@50   |    MRR@1    |   MRR@10   |   MRR@50   |   NDCG@1    |  NDCG@10   |  NDCG@50   |\n| :-------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: |\n|  SASRec   |  0.000446   |  0.00134   |   0.0160   |  0.000446   |  0.000576  |  0.00114   |  0.000445   |  0.00075   |  0.00380   |\n|  TextCNN  |   0.00267   |   0.0103   |   0.0236   |   0.00267   |  0.00434   |  0.00493   |   0.00267   |  0.00570   |  0.00860   |\n|   BERT    |   0.00722   |  0.00490   
|   0.0281   |   0.00722   |   0.0106   |   0.0124   |   0.00490   |   0.0147   |   0.0239   |\n|   KBRD    |   0.00401   |   0.0254   |   0.0588   |   0.00401   |  0.00891   |   0.0103   |   0.00401   |   0.0127   |   0.0198   |\n|   KGSF    |   0.00535   | **0.0285** | **0.0771** |   0.00535   |   0.0114   | **0.0135** |   0.00535   | **0.0154** | **0.0259** |\n| TG-ReDial | **0.00793** |   0.0251   |   0.0524   | **0.00793** | **0.0122** |   0.0134   | **0.00793** |   0.0152   |   0.0211   |\n\n\n\n### 对话任务\n\n|    模型     |  BLEU@1   |  BLEU@2   |   BLEU@3   |   BLEU@4   |  Dist@1  |  Dist@2  |  Dist@3  |  Dist@4  |  Average  |  Extreme  |  Greedy   |   PPL    |\n| :---------: | :-------: | :-------: | :--------: | :--------: | :------: | :------: | :------: | :------: | :-------: | :-------: | :-------: | :------: |\n|    HERD     |   0.120   |  0.0141   |  0.00136   |  0.000350  |  0.181   |  0.369   |  0.847   |   1.30   |   0.697   |   0.382   |   0.639   |   472    |\n| Transformer |   0.266   |  0.0440   |   0.0145   |  0.00651   |  0.324   |  0.837   |   2.02   |   3.06   |   0.879   |   0.438   |   0.680   |   30.9   |\n|    GPT2     |  0.0858   |  0.0119   |  0.00377   |   0.0110   | **2.35** | **4.62** | **8.84** | **12.5** |   0.763   |   0.297   |   0.583   |   9.26   |\n|    KBRD     |   0.267   |  0.0458   |   0.0134   |  0.00579   |  0.469   |   1.50   |   3.40   |   4.90   |   0.863   |   0.398   |   0.710   |   52.5   |\n|    KGSF     | **0.383** | **0.115** | **0.0444** | **0.0200** |  0.340   |  0.910   |   3.50   |   6.20   | **0.888** | **0.477** | **0.767** |   50.1   |\n|  TG-ReDial  |   0.125   |  0.0204   |  0.00354   |  0.000803  |  0.881   |   1.75   |   7.00   |   12.0   |   0.810   |   0.332   |   0.598   | **7.41** |\n\n\n\n### 策略任务\n\n|    模型    |   Hit@1   |  Hit@10   |  Hit@50   |   MRR@1   |  MRR@10   |  MRR@50   |  NDCG@1   |  NDCG@10  |  NDCG@50  |\n| :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | 
:-------: | :-------: | :-------: | :-------: |\n|    MGCG    |   0.591   |   0.818   |   0.883   |   0.591   |   0.680   |   0.683   |   0.591   |   0.712   |   0.729   |\n| Conv-BERT  |   0.597   |   0.814   |   0.881   |   0.597   |   0.684   |   0.687   |   0.597   |   0.716   |   0.731   |\n| Topic-BERT |   0.598   |   0.828   |   0.885   |   0.598   |   0.690   |   0.693   |   0.598   |   0.724   |   0.737   |\n| TG-ReDial  | **0.600** | **0.830** | **0.893** | **0.600** | **0.693** | **0.696** | **0.600** | **0.727** | **0.741** |\n\n上述结果是我们使用 CRSLab 进行实验得到的。然而，这些算法是根据我们的经验和理解来实现和调参的，可能还没有达到它们的最佳性能。如果您能在某个具体算法上得到更好的结果，请告知我们。验证结果后，我们会更新该表。\n\n## 发行版本\n\n| 版本号 |   发行日期    |     特性     |\n| :----: | :-----------: | :----------: |\n| v0.1.1 | 1 / 4 / 2021  | Basic CRSLab |\n| v0.1.2 | 3 / 28 / 2021 |    CRSLab    |\n\n\n\n## 贡献\n\n如果您遇到错误或有任何建议，请通过 [Issue](https://github.com/RUCAIBox/CRSLab/issues) 进行反馈\n\n我们欢迎关于修复错误、添加新特性的任何贡献。\n\n如果想贡献代码，请先在 Issue 中提出问题，然后再提 PR。\n\n我们感谢 [@shubaoyu](https://github.com/shubaoyu), [@ToheartZhang](https://github.com/ToheartZhang) 通过 PR 为项目贡献的新特性。\n\n\n\n## 引用\n\n如果你觉得 CRSLab 对你的科研工作有帮助，请引用我们的[论文](https://arxiv.org/pdf/2101.00939.pdf)：\n\n```\n@article{crslab,\n    title={CRSLab: An Open-Source Toolkit for Building Conversational Recommender System},\n    author={Kun Zhou, Xiaolei Wang, Yuanhang Zhou, Chenzhan Shang, Yuan Cheng, Wayne Xin Zhao, Yaliang Li, Ji-Rong Wen},\n    year={2021},\n    journal={arXiv preprint arXiv:2101.00939}\n}\n```\n\n\n\n## 项目团队\n\n**CRSLab** 由中国人民大学 [AI Box](http://aibox.ruc.edu.cn/) 小组开发和维护。\n\n\n\n## 免责声明\n\n**CRSLab** 基于 [MIT License](./LICENSE) 进行开发，本项目的所有数据和代码只能被用于学术目的。\n"
  },
  {
    "path": "config/conversation/gpt2/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nconv_model: GPT2\n# optim\nconv:\n  epoch: 1\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n"
  },
  {
    "path": "config/conversation/gpt2/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nconv_model: GPT2\n# optim\nconv:\n  epoch: 1\n  batch_size: 4\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n"
  },
  {
    "path": "config/conversation/gpt2/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nconv_model: GPT2\n# optim\nconv:\n  epoch: 1\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n"
  },
  {
    "path": "config/conversation/gpt2/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nconv_model: GPT2\n# optim\nconv:\n  epoch: 1\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n"
  },
  {
    "path": "config/conversation/gpt2/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nconv_model: GPT2\n# optim\nconv:\n  epoch: 1\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n"
  },
  {
    "path": "config/conversation/gpt2/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nconv_model: GPT2\n# optim\nconv:\n  epoch: 50\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  early_stop: true\n  stop_mode: min\n  impatience: 3\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000"
  },
  {
    "path": "config/conversation/transformer/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  conv: jieba\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nconv_model: Transformer\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\nconv:\n  epoch: 1\n  batch_size: 64\n  early_stop: True\n  stop_mode: min\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5"
  },
  {
    "path": "config/conversation/transformer/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  conv: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nconv_model: Transformer\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\nconv:\n  epoch: 1\n  batch_size: 256\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n  early_stop: true\n  stop_mode: min\n  impatience: 3"
  },
  {
    "path": "config/conversation/transformer/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  conv: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nconv_model: Transformer\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\nconv:\n  epoch: 1\n  batch_size: 256\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n  early_stop: true\n  stop_mode: min\n  impatience: 3"
  },
  {
    "path": "config/conversation/transformer/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  conv: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nconv_model: Transformer\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\nconv:\n  epoch: 1\n  batch_size: 256\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n  early_stop: true\n  stop_mode: min\n  impatience: 3"
  },
  {
    "path": "config/conversation/transformer/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  conv: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nconv_model: Transformer\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\nconv:\n  epoch: 1\n  batch_size: 64\n  early_stop: True\n  stop_mode: min\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5"
  },
  {
    "path": "config/conversation/transformer/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  conv: pkuseg\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nconv_model: Transformer\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\nconv:\n  epoch: 50\n  batch_size: 64\n  early_stop: True\n  stop_mode: min\n  patience: 3\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    factor: 0.5"
  },
  {
    "path": "config/crs/inspired/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\n# rec\nrec_model: InspiredRec\n# conv\nconv_model: InspiredConv\n# embedding: word2vec\nembedding_dim: 300\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  early_stop: true\n  stop_mode: max\n  impatience: 3\n  lr_bert: !!float 1e-5\nconv:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 3e-5\n    eps: !!float 1e-06\n    weight_decay: !!float 0.01\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 100\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/inspired/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\n# rec\nrec_model: InspiredRec\n# conv\nconv_model: InspiredConv\n# embedding: word2vec\nembedding_dim: 300\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  early_stop: true\n  stop_mode: max\n  impatience: 3\n  lr_bert: !!float 1e-5\nconv:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 3e-5\n    eps: !!float 1e-06\n    weight_decay: !!float 0.01\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 100\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/inspired/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\n# rec\nrec_model: InspiredRec\n# conv\nconv_model: InspiredConv\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  early_stop: true\n  stop_mode: max\n  impatience: 3\n  lr_bert: !!float 1e-5\nconv:\n  epoch: 50\n  batch_size: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 3e-5\n    eps: !!float 1e-06\n    weight_decay: !!float 0.01\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 100\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n  label_smoothing: -1\n"
  },
  {
    "path": "config/crs/inspired/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\n# rec\nrec_model: InspiredRec\n# conv\nconv_model: InspiredConv\n# embedding: word2vec\nembedding_dim: 300\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  early_stop: true\n  stop_mode: max\n  impatience: 3\n  lr_bert: !!float 1e-5\nconv:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 3e-5\n    eps: !!float 1e-06\n    weight_decay: !!float 0.01\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 100\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/inspired/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\n# rec\nrec_model: InspiredRec\n# conv\nconv_model: InspiredConv\n# embedding: word2vec\nembedding_dim: 300\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  early_stop: true\n  stop_mode: max\n  impatience: 3\n  lr_bert: !!float 1e-5\nconv:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 3e-5\n    eps: !!float 1e-06\n    weight_decay: !!float 0.01\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 100\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/inspired/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\n# rec\nrec_model: InspiredRec\n# conv\nconv_model: InspiredConv\n# embedding: word2vec\nembedding_dim: 300\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  early_stop: true\n  stop_mode: max\n  impatience: 3\n  lr_bert: !!float 1e-5\nconv:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 3e-5\n    eps: !!float 1e-06\n    weight_decay: !!float 0.01\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 100\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/kbrd/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize: jieba\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nmodel: KBRD\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\nuser_proj_dim: 512\n# optim\nrec:\n  epoch: 1\n  batch_size: 4096\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nconv:\n  epoch: 1\n  batch_size: 64\n  early_stop: True\n  stop_mode: min\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5"
  },
  {
    "path": "config/crs/kbrd/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nmodel: KBRD\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\nuser_proj_dim: 512\n# optim\nrec:\n  epoch: 1\n  batch_size: 4096\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nconv:\n  epoch: 1\n  batch_size: 256\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n  early_stop: true\n  stop_mode: min\n  impatience: 3"
  },
  {
    "path": "config/crs/kbrd/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nmodel: KBRD\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\nuser_proj_dim: 512\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nconv:\n  epoch: 1\n  batch_size: 64\n  early_stop: True\n  stop_mode: min\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5"
  },
  {
    "path": "config/crs/kbrd/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nmodel: KBRD\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\nuser_proj_dim: 512\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nconv:\n  epoch: 1\n  batch_size: 64\n  early_stop: True\n  stop_mode: min\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5"
  },
  {
    "path": "config/crs/kbrd/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize: nltk\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nmodel: KBRD\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\nuser_proj_dim: 512\n# optim\nrec:\n  epoch: 10\n  batch_size: 4096\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nconv:\n  epoch: 10\n  batch_size: 32\n  early_stop: True\n  stop_mode: min\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5"
  },
  {
    "path": "config/crs/kbrd/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize: pkuseg\n# dataloader\ncontext_truncate: 1024\nresponse_truncate: 1024\nscale: 1\n# model\nmodel: KBRD\ntoken_emb_dim: 300\nn_relation: 56\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\nuser_proj_dim: 512\n# optim\nrec:\n  epoch: 100\n  batch_size: 64\n  early_stop: True\n  stop_mode: max\n  patience: 3\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nconv:\n  epoch: 100\n  batch_size: 16\n  early_stop: True\n  stop_mode: min\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5"
  },
  {
    "path": "config/crs/kgsf/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize: jieba\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nscale: 1\n# model\nmodel: KGSF\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\npretrain:\n  epoch: 1\n  batch_size: 4096\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  early_stop: true\n  stop_mode: max\n  impatience: 3\nconv:\n  epoch: 1\n  batch_size: 256\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n"
  },
  {
    "path": "config/crs/kgsf/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize: nltk\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nscale: 1\n# model\nmodel: KGSF\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\npretrain:\n  epoch: 1\n  batch_size: 64\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nrec:\n  epoch: 1\n  batch_size: 64\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  early_stop: true\n  stop_mode: max\n  impatience: 3\nconv:\n  epoch: 1\n  batch_size: 64\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n"
  },
  {
    "path": "config/crs/kgsf/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize: nltk\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nscale: 1\n# model\nmodel: KGSF\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\npretrain:\n  epoch: 1\n  batch_size: 4096\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  early_stop: true\n  stop_mode: max\n  impatience: 3\nconv:\n  epoch: 1\n  batch_size: 256\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n"
  },
  {
    "path": "config/crs/kgsf/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize: nltk\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nscale: 1\n# model\nmodel: KGSF\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\npretrain:\n  epoch: 1\n  batch_size: 4096\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  early_stop: true\n  stop_mode: max\n  impatience: 3\nconv:\n  epoch: 1\n  batch_size: 256\n  optimizer:\n    name: Adam\n    lr: !!float 3e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n"
  },
  {
    "path": "config/crs/kgsf/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize: nltk\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nscale: 1\n# model\nmodel: KGSF\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\npretrain:\n  epoch: 3\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\nrec:\n  epoch: 9\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\nconv:\n  epoch: 90\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n"
  },
  {
    "path": "config/crs/kgsf/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize: pkuseg\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nscale: 1\n# model\nmodel: KGSF\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\n# optim\npretrain:\n  epoch: 50\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\nrec:\n  epoch: 20\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  stop_mode: max\n  impatience: 3\nconv:\n  epoch: 10\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n"
  },
  {
    "path": "config/crs/ntrd/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize: pkuseg\nembedding: word2vec.npy\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nscale: 1\n# model\nmodel: NTRD\ntoken_emb_dim: 300\nkg_emb_dim: 128\nnum_bases: 8\nn_heads: 2\nn_layers: 2\nffn_size: 300\ndropout: 0.1\nattention_dropout: 0.0\nrelu_dropout: 0.1\nlearn_positional_embeddings: false\nembeddings_scale: true\nreduction: false\nn_positions: 1024\ngen_loss_weight: 5\nn_movies: 62287\nreplace_token: '[ITEM]'\n# optim\npretrain:\n  epoch: 50\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\nrec:\n  epoch: 20\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  stop_mode: max\n  impatience: 3\nconv:\n  epoch: 10\n  batch_size: 64\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    patience: 3\n    factor: 0.5\n  gradient_clip: 0.1\n"
  },
  {
    "path": "config/crs/redial/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: jieba\n  conv: jieba\n# dataloader\nutterance_truncate: 80\nconversation_truncate: 40\nscale: 1\n# model\n# rec\nrec_model: ReDialRec\nautorec_layer_sizes: [ 1000 ]\nautorec_f: sigmoid\nautorec_g: sigmoid\n# conv\nconv_model: ReDialConv\n# embedding: word2vec\nembedding_dim: 300\nutterance_encoder_hidden_size: 256\ndialog_encoder_hidden_size: 256\ndialog_encoder_num_layers: 1\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\nconv:\n  epoch: 1\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/redial/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: nltk\n  conv: nltk\n# dataloader\nutterance_truncate: 80\nconversation_truncate: 40\nscale: 1\n# model\n# rec\nrec_model: ReDialRec\nautorec_layer_sizes: [ 1000 ]\nautorec_f: sigmoid\nautorec_g: sigmoid\n# conv\nconv_model: ReDialConv\n#embedding: word2vec\nembedding_dim: 300\nutterance_encoder_hidden_size: 256\ndialog_encoder_hidden_size: 256\ndialog_encoder_num_layers: 1\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\nconv:\n  epoch: 1\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/redial/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  rec: nltk\n  conv: nltk\n# dataloader\nutterance_truncate: 80\nconversation_truncate: 40\nscale: 1\n# model\n# rec\nrec_model: ReDialRec\nautorec_layer_sizes: [ 1000 ]\nautorec_f: sigmoid\nautorec_g: sigmoid\n# conv\nconv_model: ReDialConv\n# embedding: word2vec\nembedding_dim: 300\nutterance_encoder_hidden_size: 256\ndialog_encoder_hidden_size: 256\ndialog_encoder_num_layers: 1\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\nconv:\n  epoch: 1\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/redial/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: nltk\n  conv: nltk\n# dataloader\nutterance_truncate: 80\nconversation_truncate: 40\nscale: 1\n# model\n# rec\nrec_model: ReDialRec\nautorec_layer_sizes: [ 1000 ]\nautorec_f: sigmoid\nautorec_g: sigmoid\n# conv\nconv_model: ReDialConv\n# embedding: word2vec\nembedding_dim: 300\nutterance_encoder_hidden_size: 256\ndialog_encoder_hidden_size: 256\ndialog_encoder_num_layers: 1\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\nconv:\n  epoch: 1\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/redial/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  rec: nltk\n  conv: nltk\n# dataloader\nutterance_truncate: 80\nconversation_truncate: 40\nscale: 1\n# model\n# rec\nrec_model: ReDialRec\nautorec_layer_sizes: [ 1000 ]\nautorec_f: sigmoid\nautorec_g: sigmoid\n# conv\nconv_model: ReDialConv\n# embedding: word2vec\nembedding_dim: 300\nutterance_encoder_hidden_size: 256\ndialog_encoder_hidden_size: 256\ndialog_encoder_num_layers: 1\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 50\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\nconv:\n  epoch: 50\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/redial/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: pkuseg\n  conv: pkuseg\n# dataloader\nutterance_truncate: 80\nconversation_truncate: 40\nscale: 1\n# model\n# rec\nrec_model: ReDialRec\nautorec_layer_sizes: [ 1000 ]\nautorec_f: sigmoid\nautorec_g: sigmoid\n# conv\nconv_model: ReDialConv\n#embedding: word2vec\nembedding_dim: 300\nutterance_encoder_hidden_size: 256\ndialog_encoder_hidden_size: 256\ndialog_encoder_num_layers: 1\nuse_dropout: False\ndropout: 0.3\ndecoder_hidden_size: 256\ndecoder_num_layers: 1\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\nconv:\n  epoch: 1\n  batch_size: 128\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/tgredial/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TGRec\nconv_model: TGConv\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n  early_stop: true\n  impatience: 3\n  stop_mode: max\nconv:\n  epoch: 1\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/tgredial/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TGRec\nconv_model: TGConv\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n  early_stop: true\n  impatience: 3\n  stop_mode: max\nconv:\n  epoch: 1\n  batch_size: 4\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/tgredial/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TGRec\nconv_model: TGConv\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n  early_stop: true\n  impatience: 3\n  stop_mode: max\nconv:\n  epoch: 1\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/tgredial/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TGRec\nconv_model: TGConv\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n  early_stop: true\n  impatience: 3\n  stop_mode: max\nconv:\n  epoch: 1\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/tgredial/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n  conv: gpt2\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TGRec\nconv_model: TGConv\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 10\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-4\n    weight_decay: 0\n  lr_bert: !!float 1e-5\n  early_stop: true\n  impatience: 3\n  stop_mode: max\nconv:\n  epoch: 10\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n  early_stop: true\n  impatience: 3\n  stop_mode: min\n"
  },
  {
    "path": "config/crs/tgredial/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n  conv: gpt2\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TGRec\nconv_model: TGConv\npolicy_model: TGPolicy\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 50\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n  early_stop: true\n  impatience: 3\n  stop_mode: max\nconv:\n  epoch: 50\n  batch_size: 8\n  gradient_clip: 1.0\n  update_freq: 1\n  optimizer:\n    name: AdamW\n    lr: !!float 1.5e-4\n  lr_scheduler:\n    name: TransformersLinearLR\n    warmup_steps: 2000\n  early_stop: true\n  impatience: 3\n  stop_mode: min\npolicy:\n  epoch: 50\n  batch_size: 8\n  weight_decay: 0.01\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-5\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "config/policy/conv_bert/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\npolicy_model: ConvBERT\n# optim\npolicy:\n  epoch: 50\n  batch_size: 8\n  weight_decay: 0.01\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-5\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "config/policy/mgcg/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\npolicy_model: MGCG\ndropout_hidden: 0\nnum_layers: 1\nhidden_size: 300\nembedding_dim: 300\nn_sent: 10\n# optim\npolicy:\n  epoch: 100\n  batch_size: 1024\n  weight_decay: 0.01\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-4\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "config/policy/pmi/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\npolicy_model: PMI\n# optim\npolicy:\n  epoch: 1\n  batch_size: 1024\n  weight_decay: 0.01\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-5\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "config/policy/profile_bert/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\npolicy_model: ProfileBERT\nn_sent: 10\n# optim\npolicy:\n  epoch: 50\n  batch_size: 8\n  weight_decay: 0.01\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-5\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "config/policy/topic_bert/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  policy: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\npolicy_model: TopicBERT\n# optim\npolicy:\n  epoch: 50\n  batch_size: 8\n  weight_decay: 0.01\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-5\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "config/recommendation/bert/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: BERT\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/bert/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: BERT\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/bert/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: BERT\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/bert/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: BERT\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/bert/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: BERT\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/bert/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: BERT\n# optim\nrec:\n  epoch: 20\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  early_stop: true\n  stop_mode: max\n  impatience: 3\n  lr_bert: !!float 1e-5"
  },
  {
    "path": "config/recommendation/gru4rec/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: GRU4REC\ngru_hidden_size: 50\nnum_layers: 3\nembedding_dim: 50\ndropout_input: 0\ndropout_hidden: 0.0\nhidden_size: 50\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/gru4rec/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: GRU4REC\ngru_hidden_size: 50\nnum_layers: 3\nembedding_dim: 50\ndropout_input: 0\ndropout_hidden: 0.0\nhidden_size: 50\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-2\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/gru4rec/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: GRU4REC\ngru_hidden_size: 50\nnum_layers: 3\nembedding_dim: 50\ndropout_input: 0\ndropout_hidden: 0.0\nhidden_size: 50\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/gru4rec/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: GRU4REC\ngru_hidden_size: 50\nnum_layers: 3\nembedding_dim: 50\ndropout_input: 0\ndropout_hidden: 0.0\nhidden_size: 50\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/gru4rec/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: GRU4REC\ngru_hidden_size: 50\nnum_layers: 3\nembedding_dim: 50\ndropout_input: 0\ndropout_hidden: 0.0\nhidden_size: 50\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-2\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/gru4rec/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: GRU4REC\ngru_hidden_size: 50\nnum_layers: 3\nembedding_dim: 50\ndropout_input: 0\ndropout_hidden: 0.0\nhidden_size: 50\n# optim\nrec:\n  epoch: 50\n  batch_size: 64\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "config/recommendation/popularity/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: Popularity\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/popularity/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: Popularity\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/popularity/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: Popularity\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/popularity/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: Popularity\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/popularity/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: Popularity\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/popularity/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: Popularity\n# optim\nrec:\n  epoch: 1\n  batch_size: 1024\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5"
  },
  {
    "path": "config/recommendation/sasrec/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: SASREC\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/sasrec/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: SASREC\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-2\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/sasrec/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: SASREC\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/sasrec/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: SASREC\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/sasrec/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: SASREC\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/sasrec/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: bert\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: SASREC\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\n# optim\nrec:\n  epoch: 50\n  batch_size: 256\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "config/recommendation/textcnn/durecdial.yaml",
    "content": "# dataset\ndataset: DuRecDial\ntokenize:\n  rec: jieba\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TextCNN\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\nnum_filters: 256\nembed: 300\nfilter_sizes: (2, 3, 4)\ndropout: 0.5\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/textcnn/gorecdial.yaml",
    "content": "# dataset\ndataset: GoRecDial\ntokenize:\n  rec: nltk\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TextCNN\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\nnum_filters: 256\nembed: 300\nfilter_sizes: (2, 3, 4)\ndropout: 0.5\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/textcnn/inspired.yaml",
    "content": "# dataset\ndataset: Inspired\ntokenize:\n  rec: nltk\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TextCNN\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\nnum_filters: 256\nembed: 300\nfilter_sizes: (2, 3, 4)\ndropout: 0.5\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/textcnn/opendialkg.yaml",
    "content": "# dataset\ndataset: OpenDialKG\ntokenize:\n  rec: nltk\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TextCNN\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\nnum_filters: 256\nembed: 300\nfilter_sizes: (2, 3, 4)\ndropout: 0.5\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/textcnn/redial.yaml",
    "content": "# dataset\ndataset: ReDial\ntokenize:\n  rec: nltk\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TextCNN\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\nnum_filters: 256\nembed: 300\nfilter_sizes: (2, 3, 4)\ndropout: 0.5\n# optim\nrec:\n  epoch: 1\n  batch_size: 8\n  optimizer:\n    name: AdamW\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n"
  },
  {
    "path": "config/recommendation/textcnn/tgredial.yaml",
    "content": "# dataset\ndataset: TGReDial\ntokenize:\n  rec: sougou\n# dataloader\ncontext_truncate: 256\nresponse_truncate: 30\nitem_truncate: 100\nscale: 1\n# model\nrec_model: TextCNN\nhidden_dropout_prob: 0.2\ninitializer_range: 0.02\nhidden_size: 50\nmax_history_items: 100\nnum_attention_heads: 1\nattention_probs_dropout_prob: 0.2\nhidden_act: gelu\nnum_hidden_layers: 2\nnum_filters: 256\nembed: 300\nfilter_sizes: (2, 3, 4)\ndropout: 0.5\n# optim\nrec:\n  epoch: 50\n  batch_size: 64\n  optimizer:\n    name: Adam\n    lr: !!float 1e-3\n    weight_decay: !!float 0.0000\n  lr_bert: !!float 1e-5\n  early_stop: true\n  stop_mode: max\n  impatience: 3"
  },
  {
    "path": "crslab/__init__.py",
    "content": "__version__ = '0.0.1'\n"
  },
  {
    "path": "crslab/config/__init__.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/29\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n\"\"\"Config module which loads parameters for the whole system.\n\nAttributes:\n    SAVE_PATH (str): where system to save.\n    DATASET_PATH (str): where dataset to save.\n    MODEL_PATH (str): where model related data to save.\n    PRETRAIN_PATH (str): where pretrained model to save.\n    EMBEDDING_PATH (str): where pretrained embedding to save, used for evaluate embedding related metrics.\n\"\"\"\n\nimport os\nfrom os.path import dirname, realpath\n\nfrom .config import Config\n\nROOT_PATH = dirname(dirname(dirname(realpath(__file__))))\nSAVE_PATH = os.path.join(ROOT_PATH, 'save')\nDATA_PATH = os.path.join(ROOT_PATH, 'data')\nDATASET_PATH = os.path.join(DATA_PATH, 'dataset')\nMODEL_PATH = os.path.join(DATA_PATH, 'model')\nPRETRAIN_PATH = os.path.join(MODEL_PATH, 'pretrain')\nEMBEDDING_PATH = os.path.join(DATA_PATH, 'embedding')\n"
  },
  {
    "path": "crslab/config/config.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2021/1/9\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nimport json\nimport os\nimport time\nfrom pprint import pprint\n\nimport yaml\nimport torch\nfrom loguru import logger\nfrom tqdm import tqdm\n\n\nclass Config:\n    \"\"\"Configurator module that load the defined parameters.\"\"\"\n\n    def __init__(self, config_file, gpu='-1', debug=False):\n        \"\"\"Load parameters and set log level.\n\n        Args:\n            config_file (str): path to the config file, which should be in ``yaml`` format.\n                You can use default config provided in the `Github repo`_, or write it by yourself.\n            debug (bool, optional): whether to enable debug function during running. Defaults to False.\n\n        .. _Github repo:\n            https://github.com/RUCAIBox/CRSLab\n\n        \"\"\"\n\n        self.opt = self.load_yaml_configs(config_file)\n        # gpu\n        os.environ['CUDA_VISIBLE_DEVICES'] = gpu\n        if gpu != '-1':\n            self.opt['gpu'] = [i for i in range(len(gpu.split(',')))]\n        else:\n            self.opt['gpu'] = [-1]\n        # dataset\n        dataset = self.opt['dataset']\n        tokenize = self.opt['tokenize']\n        if isinstance(tokenize, dict):\n            tokenize = ', '.join(tokenize.values())\n        # model\n        model = self.opt.get('model', None)\n        rec_model = self.opt.get('rec_model', None)\n        conv_model = self.opt.get('conv_model', None)\n        policy_model = self.opt.get('policy_model', None)\n        if model:\n            model_name = model\n        else:\n            models = []\n            if rec_model:\n                models.append(rec_model)\n            if conv_model:\n                models.append(conv_model)\n            if policy_model:\n                models.append(policy_model)\n   
         model_name = '_'.join(models)\n        self.opt['model_name'] = model_name\n        # log\n        log_name = self.opt.get(\"log_name\", dataset + '_' + model_name + '_' + time.strftime(\"%Y-%m-%d-%H-%M-%S\",\n                                                                                             time.localtime())) + \".log\"\n        if not os.path.exists(\"log\"):\n            os.makedirs(\"log\")\n        logger.remove()\n        if debug:\n            level = 'DEBUG'\n        else:\n            level = 'INFO'\n        logger.add(os.path.join(\"log\", log_name), level=level)\n        logger.add(lambda msg: tqdm.write(msg, end=''), colorize=True, level=level)\n\n        logger.info(f\"[Dataset: {dataset} tokenized in {tokenize}]\")\n        if model:\n            logger.info(f'[Model: {model}]')\n        if rec_model:\n            logger.info(f'[Recommendation Model: {rec_model}]')\n        if conv_model:\n            logger.info(f'[Conversation Model: {conv_model}]')\n        if policy_model:\n            logger.info(f'[Policy Model: {policy_model}]')\n        logger.info(\"[Config]\" + '\\n' + json.dumps(self.opt, indent=4))\n\n    @staticmethod\n    def load_yaml_configs(filename):\n        \"\"\"This function reads ``yaml`` file to build config dictionary\n\n        Args:\n            filename (str): path to ``yaml`` config\n\n        Returns:\n            dict: config\n\n        \"\"\"\n        config_dict = dict()\n        with open(filename, 'r', encoding='utf-8') as f:\n            config_dict.update(yaml.safe_load(f.read()))\n        return config_dict\n\n    def __setitem__(self, key, value):\n        if not isinstance(key, str):\n            raise TypeError(\"index must be a str.\")\n        self.opt[key] = value\n\n    def __getitem__(self, item):\n        if item in self.opt:\n            return self.opt[item]\n        else:\n            return None\n\n    def get(self, item, default=None):\n        \"\"\"Get value of corrsponding item 
in config\n\n        Args:\n            item (str): key to query in config\n            default (optional): default value for item if not found in config. Defaults to None.\n\n        Returns:\n            value of corrsponding item in config\n\n        \"\"\"\n        if item in self.opt:\n            return self.opt[item]\n        else:\n            return default\n\n    def __contains__(self, key):\n        if not isinstance(key, str):\n            raise TypeError(\"index must be a str.\")\n        return key in self.opt\n\n    def __str__(self):\n        return str(self.opt)\n\n    def __repr__(self):\n        return self.__str__()\n\n\nif __name__ == '__main__':\n    opt_dict = Config('../../config/crs/kbrd/redial.yaml')\n    pprint(opt_dict)\n"
  },
  {
    "path": "crslab/data/__init__.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020/12/29, 2020/12/17\n# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com\n\n# @Time   : 2021/10/06\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\n\"\"\"Data module which reads, processes and batches data for the whole system\n\nAttributes:\n    dataset_register_table (dict): record all supported dataset\n    dataset_language_map (dict): record all dataset corresponding language\n    dataloader_register_table (dict): record all model corresponding dataloader\n\n\"\"\"\n\nfrom crslab.data.dataloader import *\nfrom crslab.data.dataset import *\n\ndataset_register_table = {\n    'ReDial': ReDialDataset,\n    'TGReDial': TGReDialDataset,\n    'GoRecDial': GoRecDialDataset,\n    'OpenDialKG': OpenDialKGDataset,\n    'Inspired': InspiredDataset,\n    'DuRecDial': DuRecDialDataset\n}\n\ndataset_language_map = {\n    'ReDial': 'en',\n    'TGReDial': 'zh',\n    'GoRecDial': 'en',\n    'OpenDialKG': 'en',\n    'Inspired': 'en',\n    'DuRecDial': 'zh'\n}\n\ndataloader_register_table = {\n    'KGSF': KGSFDataLoader,\n    'KBRD': KBRDDataLoader,\n    'TGReDial': TGReDialDataLoader,\n    'TGRec': TGReDialDataLoader,\n    'TGConv': TGReDialDataLoader,\n    'TGPolicy': TGReDialDataLoader,\n    'TGRec_TGConv': TGReDialDataLoader,\n    'TGRec_TGConv_TGPolicy': TGReDialDataLoader,\n    'ReDialRec': ReDialDataLoader,\n    'ReDialConv': ReDialDataLoader,\n    'ReDialRec_ReDialConv': ReDialDataLoader,\n    'InspiredRec_InspiredConv': InspiredDataLoader,\n    'BERT': TGReDialDataLoader,\n    'SASREC': TGReDialDataLoader,\n    'TextCNN': TGReDialDataLoader,\n    'GRU4REC': TGReDialDataLoader,\n    'Popularity': TGReDialDataLoader,\n    'Transformer': KGSFDataLoader,\n    'GPT2': TGReDialDataLoader,\n    'ConvBERT': TGReDialDataLoader,\n    'TopicBERT': 
TGReDialDataLoader,\n    'ProfileBERT': TGReDialDataLoader,\n    'MGCG': TGReDialDataLoader,\n    'PMI': TGReDialDataLoader,\n    'NTRD': NTRDDataLoader\n}\n\n\ndef get_dataset(opt, tokenize, restore, save) -> BaseDataset:\n    \"\"\"get and process dataset\n\n    Args:\n        opt (Config or dict): config for dataset or the whole system.\n        tokenize (str): how to tokenize the dataset.\n        restore (bool): whether to restore saved dataset which has been processed.\n        save (bool): whether to save dataset after processing.\n\n    Returns:\n        processed dataset\n\n    \"\"\"\n    dataset = opt['dataset']\n    if dataset in dataset_register_table:\n        return dataset_register_table[dataset](opt, tokenize, restore, save)\n    else:\n        raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented')\n\n\ndef get_dataloader(opt, dataset, vocab) -> BaseDataLoader:\n    \"\"\"get dataloader to batchify dataset\n\n    Args:\n        opt (Config or dict): config for dataloader or the whole system.\n        dataset: processed raw data, no side data.\n        vocab (dict): all kinds of useful size, idx and map between token and idx.\n\n    Returns:\n        dataloader\n\n    \"\"\"\n    model_name = opt['model_name']\n    if model_name in dataloader_register_table:\n        return dataloader_register_table[model_name](opt, dataset, vocab)\n    else:\n        raise NotImplementedError(f'The dataloader [{model_name}] has not been implemented')\n"
  },
  {
    "path": "crslab/data/dataloader/__init__.py",
    "content": "from .base import BaseDataLoader\nfrom .inspired import InspiredDataLoader\nfrom .kbrd import KBRDDataLoader\nfrom .kgsf import KGSFDataLoader\nfrom .redial import ReDialDataLoader\nfrom .tgredial import TGReDialDataLoader\nfrom .ntrd import NTRDDataLoader\n"
  },
  {
    "path": "crslab/data/dataloader/base.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2020/12/29\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nimport random\nfrom abc import ABC\n\nfrom loguru import logger\nfrom math import ceil\nfrom tqdm import tqdm\n\n\nclass BaseDataLoader(ABC):\n    \"\"\"Abstract class of dataloader\n\n    Notes:\n        ``'scale'`` can be set in config to limit the size of dataset.\n\n    \"\"\"\n\n    def __init__(self, opt, dataset):\n        \"\"\"\n        Args:\n            opt (Config or dict): config for dataloader or the whole system.\n            dataset: dataset\n\n        \"\"\"\n        self.opt = opt\n        self.dataset = dataset\n        self.scale = opt.get('scale', 1)\n        assert 0 < self.scale <= 1\n\n    def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None):\n        \"\"\"Collate batch data for system to fit\n\n        Args:\n            batch_fn (func): function to collate data\n            batch_size (int):\n            shuffle (bool, optional): Defaults to True.\n            process_fn (func, optional): function to process dataset before batchify. 
Defaults to None.\n\n        Yields:\n            tuple or dict of torch.Tensor: batch data for system to fit\n\n        \"\"\"\n        dataset = self.dataset\n        if process_fn is not None:\n            dataset = process_fn()\n            logger.info('[Finish dataset process before batchify]')\n        dataset = dataset[:ceil(len(dataset) * self.scale)]\n        logger.debug(f'[Dataset size: {len(dataset)}]')\n\n        batch_num = ceil(len(dataset) / batch_size)\n        idx_list = list(range(len(dataset)))\n        if shuffle:\n            random.shuffle(idx_list)\n\n        for start_idx in tqdm(range(batch_num)):\n            batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size]\n            batch = [dataset[idx] for idx in batch_idx]\n            batch = batch_fn(batch)\n            if batch == False:\n                continue\n            else:\n                yield(batch) \n\n    def get_conv_data(self, batch_size, shuffle=True):\n        \"\"\"get_data wrapper for conversation.\n\n        You can implement your own process_fn in ``conv_process_fn``, batch_fn in ``conv_batchify``.\n\n        Args:\n            batch_size (int):\n            shuffle (bool, optional): Defaults to True.\n\n        Yields:\n            tuple or dict of torch.Tensor: batch data for conversation.\n\n        \"\"\"\n        return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn)\n\n    def get_rec_data(self, batch_size, shuffle=True):\n        \"\"\"get_data wrapper for recommendation.\n\n        You can implement your own process_fn in ``rec_process_fn``, batch_fn in ``rec_batchify``.\n\n        Args:\n            batch_size (int):\n            shuffle (bool, optional): Defaults to True.\n\n        Yields:\n            tuple or dict of torch.Tensor: batch data for recommendation.\n\n        \"\"\"\n        return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn)\n\n    def get_policy_data(self, 
batch_size, shuffle=True):\n        \"\"\"get_data wrapper for policy.\n\n        You can implement your own process_fn in ``self.policy_process_fn``, batch_fn in ``policy_batchify``.\n\n        Args:\n            batch_size (int):\n            shuffle (bool, optional): Defaults to True.\n\n        Yields:\n            tuple or dict of torch.Tensor: batch data for policy.\n\n        \"\"\"\n        return self.get_data(self.policy_batchify, batch_size, shuffle, self.policy_process_fn)\n\n    def conv_process_fn(self):\n        \"\"\"Process whole data for conversation before batch_fn.\n\n        Returns:\n            processed dataset. Defaults to return the same as `self.dataset`.\n\n        \"\"\"\n        return self.dataset\n\n    def conv_batchify(self, batch):\n        \"\"\"batchify data for conversation after process.\n\n        Args:\n            batch (list): processed batch dataset.\n\n        Returns:\n            batch data for the system to train conversation part.\n        \"\"\"\n        raise NotImplementedError('dataloader must implement conv_batchify() method')\n\n    def rec_process_fn(self):\n        \"\"\"Process whole data for recommendation before batch_fn.\n\n        Returns:\n            processed dataset. Defaults to return the same as `self.dataset`.\n\n        \"\"\"\n        return self.dataset\n\n    def rec_batchify(self, batch):\n        \"\"\"batchify data for recommendation after process.\n\n        Args:\n            batch (list): processed batch dataset.\n\n        Returns:\n            batch data for the system to train recommendation part.\n        \"\"\"\n        raise NotImplementedError('dataloader must implement rec_batchify() method')\n\n    def policy_process_fn(self):\n        \"\"\"Process whole data for policy before batch_fn.\n\n        Returns:\n            processed dataset. 
Defaults to return the same as `self.dataset`.\n\n        \"\"\"\n        return self.dataset\n\n    def policy_batchify(self, batch):\n        \"\"\"batchify data for policy after process.\n\n        Args:\n            batch (list): processed batch dataset.\n\n        Returns:\n            batch data for the system to train policy part.\n        \"\"\"\n        raise NotImplementedError('dataloader must implement policy_batchify() method')\n\n    def retain_recommender_target(self):\n        \"\"\"keep data whose role is recommender.\n\n        Returns:\n            Recommender part of ``self.dataset``.\n\n        \"\"\"\n        dataset = []\n        for conv_dict in tqdm(self.dataset):\n            if conv_dict['role'] == 'Recommender':\n                dataset.append(conv_dict)\n        return dataset\n\n    def rec_interact(self, data):\n        \"\"\"process user input data for system to recommend.\n\n        Args:\n            data: user input data.\n\n        Returns:\n            data for system to recommend.\n        \"\"\"\n        pass\n\n    def conv_interact(self, data):\n        \"\"\"Process user input data for system to converse.\n\n        Args:\n            data: user input data.\n\n        Returns:\n            data for system in converse.\n        \"\"\"\n        pass\n"
  },
  {
    "path": "crslab/data/dataloader/inspired.py",
    "content": "# @Time   : 2021/3/11\n# @Author : Beichen Zhang\n# @Email  : zhangbeichen724@gmail.com\n\nfrom copy import deepcopy\n\nimport torch\nfrom tqdm import tqdm\n\nfrom crslab.data.dataloader.base import BaseDataLoader\nfrom crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt\n\n\nclass InspiredDataLoader(BaseDataLoader):\n    \"\"\"Dataloader for model Inspired.\n\n    Notes:\n        You can set the following parameters in config:\n\n        - ``'context_truncate'``: the maximum length of context.\n        - ``'response_truncate'``: the maximum length of response.\n        - ``'entity_truncate'``: the maximum length of mentioned entities in context.\n        - ``'word_truncate'``: the maximum length of mentioned words in context.\n        - ``'item_truncate'``: the maximum length of mentioned items in context.\n\n        The following values must be specified in ``vocab``:\n\n        - ``'pad'``\n        - ``'start'``\n        - ``'end'``\n        - ``'unk'``\n        - ``'pad_entity'``\n        - ``'pad_word'``\n\n        the above values specify the id of needed special token.\n\n        - ``'ind2tok'``: map from index to token.\n        - ``'tok2ind'``: map from token to index.\n        - ``'vocab_size'``: size of vocab.\n        - ``'id2entity'``: map from index to entity.\n        - ``'n_entity'``: number of entities in the entity KG of dataset.\n        - ``'sent_split'`` (optional): token used to split sentence. Defaults to ``'end'``.\n        - ``'word_split'`` (optional): token used to split word. 
Defaults to ``'end'``.\n\n    \"\"\"\n\n    def __init__(self, opt, dataset, vocab):\n        \"\"\"\n\n        Args:\n            opt (Config or dict): config for dataloader or the whole system.\n            dataset: data for model.\n            vocab (dict): all kinds of useful size, idx and map between token and idx.\n\n        \"\"\"\n        super().__init__(opt, dataset)\n\n        self.n_entity = vocab['n_entity']\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.unk_token_idx = vocab['unk']\n        self.conv_bos_id = vocab['start']\n        self.cls_id = vocab['start']\n        self.sep_id = vocab['end']\n        if 'sent_split' in vocab:\n            self.sent_split_idx = vocab['sent_split']\n        else:\n            self.sent_split_idx = vocab['end']\n\n        self.pad_entity_idx = vocab['pad_entity']\n        self.pad_word_idx = vocab['pad_word']\n\n        self.tok2ind = vocab['tok2ind']\n        self.ind2tok = vocab['ind2tok']\n        self.id2entity = vocab['id2entity']\n\n        self.context_truncate = opt.get('context_truncate', None)\n        self.response_truncate = opt.get('response_truncate', None)\n\n    def rec_process_fn(self, *args, **kwargs):\n        augment_dataset = []\n        for conv_dict in tqdm(self.dataset):\n            if conv_dict['role'] == 'Recommender':\n                for movie in conv_dict['items']:\n                    augment_conv_dict = deepcopy(conv_dict)\n                    augment_conv_dict['item'] = movie\n                    augment_dataset.append(augment_conv_dict)\n        return augment_dataset\n\n    def _process_rec_context(self, context_tokens):\n        compact_context = []\n        for i, utterance in enumerate(context_tokens):\n            if i != 0:\n                utterance.insert(0, self.sent_split_idx)\n            compact_context.append(utterance)\n        compat_context = 
truncate(merge_utt(compact_context),\n                                  self.context_truncate - 2,\n                                  truncate_tail=False)\n        compat_context = add_start_end_token_idx(compat_context,\n                                                 self.start_token_idx,\n                                                 self.end_token_idx)\n        return compat_context\n\n    def rec_batchify(self, batch):\n        batch_context = []\n        batch_movie_id = []\n\n        for conv_dict in batch:\n            context = self._process_rec_context(conv_dict['context_tokens'])\n            batch_context.append(context)\n\n            item_id = conv_dict['item']\n            batch_movie_id.append(item_id)\n\n        batch_context = padded_tensor(batch_context,\n                                      self.pad_token_idx,\n                                      max_len=self.context_truncate)\n        batch_mask = (batch_context != self.pad_token_idx).long()\n\n        return (batch_context, batch_mask, torch.tensor(batch_movie_id))\n\n    def conv_batchify(self, batch):\n        \"\"\"get batch and corresponding roles\n        \"\"\"\n        batch_roles = []\n        batch_context_tokens = []\n        batch_response = []\n\n        for conv_dict in batch:\n            batch_roles.append(0 if conv_dict['role'] == 'Seeker' else 1)\n            context_tokens = [utter + [self.conv_bos_id] for utter in conv_dict['context_tokens']]\n            context_tokens[-1] = context_tokens[-1][:-1]\n            batch_context_tokens.append(\n                truncate(merge_utt(context_tokens), max_length=self.context_truncate, truncate_tail=False),\n            )\n            batch_response.append(\n                add_start_end_token_idx(\n                    truncate(conv_dict['response'], max_length=self.response_truncate - 2),\n                    start_token_idx=self.start_token_idx,\n                    end_token_idx=self.end_token_idx\n                )\n        
    )\n\n        batch_context_tokens = padded_tensor(items=batch_context_tokens,\n                                             pad_idx=self.pad_token_idx,\n                                             max_len=self.context_truncate,\n                                             pad_tail=False)\n        batch_response = padded_tensor(batch_response,\n                                       pad_idx=self.pad_token_idx,\n                                       max_len=self.response_truncate,\n                                       pad_tail=True)\n        batch_input_ids = torch.cat((batch_context_tokens, batch_response), dim=1)\n        batch_roles = torch.tensor(batch_roles)\n\n        return (batch_roles,\n                batch_input_ids,\n                batch_context_tokens,\n                batch_response)\n\n    def policy_batchify(self, batch):\n        pass\n"
  },
  {
    "path": "crslab/data/dataloader/kbrd.py",
    "content": "# @Time   : 2020/11/27\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/2\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nimport torch\nfrom tqdm import tqdm\n\nfrom crslab.data.dataloader.base import BaseDataLoader\nfrom crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt\n\n\nclass KBRDDataLoader(BaseDataLoader):\n    \"\"\"Dataloader for model KBRD.\n\n    Notes:\n        You can set the following parameters in config:\n\n        - ``'context_truncate'``: the maximum length of context.\n        - ``'response_truncate'``: the maximum length of response.\n        - ``'entity_truncate'``: the maximum length of mentioned entities in context.\n\n        The following values must be specified in ``vocab``:\n\n        - ``'pad'``\n        - ``'start'``\n        - ``'end'``\n        - ``'pad_entity'``\n\n        the above values specify the id of needed special token.\n\n    \"\"\"\n\n    def __init__(self, opt, dataset, vocab):\n        \"\"\"\n\n        Args:\n            opt (Config or dict): config for dataloader or the whole system.\n            dataset: data for model.\n            vocab (dict): all kinds of useful size, idx and map between token and idx.\n\n        \"\"\"\n        super().__init__(opt, dataset)\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.pad_entity_idx = vocab['pad_entity']\n        self.context_truncate = opt.get('context_truncate', None)\n        self.response_truncate = opt.get('response_truncate', None)\n        self.entity_truncate = opt.get('entity_truncate', None)\n\n    def rec_process_fn(self):\n        augment_dataset = []\n        for conv_dict in tqdm(self.dataset):\n            if conv_dict['role'] == 'Recommender':\n                for movie in conv_dict['items']:\n                    augment_conv_dict = 
{'context_entities': conv_dict['context_entities'], 'item': movie}\n                    augment_dataset.append(augment_conv_dict)\n        return augment_dataset\n\n    def rec_batchify(self, batch):\n        batch_context_entities = []\n        batch_movies = []\n        for conv_dict in batch:\n            batch_context_entities.append(conv_dict['context_entities'])\n            batch_movies.append(conv_dict['item'])\n\n        return {\n            \"context_entities\": batch_context_entities,\n            \"item\": torch.tensor(batch_movies, dtype=torch.long)\n        }\n\n    def conv_process_fn(self, *args, **kwargs):\n        return self.retain_recommender_target()\n\n    def conv_batchify(self, batch):\n        batch_context_tokens = []\n        batch_context_entities = []\n        batch_response = []\n        for conv_dict in batch:\n            batch_context_tokens.append(\n                truncate(merge_utt(conv_dict['context_tokens']), self.context_truncate, truncate_tail=False))\n            batch_context_entities.append(conv_dict['context_entities'])\n            batch_response.append(\n                add_start_end_token_idx(truncate(conv_dict['response'], self.response_truncate - 2),\n                                        start_token_idx=self.start_token_idx,\n                                        end_token_idx=self.end_token_idx))\n\n        return {\n            \"context_tokens\": padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),\n            \"context_entities\": batch_context_entities,\n            \"response\": padded_tensor(batch_response, self.pad_token_idx)\n        }\n\n    def policy_batchify(self, *args, **kwargs):\n        pass\n"
  },
  {
    "path": "crslab/data/dataloader/kgsf.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2020/12/2\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nfrom copy import deepcopy\n\nimport torch\nfrom tqdm import tqdm\n\nfrom crslab.data.dataloader.base import BaseDataLoader\nfrom crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, get_onehot, truncate, merge_utt\n\n\nclass KGSFDataLoader(BaseDataLoader):\n    \"\"\"Dataloader for model KGSF.\n\n    Notes:\n        You can set the following parameters in config:\n\n        - ``'context_truncate'``: the maximum length of context.\n        - ``'response_truncate'``: the maximum length of response.\n        - ``'entity_truncate'``: the maximum length of mentioned entities in context.\n        - ``'word_truncate'``: the maximum length of mentioned words in context.\n\n        The following values must be specified in ``vocab``:\n\n        - ``'pad'``\n        - ``'start'``\n        - ``'end'``\n        - ``'pad_entity'``\n        - ``'pad_word'``\n\n        the above values specify the id of needed special token.\n\n        - ``'n_entity'``: the number of entities in the entity KG of dataset.\n\n    \"\"\"\n\n    def __init__(self, opt, dataset, vocab):\n        \"\"\"\n\n        Args:\n            opt (Config or dict): config for dataloader or the whole system.\n            dataset: data for model.\n            vocab (dict): all kinds of useful size, idx and map between token and idx.\n\n        \"\"\"\n        super().__init__(opt, dataset)\n        self.n_entity = vocab['n_entity']\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.pad_entity_idx = vocab['pad_entity']\n        self.pad_word_idx = vocab['pad_word']\n        self.context_truncate = opt.get('context_truncate', None)\n        
self.response_truncate = opt.get('response_truncate', None)\n        self.entity_truncate = opt.get('entity_truncate', None)\n        self.word_truncate = opt.get('word_truncate', None)\n\n    def get_pretrain_data(self, batch_size, shuffle=True):\n        return self.get_data(self.pretrain_batchify, batch_size, shuffle, self.retain_recommender_target)\n\n    def pretrain_batchify(self, batch):\n        batch_context_entities = []\n        batch_context_words = []\n        for conv_dict in batch:\n            batch_context_entities.append(\n                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))\n            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))\n\n        return (padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),\n                get_onehot(batch_context_entities, self.n_entity))\n\n    def rec_process_fn(self):\n        augment_dataset = []\n        for conv_dict in tqdm(self.dataset):\n            if conv_dict['role'] == 'Recommender':\n                for movie in conv_dict['items']:\n                    augment_conv_dict = deepcopy(conv_dict)\n                    augment_conv_dict['item'] = movie\n                    augment_dataset.append(augment_conv_dict)\n        return augment_dataset\n\n    def rec_batchify(self, batch):\n        batch_context_entities = []\n        batch_context_words = []\n        batch_item = []\n        for conv_dict in batch:\n            batch_context_entities.append(\n                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))\n            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))\n            batch_item.append(conv_dict['item'])\n\n        return (padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),\n                padded_tensor(batch_context_words, self.pad_word_idx, 
pad_tail=False),\n                get_onehot(batch_context_entities, self.n_entity),\n                torch.tensor(batch_item, dtype=torch.long))\n\n    def conv_process_fn(self, *args, **kwargs):\n        return self.retain_recommender_target()\n\n    def conv_batchify(self, batch):\n        batch_context_tokens = []\n        batch_context_entities = []\n        batch_context_words = []\n        batch_response = []\n        for conv_dict in batch:\n            batch_context_tokens.append(\n                truncate(merge_utt(conv_dict['context_tokens']), self.context_truncate, truncate_tail=False))\n            batch_context_entities.append(\n                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))\n            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))\n            batch_response.append(\n                add_start_end_token_idx(truncate(conv_dict['response'], self.response_truncate - 2),\n                                        start_token_idx=self.start_token_idx,\n                                        end_token_idx=self.end_token_idx))\n\n        return (padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),\n                padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),\n                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),\n                padded_tensor(batch_response, self.pad_token_idx))\n\n    def policy_batchify(self, *args, **kwargs):\n        pass\n"
  },
  {
    "path": "crslab/data/dataloader/ntrd.py",
    "content": "# @Time   : 2021/10/06\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\nfrom copy import deepcopy\n\nimport torch\nfrom tqdm import tqdm\n\nfrom crslab.data.dataloader.base import BaseDataLoader\nfrom crslab.data.dataloader.utils import add_start_end_token_idx, merge_utt_replace, padded_tensor, get_onehot, truncate, merge_utt\n\n\nclass NTRDDataLoader(BaseDataLoader):\n    def __init__(self, opt, dataset, vocab):\n        \"\"\"\n\n        Args:\n            opt (Config or dict): config for dataloader or the whole system.\n            dataset: data for model.\n            vocab (dict): all kinds of useful size, idx and map between token and idx.\n\n        \"\"\"\n        super().__init__(opt, dataset)\n        self.n_entity = vocab['n_entity']\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.pad_entity_idx = vocab['pad_entity']\n        self.pad_word_idx = vocab['pad_word']\n        self.context_truncate = opt.get('context_truncate', None)\n        self.response_truncate = opt.get('response_truncate', None)\n        self.entity_truncate = opt.get('entity_truncate', None)\n        self.word_truncate = opt.get('word_truncate', None)\n        self.replace_token = opt.get('replace_token',None)\n        self.replace_token_idx = vocab[self.replace_token]\n\n    def get_pretrain_data(self, batch_size, shuffle=True):\n        return self.get_data(self.pretrain_batchify, batch_size, shuffle, self.retain_recommender_target)\n\n    def pretrain_batchify(self, batch):\n        batch_context_entities = []\n        batch_context_words = []\n        for conv_dict in batch:\n            batch_context_entities.append(\n                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))\n            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))\n\n        
return (padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),\n                get_onehot(batch_context_entities, self.n_entity))\n\n    def rec_process_fn(self):\n        augment_dataset = []\n        for conv_dict in tqdm(self.dataset):\n            if conv_dict['role'] == 'Recommender':\n                for movie in conv_dict['items']:\n                    augment_conv_dict = deepcopy(conv_dict)\n                    augment_conv_dict['item'] = movie\n                    augment_dataset.append(augment_conv_dict)\n        return augment_dataset\n\n    def rec_batchify(self, batch):\n        batch_context_entities = []\n        batch_context_words = []\n        batch_item = []\n        for conv_dict in batch:\n            batch_context_entities.append(\n                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))\n            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))\n            batch_item.append(conv_dict['item'])\n\n        return (padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),\n                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),\n                get_onehot(batch_context_entities, self.n_entity),\n                torch.tensor(batch_item, dtype=torch.long))\n\n    def conv_process_fn(self, *args, **kwargs):\n        return self.retain_recommender_target()\n\n    def conv_batchify(self, batch):\n        batch_context_tokens = []\n        batch_context_entities = []\n        batch_context_words = []\n        batch_response = []\n        flag = False\n        batch_all_movies = [] \n        for conv_dict in batch:\n            temp = add_start_end_token_idx(truncate(conv_dict['response'], self.response_truncate - 2),\n                                        start_token_idx=self.start_token_idx,\n                                        end_token_idx=self.end_token_idx)\n\n            if 
temp.count(self.replace_token_idx) != 0:\n                flag = True\n            batch_context_tokens.append(\n                truncate(merge_utt(conv_dict['context_tokens']), self.context_truncate, truncate_tail=False))\n            batch_context_entities.append(\n                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))\n            batch_context_words.append(truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))\n            batch_response.append(\n                add_start_end_token_idx(truncate(conv_dict['response'], self.response_truncate - 2),\n                                        start_token_idx=self.start_token_idx,\n                                        end_token_idx=self.end_token_idx))\n            \n            batch_all_movies.append(\n                truncate(conv_dict['items'], temp.count(self.replace_token_idx), truncate_tail=False)) #only use movies, not all entities.\n        if flag == False:# zero slot in a batch\n            return False\n\n        return (padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),\n                padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),\n                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),\n                padded_tensor(batch_response, self.pad_token_idx),\n                padded_tensor(batch_all_movies, self.pad_entity_idx, pad_tail=False)) \n\n    def policy_batchify(self, *args, **kwargs):\n        pass"
  },
  {
    "path": "crslab/data/dataloader/redial.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE:\n# @Time   : 2020/12/16\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nimport re\nfrom copy import copy\n\nimport torch\nfrom tqdm import tqdm\n\nfrom crslab.data.dataloader.base import BaseDataLoader\nfrom crslab.data.dataloader.utils import padded_tensor, get_onehot, truncate\n\nmovie_pattern = re.compile(r'^@\\d{5,6}$')\n\n\nclass ReDialDataLoader(BaseDataLoader):\n    \"\"\"Dataloader for model ReDial.\n\n    Notes:\n        You can set the following parameters in config:\n\n        - ``'utterance_truncate'``: the maximum length of a single utterance.\n        - ``'conversation_truncate'``: the maximum length of the whole conversation.\n\n        The following values must be specified in ``vocab``:\n\n        - ``'pad'``\n        - ``'start'``\n        - ``'end'``\n        - ``'unk'``\n\n        the above values specify the id of needed special token.\n\n        - ``'ind2tok'``: map from index to token.\n        - ``'n_entity'``: number of entities in the entity KG of dataset.\n        - ``'vocab_size'``: size of vocab.\n\n    \"\"\"\n\n    def __init__(self, opt, dataset, vocab):\n        \"\"\"\n\n        Args:\n            opt (Config or dict): config for dataloader or the whole system.\n            dataset: data for model.\n            vocab (dict): all kinds of useful size, idx and map between token and idx.\n\n        \"\"\"\n        super().__init__(opt, dataset)\n        self.ind2tok = vocab['ind2tok']\n        self.n_entity = vocab['n_entity']\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.unk_token_idx = vocab['unk']\n        self.item_token_idx = vocab['vocab_size']\n        self.conversation_truncate = self.opt.get('conversation_truncate', None)\n        self.utterance_truncate = self.opt.get('utterance_truncate', None)\n\n   
 def rec_process_fn(self, *args, **kwargs):\n        dataset = []\n        for conversation in self.dataset:\n            if conversation['role'] == 'Recommender':\n                for item in conversation['items']:\n                    context_entities = conversation['context_entities']\n                    dataset.append({'context_entities': context_entities, 'item': item})\n        return dataset\n\n    def rec_batchify(self, batch):\n        batch_context_entities = []\n        batch_item = []\n        for conversation in batch:\n            batch_context_entities.append(conversation['context_entities'])\n            batch_item.append(conversation['item'])\n        context_entities = get_onehot(batch_context_entities, self.n_entity)\n        return {'context_entities': context_entities, 'item': torch.tensor(batch_item, dtype=torch.long)}\n\n    def conv_process_fn(self):\n        dataset = []\n        for conversation in tqdm(self.dataset):\n            if conversation['role'] != 'Recommender':\n                continue\n            context_tokens = [truncate(utterance, self.utterance_truncate, truncate_tail=True) for utterance in\n                              conversation['context_tokens']]\n            context_tokens = truncate(context_tokens, self.conversation_truncate, truncate_tail=True)\n            context_length = len(context_tokens)\n            utterance_lengths = [len(utterance) for utterance in context_tokens]\n            request = context_tokens[-1]\n            response = truncate(conversation['response'], self.utterance_truncate, truncate_tail=True)\n            dataset.append({'context_tokens': context_tokens, 'context_length': context_length,\n                            'utterance_lengths': utterance_lengths, 'request': request, 'response': response})\n        return dataset\n\n    def conv_batchify(self, batch):\n        max_utterance_length = max([max(conversation['utterance_lengths']) for conversation in batch])\n        
max_response_length = max([len(conversation['response']) for conversation in batch])\n        max_utterance_length = max(max_utterance_length, max_response_length)\n        max_context_length = max([conversation['context_length'] for conversation in batch])\n        batch_context = []\n        batch_context_length = []\n        batch_utterance_lengths = []\n        batch_request = []  # tensor\n        batch_request_length = []\n        batch_response = []\n\n        for conversation in batch:\n            padded_context = padded_tensor(conversation['context_tokens'], pad_idx=self.pad_token_idx,\n                                           pad_tail=True, max_len=max_utterance_length)\n            if len(conversation['context_tokens']) < max_context_length:\n                pad_tensor = padded_context.new_full(\n                    (max_context_length - len(conversation['context_tokens']), max_utterance_length), self.pad_token_idx\n                )\n                padded_context = torch.cat((padded_context, pad_tensor), 0)\n            batch_context.append(padded_context)\n            batch_context_length.append(conversation['context_length'])\n            batch_utterance_lengths.append(conversation['utterance_lengths'] +\n                                           [0] * (max_context_length - len(conversation['context_tokens'])))\n\n            request = conversation['request']\n            batch_request_length.append(len(request))\n            batch_request.append(request)\n\n            response = copy(conversation['response'])\n            # replace '^\\d{5,6}$' by '__item__'\n            for i in range(len(response)):\n                if movie_pattern.match(self.ind2tok[response[i]]):\n                    response[i] = self.item_token_idx\n            batch_response.append(response)\n\n        context = torch.stack(batch_context, dim=0)\n        request = padded_tensor(batch_request, self.pad_token_idx, pad_tail=True, max_len=max_utterance_length)\n        
response = padded_tensor(batch_response, self.pad_token_idx, pad_tail=True,\n                                 max_len=max_utterance_length)  # (bs, utt_len)\n\n        return {'context': context, 'context_lengths': torch.tensor(batch_context_length),\n                'utterance_lengths': torch.tensor(batch_utterance_lengths), 'request': request,\n                'request_lengths': torch.tensor(batch_request_length), 'response': response}\n\n    def policy_batchify(self, batch):\n        pass\n"
  },
  {
    "path": "crslab/data/dataloader/tgredial.py",
    "content": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2020/12/29, 2020/12/15\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @Email  : wxl1999@foxmail.com, sdzyh002@gmail\n\nimport random\nfrom copy import deepcopy\n\nimport torch\nfrom tqdm import tqdm\n\nfrom crslab.data.dataloader.base import BaseDataLoader\nfrom crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt\n\n\nclass TGReDialDataLoader(BaseDataLoader):\n    \"\"\"Dataloader for model TGReDial.\n\n    Notes:\n        You can set the following parameters in config:\n\n        - ``'context_truncate'``: the maximum length of context.\n        - ``'response_truncate'``: the maximum length of response.\n        - ``'entity_truncate'``: the maximum length of mentioned entities in context.\n        - ``'word_truncate'``: the maximum length of mentioned words in context.\n        - ``'item_truncate'``: the maximum length of mentioned items in context.\n\n        The following values must be specified in ``vocab``:\n\n        - ``'pad'``\n        - ``'start'``\n        - ``'end'``\n        - ``'unk'``\n        - ``'pad_entity'``\n        - ``'pad_word'``\n\n        the above values specify the id of needed special token.\n\n        - ``'ind2tok'``: map from index to token.\n        - ``'tok2ind'``: map from token to index.\n        - ``'vocab_size'``: size of vocab.\n        - ``'id2entity'``: map from index to entity.\n        - ``'n_entity'``: number of entities in the entity KG of dataset.\n        - ``'sent_split'`` (optional): token used to split sentence. Defaults to ``'end'``.\n        - ``'word_split'`` (optional): token used to split word. 
Defaults to ``'end'``.\n        - ``'pad_topic'`` (optional): token used to pad topic.\n        - ``'ind2topic'`` (optional): map from index to topic.\n\n    \"\"\"\n\n    def __init__(self, opt, dataset, vocab):\n        \"\"\"\n\n        Args:\n            opt (Config or dict): config for dataloader or the whole system.\n            dataset: data for model.\n            vocab (dict): all kinds of useful size, idx and map between token and idx.\n\n        \"\"\"\n        super().__init__(opt, dataset)\n\n        self.n_entity = vocab['n_entity']\n        self.item_size = self.n_entity\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.unk_token_idx = vocab['unk']\n        self.conv_bos_id = vocab['start']\n        self.cls_id = vocab['start']\n        self.sep_id = vocab['end']\n        if 'sent_split' in vocab:\n            self.sent_split_idx = vocab['sent_split']\n        else:\n            self.sent_split_idx = vocab['end']\n        if 'word_split' in vocab:\n            self.word_split_idx = vocab['word_split']\n        else:\n            self.word_split_idx = vocab['end']\n\n        self.pad_entity_idx = vocab['pad_entity']\n        self.pad_word_idx = vocab['pad_word']\n        if 'pad_topic' in vocab:\n            self.pad_topic_idx = vocab['pad_topic']\n\n        self.tok2ind = vocab['tok2ind']\n        self.ind2tok = vocab['ind2tok']\n        self.id2entity = vocab['id2entity']\n        if 'ind2topic' in vocab:\n            self.ind2topic = vocab['ind2topic']\n\n        self.context_truncate = opt.get('context_truncate', None)\n        self.response_truncate = opt.get('response_truncate', None)\n        self.entity_truncate = opt.get('entity_truncate', None)\n        self.word_truncate = opt.get('word_truncate', None)\n        self.item_truncate = opt.get('item_truncate', None)\n\n    def rec_process_fn(self, *args, **kwargs):\n        augment_dataset = []\n   
     for conv_dict in tqdm(self.dataset):\n            for movie in conv_dict['items']:\n                augment_conv_dict = deepcopy(conv_dict)\n                augment_conv_dict['item'] = movie\n                augment_dataset.append(augment_conv_dict)\n        return augment_dataset\n\n    def _process_rec_context(self, context_tokens):\n        compact_context = []\n        for i, utterance in enumerate(context_tokens):\n            if i != 0:\n                utterance.insert(0, self.sent_split_idx)\n            compact_context.append(utterance)\n        compat_context = truncate(merge_utt(compact_context),\n                                  self.context_truncate - 2,\n                                  truncate_tail=False)\n        compat_context = add_start_end_token_idx(compat_context,\n                                                 self.start_token_idx,\n                                                 self.end_token_idx)\n        return compat_context\n\n    def _neg_sample(self, item_set):\n        item = random.randint(1, self.item_size)\n        while item in item_set:\n            item = random.randint(1, self.item_size)\n        return item\n\n    def _process_history(self, context_items, item_id=None):\n        input_ids = truncate(context_items,\n                             max_length=self.item_truncate,\n                             truncate_tail=False)\n        input_mask = [1] * len(input_ids)\n        sample_negs = []\n        seq_set = set(input_ids)\n        for _ in input_ids:\n            sample_negs.append(self._neg_sample(seq_set))\n\n        if item_id is not None:\n            target_pos = input_ids[1:] + [item_id]\n            return input_ids, target_pos, input_mask, sample_negs\n        else:\n            return input_ids, input_mask, sample_negs\n\n    def rec_batchify(self, batch):\n        batch_context = []\n        batch_movie_id = []\n        batch_input_ids = []\n        batch_target_pos = []\n        batch_input_mask = []\n 
       batch_sample_negs = []\n\n        for conv_dict in batch:\n            context = self._process_rec_context(conv_dict['context_tokens'])\n            batch_context.append(context)\n\n            item_id = conv_dict['item']\n            batch_movie_id.append(item_id)\n\n            if 'interaction_history' in conv_dict:\n                context_items = conv_dict['interaction_history'] + conv_dict[\n                    'context_items']\n            else:\n                context_items = conv_dict['context_items']\n\n            input_ids, target_pos, input_mask, sample_negs = self._process_history(\n                context_items, item_id)\n            batch_input_ids.append(input_ids)\n            batch_target_pos.append(target_pos)\n            batch_input_mask.append(input_mask)\n            batch_sample_negs.append(sample_negs)\n\n        batch_context = padded_tensor(batch_context,\n                                      self.pad_token_idx,\n                                      max_len=self.context_truncate)\n        batch_mask = (batch_context != self.pad_token_idx).long()\n\n        return (batch_context, batch_mask,\n                padded_tensor(batch_input_ids,\n                              pad_idx=self.pad_token_idx,\n                              pad_tail=False,\n                              max_len=self.item_truncate),\n                padded_tensor(batch_target_pos,\n                              pad_idx=self.pad_token_idx,\n                              pad_tail=False,\n                              max_len=self.item_truncate),\n                padded_tensor(batch_input_mask,\n                              pad_idx=self.pad_token_idx,\n                              pad_tail=False,\n                              max_len=self.item_truncate),\n                padded_tensor(batch_sample_negs,\n                              pad_idx=self.pad_token_idx,\n                              pad_tail=False,\n                              
max_len=self.item_truncate),\n                torch.tensor(batch_movie_id))\n\n    def rec_interact(self, data):\n        context = [self._process_rec_context(data['context_tokens'])]\n        if 'interaction_history' in data:\n            context_items = data['interaction_history'] + data['context_items']\n        else:\n            context_items = data['context_items']\n        input_ids, input_mask, sample_negs = self._process_history(context_items)\n        input_ids, input_mask, sample_negs = [input_ids], [input_mask], [sample_negs]\n\n        context = padded_tensor(context,\n                                self.pad_token_idx,\n                                max_len=self.context_truncate)\n        mask = (context != self.pad_token_idx).long()\n\n        return (context, mask,\n                padded_tensor(input_ids,\n                              pad_idx=self.pad_token_idx,\n                              pad_tail=False,\n                              max_len=self.item_truncate),\n                None,\n                padded_tensor(input_mask,\n                              pad_idx=self.pad_token_idx,\n                              pad_tail=False,\n                              max_len=self.item_truncate),\n                padded_tensor(sample_negs,\n                              pad_idx=self.pad_token_idx,\n                              pad_tail=False,\n                              max_len=self.item_truncate),\n                None)\n\n    def conv_batchify(self, batch):\n        batch_context_tokens = []\n        batch_enhanced_context_tokens = []\n        batch_response = []\n        batch_context_entities = []\n        batch_context_words = []\n        for conv_dict in batch:\n            context_tokens = [utter + [self.conv_bos_id] for utter in conv_dict['context_tokens']]\n            context_tokens[-1] = context_tokens[-1][:-1]\n            batch_context_tokens.append(\n                truncate(merge_utt(context_tokens), 
max_length=self.context_truncate, truncate_tail=False),\n            )\n            batch_response.append(\n                add_start_end_token_idx(\n                    truncate(conv_dict['response'], max_length=self.response_truncate - 2),\n                    start_token_idx=self.start_token_idx,\n                    end_token_idx=self.end_token_idx\n                )\n            )\n            batch_context_entities.append(\n                truncate(conv_dict['context_entities'],\n                         self.entity_truncate,\n                         truncate_tail=False))\n            batch_context_words.append(\n                truncate(conv_dict['context_words'],\n                         self.word_truncate,\n                         truncate_tail=False))\n\n            enhanced_topic = []\n            if 'target' in conv_dict:\n                for target_policy in conv_dict['target']:\n                    topic_variable = target_policy[1]\n                    if isinstance(topic_variable, list):\n                        for topic in topic_variable:\n                            enhanced_topic.append(topic)\n                enhanced_topic = [[\n                    self.tok2ind.get(token, self.unk_token_idx) for token in self.ind2topic[topic_id]\n                ] for topic_id in enhanced_topic]\n                enhanced_topic = merge_utt(enhanced_topic, self.word_split_idx, False, self.sent_split_idx)\n\n            enhanced_movie = []\n            if 'items' in conv_dict:\n                for movie_id in conv_dict['items']:\n                    enhanced_movie.append(movie_id)\n                enhanced_movie = [\n                    [self.tok2ind.get(token, self.unk_token_idx) for token in self.id2entity[movie_id].split('（')[0]]\n                    for movie_id in enhanced_movie]\n                enhanced_movie = truncate(merge_utt(enhanced_movie, self.word_split_idx, self.sent_split_idx),\n                                          self.item_truncate, 
truncate_tail=False)\n\n            if len(enhanced_movie) != 0:\n                enhanced_context_tokens = enhanced_movie + truncate(batch_context_tokens[-1],\n                                                                    max_length=self.context_truncate - len(\n                                                                        enhanced_movie), truncate_tail=False)\n            elif len(enhanced_topic) != 0:\n                enhanced_context_tokens = enhanced_topic + truncate(batch_context_tokens[-1],\n                                                                    max_length=self.context_truncate - len(\n                                                                        enhanced_topic), truncate_tail=False)\n            else:\n                enhanced_context_tokens = batch_context_tokens[-1]\n            batch_enhanced_context_tokens.append(\n                enhanced_context_tokens\n            )\n\n        batch_context_tokens = padded_tensor(items=batch_context_tokens,\n                                             pad_idx=self.pad_token_idx,\n                                             max_len=self.context_truncate,\n                                             pad_tail=False)\n        batch_response = padded_tensor(batch_response,\n                                       pad_idx=self.pad_token_idx,\n                                       max_len=self.response_truncate,\n                                       pad_tail=True)\n        batch_input_ids = torch.cat((batch_context_tokens, batch_response), dim=1)\n        batch_enhanced_context_tokens = padded_tensor(items=batch_enhanced_context_tokens,\n                                                      pad_idx=self.pad_token_idx,\n                                                      max_len=self.context_truncate,\n                                                      pad_tail=False)\n        batch_enhanced_input_ids = torch.cat((batch_enhanced_context_tokens, batch_response), dim=1)\n\n      
  return (batch_enhanced_input_ids, batch_enhanced_context_tokens,\n                batch_input_ids, batch_context_tokens,\n                padded_tensor(batch_context_entities,\n                              self.pad_entity_idx,\n                              pad_tail=False),\n                padded_tensor(batch_context_words,\n                              self.pad_word_idx,\n                              pad_tail=False), batch_response)\n\n    def conv_interact(self, data):\n        context_tokens = [utter + [self.conv_bos_id] for utter in data['context_tokens']]\n        context_tokens[-1] = context_tokens[-1][:-1]\n        context_tokens = [truncate(merge_utt(context_tokens), max_length=self.context_truncate, truncate_tail=False)]\n        context_tokens = padded_tensor(items=context_tokens,\n                                       pad_idx=self.pad_token_idx,\n                                       max_len=self.context_truncate,\n                                       pad_tail=False)\n        context_entities = [truncate(data['context_entities'], self.entity_truncate, truncate_tail=False)]\n        context_words = [truncate(data['context_words'], self.word_truncate, truncate_tail=False)]\n\n        return (context_tokens, context_tokens,\n                context_tokens, context_tokens,\n                padded_tensor(context_entities,\n                              self.pad_entity_idx,\n                              pad_tail=False),\n                padded_tensor(context_words,\n                              self.pad_word_idx,\n                              pad_tail=False), None)\n\n    def policy_process_fn(self, *args, **kwargs):\n        augment_dataset = []\n        for conv_dict in tqdm(self.dataset):\n            for target_policy in conv_dict['target']:\n                topic_variable = target_policy[1]\n                for topic in topic_variable:\n                    augment_conv_dict = deepcopy(conv_dict)\n                    
augment_conv_dict['target_topic'] = topic\n                    augment_dataset.append(augment_conv_dict)\n        return augment_dataset\n\n    def policy_batchify(self, batch):\n        batch_context = []\n        batch_context_policy = []\n        batch_user_profile = []\n        batch_target = []\n\n        for conv_dict in batch:\n            final_topic = conv_dict['final']\n            final_topic = [[\n                self.tok2ind.get(token, self.unk_token_idx) for token in self.ind2topic[topic_id]\n            ] for topic_id in final_topic[1]]\n            final_topic = merge_utt(final_topic, self.word_split_idx, False, self.sep_id)\n\n            context = conv_dict['context_tokens']\n            context = merge_utt(context,\n                                self.sent_split_idx,\n                                False,\n                                self.sep_id)\n            context += final_topic\n            context = add_start_end_token_idx(\n                truncate(context, max_length=self.context_truncate - 1, truncate_tail=False),\n                start_token_idx=self.cls_id)\n            batch_context.append(context)\n\n            # [topic, topic, ..., topic]\n            context_policy = []\n            for policies_one_turn in conv_dict['context_policy']:\n                if len(policies_one_turn) != 0:\n                    for policy in policies_one_turn:\n                        for topic_id in policy[1]:\n                            if topic_id != self.pad_topic_idx:\n                                policy = []\n                                for token in self.ind2topic[topic_id]:\n                                    policy.append(self.tok2ind.get(token, self.unk_token_idx))\n                                context_policy.append(policy)\n            context_policy = merge_utt(context_policy, self.word_split_idx, False)\n            context_policy = add_start_end_token_idx(\n                context_policy,\n                
start_token_idx=self.cls_id,\n                end_token_idx=self.sep_id)\n            context_policy += final_topic\n            batch_context_policy.append(context_policy)\n\n            batch_user_profile.extend(conv_dict['user_profile'])\n\n            batch_target.append(conv_dict['target_topic'])\n\n        batch_context = padded_tensor(batch_context,\n                                      pad_idx=self.pad_token_idx,\n                                      pad_tail=True,\n                                      max_len=self.context_truncate)\n        batch_cotnext_mask = (batch_context != self.pad_token_idx).long()\n        batch_context_policy = padded_tensor(batch_context_policy,\n                                             pad_idx=self.pad_token_idx,\n                                             pad_tail=True)\n        batch_context_policy_mask = (batch_context_policy != 0).long()\n        batch_user_profile = padded_tensor(batch_user_profile,\n                                           pad_idx=self.pad_token_idx,\n                                           pad_tail=True)\n        batch_user_profile_mask = (batch_user_profile != 0).long()\n        batch_target = torch.tensor(batch_target, dtype=torch.long)\n\n        return (batch_context, batch_cotnext_mask, batch_context_policy,\n                batch_context_policy_mask, batch_user_profile,\n                batch_user_profile_mask, batch_target)\n"
  },
  {
    "path": "crslab/data/dataloader/utils.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/10\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/20, 2020/12/15\n# @Author  :   Xiaolei Wang, Yuanhang Zhou\n# @email   :   wxl1999@foxmail.com, sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2021/10/06\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\n\nfrom copy import copy\n\nimport torch\nfrom typing import List, Union, Optional\n\n\ndef padded_tensor(\n        items: List[Union[List[int], torch.LongTensor]],\n        pad_idx: int = 0,\n        pad_tail: bool = True,\n        max_len: Optional[int] = None,\n) -> torch.LongTensor:\n    \"\"\"Create a padded matrix from an uneven list of lists.\n\n    Returns padded matrix.\n\n    Matrix is right-padded (filled to the right) by default, but can be\n    left padded if the flag is set to True.\n\n    Matrix can also be placed on cuda automatically.\n\n    :param list[iter[int]] items: List of items\n    :param int pad_idx: the value to use for padding\n    :param bool pad_tail:\n    :param int max_len: if None, the max length is the maximum item length\n\n    :returns: padded tensor.\n    :rtype: Tensor[int64]\n\n    \"\"\"\n    # number of items\n    n = len(items)\n    # length of each item\n    lens: List[int] = [len(item) for item in items]  # type: ignore\n    # max in time dimension\n    t = max(lens) if max_len is None else max_len\n    # if input tensors are empty, we should expand to nulls\n    t = max(t, 1)\n\n    if isinstance(items[0], torch.Tensor):\n        # keep type of input tensors, they may already be cuda ones\n        output = items[0].new(n, t)  # type: ignore\n    else:\n        output = torch.LongTensor(n, t)  # type: ignore\n    output.fill_(pad_idx)\n\n    for i, (item, length) in enumerate(zip(items, lens)):\n        if length == 0:\n            # skip empty items\n            continue\n        if not isinstance(item, torch.Tensor):\n            # put 
non-tensors into a tensor\n            item = torch.tensor(item, dtype=torch.long)  # type: ignore\n        if pad_tail:\n            # place at beginning\n            output[i, :length] = item\n        else:\n            # place at end\n            output[i, t - length:] = item\n\n    return output\n\n\ndef get_onehot(data_list, categories) -> torch.Tensor:\n    \"\"\"Transform lists of label into one-hot.\n\n    Args:\n        data_list (list of list of int): source data.\n        categories (int): #label class.\n\n    Returns:\n        torch.Tensor: one-hot labels.\n\n    \"\"\"\n    onehot_labels = []\n    for label_list in data_list:\n        onehot_label = torch.zeros(categories)\n        for label in label_list:\n            onehot_label[label] = 1.0 / len(label_list)\n        onehot_labels.append(onehot_label)\n    return torch.stack(onehot_labels, dim=0)\n\n\ndef add_start_end_token_idx(vec: list, start_token_idx: int = None, end_token_idx: int = None):\n    \"\"\"Can choose to add start token in the beginning and end token in the end.\n\n    Args:\n        vec: source list composed of indexes.\n        start_token_idx: index of start token.\n        end_token_idx: index of end token.\n\n    Returns:\n        list: list added start or end token index.\n\n    \"\"\"\n    res = copy(vec)\n    if start_token_idx:\n        res.insert(0, start_token_idx)\n    if end_token_idx:\n        res.append(end_token_idx)\n    return res\n\n\ndef truncate(vec, max_length, truncate_tail=True):\n    \"\"\"truncate vec to make its length no more than max length.\n\n    Args:\n        vec (list): source list.\n        max_length (int)\n        truncate_tail (bool, optional): Defaults to True.\n\n    Returns:\n        list: truncated vec.\n\n    \"\"\"\n    if max_length is None:\n        return vec\n    if len(vec) <= max_length:\n        return vec\n    if max_length == 0:\n        return []\n    if truncate_tail:\n        return vec[:max_length]\n    else:\n        return 
vec[-max_length:]\n\n\ndef merge_utt(conversation, split_token_idx=None, keep_split_in_tail=False, final_token_idx=None):\n    \"\"\"merge utterances in one conversation.\n\n    Args:\n        conversation (list of list of int): conversation consist of utterances consist of tokens.\n        split_token_idx (int): index of split token. Defaults to None.\n        keep_split_in_tail (bool): split in tail or head. Defaults to False.\n        final_token_idx (int): index of final token. Defaults to None.\n\n    Returns:\n        list: tokens of all utterances in one list.\n\n    \"\"\"\n    merged_conv = []\n    for utt in conversation:\n        for token in utt:\n            merged_conv.append(token)\n        if split_token_idx:\n            merged_conv.append(split_token_idx)\n    if split_token_idx and not keep_split_in_tail:\n        merged_conv = merged_conv[:-1]\n    if final_token_idx:\n        merged_conv.append(final_token_idx)\n    return merged_conv\n\ndef merge_utt_replace(conversation,detect_token=None,replace_token=None,method=\"in\"):\n    if method == 'in': \n        replaced_conv = []\n        for utt in conversation:\n            for token in utt:\n                if detect_token in token:\n                    replaced_conv.append(replace_token)\n                else:\n                    replaced_conv.append(token)\n        return replaced_conv\n    else:\n        return [token.replace(detect_token,replace_token) for utt in conversation for token in utt]\n"
  },
  {
    "path": "crslab/data/dataset/__init__.py",
    "content": "from .base import BaseDataset\nfrom .durecdial import DuRecDialDataset\nfrom .gorecdial import GoRecDialDataset\nfrom .inspired import InspiredDataset\nfrom .opendialkg import OpenDialKGDataset\nfrom .redial import ReDialDataset\nfrom .tgredial import TGReDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/base.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2020/12/13\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nimport os\nimport pickle as pkl\nfrom abc import ABC, abstractmethod\n\nimport numpy as np\nfrom loguru import logger\n\nfrom crslab.download import build\n\n\nclass BaseDataset(ABC):\n    \"\"\"Abstract class of dataset\n\n    Notes:\n        ``'embedding'`` can be specified in config to use pretrained word embedding.\n\n    \"\"\"\n\n    def __init__(self, opt, dpath, resource, restore=False, save=False):\n        \"\"\"Download resource, load, process data. Support restore and save processed dataset.\n\n        Args:\n            opt (Config or dict): config for dataset or the whole system.\n            dpath (str): where to store dataset.\n            resource (dict): version, download file and special token idx of tokenized dataset.\n            restore (bool): whether to restore saved dataset which has been processed. Defaults to False.\n            save (bool): whether to save dataset after processing. 
Defaults to False.\n\n        \"\"\"\n        self.opt = opt\n        self.dpath = dpath\n\n        # download\n        dfile = resource['file']\n        build(dpath, dfile, version=resource['version'])\n\n        if not restore:\n            # load and process\n            train_data, valid_data, test_data, self.vocab = self._load_data()\n            logger.info('[Finish data load]')\n            self.train_data, self.valid_data, self.test_data, self.side_data = self._data_preprocess(train_data,\n                                                                                                     valid_data,\n                                                                                                     test_data)\n            embedding = opt.get('embedding', None)\n            if embedding:\n                self.side_data[\"embedding\"] = np.load(os.path.join(self.dpath, embedding))\n                logger.debug(f'[Load pretrained embedding {embedding}]')\n            logger.info('[Finish data preprocess]')\n        else:\n            self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab = self._load_from_restore()\n\n        if save:\n            data = (self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab)\n            self._save_to_one(data)\n\n    @abstractmethod\n    def _load_data(self):\n        \"\"\"Load dataset.\n\n        Returns:\n            (any, any, any, dict):\n\n            raw train, valid and test data.\n\n            vocab: all kinds of useful size, idx and map between token and idx.\n\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def _data_preprocess(self, train_data, valid_data, test_data):\n        \"\"\"Process raw train, valid, test data.\n\n        Args:\n            train_data: train dataset.\n            valid_data: valid dataset.\n            test_data: test dataset.\n\n        Returns:\n            (list of dict, dict):\n\n            train/valid/test_data, each dict is 
in the following format::\n\n                 {\n                    'role' (str):\n                        'Seeker' or 'Recommender',\n                    'user_profile' (list of list of int):\n                        id of tokens of sentences of user profile,\n                    'context_tokens' (list of list int):\n                        token ids of preprocessed contextual dialogs,\n                    'response' (list of int):\n                        token ids of the ground-truth response,\n                    'interaction_history' (list of int):\n                        id of items which have interaction of the user in current turn,\n                    'context_items' (list of int):\n                        item ids mentioned in context,\n                    'items' (list of int):\n                        item ids mentioned in current turn, we only keep\n                        those in entity kg for comparison,\n                    'context_entities' (list of int):\n                        if necessary, id of entities in context,\n                    'context_words' (list of int):\n                        if necessary, id of words in context,\n                    'context_policy' (list of list of list):\n                        policy of each context turn, one turn may have several policies,\n                        where first is action and second is keyword,\n                    'target' (list): policy of current turn,\n                    'final' (list): final goal for current turn\n                }\n\n            side_data, which is in the following format::\n\n                {\n                    'entity_kg': {\n                        'edge' (list of tuple): (head_entity_id, tail_entity_id, relation_id),\n                        'n_relation' (int): number of distinct relations,\n                        'entity' (list of str): str of entities, used for entity linking\n                    }\n                    'word_kg': {\n                       
 'edge' (list of tuple): (head_entity_id, tail_entity_id),\n                        'entity' (list of str): str of entities, used for entity linking\n                    }\n                    'item_entity_ids' (list of int): entity id of each item;\n                }\n\n        \"\"\"\n        pass\n\n    def _load_from_restore(self, file_name=\"all_data.pkl\"):\n        \"\"\"Restore saved dataset.\n\n        Args:\n            file_name (str): file of saved dataset. Defaults to \"all_data.pkl\".\n\n        \"\"\"\n        if not os.path.exists(os.path.join(self.dpath, file_name)):\n            raise ValueError(f'Saved dataset [{file_name}] does not exist')\n        with open(os.path.join(self.dpath, file_name), 'rb') as f:\n            dataset = pkl.load(f)\n        logger.info(f'Restore dataset from [{file_name}]')\n        return dataset\n\n    def _save_to_one(self, data, file_name=\"all_data.pkl\"):\n        \"\"\"Save all processed dataset and vocab into one file.\n\n        Args:\n            data (tuple): all dataset and vocab.\n            file_name (str, optional): file to save dataset. Defaults to \"all_data.pkl\".\n\n        \"\"\"\n        if not os.path.exists(self.dpath):\n            os.makedirs(self.dpath)\n        save_path = os.path.join(self.dpath, file_name)\n        with open(save_path, 'wb') as f:\n            pkl.dump(data, f)\n        logger.info(f'[Save dataset to {file_name}]')\n"
  },
  {
    "path": "crslab/data/dataset/durecdial/__init__.py",
    "content": "from .durecdial import DuRecDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/durecdial/durecdial.py",
    "content": "# @Time   : 2020/12/21\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/21, 2021/1/2\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nr\"\"\"\nDuRecDial\n=========\nReferences:\n    Liu, Zeming, et al. `\"Towards Conversational Recommendation over Multi-Type Dialogs.\"`_ in ACL 2020.\n\n.. _\"Towards Conversational Recommendation over Multi-Type Dialogs.\":\n   https://www.aclweb.org/anthology/2020.acl-main.98/\n\n\"\"\"\n\nimport json\nimport os\nfrom copy import copy\n\nfrom loguru import logger\nfrom tqdm import tqdm\n\nfrom crslab.config import DATASET_PATH\nfrom crslab.data.dataset.base import BaseDataset\nfrom .resources import resources\n\n\nclass DuRecDialDataset(BaseDataset):\n    \"\"\"\n\n    Attributes:\n        train_data: train dataset.\n        valid_data: valid dataset.\n        test_data: test dataset.\n        vocab (dict): ::\n\n            {\n                'tok2ind': map from token to index,\n                'ind2tok': map from index to token,\n                'entity2id': map from entity to index,\n                'id2entity': map from index to entity,\n                'word2id': map from word to index,\n                'vocab_size': len(self.tok2ind),\n                'n_entity': max(self.entity2id.values()) + 1,\n                'n_word': max(self.word2id.values()) + 1,\n            }\n\n    Notes:\n        ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.\n\n    \"\"\"\n\n    def __init__(self, opt, tokenize, restore=False, save=False):\n        \"\"\"\n\n        Args:\n            opt (Config or dict): config for dataset or the whole system.\n            tokenize (str): how to tokenize dataset.\n            restore (bool): whether to restore saved dataset which has been processed. Defaults to False.\n            save (bool): whether to save dataset after processing. 
Defaults to False.\n\n        \"\"\"\n        resource = resources[tokenize]\n        self.special_token_idx = resource['special_token_idx']\n        self.unk_token_idx = self.special_token_idx['unk']\n        dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize)\n        super().__init__(opt, dpath, resource, restore, save)\n\n    def _load_data(self):\n        train_data, valid_data, test_data = self._load_raw_data()\n        self._load_vocab()\n        self._load_other_data()\n\n        vocab = {\n            'tok2ind': self.tok2ind,\n            'ind2tok': self.ind2tok,\n            'entity2id': self.entity2id,\n            'id2entity': self.id2entity,\n            'word2id': self.word2id,\n            'vocab_size': len(self.tok2ind),\n            'n_entity': self.n_entity,\n            'n_word': self.n_word,\n        }\n        vocab.update(self.special_token_idx)\n\n        return train_data, valid_data, test_data, vocab\n\n    def _load_raw_data(self):\n        with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:\n            train_data = json.load(f)\n            logger.debug(f\"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]\")\n        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:\n            valid_data = json.load(f)\n            logger.debug(f\"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]\")\n        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:\n            test_data = json.load(f)\n            logger.debug(f\"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]\")\n\n        return train_data, valid_data, test_data\n\n    def _load_vocab(self):\n        self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))\n        self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}\n\n        logger.debug(f\"[Load vocab from 
{os.path.join(self.dpath, 'token2id.json')}]\")\n        logger.debug(f\"[The size of token2index dictionary is {len(self.tok2ind)}]\")\n        logger.debug(f\"[The size of index2token dictionary is {len(self.ind2tok)}]\")\n\n    def _load_other_data(self):\n        # entity kg\n        with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f:\n            self.entity2id = json.load(f)  # {entity: entity_id}\n        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}\n        self.n_entity = max(self.entity2id.values()) + 1\n        # {head_entity_id: [(relation_id, tail_entity_id)]}\n        self.entity_kg = open(os.path.join(self.dpath, 'entity_subkg.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]\")\n\n        # hownet\n        # {concept: concept_id}\n        with open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8') as f:\n            self.word2id = json.load(f)\n        self.n_word = max(self.word2id.values()) + 1\n        # {concept \\t relation\\t concept}\n        self.word_kg = open(os.path.join(self.dpath, 'hownet_subkg.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet_subkg.txt')}]\")\n\n    def _data_preprocess(self, train_data, valid_data, test_data):\n        processed_train_data = self._raw_data_process(train_data)\n        logger.debug(\"[Finish train data process]\")\n        processed_valid_data = self._raw_data_process(valid_data)\n        logger.debug(\"[Finish valid data process]\")\n        processed_test_data = self._raw_data_process(test_data)\n        logger.debug(\"[Finish test data process]\")\n        processed_side_data = self._side_data_process()\n        logger.debug(\"[Finish side data process]\")\n    
    return processed_train_data, processed_valid_data, processed_test_data, processed_side_data\n\n    def _raw_data_process(self, raw_data):\n        augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]\n        augmented_conv_dicts = []\n        for conv in tqdm(augmented_convs):\n            augmented_conv_dicts.extend(self._augment_and_add(conv))\n        return augmented_conv_dicts\n\n    def _convert_to_id(self, conversation):\n        augmented_convs = []\n        last_role = None\n        for utt in conversation['dialog']:\n            assert utt['role'] != last_role, print(utt)\n\n            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt[\"text\"]]\n            item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id]\n            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]\n            word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]\n\n            augmented_convs.append({\n                \"role\": utt[\"role\"],\n                \"text\": text_token_ids,\n                \"entity\": entity_ids,\n                \"movie\": item_ids,\n                \"word\": word_ids\n            })\n            last_role = utt[\"role\"]\n\n        return augmented_convs\n\n    def _augment_and_add(self, raw_conv_dict):\n        augmented_conv_dicts = []\n        context_tokens, context_entities, context_words, context_items = [], [], [], []\n        entity_set, word_set = set(), set()\n        for i, conv in enumerate(raw_conv_dict):\n            text_tokens, entities, movies, words = conv[\"text\"], conv[\"entity\"], conv[\"movie\"], conv[\"word\"]\n            if len(context_tokens) > 0:\n                conv_dict = {\n                    'role': conv['role'],\n                    \"context_tokens\": copy(context_tokens),\n                    \"response\": text_tokens,\n                    
\"context_entities\": copy(context_entities),\n                    \"context_words\": copy(context_words),\n                    'context_items': copy(context_items),\n                    \"items\": movies\n                }\n                augmented_conv_dicts.append(conv_dict)\n\n            context_tokens.append(text_tokens)\n            context_items += movies\n            for entity in entities + movies:\n                if entity not in entity_set:\n                    entity_set.add(entity)\n                    context_entities.append(entity)\n            for word in words:\n                if word not in word_set:\n                    word_set.add(word)\n                    context_words.append(word)\n\n        return augmented_conv_dicts\n\n    def _side_data_process(self):\n        processed_entity_kg = self._entity_kg_process()\n        logger.debug(\"[Finish entity KG process]\")\n        processed_word_kg = self._word_kg_process()\n        logger.debug(\"[Finish word KG process]\")\n        with open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8') as f:\n            item_entity_ids = json.load(f)\n        logger.debug('[Load movie entity ids]')\n\n        side_data = {\n            \"entity_kg\": processed_entity_kg,\n            \"word_kg\": processed_word_kg,\n            \"item_entity_ids\": item_entity_ids,\n        }\n        return side_data\n\n    def _entity_kg_process(self):\n        edge_list = []  # [(entity, entity, relation)]\n        for line in self.entity_kg:\n            triple = line.strip().split('\\t')\n            e0 = self.entity2id[triple[0]]\n            e1 = self.entity2id[triple[2]]\n            r = triple[1]\n            edge_list.append((e0, e1, r))\n            edge_list.append((e1, e0, r))\n            edge_list.append((e0, e0, 'SELF_LOOP'))\n            if e1 != e0:\n                edge_list.append((e1, e1, 'SELF_LOOP'))\n\n        relation2id, edges, entities = dict(), set(), set()\n        for h, t, r 
in edge_list:\n            if r not in relation2id:\n                relation2id[r] = len(relation2id)\n            edges.add((h, t, relation2id[r]))\n            entities.add(self.id2entity[h])\n            entities.add(self.id2entity[t])\n\n        return {\n            'edge': list(edges),\n            'n_relation': len(relation2id),\n            'entity': list(entities)\n        }\n\n    def _word_kg_process(self):\n        edges = set()  # {(entity, entity)}\n        entities = set()\n        for line in self.word_kg:\n            triple = line.strip().split('\\t')\n            entities.add(triple[0])\n            entities.add(triple[2])\n            e0 = self.word2id[triple[0]]\n            e1 = self.word2id[triple[2]]\n            edges.add((e0, e1))\n            edges.add((e1, e0))\n        # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]\n        return {\n            'edge': list(edges),\n            'entity': list(entities)\n        }\n"
  },
  {
    "path": "crslab/data/dataset/durecdial/resources.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'jieba': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1',\n            'durecdial_jieba.zip',\n            'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813',\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 1,\n            'end': 2,\n            'unk': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n        },\n    },\n    'bert': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1',\n            'durecdial_bert.zip',\n            '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 101,\n            'end': 102,\n            'unk': 100,\n            'sent_split': 2,\n            'word_split': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n            'pad_topic': 0\n        },\n    },\n    'gpt2': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1',\n            'durecdial_gpt2.zip',\n            'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 101,\n            
'end': 102,\n            'unk': 100,\n            'cls': 101,\n            'sep': 102,\n            'sent_split': 2,\n            'word_split': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n            'pad_topic': 0,\n        },\n    }\n}\n"
  },
  {
    "path": "crslab/data/dataset/gorecdial/__init__.py",
    "content": "from .gorecdial import GoRecDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/gorecdial/gorecdial.py",
    "content": "# @Time   : 2020/12/12\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/13, 2021/1/2, 2020/12/19\n# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail\n\nr\"\"\"\nGoRecDial\n=========\nReferences:\n    Kang, Dongyeop, et al. `\"Recommendation as a Communication Game: Self-Supervised Bot-Play for Goal-oriented Dialogue.\"`_ in EMNLP 2019.\n\n.. _`\"Recommendation as a Communication Game: Self-Supervised Bot-Play for Goal-oriented Dialogue.\"`:\n   https://www.aclweb.org/anthology/D19-1203/\n\n\"\"\"\n\nimport json\nimport os\nfrom copy import copy\n\nfrom loguru import logger\nfrom tqdm import tqdm\n\nfrom crslab.config import DATASET_PATH\nfrom crslab.data.dataset.base import BaseDataset\nfrom .resources import resources\n\n\nclass GoRecDialDataset(BaseDataset):\n    \"\"\"\n\n    Attributes:\n        train_data: train dataset.\n        valid_data: valid dataset.\n        test_data: test dataset.\n        vocab (dict): ::\n\n            {\n                'tok2ind': map from token to index,\n                'ind2tok': map from index to token,\n                'entity2id': map from entity to index,\n                'id2entity': map from index to entity,\n                'word2id': map from word to index,\n                'vocab_size': len(self.tok2ind),\n                'n_entity': max(self.entity2id.values()) + 1,\n                'n_word': max(self.word2id.values()) + 1,\n            }\n\n    Notes:\n        ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.\n\n    \"\"\"\n\n    def __init__(self, opt, tokenize, restore=False, save=False):\n        \"\"\"Specify tokenized resource and init base dataset.\n\n        Args:\n            opt (Config or dict): config for dataset or the whole system.\n            tokenize (str): how to tokenize dataset.\n            restore (bool): whether to restore saved dataset 
which has been processed. Defaults to False.\n            save (bool): whether to save dataset after processing. Defaults to False.\n\n        \"\"\"\n        resource = resources[tokenize]\n        self.special_token_idx = resource['special_token_idx']\n        self.unk_token_idx = self.special_token_idx['unk']\n        dpath = os.path.join(DATASET_PATH, 'gorecdial', tokenize)\n        super().__init__(opt, dpath, resource, restore, save)\n\n    def _load_data(self):\n        train_data, valid_data, test_data = self._load_raw_data()\n        self._load_vocab()\n        self._load_other_data()\n\n        vocab = {\n            'tok2ind': self.tok2ind,\n            'ind2tok': self.ind2tok,\n            'entity2id': self.entity2id,\n            'id2entity': self.id2entity,\n            'word2id': self.word2id,\n            'vocab_size': len(self.tok2ind),\n            'n_entity': self.n_entity,\n            'n_word': self.n_word,\n        }\n        vocab.update(self.special_token_idx)\n\n        return train_data, valid_data, test_data, vocab\n\n    def _load_raw_data(self):\n        # load train/valid/test data\n        with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:\n            train_data = json.load(f)\n            logger.debug(f\"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]\")\n        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:\n            valid_data = json.load(f)\n            logger.debug(f\"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]\")\n        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:\n            test_data = json.load(f)\n            logger.debug(f\"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]\")\n\n        return train_data, valid_data, test_data\n\n    def _load_vocab(self):\n        self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', 
 encoding='utf-8'))\n        self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}\n\n        logger.debug(f\"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]\")\n        logger.debug(f\"[The size of token2index dictionary is {len(self.tok2ind)}]\")\n        logger.debug(f\"[The size of index2token dictionary is {len(self.ind2tok)}]\")\n\n    def _load_other_data(self):\n        # dbpedia\n        self.entity2id = json.load(\n            open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8'))  # {entity: entity_id}\n        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}\n        self.n_entity = max(self.entity2id.values()) + 1\n        # {head_entity_id: [(relation_id, tail_entity_id)]}\n        self.entity_kg = open(os.path.join(self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.txt')}]\")\n\n        # conceptnet\n        # {concept: concept_id}\n        self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8'))\n        self.n_word = max(self.word2id.values()) + 1\n        # {concept \\t relation\\t concept}\n        self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'conceptnet_subkg.txt')}]\")\n\n    def _data_preprocess(self, train_data, valid_data, test_data):\n        processed_train_data = self._raw_data_process(train_data)\n        logger.debug(\"[Finish train data process]\")\n        processed_valid_data = self._raw_data_process(valid_data)\n        logger.debug(\"[Finish valid data process]\")\n        processed_test_data = self._raw_data_process(test_data)\n        logger.debug(\"[Finish test data process]\")\n        
processed_side_data = self._side_data_process()\n        logger.debug(\"[Finish side data process]\")\n        return processed_train_data, processed_valid_data, processed_test_data, processed_side_data\n\n    def _raw_data_process(self, raw_data):\n        augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]\n        augmented_conv_dicts = []\n        for conv in tqdm(augmented_convs):\n            augmented_conv_dicts.extend(self._augment_and_add(conv))\n        return augmented_conv_dicts\n\n    def _convert_to_id(self, conversation):\n        augmented_convs = []\n        last_role = None\n        for utt in conversation['dialog']:\n            assert utt['role'] != last_role\n\n            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt[\"text\"]]\n            movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]\n            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]\n            word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]\n            policy = utt['decide']\n\n            augmented_convs.append({\n                \"role\": utt[\"role\"],\n                \"text\": text_token_ids,\n                \"entity\": entity_ids,\n                \"movie\": movie_ids,\n                \"word\": word_ids,\n                'policy': policy\n            })\n            last_role = utt[\"role\"]\n\n        return augmented_convs\n\n    def _augment_and_add(self, raw_conv_dict):\n        augmented_conv_dicts = []\n        context_tokens, context_entities, context_words, context_items = [], [], [], []\n        entity_set, word_set = set(), set()\n        for i, conv in enumerate(raw_conv_dict):\n            text_tokens, entities, movies, words, policies = conv[\"text\"], conv[\"entity\"], conv[\"movie\"], conv[\"word\"], \\\n                                                             
conv['policy']\n            if len(context_tokens) > 0 and len(text_tokens) > 0:\n                conv_dict = {\n                    'role': conv['role'],\n                    \"context_tokens\": copy(context_tokens),\n                    \"response\": text_tokens,\n                    \"context_entities\": copy(context_entities),\n                    \"context_words\": copy(context_words),\n                    'context_items': copy(context_items),\n                    \"items\": movies,\n                    'policy': policies,\n                }\n                augmented_conv_dicts.append(conv_dict)\n\n            if len(text_tokens) > 0:\n                context_tokens.append(text_tokens)\n                context_items += movies\n                for entity in entities + movies:\n                    if entity not in entity_set:\n                        entity_set.add(entity)\n                        context_entities.append(entity)\n                for word in words:\n                    if word not in word_set:\n                        word_set.add(word)\n                        context_words.append(word)\n\n        return augmented_conv_dicts\n\n    def _side_data_process(self):\n        processed_entity_kg = self._entity_kg_process()\n        logger.debug(\"[Finish entity KG process]\")\n        processed_word_kg = self._word_kg_process()\n        logger.debug(\"[Finish word KG process]\")\n        movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8'))\n        logger.debug('[Load movie entity ids]')\n\n        side_data = {\n            \"entity_kg\": processed_entity_kg,\n            \"word_kg\": processed_word_kg,\n            \"item_entity_ids\": movie_entity_ids,\n        }\n        return side_data\n\n    def _entity_kg_process(self):\n        edge_list = []  # [(entity, entity, relation)]\n        for line in self.entity_kg:\n            triple = line.strip().split('\\t')\n            e0 = 
self.entity2id[triple[0]]\n            e1 = self.entity2id[triple[2]]\n            r = triple[1]\n            edge_list.append((e0, e1, r))\n            edge_list.append((e1, e0, r))\n            edge_list.append((e0, e0, 'SELF_LOOP'))\n            if e1 != e0:\n                edge_list.append((e1, e1, 'SELF_LOOP'))\n\n        relation2id, edges, entities = dict(), set(), set()\n        for h, t, r in edge_list:\n            if r not in relation2id:\n                relation2id[r] = len(relation2id)\n            edges.add((h, t, relation2id[r]))\n            entities.add(self.id2entity[h])\n            entities.add(self.id2entity[t])\n\n        return {\n            'edge': list(edges),\n            'n_relation': len(relation2id),\n            'entity': list(entities)\n        }\n\n    def _word_kg_process(self):\n        edges = set()  # {(entity, entity)}\n        entities = set()\n        for line in self.word_kg:\n            triple = line.strip().split('\\t')\n            entities.add(triple[0])\n            entities.add(triple[2])\n            e0 = self.word2id[triple[0]]\n            e1 = self.word2id[triple[2]]\n            edges.add((e0, e1))\n            edges.add((e1, e0))\n        # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]\n        return {\n            'edge': list(edges),\n            'entity': list(entities)\n        }\n"
  },
  {
    "path": "crslab/data/dataset/gorecdial/resources.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/14\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'nltk': {\n        'version': '0.31',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESIqjwAg0ItAu7WGfukIt3cBXjzi7AZ9L_lcbFT1aS1qYQ?download=1',\n            'gorecdial_nltk.zip',\n            '58cd368f8f83c0c8555becc314a0017990545f71aefb7e93a52581c97d1b8e9b',\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 1,\n            'end': 2,\n            'unk': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n            'pad_topic': 0\n        },\n    },\n    'bert': {\n        'version': '0.31',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed1HT8gzvRpDosVT83BEj5QBnzKpjR3Zbf5u49yyWP-k6Q?download=1',\n            'gorecdial_bert.zip',\n            '4fa10c3fe8ba538af0f393c99892739fcb376d832616aa7028334c594b3fec10'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 101,\n            'end': 102,\n            'unk': 100,\n            'sent_split': 2,\n            'word_split': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n            'pad_topic': 0\n        }\n    },\n    'gpt2': {\n        'version': '0.31',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EUJOHmX8v79DkZMq0x5r9d4B0UJlfw85v-VdciwKfAhpng?download=1',\n            'gorecdial_gpt2.zip',\n            '44a15637e014b2e6628102ff654e1aef7ec1cbfa34b7ada1a03f294f72ddd4b1'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            
'start': 1,\n            'end': 2,\n            'unk': 3,\n            'sent_split': 4,\n            'word_split': 5,\n            'pad_entity': 0,\n            'pad_word': 0\n        },\n    }\n}\n"
  },
  {
    "path": "crslab/data/dataset/inspired/__init__.py",
    "content": "from .inspired import InspiredDataset\n"
  },
  {
    "path": "crslab/data/dataset/inspired/inspired.py",
    "content": "# @Time   : 2020/12/19\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/20, 2021/1/2\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nr\"\"\"\nInspired\n========\nReferences:\n    Hayati, Shirley Anugrah, et al. `\"INSPIRED: Toward Sociable Recommendation Dialog Systems.\"`_ in EMNLP 2020.\n\n.. _`\"INSPIRED: Toward Sociable Recommendation Dialog Systems.\"`:\n   https://www.aclweb.org/anthology/2020.emnlp-main.654/\n\n\"\"\"\n\nimport json\nimport os\nfrom copy import copy\n\nfrom loguru import logger\nfrom tqdm import tqdm\n\nfrom crslab.config import DATASET_PATH\nfrom crslab.data.dataset.base import BaseDataset\nfrom .resources import resources\n\n\nclass InspiredDataset(BaseDataset):\n    \"\"\"\n\n    Attributes:\n        train_data: train dataset.\n        valid_data: valid dataset.\n        test_data: test dataset.\n        vocab (dict): ::\n\n            {\n                'tok2ind': map from token to index,\n                'ind2tok': map from index to token,\n                'entity2id': map from entity to index,\n                'id2entity': map from index to entity,\n                'word2id': map from word to index,\n                'vocab_size': len(self.tok2ind),\n                'n_entity': max(self.entity2id.values()) + 1,\n                'n_word': max(self.word2id.values()) + 1,\n            }\n\n    Notes:\n        ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.\n\n    \"\"\"\n\n    def __init__(self, opt, tokenize, restore=False, save=False):\n        \"\"\"Specify tokenized resource and init base dataset.\n\n        Args:\n            opt (Config or dict): config for dataset or the whole system.\n            tokenize (str): how to tokenize dataset.\n            restore (bool): whether to restore saved dataset which has been processed. 
Defaults to False.\n            save (bool): whether to save dataset after processing. Defaults to False.\n\n        \"\"\"\n        resource = resources[tokenize]\n        self.special_token_idx = resource['special_token_idx']\n        self.unk_token_idx = self.special_token_idx['unk']\n        dpath = os.path.join(DATASET_PATH, 'inspired', tokenize)\n        super().__init__(opt, dpath, resource, restore, save)\n\n    def _load_data(self):\n        train_data, valid_data, test_data = self._load_raw_data()\n        self._load_vocab()\n        self._load_other_data()\n\n        vocab = {\n            'tok2ind': self.tok2ind,\n            'ind2tok': self.ind2tok,\n            'entity2id': self.entity2id,\n            'id2entity': self.id2entity,\n            'word2id': self.word2id,\n            'vocab_size': len(self.tok2ind),\n            'n_entity': self.n_entity,\n            'n_word': self.n_word,\n        }\n        vocab.update(self.special_token_idx)\n\n        return train_data, valid_data, test_data, vocab\n\n    def _load_raw_data(self):\n        # load train/valid/test data\n        with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:\n            train_data = json.load(f)\n            logger.debug(f\"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]\")\n        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:\n            valid_data = json.load(f)\n            logger.debug(f\"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]\")\n        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:\n            test_data = json.load(f)\n            logger.debug(f\"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]\")\n\n        return train_data, valid_data, test_data\n\n    def _load_vocab(self):\n        with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f:\n            self.tok2ind = 
json.load(f)\n        self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}\n\n        logger.debug(f\"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]\")\n        logger.debug(f\"[The size of token2index dictionary is {len(self.tok2ind)}]\")\n        logger.debug(f\"[The size of index2token dictionary is {len(self.ind2tok)}]\")\n\n    def _load_other_data(self):\n        # dbpedia\n        with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f:\n            self.entity2id = json.load(f)  # {entity: entity_id}\n        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}\n        self.n_entity = max(self.entity2id.values()) + 1\n        # {head_entity_id: [(relation_id, tail_entity_id)]}\n        self.entity_kg = open(os.path.join(self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.txt')}]\")\n\n        # conceptnet\n        # {concept: concept_id}\n        with open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8') as f:\n            self.word2id = json.load(f)\n        self.n_word = max(self.word2id.values()) + 1\n        # {concept \\t relation\\t concept}\n        self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]\")\n\n    def _data_preprocess(self, train_data, valid_data, test_data):\n        processed_train_data = self._raw_data_process(train_data)\n        logger.debug(\"[Finish train data process]\")\n        processed_valid_data = self._raw_data_process(valid_data)\n        logger.debug(\"[Finish valid data process]\")\n        processed_test_data = self._raw_data_process(test_data)\n        logger.debug(\"[Finish 
test data process]\")\n        processed_side_data = self._side_data_process()\n        logger.debug(\"[Finish side data process]\")\n        return processed_train_data, processed_valid_data, processed_test_data, processed_side_data\n\n    def _raw_data_process(self, raw_data):\n        augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]\n        augmented_conv_dicts = []\n        for conv in tqdm(augmented_convs):\n            augmented_conv_dicts.extend(self._augment_and_add(conv))\n        return augmented_conv_dicts\n\n    def _convert_to_id(self, conversation):\n        augmented_convs = []\n        last_role = None\n        for utt in conversation['dialog']:\n            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt[\"text\"]]\n            movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]\n            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]\n            word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]\n\n            if utt[\"role\"] == last_role:\n                augmented_convs[-1][\"text\"] += text_token_ids\n                augmented_convs[-1][\"movie\"] += movie_ids\n                augmented_convs[-1][\"entity\"] += entity_ids\n                augmented_convs[-1][\"word\"] += word_ids\n            else:\n                augmented_convs.append({\n                    \"role\": utt[\"role\"],\n                    \"text\": text_token_ids,\n                    \"entity\": entity_ids,\n                    \"movie\": movie_ids,\n                    \"word\": word_ids\n                })\n            last_role = utt[\"role\"]\n\n        return augmented_convs\n\n    def _augment_and_add(self, raw_conv_dict):\n        augmented_conv_dicts = []\n        context_tokens, context_entities, context_words, context_items = [], [], [], []\n        entity_set, word_set = set(), 
set()\n        for i, conv in enumerate(raw_conv_dict):\n            text_tokens, entities, movies, words = conv[\"text\"], conv[\"entity\"], conv[\"movie\"], conv[\"word\"]\n            if len(context_tokens) > 0:\n                conv_dict = {\n                    'role': conv['role'],\n                    \"context_tokens\": copy(context_tokens),\n                    \"response\": text_tokens,\n                    \"context_entities\": copy(context_entities),\n                    \"context_words\": copy(context_words),\n                    'context_items': copy(context_items),\n                    \"items\": movies,\n                }\n                augmented_conv_dicts.append(conv_dict)\n\n            context_tokens.append(text_tokens)\n            context_items += movies\n            for entity in entities + movies:\n                if entity not in entity_set:\n                    entity_set.add(entity)\n                    context_entities.append(entity)\n            for word in words:\n                if word not in word_set:\n                    word_set.add(word)\n                    context_words.append(word)\n\n        return augmented_conv_dicts\n\n    def _side_data_process(self):\n        processed_entity_kg = self._entity_kg_process()\n        logger.debug(\"[Finish entity KG process]\")\n        processed_word_kg = self._word_kg_process()\n        logger.debug(\"[Finish word KG process]\")\n        with open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8') as f:\n            movie_entity_ids = json.load(f)\n        logger.debug('[Load movie entity ids]')\n\n        side_data = {\n            \"entity_kg\": processed_entity_kg,\n            \"word_kg\": processed_word_kg,\n            \"item_entity_ids\": movie_entity_ids,\n        }\n        return side_data\n\n    def _entity_kg_process(self):\n        edge_list = []  # [(entity, entity, relation)]\n        for line in self.entity_kg:\n            triple = 
line.strip().split('\\t')\n            e0 = self.entity2id[triple[0]]\n            e1 = self.entity2id[triple[2]]\n            r = triple[1]\n            edge_list.append((e0, e1, r))\n            edge_list.append((e1, e0, r))\n            edge_list.append((e0, e0, 'SELF_LOOP'))\n            if e1 != e0:\n                edge_list.append((e1, e1, 'SELF_LOOP'))\n\n        relation2id, edges, entities = dict(), set(), set()\n        for h, t, r in edge_list:\n            if r not in relation2id:\n                relation2id[r] = len(relation2id)\n            edges.add((h, t, relation2id[r]))\n            entities.add(self.id2entity[h])\n            entities.add(self.id2entity[t])\n\n        return {\n            'edge': list(edges),\n            'n_relation': len(relation2id),\n            'entity': list(entities)\n        }\n\n    def _word_kg_process(self):\n        edges = set()  # {(entity, entity)}\n        entities = set()\n        for line in self.word_kg:\n            triple = line.strip().split('\\t')\n            entities.add(triple[0])\n            entities.add(triple[2])\n            e0 = self.word2id[triple[0]]\n            e1 = self.word2id[triple[2]]\n            edges.add((e0, e1))\n            edges.add((e1, e0))\n        # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]\n        return {\n            'edge': list(edges),\n            'entity': list(entities)\n        }\n"
  },
  {
    "path": "crslab/data/dataset/inspired/resources.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'nltk': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1',\n            'inspired_nltk.zip',\n            '776cadc7585abdbca2738addae40488826c82de3cfd4c2dc13dcdd63aefdc5c4',\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 1,\n            'end': 2,\n            'unk': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n        },\n    },\n    'bert': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1',\n            'inspired_bert.zip',\n            '9affea30978a6cd48b8038dddaa36f4cb4d8491cf8ae2de44a6d3dde2651f29c'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 101,\n            'end': 102,\n            'unk': 100,\n            'sent_split': 2,\n            'word_split': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n        },\n    },\n    'gpt2': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1',\n            'inspired_gpt2.zip',\n            '261ad7e5325258d5cb8ffef0751925a58270fb6d9f17490f8552f6b86ef1eed2'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 1,\n            'end': 2,\n            'unk': 3,\n   
         'sent_split': 4,\n            'word_split': 5,\n            'pad_entity': 0,\n            'pad_word': 0\n        },\n    }\n}\n"
  },
  {
    "path": "crslab/data/dataset/opendialkg/__init__.py",
    "content": "from .opendialkg import OpenDialKGDataset\n"
  },
  {
    "path": "crslab/data/dataset/opendialkg/opendialkg.py",
    "content": "# @Time   : 2020/12/19\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/20, 2021/1/2\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nr\"\"\"\nOpenDialKG\n==========\nReferences:\n    Moon, Seungwhan, et al. `\"Opendialkg: Explainable conversational reasoning with attention-based walks over knowledge graphs.\"`_ in ACL 2019.\n\n.. _`\"Opendialkg: Explainable conversational reasoning with attention-based walks over knowledge graphs.\"`:\n   https://www.aclweb.org/anthology/P19-1081/\n\n\"\"\"\n\nimport json\nimport os\nfrom collections import defaultdict\nfrom copy import copy\n\nfrom loguru import logger\nfrom tqdm import tqdm\n\nfrom crslab.config import DATASET_PATH\nfrom crslab.data.dataset.base import BaseDataset\nfrom .resources import resources\n\n\nclass OpenDialKGDataset(BaseDataset):\n    \"\"\"\n\n    Attributes:\n        train_data: train dataset.\n        valid_data: valid dataset.\n        test_data: test dataset.\n        vocab (dict): ::\n\n            {\n                'tok2ind': map from token to index,\n                'ind2tok': map from index to token,\n                'entity2id': map from entity to index,\n                'id2entity': map from index to entity,\n                'word2id': map from word to index,\n                'vocab_size': len(self.tok2ind),\n                'n_entity': max(self.entity2id.values()) + 1,\n                'n_word': max(self.word2id.values()) + 1,\n            }\n\n    Notes:\n        ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.\n\n    \"\"\"\n\n    def __init__(self, opt, tokenize, restore=False, save=False):\n        \"\"\"Specify tokenized resource and init base dataset.\n\n        Args:\n            opt (Config or dict): config for dataset or the whole system.\n            tokenize (str): how to tokenize dataset.\n            restore (bool): whether to restore saved 
dataset which has been processed. Defaults to False.\n            save (bool): whether to save dataset after processing. Defaults to False.\n\n        \"\"\"\n        resource = resources[tokenize]\n        self.special_token_idx = resource['special_token_idx']\n        self.unk_token_idx = self.special_token_idx['unk']\n        dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize)\n        super().__init__(opt, dpath, resource, restore, save)\n\n    def _load_data(self):\n        train_data, valid_data, test_data = self._load_raw_data()\n        self._load_vocab()\n        self._load_other_data()\n\n        vocab = {\n            'tok2ind': self.tok2ind,\n            'ind2tok': self.ind2tok,\n            'entity2id': self.entity2id,\n            'id2entity': self.id2entity,\n            'word2id': self.word2id,\n            'vocab_size': len(self.tok2ind),\n            'n_entity': self.n_entity,\n            'n_word': self.n_word,\n        }\n        vocab.update(self.special_token_idx)\n\n        return train_data, valid_data, test_data, vocab\n\n    def _load_raw_data(self):\n        # load train/valid/test data\n        with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:\n            train_data = json.load(f)\n            logger.debug(f\"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]\")\n        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:\n            valid_data = json.load(f)\n            logger.debug(f\"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]\")\n        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:\n            test_data = json.load(f)\n            logger.debug(f\"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]\")\n\n        return train_data, valid_data, test_data\n\n    def _load_vocab(self):\n        self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 
'r', encoding='utf-8'))\n        self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}\n\n        logger.debug(f\"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]\")\n        logger.debug(f\"[The size of token2index dictionary is {len(self.tok2ind)}]\")\n        logger.debug(f\"[The size of index2token dictionary is {len(self.ind2tok)}]\")\n\n    def _load_other_data(self):\n        # opendialkg\n        self.entity2id = json.load(\n            open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8'))  # {entity: entity_id}\n        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}\n        self.n_entity = max(self.entity2id.values()) + 1\n        # {head_entity_id: [(relation_id, tail_entity_id)]}\n        self.entity_kg = open(os.path.join(self.dpath, 'opendialkg_subkg.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'opendialkg_subkg.txt')}]\")\n\n        # conceptnet\n        # {concept: concept_id}\n        self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8'))\n        self.n_word = max(self.word2id.values()) + 1\n        # {concept \\t relation\\t concept}\n        self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]\")\n\n    def _data_preprocess(self, train_data, valid_data, test_data):\n        processed_train_data = self._raw_data_process(train_data)\n        logger.debug(\"[Finish train data process]\")\n        processed_valid_data = self._raw_data_process(valid_data)\n        logger.debug(\"[Finish valid data process]\")\n        processed_test_data = self._raw_data_process(test_data)\n        logger.debug(\"[Finish test data 
process]\")\n        processed_side_data = self._side_data_process()\n        logger.debug(\"[Finish side data process]\")\n        return processed_train_data, processed_valid_data, processed_test_data, processed_side_data\n\n    def _raw_data_process(self, raw_data):\n        augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]\n        augmented_conv_dicts = []\n        for conv in tqdm(augmented_convs):\n            augmented_conv_dicts.extend(self._augment_and_add(conv))\n        return augmented_conv_dicts\n\n    def _convert_to_id(self, conversation):\n        augmented_convs = []\n        last_role = None\n        for utt in conversation['dialog']:\n            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt[\"text\"]]\n            item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id]\n            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]\n            word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]\n\n            if utt[\"role\"] == last_role:\n                augmented_convs[-1][\"text\"] += text_token_ids\n                augmented_convs[-1][\"item\"] += item_ids\n                augmented_convs[-1][\"entity\"] += entity_ids\n                augmented_convs[-1][\"word\"] += word_ids\n            else:\n                augmented_convs.append({\n                    \"role\": utt[\"role\"],\n                    \"text\": text_token_ids,\n                    \"entity\": entity_ids,\n                    \"item\": item_ids,\n                    \"word\": word_ids\n                })\n            last_role = utt[\"role\"]\n\n        return augmented_convs\n\n    def _augment_and_add(self, raw_conv_dict):\n        augmented_conv_dicts = []\n        context_tokens, context_entities, context_words, context_items = [], [], [], []\n        entity_set, word_set = set(), set()\n        for 
i, conv in enumerate(raw_conv_dict):\n            text_tokens, entities, items, words = conv[\"text\"], conv[\"entity\"], conv[\"item\"], conv[\"word\"]\n            if len(context_tokens) > 0:\n                conv_dict = {\n                    'role': conv['role'],\n                    \"context_tokens\": copy(context_tokens),\n                    \"response\": text_tokens,\n                    \"context_entities\": copy(context_entities),\n                    \"context_words\": copy(context_words),\n                    'context_items': copy(context_items),\n                    \"items\": items,\n                }\n                augmented_conv_dicts.append(conv_dict)\n\n            context_tokens.append(text_tokens)\n            context_items += items\n            for entity in entities + items:\n                if entity not in entity_set:\n                    entity_set.add(entity)\n                    context_entities.append(entity)\n            for word in words:\n                if word not in word_set:\n                    word_set.add(word)\n                    context_words.append(word)\n\n        return augmented_conv_dicts\n\n    def _side_data_process(self):\n        processed_entity_kg = self._entity_kg_process()\n        logger.debug(\"[Finish entity KG process]\")\n        processed_word_kg = self._word_kg_process()\n        logger.debug(\"[Finish word KG process]\")\n        item_entity_ids = json.load(open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8'))\n        logger.debug('[Load item entity ids]')\n\n        side_data = {\n            \"entity_kg\": processed_entity_kg,\n            \"word_kg\": processed_word_kg,\n            \"item_entity_ids\": item_entity_ids,\n        }\n        return side_data\n\n    def _entity_kg_process(self):\n        edge_list = []  # [(entity, entity, relation)]\n        for line in self.entity_kg:\n            triple = line.strip().split('\\t')\n            if len(triple) != 3 or triple[0] not 
in self.entity2id or triple[2] not in self.entity2id:\n                continue\n            e0 = self.entity2id[triple[0]]\n            e1 = self.entity2id[triple[2]]\n            r = triple[1]\n            edge_list.append((e0, e1, r))\n            # edge_list.append((e1, e0, r))\n            edge_list.append((e0, e0, 'SELF_LOOP'))\n            if e1 != e0:\n                edge_list.append((e1, e1, 'SELF_LOOP'))\n\n        relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()\n        for h, t, r in edge_list:\n            relation_cnt[r] += 1\n        for h, t, r in edge_list:\n            if relation_cnt[r] > 20000:\n                if r not in relation2id:\n                    relation2id[r] = len(relation2id)\n                edges.add((h, t, relation2id[r]))\n                entities.add(self.id2entity[h])\n                entities.add(self.id2entity[t])\n\n        return {\n            'edge': list(edges),\n            'n_relation': len(relation2id),\n            'entity': list(entities)\n        }\n\n    def _word_kg_process(self):\n        edges = set()  # {(entity, entity)}\n        entities = set()\n        for line in self.word_kg:\n            triple = line.strip().split('\\t')\n            entities.add(triple[0])\n            entities.add(triple[2])\n            e0 = self.word2id[triple[0]]\n            e1 = self.word2id[triple[2]]\n            edges.add((e0, e1))\n            edges.add((e1, e0))\n        # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]\n        return {\n            'edge': list(edges),\n            'entity': list(entities)\n        }\n"
  },
  {
    "path": "crslab/data/dataset/opendialkg/resources.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/21\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'nltk': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1',\n            'opendialkg_nltk.zip',\n            '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 1,\n            'end': 2,\n            'unk': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n        },\n    },\n    'bert': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1',\n            'opendialkg_bert.zip',\n            '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 101,\n            'end': 102,\n            'unk': 100,\n            'sent_split': 2,\n            'word_split': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n        },\n    },\n    'gpt2': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1',\n            'opendialkg_gpt2.zip',\n            'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 1,\n            'end': 2,\n            'unk': 
3,\n            'sent_split': 4,\n            'word_split': 5,\n            'pad_entity': 0,\n            'pad_word': 0\n        },\n    }\n}\n"
  },
  {
    "path": "crslab/data/dataset/redial/__init__.py",
    "content": "from .redial import ReDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/redial/redial.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/23, 2021/1/3, 2020/12/19\n# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail\n\nr\"\"\"\nReDial\n======\nReferences:\n    Li, Raymond, et al. `\"Towards deep conversational recommendations.\"`_ in NeurIPS 2018.\n\n.. _`\"Towards deep conversational recommendations.\"`:\n   https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html\n\n\"\"\"\n\nimport json\nimport os\nfrom collections import defaultdict\nfrom copy import copy\n\nfrom loguru import logger\nfrom tqdm import tqdm\n\nfrom crslab.config import DATASET_PATH\nfrom crslab.data.dataset.base import BaseDataset\nfrom .resources import resources\n\n\nclass ReDialDataset(BaseDataset):\n    \"\"\"\n\n    Attributes:\n        train_data: train dataset.\n        valid_data: valid dataset.\n        test_data: test dataset.\n        vocab (dict): ::\n\n            {\n                'tok2ind': map from token to index,\n                'ind2tok': map from index to token,\n                'entity2id': map from entity to index,\n                'id2entity': map from index to entity,\n                'word2id': map from word to index,\n                'vocab_size': len(self.tok2ind),\n                'n_entity': max(self.entity2id.values()) + 1,\n                'n_word': max(self.word2id.values()) + 1,\n            }\n\n    Notes:\n        ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.\n\n    \"\"\"\n\n    def __init__(self, opt, tokenize, restore=False, save=False):\n        \"\"\"Specify tokenized resource and init base dataset.\n\n        Args:\n            opt (Config or dict): config for dataset or the whole system.\n            tokenize (str): how to tokenize dataset.\n            restore (bool): whether to restore saved dataset which has been processed. 
Defaults to False.\n            save (bool): whether to save dataset after processing. Defaults to False.\n\n        \"\"\"\n        resource = resources[tokenize]\n        self.special_token_idx = resource['special_token_idx']\n        self.unk_token_idx = self.special_token_idx['unk']\n        dpath = os.path.join(DATASET_PATH, \"redial\", tokenize)\n        super().__init__(opt, dpath, resource, restore, save)\n\n    def _load_data(self):\n        train_data, valid_data, test_data = self._load_raw_data()\n        self._load_vocab()\n        self._load_other_data()\n\n        vocab = {\n            'tok2ind': self.tok2ind,\n            'ind2tok': self.ind2tok,\n            'entity2id': self.entity2id,\n            'id2entity': self.id2entity,\n            'word2id': self.word2id,\n            'vocab_size': len(self.tok2ind),\n            'n_entity': self.n_entity,\n            'n_word': self.n_word,\n        }\n        vocab.update(self.special_token_idx)\n\n        return train_data, valid_data, test_data, vocab\n\n    def _load_raw_data(self):\n        # load train/valid/test data\n        with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:\n            train_data = json.load(f)\n            logger.debug(f\"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]\")\n        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:\n            valid_data = json.load(f)\n            logger.debug(f\"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]\")\n        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:\n            test_data = json.load(f)\n            logger.debug(f\"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]\")\n\n        return train_data, valid_data, test_data\n\n    def _load_vocab(self):\n        self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))\n        
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}\n\n        logger.debug(f\"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]\")\n        logger.debug(f\"[The size of token2index dictionary is {len(self.tok2ind)}]\")\n        logger.debug(f\"[The size of index2token dictionary is {len(self.ind2tok)}]\")\n\n    def _load_other_data(self):\n        # dbpedia\n        self.entity2id = json.load(\n            open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8'))  # {entity: entity_id}\n        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}\n        self.n_entity = max(self.entity2id.values()) + 1\n        # {head_entity_id: [(relation_id, tail_entity_id)]}\n        self.entity_kg = json.load(open(os.path.join(self.dpath, 'dbpedia_subkg.json'), 'r', encoding='utf-8'))\n        logger.debug(\n            f\"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]\")\n\n        # conceptNet\n        # {concept: concept_id}\n        self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8'))\n        self.n_word = max(self.word2id.values()) + 1\n        # {relation\\t concept \\t concept}\n        self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8')\n        logger.debug(\n            f\"[Load word dictionary and KG from {os.path.join(self.dpath, 'concept2id.json')} and {os.path.join(self.dpath, 'conceptnet_subkg.txt')}]\")\n\n    def _data_preprocess(self, train_data, valid_data, test_data):\n        processed_train_data = self._raw_data_process(train_data)\n        logger.debug(\"[Finish train data process]\")\n        processed_valid_data = self._raw_data_process(valid_data)\n        logger.debug(\"[Finish valid data process]\")\n        processed_test_data = self._raw_data_process(test_data)\n        logger.debug(\"[Finish test data 
process]\")\n        processed_side_data = self._side_data_process()\n        logger.debug(\"[Finish side data process]\")\n        return processed_train_data, processed_valid_data, processed_test_data, processed_side_data\n\n    def _raw_data_process(self, raw_data):\n        augmented_convs = [self._merge_conv_data(conversation[\"dialog\"]) for conversation in tqdm(raw_data)]\n        augmented_conv_dicts = []\n        for conv in tqdm(augmented_convs):\n            augmented_conv_dicts.extend(self._augment_and_add(conv))\n        return augmented_conv_dicts\n\n    def _merge_conv_data(self, dialog):\n        augmented_convs = []\n        last_role = None\n        for utt in dialog:\n            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt[\"text\"]]\n            movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]\n            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]\n            word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]\n\n            if utt[\"role\"] == last_role:\n                augmented_convs[-1][\"text\"] += text_token_ids\n                augmented_convs[-1][\"movie\"] += movie_ids\n                augmented_convs[-1][\"entity\"] += entity_ids\n                augmented_convs[-1][\"word\"] += word_ids\n            else:\n                augmented_convs.append({\n                    \"role\": utt[\"role\"],\n                    \"text\": text_token_ids,\n                    \"entity\": entity_ids,\n                    \"movie\": movie_ids,\n                    \"word\": word_ids\n                })\n            last_role = utt[\"role\"]\n\n        return augmented_convs\n\n    def _augment_and_add(self, raw_conv_dict):\n        augmented_conv_dicts = []\n        context_tokens, context_entities, context_words, context_items = [], [], [], []\n        entity_set, word_set = set(), set()\n        for 
i, conv in enumerate(raw_conv_dict):\n            text_tokens, entities, movies, words = conv[\"text\"], conv[\"entity\"], conv[\"movie\"], conv[\"word\"]\n            if len(context_tokens) > 0:\n                conv_dict = {\n                    \"role\": conv['role'],\n                    \"context_tokens\": copy(context_tokens),\n                    \"response\": text_tokens,\n                    \"context_entities\": copy(context_entities),\n                    \"context_words\": copy(context_words),\n                    \"context_items\": copy(context_items),\n                    \"items\": movies,\n                }\n                augmented_conv_dicts.append(conv_dict)\n\n            context_tokens.append(text_tokens)\n            context_items += movies\n            for entity in entities + movies:\n                if entity not in entity_set:\n                    entity_set.add(entity)\n                    context_entities.append(entity)\n            for word in words:\n                if word not in word_set:\n                    word_set.add(word)\n                    context_words.append(word)\n\n        return augmented_conv_dicts\n\n    def _side_data_process(self):\n        processed_entity_kg = self._entity_kg_process()\n        logger.debug(\"[Finish entity KG process]\")\n        processed_word_kg = self._word_kg_process()\n        logger.debug(\"[Finish word KG process]\")\n        movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8'))\n        logger.debug('[Load movie entity ids]')\n\n        side_data = {\n            \"entity_kg\": processed_entity_kg,\n            \"word_kg\": processed_word_kg,\n            \"item_entity_ids\": movie_entity_ids,\n        }\n        return side_data\n\n    def _entity_kg_process(self, SELF_LOOP_ID=185):\n        edge_list = []  # [(entity, entity, relation)]\n        for entity in range(self.n_entity):\n            if str(entity) not in self.entity_kg:\n       
         continue\n            edge_list.append((entity, entity, SELF_LOOP_ID))  # add self loop\n            for tail_and_relation in self.entity_kg[str(entity)]:\n                if entity != tail_and_relation[1] and tail_and_relation[0] != SELF_LOOP_ID:\n                    edge_list.append((entity, tail_and_relation[1], tail_and_relation[0]))\n                    edge_list.append((tail_and_relation[1], entity, tail_and_relation[0]))\n\n        relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()\n        for h, t, r in edge_list:\n            relation_cnt[r] += 1\n        for h, t, r in edge_list:\n            if relation_cnt[r] > 1000:\n                if r not in relation2id:\n                    relation2id[r] = len(relation2id)\n                edges.add((h, t, relation2id[r]))\n                entities.add(self.id2entity[h])\n                entities.add(self.id2entity[t])\n        return {\n            'edge': list(edges),\n            'n_relation': len(relation2id),\n            'entity': list(entities)\n        }\n\n    def _word_kg_process(self):\n        edges = set()  # {(entity, entity)}\n        entities = set()\n        for line in self.word_kg:\n            kg = line.strip().split('\\t')\n            entities.add(kg[1].split('/')[0])\n            entities.add(kg[2].split('/')[0])\n            e0 = self.word2id[kg[1].split('/')[0]]\n            e1 = self.word2id[kg[2].split('/')[0]]\n            edges.add((e0, e1))\n            edges.add((e1, e0))\n        # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]\n        return {\n            'edge': list(edges),\n            'entity': list(entities)\n        }\n"
  },
  {
    "path": "crslab/data/dataset/redial/resources.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/1\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'nltk': {\n        'version': '0.31',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1',\n            'redial_nltk.zip',\n            '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424',\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 1,\n            'end': 2,\n            'unk': 3,\n            'pad_entity': 0,\n            'pad_word': 0\n        },\n    },\n    'bert': {\n        'version': '0.31',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1',\n            'redial_bert.zip',\n            'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd',\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 101,\n            'end': 102,\n            'unk': 100,\n            'sent_split': 2,\n            'word_split': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n        },\n    },\n    'gpt2': {\n        'version': '0.31',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1',\n            'redial_gpt2.zip',\n            '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8',\n        ),\n        'special_token_idx': {\n            'pad': -100,\n            'start': 1,\n            'end': 2,\n            'unk': 3,\n   
         'sent_split': 4,\n            'word_split': 5,\n            'pad_entity': 0,\n            'pad_word': 0\n        },\n    }\n}\n"
  },
  {
    "path": "crslab/data/dataset/tgredial/__init__.py",
    "content": "from .tgredial import TGReDialDataset\n"
  },
  {
    "path": "crslab/data/dataset/tgredial/resources.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/4\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'pkuseg': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1',\n            'tgredial_pkuseg.zip',\n            '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6',\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 1,\n            'end': 2,\n            'unk': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n            'pad_topic': 0\n        },\n    },\n    'bert': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1',\n            'tgredial_bert.zip',\n            'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            'start': 101,\n            'end': 102,\n            'unk': 100,\n            'sent_split': 2,\n            'word_split': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n            'pad_topic': 0\n        },\n    },\n    'gpt2': {\n        'version': '0.3',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1',\n            'tgredial_gpt2.zip',\n            '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9'\n        ),\n        'special_token_idx': {\n            'pad': 0,\n            
'start': 101,\n            'end': 102,\n            'unk': 100,\n            'cls': 101,\n            'sep': 102,\n            'sent_split': 2,\n            'word_split': 3,\n            'pad_entity': 0,\n            'pad_word': 0,\n            'pad_topic': 0,\n        },\n    }\n}\n"
  },
  {
    "path": "crslab/data/dataset/tgredial/tgredial.py",
    "content": "# @Time   : 2020/12/4\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/12/6, 2021/1/2, 2020/12/19\n# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou\n# @Email  : francis_kun_zhou@163.com, sdzyh002@gmail\n\nr\"\"\"\nTGReDial\n========\nReferences:\n    Zhou, Kun, et al. `\"Towards Topic-Guided Conversational Recommender System.\"`_ in COLING 2020.\n\n.. _`\"Towards Topic-Guided Conversational Recommender System.\"`:\n   https://www.aclweb.org/anthology/2020.coling-main.365/\n\n\"\"\"\n\nimport json\nimport os\nfrom collections import defaultdict\nfrom copy import copy\nimport numpy as np\nfrom loguru import logger\nfrom tqdm import tqdm\n\nfrom crslab.config import DATASET_PATH\nfrom crslab.data.dataset.base import BaseDataset\nfrom .resources import resources\n\n\nclass TGReDialDataset(BaseDataset):\n    \"\"\"\n\n    Attributes:\n        train_data: train dataset.\n        valid_data: valid dataset.\n        test_data: test dataset.\n        vocab (dict): ::\n\n            {\n                'tok2ind': map from token to index,\n                'ind2tok': map from index to token,\n                'topic2ind': map from topic to index,\n                'ind2topic': map from index to topic,\n                'entity2id': map from entity to index,\n                'id2entity': map from index to entity,\n                'word2id': map from word to index,\n                'vocab_size': len(self.tok2ind),\n                'n_topic': len(self.topic2ind) + 1,\n                'n_entity': max(self.entity2id.values()) + 1,\n                'n_word': max(self.word2id.values()) + 1,\n            }\n\n    Notes:\n        ``'unk'`` and ``'pad_topic'`` must be specified in ``'special_token_idx'`` in ``resources.py``.\n\n    \"\"\"\n\n    def __init__(self, opt, tokenize, restore=False, save=False):\n        \"\"\"Specify tokenized resource and init base dataset.\n\n        Args:\n            opt (Config or dict): config for 
dataset or the whole system.\n            tokenize (str): how to tokenize dataset.\n            restore (bool): whether to restore saved dataset which has been processed. Defaults to False.\n            save (bool): whether to save dataset after processing. Defaults to False.\n\n        \"\"\"\n        resource = resources[tokenize]\n        self.special_token_idx = resource['special_token_idx']\n        self.unk_token_idx = self.special_token_idx['unk']\n        self.pad_topic_idx = self.special_token_idx['pad_topic']\n        dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize)\n        self.replace_token = opt.get('replace_token',None)\n        self.replace_token_idx = opt.get('replace_token_idx',None)\n        super().__init__(opt, dpath, resource, restore, save)\n        if self.replace_token:\n            if self.replace_token_idx:\n                self.side_data[\"embedding\"][self.replace_token_idx] = self.side_data['embedding'][0]\n            else:\n                self.side_data[\"embedding\"] = np.insert(self.side_data[\"embedding\"],len(self.side_data[\"embedding\"]),self.side_data['embedding'][0],axis=0)\n        \n\n    def _load_data(self):\n        train_data, valid_data, test_data = self._load_raw_data()\n        self._load_vocab()\n        self._load_other_data()\n\n        vocab = {\n            'tok2ind': self.tok2ind,\n            'ind2tok': self.ind2tok,\n            'topic2ind': self.topic2ind,\n            'ind2topic': self.ind2topic,\n            'entity2id': self.entity2id,\n            'id2entity': self.id2entity,\n            'word2id': self.word2id,\n            'vocab_size': len(self.tok2ind),\n            'n_topic': len(self.topic2ind) + 1,\n            'n_entity': self.n_entity,\n            'n_word': self.n_word,\n        }\n        vocab.update(self.special_token_idx)\n\n        return train_data, valid_data, test_data, vocab\n\n    def _load_raw_data(self):\n        # load train/valid/test data\n        with 
open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:\n            train_data = json.load(f)\n            logger.debug(f\"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]\")\n        with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:\n            valid_data = json.load(f)\n            logger.debug(f\"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]\")\n        with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:\n            test_data = json.load(f)\n            logger.debug(f\"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]\")\n\n        return train_data, valid_data, test_data\n\n    def _load_vocab(self):\n        self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))\n        self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}\n        # add special tokens\n        if self.replace_token:\n            if self.replace_token not in self.tok2ind:\n                if self.replace_token_idx:\n                    self.ind2tok[self.replace_token_idx] = self.replace_token\n                    self.tok2ind[self.replace_token] = self.replace_token_idx\n                    self.special_token_idx[self.replace_token] = self.replace_token_idx\n                else:\n                    self.ind2tok[len(self.tok2ind)] = self.replace_token\n                    self.tok2ind[self.replace_token] = len(self.tok2ind)\n                    self.special_token_idx[self.replace_token] = len(self.tok2ind)-1 \n        logger.debug(f\"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]\")\n        logger.debug(f\"[The size of token2index dictionary is {len(self.tok2ind)}]\")\n        logger.debug(f\"[The size of index2token dictionary is {len(self.ind2tok)}]\")\n\n        self.topic2ind = json.load(open(os.path.join(self.dpath, 'topic2id.json'), 'r', encoding='utf-8'))\n        
self.ind2topic = {idx: word for word, idx in self.topic2ind.items()}\n\n        logger.debug(f\"[Load vocab from {os.path.join(self.dpath, 'topic2id.json')}]\")\n        logger.debug(f\"[The size of token2index dictionary is {len(self.topic2ind)}]\")\n        logger.debug(f\"[The size of index2token dictionary is {len(self.ind2topic)}]\")\n\n    def _load_other_data(self):\n        # cn-dbpedia\n        self.entity2id = json.load(\n            open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8'))  # {entity: entity_id}\n        self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}\n        self.n_entity = max(self.entity2id.values()) + 1\n        # {head_entity_id: [(relation_id, tail_entity_id)]}\n        self.entity_kg = open(os.path.join(self.dpath, 'cn-dbpedia.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'cn-dbpedia.txt')}]\")\n\n        # hownet\n        # {concept: concept_id}\n        self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8'))\n        self.n_word = max(self.word2id.values()) + 1\n        # {relation\\t concept \\t concept}\n        self.word_kg = open(os.path.join(self.dpath, 'hownet.txt'), encoding='utf-8')\n        logger.debug(\n            f\"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet.txt')}]\")\n\n        # user interaction history dictionary\n        self.conv2history = json.load(open(os.path.join(self.dpath, 'user2history.json'), 'r', encoding='utf-8'))\n        logger.debug(f\"[Load user interaction history from {os.path.join(self.dpath, 'user2history.json')}]\")\n\n        # user profile\n        self.user2profile = json.load(open(os.path.join(self.dpath, 'user2profile.json'), 'r', encoding='utf-8'))\n        logger.debug(f\"[Load user profile from 
{os.path.join(self.dpath, 'user2profile.json')}\")\n\n\n    def _data_preprocess(self, train_data, valid_data, test_data):\n        processed_train_data = self._raw_data_process(train_data)\n        logger.debug(\"[Finish train data process]\")\n        processed_valid_data = self._raw_data_process(valid_data)\n        logger.debug(\"[Finish valid data process]\")\n        processed_test_data = self._raw_data_process(test_data)\n        logger.debug(\"[Finish test data process]\")\n        processed_side_data = self._side_data_process()\n        logger.debug(\"[Finish side data process]\")\n        return processed_train_data, processed_valid_data, processed_test_data, processed_side_data\n\n    def _raw_data_process(self, raw_data):\n        augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]\n        augmented_conv_dicts = []\n        for conv in tqdm(augmented_convs):\n            augmented_conv_dicts.extend(self._augment_and_add(conv))\n        return augmented_conv_dicts\n\n    def _convert_to_id(self, conversation):\n        augmented_convs = []\n        last_role = None\n        for utt in conversation['messages']:\n            assert utt['role'] != last_role\n            # change movies into slots\n            if self.replace_token:\n                if len(utt['movie']) != 0:\n                    while  '《' in utt['text'] :\n                        begin = utt['text'].index(\"《\")\n                        end = utt['text'].index(\"》\")\n                        utt['text'] = utt['text'][:begin] + [self.replace_token] + utt['text'][end+1:]\n            text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt[\"text\"]]\n            movie_ids = [self.entity2id[movie] for movie in utt['movie'] if movie in self.entity2id]\n            entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]\n            word_ids = [self.word2id[word] for word in utt['word'] if word in 
self.word2id]\n            policy = []\n            for action, kw in zip(utt['target'][1::2], utt['target'][2::2]):\n                if kw is None or action == '推荐电影':\n                    continue\n                if isinstance(kw, str):\n                    kw = [kw]\n                kw = [self.topic2ind.get(k, self.pad_topic_idx) for k in kw]\n                policy.append([action, kw])\n            final_kws = [self.topic2ind[kw] if kw is not None else self.pad_topic_idx for kw in utt['final'][1]]\n            final = [utt['final'][0], final_kws]\n            conv_utt_id = str(conversation['conv_id']) + '/' + str(utt['local_id'])\n            interaction_history = self.conv2history.get(conv_utt_id, [])\n            user_profile = self.user2profile[conversation['user_id']]\n            user_profile = [[self.tok2ind.get(token, self.unk_token_idx) for token in sent] for sent in user_profile]\n\n            augmented_convs.append({\n                \"role\": utt[\"role\"],\n                \"text\": text_token_ids,\n                \"entity\": entity_ids,\n                \"movie\": movie_ids,\n                \"word\": word_ids,\n                'policy': policy,\n                'final': final,\n                'interaction_history': interaction_history,\n                'user_profile': user_profile\n            })\n            last_role = utt[\"role\"]\n\n        return augmented_convs\n\n    def _augment_and_add(self, raw_conv_dict):\n        augmented_conv_dicts = []\n        context_tokens, context_entities, context_words, context_policy, context_items = [], [], [], [], []\n        entity_set, word_set = set(), set()\n        for i, conv in enumerate(raw_conv_dict):\n            text_tokens, entities, movies, words, policies = conv[\"text\"], conv[\"entity\"], conv[\"movie\"], conv[\"word\"], \\\n                                                             conv['policy']\n            if self.replace_token is not None: \n                if 
text_tokens.count(30000) != len(movies):\n                    continue # the number of slots doesn't equal to the number of movies\n                \n            if len(context_tokens) > 0:\n                conv_dict = {\n                    'role': conv['role'],\n                    'user_profile': conv['user_profile'],\n                    \"context_tokens\": copy(context_tokens),\n                    \"response\": text_tokens,\n                    \"context_entities\": copy(context_entities),\n                    \"context_words\": copy(context_words),\n                    'interaction_history': conv['interaction_history'],\n                    'context_items': copy(context_items),\n                    \"items\": movies,\n                    'context_policy': copy(context_policy),\n                    'target': policies,\n                    'final': conv['final'],\n                }\n                augmented_conv_dicts.append(conv_dict)\n\n            context_tokens.append(text_tokens)\n            context_policy.append(policies)\n            context_items += movies\n            for entity in entities + movies:\n                if entity not in entity_set:\n                    entity_set.add(entity)\n                    context_entities.append(entity)\n            for word in words:\n                if word not in word_set:\n                    word_set.add(word)\n                    context_words.append(word)\n\n        return augmented_conv_dicts\n\n    def _side_data_process(self):\n        processed_entity_kg = self._entity_kg_process()\n        logger.debug(\"[Finish entity KG process]\")\n        processed_word_kg = self._word_kg_process()\n        logger.debug(\"[Finish word KG process]\")\n        movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8'))\n        logger.debug('[Load movie entity ids]')\n\n        side_data = {\n            \"entity_kg\": processed_entity_kg,\n            \"word_kg\": 
processed_word_kg,\n            \"item_entity_ids\": movie_entity_ids,\n        }\n        return side_data\n\n    def _entity_kg_process(self):\n        edge_list = []  # [(entity, entity, relation)]\n        for line in self.entity_kg:\n            triple = line.strip().split('\\t')\n            e0 = self.entity2id[triple[0]]\n            e1 = self.entity2id[triple[2]]\n            r = triple[1]\n            edge_list.append((e0, e1, r))\n            edge_list.append((e1, e0, r))\n            edge_list.append((e0, e0, 'SELF_LOOP'))\n            if e1 != e0:\n                edge_list.append((e1, e1, 'SELF_LOOP'))\n\n        relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()\n        for h, t, r in edge_list:\n            relation_cnt[r] += 1\n        for h, t, r in edge_list:\n            if r not in relation2id:\n                relation2id[r] = len(relation2id)\n            edges.add((h, t, relation2id[r]))\n            entities.add(self.id2entity[h])\n            entities.add(self.id2entity[t])\n\n        return {\n            'edge': list(edges),\n            'n_relation': len(relation2id),\n            'entity': list(entities)\n        }\n\n    def _word_kg_process(self):\n        edges = set()  # {(entity, entity)}\n        entities = set()\n        for line in self.word_kg:\n            triple = line.strip().split('\\t')\n            entities.add(triple[0])\n            entities.add(triple[2])\n            e0 = self.word2id[triple[0]]\n            e1 = self.word2id[triple[2]]\n            edges.add((e0, e1))\n            edges.add((e1, e0))\n        # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]\n        return {\n            'edge': list(edges),\n            'entity': list(entities)\n        }\n"
  },
  {
    "path": "crslab/download.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/7\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/7\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nimport hashlib\nimport os\nimport shutil\nimport time\n\nimport datetime\nimport requests\nimport tqdm\nfrom loguru import logger\n\n\nclass DownloadableFile:\n    \"\"\"\n    A class used to abstract any file that has to be downloaded online.\n\n    Any task that needs to download a file needs to have a list RESOURCES\n    that have objects of this class as elements.\n\n    This class provides the following functionality:\n\n    - Download a file from a URL\n    - Untar the file if zipped\n    - Checksum for the downloaded file\n\n    An object of this class needs to be created with:\n\n    - url <string> : URL or Google Drive id to download from\n    - file_name <string> : File name that the file should be named\n    - hashcode <string> : SHA256 hashcode of the downloaded file\n    - zipped <boolean> : False if the file is not compressed\n    - from_google <boolean> : True if the file is from Google Drive\n    \"\"\"\n\n    def __init__(self, url, file_name, hashcode, zipped=True, from_google=False):\n        self.url = url\n        self.file_name = file_name\n        self.hashcode = hashcode\n        self.zipped = zipped\n        self.from_google = from_google\n\n    def checksum(self, dpath):\n        \"\"\"\n        Checksum on a given file.\n\n        :param dpath: path to the downloaded file.\n        \"\"\"\n        sha256_hash = hashlib.sha256()\n        with open(os.path.join(dpath, self.file_name), \"rb\") as f:\n            for byte_block in iter(lambda: f.read(65536), b\"\"):\n                sha256_hash.update(byte_block)\n            if sha256_hash.hexdigest() != self.hashcode:\n                # remove_dir(dpath)\n                raise AssertionError(\n                    f\"[ Checksum for {self.file_name} from 
\\n{self.url}\\n\"\n                    \"does not match the expected checksum. Please try again. ]\"\n                )\n            else:\n                logger.debug(\"Checksum Successful\")\n                pass\n\n    def download_file(self, dpath):\n        if self.from_google:\n            download_from_google_drive(self.url, os.path.join(dpath, self.file_name))\n        else:\n            download(self.url, dpath, self.file_name)\n\n        self.checksum(dpath)\n\n        if self.zipped:\n            untar(dpath, self.file_name)\n\n\ndef download(url, path, fname, redownload=False, num_retries=5):\n    \"\"\"\n    Download file using `requests`.\n    If ``redownload`` is set to false, then will not download tar file again if it is\n    present (default ``False``).\n    \"\"\"\n    outfile = os.path.join(path, fname)\n    download = not os.path.exists(outfile) or redownload\n    logger.info(f\"Downloading {url} to {outfile}\")\n    retry = num_retries\n    exp_backoff = [2 ** r for r in reversed(range(retry))]\n\n    pbar = tqdm.tqdm(unit='B', unit_scale=True, desc='Downloading {}'.format(fname))\n\n    while download and retry > 0:\n        response = None\n        try:\n            headers = {\n                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.60',\n            }\n            response = requests.get(url, stream=True, headers=headers)\n\n            # negative reply could be 'none' or just missing\n            CHUNK_SIZE = 32768\n            total_size = int(response.headers.get('Content-Length', -1))\n            # server returns remaining size if resuming, so adjust total\n            pbar.total = total_size\n            done = 0\n\n            with open(outfile, 'wb') as f:\n                for chunk in response.iter_content(CHUNK_SIZE):\n                    if chunk:  # filter out keep-alive new chunks\n                        f.write(chunk)\n   
                 if total_size > 0:\n                        done += len(chunk)\n                        if total_size < done:\n                            # don't freak out if content-length was too small\n                            total_size = done\n                            pbar.total = total_size\n                        pbar.update(len(chunk))\n                break\n        except (\n                requests.exceptions.ConnectionError,\n                requests.exceptions.ReadTimeout,\n        ):\n            retry -= 1\n            pbar.clear()\n            if retry > 0:\n                pl = 'y' if retry == 1 else 'ies'\n                logger.debug(\n                    f'Connection error, retrying. ({retry} retr{pl} left)'\n                )\n                time.sleep(exp_backoff[retry])\n            else:\n                logger.error('Retried too many times, stopped retrying.')\n        finally:\n            if response:\n                response.close()\n    if retry <= 0:\n        raise RuntimeError('Connection broken too many times. Stopped retrying.')\n\n    if download and retry > 0:\n        pbar.update(done - pbar.n)\n        if done < total_size:\n            raise RuntimeError(\n                f'Received less data than specified in Content-Length header for '\n                f'{url}. 
There may be a download problem.'\n            )\n\n    pbar.close()\n\n\ndef _get_confirm_token(response):\n    for key, value in response.cookies.items():\n        if key.startswith('download_warning'):\n            return value\n    return None\n\n\ndef download_from_google_drive(gd_id, destination):\n    \"\"\"\n    Use the requests package to download a file from Google Drive.\n    \"\"\"\n    URL = 'https://docs.google.com/uc?export=download'\n\n    with requests.Session() as session:\n        response = session.get(URL, params={'id': gd_id}, stream=True)\n        token = _get_confirm_token(response)\n\n        if token:\n            response.close()\n            params = {'id': gd_id, 'confirm': token}\n            response = session.get(URL, params=params, stream=True)\n\n        CHUNK_SIZE = 32768\n        with open(destination, 'wb') as f:\n            for chunk in response.iter_content(CHUNK_SIZE):\n                if chunk:  # filter out keep-alive new chunks\n                    f.write(chunk)\n        response.close()\n\n\ndef move(path1, path2):\n    \"\"\"\n    Rename the given file.\n    \"\"\"\n    shutil.move(path1, path2)\n\n\ndef untar(path, fname, deleteTar=True):\n    \"\"\"\n    Unpack the given archive file to the same directory.\n\n    :param str path:\n        The folder containing the archive. 
Will contain the contents.\n\n    :param str fname:\n        The filename of the archive file.\n\n    :param bool deleteTar:\n        If true, the archive will be deleted after extraction.\n    \"\"\"\n    logger.debug(f'unpacking {fname}')\n    fullpath = os.path.join(path, fname)\n    shutil.unpack_archive(fullpath, path)\n    if deleteTar:\n        os.remove(fullpath)\n\n\ndef make_dir(path):\n    \"\"\"\n    Make the directory and any nonexistent parent directories (`mkdir -p`).\n    \"\"\"\n    # the current working directory is a fine path\n    if path != '':\n        os.makedirs(path, exist_ok=True)\n\n\ndef remove_dir(path):\n    \"\"\"\n    Remove the given directory, if it exists.\n    \"\"\"\n    shutil.rmtree(path, ignore_errors=True)\n\n\ndef check_build(path, version_string=None):\n    \"\"\"\n    Check if '.built' flag has been set for that task.\n\n    If a version_string is provided, this has to match, or the version is regarded as\n    not built.\n    \"\"\"\n    if version_string:\n        fname = os.path.join(path, '.built')\n        if not os.path.isfile(fname):\n            return False\n        else:\n            with open(fname, 'r') as read:\n                text = read.read().split('\\n')\n            return len(text) > 1 and text[1] == version_string\n    else:\n        return os.path.isfile(os.path.join(path, '.built'))\n\n\ndef mark_done(path, version_string=None):\n    \"\"\"\n    Mark this path as prebuilt.\n\n    Marks the path as done by adding a '.built' file with the current timestamp\n    plus a version description string if specified.\n\n    :param str path:\n        The file path to mark as built.\n\n    :param str version_string:\n        The version of this dataset.\n    \"\"\"\n    with open(os.path.join(path, '.built'), 'w') as write:\n        write.write(str(datetime.datetime.today()))\n        if version_string:\n            write.write('\\n' + version_string)\n\n\ndef build(dpath, dfile, version=None):\n    if not 
check_build(dpath, version):\n        logger.info('[Building data: ' + dpath + ']')\n        if check_build(dpath):\n            remove_dir(dpath)\n        make_dir(dpath)\n        # Download the data.\n        downloadable_file = dfile\n        downloadable_file.download_file(dpath)\n        mark_done(dpath, version)\n"
  },
  {
    "path": "crslab/evaluator/__init__.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/22\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom loguru import logger\n\nfrom .conv import ConvEvaluator\nfrom .rec import RecEvaluator\nfrom .standard import StandardEvaluator\nfrom ..data import dataset_language_map\n\nEvaluator_register_table = {\n    'rec': RecEvaluator,\n    'conv': ConvEvaluator,\n    'standard': StandardEvaluator\n}\n\n\ndef get_evaluator(evaluator_name, dataset, tensorboard=False):\n    if evaluator_name in Evaluator_register_table:\n        if evaluator_name in ('conv', 'standard'):\n            language = dataset_language_map[dataset]\n            evaluator = Evaluator_register_table[evaluator_name](language, tensorboard=tensorboard)\n        else:\n            evaluator = Evaluator_register_table[evaluator_name](tensorboard=tensorboard)\n        logger.info(f'[Build evaluator {evaluator_name}]')\n        return evaluator\n    else:\n        raise NotImplementedError(f'Model [{evaluator_name}] has not been implemented')\n"
  },
  {
    "path": "crslab/evaluator/base.py",
    "content": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nfrom abc import ABC, abstractmethod\n\n\nclass BaseEvaluator(ABC):\n    \"\"\"Base class for evaluator\"\"\"\n\n    def rec_evaluate(self, preds, label):\n        pass\n\n    def gen_evaluate(self, preds, label):\n        pass\n\n    def policy_evaluate(self, preds, label):\n        pass\n\n    @abstractmethod\n    def report(self, epoch, mode):\n        pass\n\n    @abstractmethod\n    def reset_metrics(self):\n        pass\n"
  },
  {
    "path": "crslab/evaluator/conv.py",
    "content": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/18\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\nimport os\nimport time\nfrom collections import defaultdict\n\nimport fasttext\nfrom loguru import logger\nfrom nltk import ngrams\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom crslab.evaluator.base import BaseEvaluator\nfrom crslab.evaluator.utils import nice_report\nfrom .embeddings import resources\nfrom .metrics import *\nfrom ..config import EMBEDDING_PATH\nfrom ..download import build\n\n\nclass ConvEvaluator(BaseEvaluator):\n    \"\"\"The evaluator specially for conversational model\n    \n    Args:\n        dist_set: the set to record dist n-gram\n        dist_cnt: the count of dist n-gram evaluation\n        gen_metrics: the metrics to evaluate conversational model, including bleu, dist, embedding metrics, f1\n        optim_metrics: the metrics to optimize in training\n\n    \"\"\"\n\n    def __init__(self, tensorboard=False):\n        super(ConvEvaluator, self).__init__()\n        self.dist_set = defaultdict(set)\n        self.dist_cnt = 0\n        self.gen_metrics = Metrics()\n        self.optim_metrics = Metrics()\n        self.tensorboard = tensorboard\n        if self.tensorboard:\n            self.writer = SummaryWriter(log_dir='runs/' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime()))\n            self.reports_name = ['Generation Metrics', 'Optimization Metrics']\n\n    def _load_embedding(self, language):\n        resource = resources[language]\n        dpath = os.path.join(EMBEDDING_PATH, language)\n        build(dpath, resource['file'], resource['version'])\n\n        model_file = os.path.join(dpath, f'cc.{language}.300.bin')\n        self.ft = fasttext.load_model(model_file)\n        logger.info(f'[Load {model_file} for embedding metric')\n\n    def _get_sent_embedding(self, sent):\n        return [self.ft[token] for token in 
sent.split()]\n\n    def gen_evaluate(self, hyp, refs):\n        if hyp:\n            self.gen_metrics.add(\"f1\", F1Metric.compute(hyp, refs))\n\n            for k in range(1, 5):\n                self.gen_metrics.add(f\"bleu@{k}\", BleuMetric.compute(hyp, refs, k))\n                # split sentence to tokens here\n                hyp_token = hyp.split()\n                for token in ngrams(hyp_token, k):\n                    self.dist_set[f\"dist@{k}\"].add(token)\n            self.dist_cnt += 1\n\n            hyp_emb = self._get_sent_embedding(hyp)\n            ref_embs = [self._get_sent_embedding(ref) for ref in refs]\n            self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs))\n            self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs))\n            self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs))\n\n    def report(self, epoch=-1, mode='test'):\n        for k, v in self.dist_set.items():\n            self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt))\n        reports = [self.gen_metrics.report(), self.optim_metrics.report()]\n        if self.tensorboard and mode != 'test':\n            for idx, task_report in enumerate(reports):\n                for each_metric, value in task_report.items():\n                    self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch)\n\n        logger.info('\\n' + nice_report(aggregate_unnamed_reports(reports)))\n\n    def reset_metrics(self):\n        self.gen_metrics.clear()\n        self.dist_cnt = 0\n        self.dist_set.clear()\n        self.optim_metrics.clear()\n"
  },
  {
    "path": "crslab/evaluator/embeddings.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/18\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/18\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'zh': {\n        'version': '0.2',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1',\n            'cc.zh.300.zip',\n            'effd9806809a1db106b5166b817aaafaaf3f005846f730d4c49f88c7a28a0ac3'\n        )\n    },\n    'en': {\n        'version': '0.2',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1',\n            'cc.en.300.zip',\n            '96a06a77da70325997eaa52bfd9acb1359a7c3754cb1c1aed2fc27c04936d53e'\n        )\n    }\n}\n"
  },
  {
    "path": "crslab/evaluator/end2end.py",
    "content": ""
  },
  {
    "path": "crslab/evaluator/metrics/__init__.py",
    "content": "from .base import Metric, Metrics, aggregate_unnamed_reports, AverageMetric\nfrom .gen import BleuMetric, ExactMatchMetric, F1Metric, DistMetric, EmbeddingAverage, VectorExtrema, \\\n    GreedyMatch\nfrom .rec import HitMetric, NDCGMetric, MRRMetric\n"
  },
  {
    "path": "crslab/evaluator/metrics/base.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020/12/2\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nimport functools\nfrom abc import ABC, abstractmethod\n\nimport torch\nfrom typing import Any, Union, List, Optional, Dict\n\nTScalar = Union[int, float, torch.Tensor]\nTVector = Union[List[TScalar], torch.Tensor]\n\n\n@functools.total_ordering\nclass Metric(ABC):\n    \"\"\"\n    Base class for storing metrics.\n\n    Subclasses should define .value(). Examples are provided for each subclass.\n    \"\"\"\n\n    @abstractmethod\n    def value(self) -> float:\n        \"\"\"\n        Return the value of the metric as a float.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def __add__(self, other: Any) -> 'Metric':\n        raise NotImplementedError\n\n    def __iadd__(self, other):\n        return self.__radd__(other)\n\n    def __radd__(self, other: Any):\n        if other is None:\n            return self\n        return self.__add__(other)\n\n    def __str__(self) -> str:\n        return f'{self.value():.4g}'\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.value():.4g})'\n\n    def __float__(self) -> float:\n        return float(self.value())\n\n    def __int__(self) -> int:\n        return int(self.value())\n\n    def __eq__(self, other: Any) -> bool:\n        if isinstance(other, Metric):\n            return self.value() == other.value()\n        else:\n            return self.value() == other\n\n    def __lt__(self, other: Any) -> bool:\n        if isinstance(other, Metric):\n            return self.value() < other.value()\n        else:\n            return self.value() < other\n\n    def __sub__(self, other: Any) -> float:\n        \"\"\"\n        Used heavily for assertAlmostEqual.\n        \"\"\"\n        if not isinstance(other, float):\n            raise 
TypeError('Metrics.__sub__ is intentionally limited to floats.')\n        return self.value() - other\n\n    def __rsub__(self, other: Any) -> float:\n        \"\"\"\n        Used heavily for assertAlmostEqual.\n\n        NOTE: This is not necessary in python 3.7+.\n        \"\"\"\n        if not isinstance(other, float):\n            raise TypeError('Metrics.__rsub__ is intentionally limited to floats.')\n        return other - self.value()\n\n    @classmethod\n    def as_number(cls, obj: TScalar) -> Union[int, float]:\n        if isinstance(obj, torch.Tensor):\n            obj_as_number: Union[int, float] = obj.item()\n        else:\n            obj_as_number = obj  # type: ignore\n        assert isinstance(obj_as_number, int) or isinstance(obj_as_number, float)\n        return obj_as_number\n\n    @classmethod\n    def as_float(cls, obj: TScalar) -> float:\n        return float(cls.as_number(obj))\n\n    @classmethod\n    def as_int(cls, obj: TScalar) -> int:\n        return int(cls.as_number(obj))\n\n    @classmethod\n    def many(cls, *objs: List[TVector]) -> List['Metric']:\n        \"\"\"\n        Construct many of a Metric from the base parts.\n\n        Useful if you separately compute numerators and denomenators, etc.\n        \"\"\"\n        lengths = [len(o) for o in objs]\n        if len(set(lengths)) != 1:\n            raise IndexError(f'Uneven {cls.__name__} constructions: {lengths}')\n        return [cls(*items) for items in zip(*objs)]\n\n\nclass SumMetric(Metric):\n    \"\"\"\n    Class that keeps a running sum of some metric.\n\n    Examples of SumMetric include things like \"exs\", the number of examples seen since\n    the last report, which depends exactly on a teacher.\n    \"\"\"\n\n    __slots__ = ('_sum',)\n\n    def __init__(self, sum_: TScalar = 0):\n        if isinstance(sum_, torch.Tensor):\n            self._sum = sum_.item()\n        else:\n            assert isinstance(sum_, (int, float))\n            self._sum = sum_\n\n    def 
__add__(self, other: Optional['SumMetric']) -> 'SumMetric':\n        # NOTE: hinting can be cleaned up with \"from __future__ import annotations\" when\n        # we drop Python 3.6\n        if other is None:\n            return self\n        full_sum = self._sum + other._sum\n        # always keep the same return type\n        return type(self)(sum_=full_sum)\n\n    def value(self) -> float:\n        return self._sum\n\n\nclass AverageMetric(Metric):\n    \"\"\"\n    Class that keeps a running average of some metric.\n\n    Examples of AverageMetrics include hits@1, F1, accuracy, etc. These metrics all have\n    per-example values that can be directly mapped back to a teacher.\n    \"\"\"\n\n    __slots__ = ('_numer', '_denom')\n\n    def __init__(self, numer: TScalar, denom: TScalar = 1):\n        self._numer = self.as_number(numer)\n        self._denom = self.as_number(denom)\n\n    def __add__(self, other: Optional['AverageMetric']) -> 'AverageMetric':\n        # NOTE: hinting can be cleaned up with \"from __future__ import annotations\" when\n        # we drop Python 3.6\n        if other is None:\n            return self\n        full_numer: TScalar = self._numer + other._numer\n        full_denom: TScalar = self._denom + other._denom\n        # always keep the same return type\n        return type(self)(numer=full_numer, denom=full_denom)\n\n    def value(self) -> float:\n        if self._numer == 0 and self._denom == 0:\n            # don't nan out if we haven't counted anything\n            return 0.0\n        if self._denom == 0:\n            return float('nan')\n        return self._numer / self._denom\n\n\ndef aggregate_unnamed_reports(reports: List[Dict[str, Metric]]) -> Dict[str, Metric]:\n    \"\"\"\n    Combines metrics without regard for tracking provenence.\n    \"\"\"\n    m: Dict[str, Metric] = {}\n    for task_report in reports:\n        for each_metric, value in task_report.items():\n            m[each_metric] = m.get(each_metric) + value\n    
return m\n\n\nclass Metrics(object):\n    \"\"\"\n    Metrics aggregator.\n    \"\"\"\n\n    def __init__(self):\n        self._data = {}\n\n    def __str__(self):\n        return str(self._data)\n\n    def __repr__(self):\n        return f'Metrics({repr(self._data)})'\n\n    def get(self, key: str):\n        if key in self._data.keys():\n            return self._data[key].value()\n        else:\n            raise\n\n    def __getitem__(self, item):\n        return self.get(item)\n\n    def add(self, key: str, value: Optional[Metric]) -> None:\n        \"\"\"\n        Record an accumulation to a metric.\n        \"\"\"\n        self._data[key] = self._data.get(key) + value\n\n    def report(self):\n        \"\"\"\n        Report the metrics over all data seen so far.\n        \"\"\"\n        return {k: v for k, v in self._data.items()}\n\n    def clear(self):\n        \"\"\"\n        Clear all the metrics.\n        \"\"\"\n        self._data.clear()\n"
  },
  {
    "path": "crslab/evaluator/metrics/gen.py",
    "content": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/18\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nimport re\nfrom collections import Counter\n\nimport math\nimport numpy as np\nfrom nltk import ngrams\nfrom nltk.translate.bleu_score import sentence_bleu\nfrom sklearn.metrics.pairwise import cosine_similarity\nfrom typing import List, Optional\n\nfrom crslab.evaluator.metrics.base import AverageMetric, SumMetric\n\nre_art = re.compile(r'\\b(a|an|the)\\b')\nre_punc = re.compile(r'[!\"#$%&()*+,-./:;<=>?@\\[\\]\\\\^`{|}~_\\']')\nre_space = re.compile(r'\\s+')\n\n\nclass PPLMetric(AverageMetric):\n    def value(self):\n        return math.exp(super().value())\n\n\ndef normalize_answer(s):\n    \"\"\"\n    Lower text and remove punctuation, articles and extra whitespace.\n    \"\"\"\n\n    s = s.lower()\n    s = re_punc.sub(' ', s)\n    s = re_art.sub(' ', s)\n    s = re_space.sub(' ', s)\n    # s = ' '.join(s.split())\n    return s\n\n\nclass ExactMatchMetric(AverageMetric):\n    @staticmethod\n    def compute(guess: str, answers: List[str]) -> 'ExactMatchMetric':\n        if guess is None or answers is None:\n            return None\n        for a in answers:\n            if guess == a:\n                return ExactMatchMetric(1)\n        return ExactMatchMetric(0)\n\n\nclass F1Metric(AverageMetric):\n    \"\"\"\n    Helper class which computes token-level F1.\n    \"\"\"\n\n    @staticmethod\n    def _prec_recall_f1_score(pred_items, gold_items):\n        \"\"\"\n        Compute precision, recall and f1 given a set of gold and prediction items.\n\n        :param pred_items: iterable of predicted values\n        :param gold_items: iterable of gold values\n\n        :return: tuple (p, r, f1) for precision, recall, f1\n        \"\"\"\n        common = Counter(gold_items) & Counter(pred_items)\n        num_same = sum(common.values())\n        if num_same == 0:\n            
return 0\n        precision = 1.0 * num_same / len(pred_items)\n        recall = 1.0 * num_same / len(gold_items)\n        f1 = (2 * precision * recall) / (precision + recall)\n        return f1\n\n    @staticmethod\n    def compute(guess: str, answers: List[str]) -> 'F1Metric':\n        if guess is None or answers is None:\n            return AverageMetric(0, 0)\n        g_tokens = guess.split()\n        scores = [\n            F1Metric._prec_recall_f1_score(g_tokens, a.split())\n            for a in answers\n        ]\n        return F1Metric(max(scores), 1)\n\n\nclass BleuMetric(AverageMetric):\n    @staticmethod\n    def compute(guess: str, answers: List[str], k: int) -> Optional['BleuMetric']:\n        \"\"\"\n        Compute approximate BLEU score between guess and a set of answers.\n        \"\"\"\n\n        weights = [0] * 4\n        weights[k - 1] = 1\n        score = sentence_bleu(\n            [a.split(\" \") for a in answers],\n            guess.split(\" \"),\n            weights=weights,\n        )\n        return BleuMetric(score)\n\n\nclass DistMetric(SumMetric):\n    @staticmethod\n    def compute(sent: str, k: int) -> 'DistMetric':\n        token_set = set()\n        for token in ngrams(sent.split(), k):\n            token_set.add(token)\n        return DistMetric(len(token_set))\n\n\nclass EmbeddingAverage(AverageMetric):\n    @staticmethod\n    def _avg_embedding(embedding):\n        return np.sum(embedding, axis=0) / (np.linalg.norm(np.sum(embedding, axis=0)) + 1e-12)\n\n    @staticmethod\n    def compute(hyp_embedding, ref_embeddings) -> 'EmbeddingAverage':\n        hyp_avg_emb = EmbeddingAverage._avg_embedding(hyp_embedding).reshape(1, -1)\n        ref_avg_embs = [EmbeddingAverage._avg_embedding(emb) for emb in ref_embeddings]\n        ref_avg_embs = np.array(ref_avg_embs)\n        return EmbeddingAverage(float(cosine_similarity(hyp_avg_emb, ref_avg_embs).max()))\n\n\nclass VectorExtrema(AverageMetric):\n    @staticmethod\n    def 
_extreme_embedding(embedding):\n        max_emb = np.max(embedding, axis=0)\n        min_emb = np.min(embedding, axis=0)\n        extreme_emb = np.fromiter(\n            map(lambda x, y: x if ((x > y or x < -y) and y > 0) or ((x < y or x > -y) and y < 0) else y, max_emb,\n                min_emb), dtype=float)\n        return extreme_emb\n\n    @staticmethod\n    def compute(hyp_embedding, ref_embeddings) -> 'VectorExtrema':\n        hyp_ext_emb = VectorExtrema._extreme_embedding(hyp_embedding).reshape(1, -1)\n        ref_ext_embs = [VectorExtrema._extreme_embedding(emb) for emb in ref_embeddings]\n        ref_ext_embs = np.asarray(ref_ext_embs)\n        return VectorExtrema(float(cosine_similarity(hyp_ext_emb, ref_ext_embs).max()))\n\n\nclass GreedyMatch(AverageMetric):\n    @staticmethod\n    def compute(hyp_embedding, ref_embeddings) -> 'GreedyMatch':\n        hyp_emb = np.asarray(hyp_embedding)\n        ref_embs = (np.asarray(ref_embedding) for ref_embedding in ref_embeddings)\n        score_max = 0\n        for ref_emb in ref_embs:\n            sim_mat = cosine_similarity(hyp_emb, ref_emb)\n            score_max = max(score_max, (sim_mat.max(axis=0).mean() + sim_mat.max(axis=1).mean()) / 2)\n        return GreedyMatch(score_max)\n"
  },
  {
    "path": "crslab/evaluator/metrics/rec.py",
    "content": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/2\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\nimport math\n\nfrom crslab.evaluator.metrics.base import AverageMetric\n\n\nclass HitMetric(AverageMetric):\n    @staticmethod\n    def compute(ranks, label, k) -> 'HitMetric':\n        return HitMetric(int(label in ranks[:k]))\n\n\nclass NDCGMetric(AverageMetric):\n    @staticmethod\n    def compute(ranks, label, k) -> 'NDCGMetric':\n        if label in ranks[:k]:\n            label_rank = ranks.index(label)\n            return NDCGMetric(1 / math.log2(label_rank + 2))\n        return NDCGMetric(0)\n\n\nclass MRRMetric(AverageMetric):\n    @staticmethod\n    def compute(ranks, label, k) -> 'MRRMetric':\n        if label in ranks[:k]:\n            label_rank = ranks.index(label)\n            return MRRMetric(1 / (label_rank + 1))\n        return MRRMetric(0)\n"
  },
  {
    "path": "crslab/evaluator/rec.py",
    "content": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/17\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nimport time\n\nfrom loguru import logger\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom crslab.evaluator.base import BaseEvaluator\nfrom crslab.evaluator.utils import nice_report\nfrom .metrics import *\n\n\nclass RecEvaluator(BaseEvaluator):\n    \"\"\"The evaluator specially for reommender model\n    \n    Args:\n        rec_metrics: the metrics to evaluate recommender model, including hit@K, ndcg@K and mrr@K\n        optim_metrics: the metrics to optimize in training\n    \"\"\"\n\n    def __init__(self, tensorboard=False):\n        super(RecEvaluator, self).__init__()\n        self.rec_metrics = Metrics()\n        self.optim_metrics = Metrics()\n        self.tensorboard = tensorboard\n        if self.tensorboard:\n            self.writer = SummaryWriter(log_dir='runs/' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime()))\n            self.reports_name = ['Recommendation Metrics', 'Optimization Metrics']\n\n    def rec_evaluate(self, ranks, label):\n        for k in [1, 10, 50]:\n            if len(ranks) >= k:\n                self.rec_metrics.add(f\"hit@{k}\", HitMetric.compute(ranks, label, k))\n                self.rec_metrics.add(f\"ndcg@{k}\", NDCGMetric.compute(ranks, label, k))\n                self.rec_metrics.add(f\"mrr@{k}\", MRRMetric.compute(ranks, label, k))\n\n    def report(self, epoch=-1, mode='test'):\n        reports = [self.rec_metrics.report(), self.optim_metrics.report()]\n        if self.tensorboard and mode != 'test':\n            for idx, task_report in enumerate(reports):\n                for each_metric, value in task_report.items():\n                    self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch)\n        logger.info('\\n' + 
nice_report(aggregate_unnamed_reports(reports)))\n\n    def reset_metrics(self):\n        self.rec_metrics.clear()\n        self.optim_metrics.clear()\n"
  },
  {
    "path": "crslab/evaluator/standard.py",
    "content": "# @Time   : 2020/11/30\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2020/12/18\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nimport os\nimport time\nfrom collections import defaultdict\n\nimport fasttext\nfrom loguru import logger\nfrom nltk import ngrams\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom crslab.evaluator.base import BaseEvaluator\nfrom crslab.evaluator.utils import nice_report\nfrom .embeddings import resources\nfrom .metrics import *\nfrom ..config import EMBEDDING_PATH\nfrom ..download import build\n\n\nclass StandardEvaluator(BaseEvaluator):\n    \"\"\"The evaluator for all kind of model(recommender, conversation, policy)\n    \n    Args:\n        rec_metrics: the metrics to evaluate recommender model, including hit@K, ndcg@K and mrr@K\n        dist_set: the set to record dist n-gram\n        dist_cnt: the count of dist n-gram evaluation\n        gen_metrics: the metrics to evaluate conversational model, including bleu, dist, embedding metrics, f1\n        optim_metrics: the metrics to optimize in training\n    \"\"\"\n\n    def __init__(self, language, tensorboard=False):\n        super(StandardEvaluator, self).__init__()\n        # rec\n        self.rec_metrics = Metrics()\n        # gen\n        self.dist_set = defaultdict(set)\n        self.dist_cnt = 0\n        self.gen_metrics = Metrics()\n        self._load_embedding(language)\n        # optim\n        self.optim_metrics = Metrics()\n        # tensorboard\n        self.tensorboard = tensorboard\n        if self.tensorboard:\n            self.writer = SummaryWriter(log_dir='runs/' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime()))\n            self.reports_name = ['Recommendation Metrics', 'Generation Metrics', 'Optimization Metrics']\n\n    def _load_embedding(self, language):\n        resource = resources[language]\n        dpath = os.path.join(EMBEDDING_PATH, language)\n        build(dpath, 
resource['file'], resource['version'])\n\n        model_file = os.path.join(dpath, f'cc.{language}.300.bin')\n        self.ft = fasttext.load_model(model_file)\n        logger.info(f'[Load {model_file} for embedding metric')\n\n    def _get_sent_embedding(self, sent):\n        return [self.ft[token] for token in sent.split()]\n\n    def rec_evaluate(self, ranks, label):\n        for k in [1, 10, 50]:\n            if len(ranks) >= k:\n                self.rec_metrics.add(f\"hit@{k}\", HitMetric.compute(ranks, label, k))\n                self.rec_metrics.add(f\"ndcg@{k}\", NDCGMetric.compute(ranks, label, k))\n                self.rec_metrics.add(f\"mrr@{k}\", MRRMetric.compute(ranks, label, k))\n\n    def gen_evaluate(self, hyp, refs):\n        if hyp:\n            self.gen_metrics.add(\"f1\", F1Metric.compute(hyp, refs))\n\n            for k in range(1, 5):\n                self.gen_metrics.add(f\"bleu@{k}\", BleuMetric.compute(hyp, refs, k))\n                for token in ngrams(hyp, k):\n                    self.dist_set[f\"dist@{k}\"].add(token)\n            self.dist_cnt += 1\n\n            hyp_emb = self._get_sent_embedding(hyp)\n            ref_embs = [self._get_sent_embedding(ref) for ref in refs]\n            self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs))\n            self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs))\n            self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs))\n\n    def report(self, epoch=-1, mode='test'):\n        for k, v in self.dist_set.items():\n            self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt))\n        reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()]\n        if self.tensorboard and mode != 'test':\n            for idx, task_report in enumerate(reports):\n                for each_metric, value in task_report.items():\n                    
self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch)\n        logger.info('\\n' + nice_report(aggregate_unnamed_reports(reports)))\n\n    def reset_metrics(self):\n        # rec\n        self.rec_metrics.clear()\n        # conv\n        self.gen_metrics.clear()\n        self.dist_cnt = 0\n        self.dist_set.clear()\n        # optim\n        self.optim_metrics.clear()\n"
  },
  {
    "path": "crslab/evaluator/utils.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/17\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/17\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nimport json\nimport re\nimport shutil\nfrom collections import OrderedDict\n\nimport math\nimport torch\nfrom typing import Union, Tuple\n\nfrom .metrics import Metric\n\n\ndef _line_width():\n    try:\n        # if we're in an interactive ipython notebook, hardcode a longer width\n        __IPYTHON__\n        return 128\n    except NameError:\n        return shutil.get_terminal_size((88, 24)).columns\n\n\ndef float_formatter(f: Union[float, int]) -> str:\n    \"\"\"\n    Format a float as a pretty string.\n    \"\"\"\n    if f != f:\n        # instead of returning nan, return \"\" so it shows blank in table\n        return \"\"\n    if isinstance(f, int):\n        # don't do any rounding of integers, leave them alone\n        return str(f)\n    if f >= 1000:\n        # numbers > 1000 just round to the nearest integer\n        s = f'{f:.0f}'\n    else:\n        # otherwise show 4 significant figures, regardless of decimal spot\n        s = f'{f:.4g}'\n    # replace leading 0's with blanks for easier reading\n    # example:  -0.32 to -.32\n    s = s.replace('-0.', '-.')\n    if s.startswith('0.'):\n        s = s[1:]\n    # Add the trailing 0's to always show 4 digits\n    # example: .32 to .3200\n    if s[0] == '.' 
and len(s) < 5:\n        s += '0' * (5 - len(s))\n    return s\n\n\ndef round_sigfigs(x: Union[float, 'torch.Tensor'], sigfigs=4) -> float:\n    \"\"\"\n    Round value to specified significant figures.\n\n    :param x: input number\n    :param sigfigs: number of significant figures to return\n\n    :returns: float number rounded to specified sigfigs\n    \"\"\"\n    x_: float\n    if isinstance(x, torch.Tensor):\n        x_ = x.item()\n    else:\n        x_ = x  # type: ignore\n\n    try:\n        if x_ == 0:\n            return 0\n        return round(x_, -math.floor(math.log10(abs(x_)) - sigfigs + 1))\n    except (ValueError, OverflowError) as ex:\n        if x_ in [float('inf'), float('-inf')] or x_ != x_:  # inf or nan\n            return x_\n        else:\n            raise ex\n\n\ndef _report_sort_key(report_key: str) -> Tuple[str, str]:\n    \"\"\"\n    Sorting name for reports.\n\n    Sorts by main metric alphabetically, then by task.\n    \"\"\"\n    # if metric is on its own, like \"f1\", we will return ('', 'f1')\n    # if metric is from multitask, we denote it.\n    # e.g. \"convai2/f1\" -> ('convai2', 'f1')\n    # we handle multiple cases of / because sometimes teacher IDs have\n    # filenames.\n    fields = report_key.split(\"/\")\n    main_key = fields.pop(-1)\n    sub_key = '/'.join(fields)\n    return (sub_key or 'all', main_key)\n\n\ndef nice_report(report) -> str:\n    \"\"\"\n    Render an agent Report as a beautiful string.\n\n    If pandas is installed,  we will use it to render as a table. Multitask\n    metrics will be shown per row, e.g.\n\n    .. 
code-block:\n                 f1   ppl\n       all     .410  27.0\n       task1   .400  32.0\n       task2   .420  22.0\n\n    If pandas is not available, we will use a dict with like-metrics placed\n    next to each other.\n    \"\"\"\n    if not report:\n        return \"\"\n\n    try:\n        import pandas as pd\n\n        use_pandas = True\n    except ImportError:\n        use_pandas = False\n\n    sorted_keys = sorted(report.keys(), key=_report_sort_key)\n    output: OrderedDict[Union[str, Tuple[str, str]], float] = OrderedDict()\n    for k in sorted_keys:\n        v = report[k]\n        if isinstance(v, Metric):\n            v = v.value()\n        if use_pandas:\n            output[_report_sort_key(k)] = v\n        else:\n            output[k] = v\n\n    if use_pandas:\n        line_width = _line_width()\n\n        df = pd.DataFrame([output])\n        df.columns = pd.MultiIndex.from_tuples(df.columns)\n        df = df.stack().transpose().droplevel(0, axis=1)\n        result = \"   \" + df.to_string(\n            na_rep=\"\",\n            line_width=line_width - 3,  # -3 for the extra spaces we add\n            float_format=float_formatter,\n            index=df.shape[0] > 1,\n        ).replace(\"\\n\\n\", \"\\n\").replace(\"\\n\", \"\\n   \")\n        result = re.sub(r\"\\s+$\", \"\", result)\n        return result\n    else:\n        return json.dumps(\n            {\n                k: round_sigfigs(v, 4) if isinstance(v, float) else v\n                for k, v in output.items()\n            }\n        )\n"
  },
  {
    "path": "crslab/model/__init__.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020/12/24\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\n# @Time   : 2021/10/06\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\nimport torch\nfrom loguru import logger\n\nfrom .conversation import *\nfrom .crs import *\nfrom .policy import *\nfrom .recommendation import *\n\nModel_register_table = {\n    'KGSF': KGSFModel,\n    'KBRD': KBRDModel,\n    'TGRec': TGRecModel,\n    'TGConv': TGConvModel,\n    'TGPolicy': TGPolicyModel,\n    'ReDialRec': ReDialRecModel,\n    'ReDialConv': ReDialConvModel,\n    'InspiredRec': InspiredRecModel,\n    'InspiredConv': InspiredConvModel,\n    'GPT2': GPT2Model,\n    'Transformer': TransformerModel,\n    'ConvBERT': ConvBERTModel,\n    'ProfileBERT': ProfileBERTModel,\n    'TopicBERT': TopicBERTModel,\n    'PMI': PMIModel,\n    'MGCG': MGCGModel,\n    'BERT': BERTModel,\n    'SASREC': SASRECModel,\n    'GRU4REC': GRU4RECModel,\n    'Popularity': PopularityModel,\n    'TextCNN': TextCNNModel,\n    'NTRD': NTRDModel\n}\n\n\ndef get_model(config, model_name, device, vocab, side_data=None):\n    if model_name in Model_register_table:\n        model = Model_register_table[model_name](config, device, vocab, side_data)\n        logger.info(f'[Build model {model_name}]')\n        if config.opt[\"gpu\"] == [-1]:\n            return model\n        else:\n            if len(config.opt[\"gpu\"]) > 1:\n                if model_name == 'PMI' or model_name == 'KBRD':\n                    logger.info(f'[PMI/KBRD model does not support multi GPUs yet, using single GPU now]')\n                    return model.to(device)\n                else:\n                    return torch.nn.DataParallel(model, device_ids=config[\"gpu\"])\n            else:\n                return model.to(device)\n\n    else:\n        raise NotImplementedError('Model 
[{}] has not been implemented'.format(model_name))\n"
  },
  {
    "path": "crslab/model/base.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020/12/29\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nfrom abc import ABC, abstractmethod\n\nfrom torch import nn\n\nfrom crslab.download import build\n\n\nclass BaseModel(ABC, nn.Module):\n    \"\"\"Base class for all models\"\"\"\n\n    def __init__(self, opt, device, dpath=None, resource=None):\n        super(BaseModel, self).__init__()\n        self.opt = opt\n        self.device = device\n\n        if resource is not None:\n            self.dpath = dpath\n            dfile = resource['file']\n            build(dpath, dfile, version=resource['version'])\n\n        self.build_model()\n\n    @abstractmethod\n    def build_model(self, *args, **kwargs):\n        \"\"\"build model\"\"\"\n        pass\n\n    def recommend(self, batch, mode):\n        \"\"\"calculate loss and prediction of recommendation for batch under certain mode\n\n        Args:\n            batch (dict or tuple): batch data\n            mode (str, optional): train/valid/test.\n        \"\"\"\n        pass\n\n    def converse(self, batch, mode):\n        \"\"\"calculate loss and prediction of conversation for batch under certain mode\n\n        Args:\n            batch (dict or tuple): batch data\n            mode (str, optional): train/valid/test.\n        \"\"\"\n        pass\n\n    def guide(self, batch, mode):\n        \"\"\"calculate loss and prediction of guidance for batch under certain mode\n\n        Args:\n            batch (dict or tuple): batch data\n            mode (str, optional): train/valid/test.\n        \"\"\"\n        pass\n"
  },
  {
    "path": "crslab/model/conversation/__init__.py",
    "content": "from .gpt2 import GPT2Model\nfrom .transformer import TransformerModel\n"
  },
  {
    "path": "crslab/model/conversation/gpt2/__init__.py",
    "content": "from .gpt2 import GPT2Model\n"
  },
  {
    "path": "crslab/model/conversation/gpt2/gpt2.py",
    "content": "# @Time   : 2020/12/14\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2021/1/7\n# @Author : Xiaolei Wang\n# @email  : wxl1999@foxmail.com\n\nr\"\"\"\nGPT2\n====\nReferences:\n    Radford, Alec, et al. `\"Language Models are Unsupervised Multitask Learners.\"`_.\n\n.. _`\"Language Models are Unsupervised Multitask Learners.\"`:\n   https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf\n\n\"\"\"\n\nimport os\n\nimport torch\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import GPT2LMHeadModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\n\n\nclass GPT2Model(BaseModel):\n    \"\"\"\n        \n    Attributes:\n        context_truncate: A integer indicating the length of dialogue context.\n        response_truncate: A integer indicating the length of dialogue response.\n        pad_id: A integer indicating the id of padding token.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.context_truncate = opt['context_truncate']\n        self.response_truncate = opt['response_truncate']\n        self.pad_id = vocab['pad']\n\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['gpt2'][language]\n        dpath = os.path.join(PRETRAIN_PATH, \"gpt2\", language)\n        super(GPT2Model, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self):\n        \"\"\"build model\"\"\"\n        self.model = 
GPT2LMHeadModel.from_pretrained(self.dpath)\n        self.loss = CrossEntropyLoss(ignore_index=self.pad_id)\n\n    def forward(self, batch, mode):\n        _, _, input_ids, context, _, _, y = batch\n        if mode != 'test':\n            # torch.tensor's shape = (bs, seq_len, v_s); tuple's length = 12\n            lm_logits = self.model(input_ids).logits\n\n            # index from 1 to self.reponse_truncate is valid response\n            loss = self.calculate_loss(\n                lm_logits[:, -self.response_truncate:-1, :],\n                input_ids[:, -self.response_truncate + 1:])\n\n            pred = torch.max(lm_logits, dim=2)[1]  # [bs, seq_len]\n            pred = pred[:, -self.response_truncate:]\n\n            return loss, pred\n        else:\n            return self.generate(context)\n\n    def generate(self, context):\n        \"\"\"\n        Args:\n            context: torch.tensor, shape=(bs, context_turncate)\n\n        Returns:\n            generated_response: torch.tensor, shape=(bs, reponse_turncate-1)\n        \"\"\"\n        generated_response = []\n        former_hidden_state = None\n        context = context[..., -self.response_truncate + 1:]\n\n        for i in range(self.response_truncate - 1):\n            outputs = self.model(context, former_hidden_state)  # (bs, c_t, v_s),\n            last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values\n\n            next_token_logits = last_hidden_state[:, -1, :]  # (bs, v_s)\n            preds = next_token_logits.argmax(dim=-1).long()  # (bs)\n\n            context = preds.unsqueeze(1)\n            generated_response.append(preds)\n\n        generated_response = torch.stack(generated_response).T\n\n        return generated_response\n\n    def calculate_loss(self, logit, labels):\n        \"\"\"\n        Args:\n            preds: torch.FloatTensor, shape=(bs, response_truncate, vocab_size)\n            labels: torch.LongTensor, shape=(bs, response_truncate)\n\n        
\"\"\"\n\n        loss = self.loss(logit.reshape(-1, logit.size(-1)), labels.reshape(-1))\n        return loss\n\n    def generate_bs(self, context, beam=4):\n        context = context[..., -self.response_truncate + 1:]\n        context_former = context\n        batch_size = context.shape[0]\n        sequences = [[[list(), 1.0]]] * batch_size\n        for i in range(self.response_truncate - 1):\n            if sequences != [[[list(), 1.0]]] * batch_size:\n                context = []\n                for i in range(batch_size):\n                    for cand in sequences[i]:\n                        text = torch.cat(\n                            (context_former[i], torch.tensor(cand[0]).to(self.device)))  # 由于取消了state，与之前的context拼接\n                        context.append(text)\n                context = torch.stack(context)\n            with torch.no_grad():\n                outputs = self.model(context)\n            last_hidden_state, state = outputs.logits, outputs.past_key_values\n            next_token_logits = last_hidden_state[:, -1, :]\n            next_token_probs = torch.nn.functional.softmax(next_token_logits)\n            topk = torch.topk(next_token_probs, beam, dim=-1)\n            probs = topk.values.reshape([batch_size, -1, beam])  # (bs, candidate, beam)\n            preds = topk.indices.reshape([batch_size, -1, beam])  # (bs, candidate, beam)\n\n            for j in range(batch_size):\n                all_candidates = []\n                for n in range(len(sequences[j])):\n                    for k in range(beam):\n                        seq = sequences[j][n][0]\n                        prob = sequences[j][n][1]\n                        seq_tmp = seq.copy()\n                        seq_tmp.append(preds[j][n][k])\n                        candidate = [seq_tmp, prob * probs[j][n][k]]\n                        all_candidates.append(candidate)\n                ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)\n                
sequences[j] = ordered[:beam]\n\n        res = []\n        for i in range(batch_size):\n            res.append(torch.stack(sequences[i][0][0]))\n        res = torch.stack(res)\n        return res\n"
  },
  {
    "path": "crslab/model/conversation/transformer/__init__.py",
    "content": "from .transformer import TransformerModel\n"
  },
  {
    "path": "crslab/model/conversation/transformer/transformer.py",
    "content": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nTransformer\n===========\nReferences:\n    Zhou, Kun, et al. `\"Towards Topic-Guided Conversational Recommender System.\"`_ in COLING 2020.\n\n.. _`\"Towards Topic-Guided Conversational Recommender System.\"`:\n   https://www.aclweb.org/anthology/2020.coling-main.365/\n\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom loguru import logger\nfrom torch import nn\n\nfrom crslab.model.base import BaseModel\nfrom crslab.model.utils.functions import edge_to_pyg_format\nfrom crslab.model.utils.modules.transformer import TransformerEncoder, TransformerDecoder\n\n\nclass TransformerModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        vocab_size: A integer indicating the vocabulary size.\n        pad_token_idx: A integer indicating the id of padding token.\n        start_token_idx: A integer indicating the id of start token.\n        end_token_idx: A integer indicating the id of end token.\n        token_emb_dim: A integer indicating the dimension of token embedding layer.\n        pretrain_embedding: A string indicating the path of pretrained embedding.\n        n_word: A integer indicating the number of words.\n        n_entity: A integer indicating the number of entities.\n        pad_word_idx: A integer indicating the id of word padding.\n        pad_entity_idx: A integer indicating the id of entity padding.\n        num_bases: A integer indicating the number of bases.\n        kg_emb_dim: A integer indicating the dimension of kg embedding.\n        n_heads: A integer indicating the number of heads.\n        n_layers: A integer indicating the number of layer.\n        ffn_size: A integer indicating the size of ffn hidden.\n        dropout: A float indicating the drouput rate.\n        attention_dropout: A 
integer indicating the drouput rate of attention layer.\n        relu_dropout: A integer indicating the drouput rate of relu layer.\n        learn_positional_embeddings: A boolean indicating if we learn the positional embedding.\n        embeddings_scale: A boolean indicating if we use the embeddings scale.\n        reduction: A boolean indicating if we use the reduction.\n        n_positions: A integer indicating the number of position.\n        longest_label: A integer indicating the longest length for response generation.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        # vocab\n        self.vocab_size = vocab['vocab_size']\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.token_emb_dim = opt['token_emb_dim']\n        self.pretrain_embedding = side_data.get('embedding', None)\n        # kg\n        self.n_word = vocab['n_word']\n        self.n_entity = vocab['n_entity']\n        self.pad_word_idx = vocab['pad_word']\n        self.pad_entity_idx = vocab['pad_entity']\n        entity_kg = side_data['entity_kg']\n        self.n_relation = entity_kg['n_relation']\n        entity_edges = entity_kg['edge']\n        self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format(entity_edges, 'RGCN')\n        self.entity_edge_idx = self.entity_edge_idx.to(device)\n        self.entity_edge_type = self.entity_edge_type.to(device)\n        word_edges = side_data['word_kg']['edge']\n        self.word_edges = edge_to_pyg_format(word_edges, 'GCN').to(device)\n        self.num_bases = 
opt['num_bases']\n        self.kg_emb_dim = opt['kg_emb_dim']\n        # transformer\n        self.n_heads = opt['n_heads']\n        self.n_layers = opt['n_layers']\n        self.ffn_size = opt['ffn_size']\n        self.dropout = opt['dropout']\n        self.attention_dropout = opt['attention_dropout']\n        self.relu_dropout = opt['relu_dropout']\n        self.learn_positional_embeddings = opt['learn_positional_embeddings']\n        self.embeddings_scale = opt['embeddings_scale']\n        self.reduction = opt['reduction']\n        self.n_positions = opt['n_positions']\n        self.longest_label = opt.get('longest_label', 1)\n        super(TransformerModel, self).__init__(opt, device)\n\n    def build_model(self):\n        self._init_embeddings()\n        self._build_conversation_layer()\n\n    def _init_embeddings(self):\n        if self.pretrain_embedding is not None:\n            self.token_embedding = nn.Embedding.from_pretrained(\n                torch.as_tensor(self.pretrain_embedding, dtype=torch.float), freeze=False,\n                padding_idx=self.pad_token_idx)\n        else:\n            self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx)\n            nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)\n            nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0)\n\n        logger.debug('[Finish init embeddings]')\n\n    def _build_conversation_layer(self):\n        self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long))\n        self.conv_encoder = TransformerEncoder(\n            n_heads=self.n_heads,\n            n_layers=self.n_layers,\n            embedding_size=self.token_emb_dim,\n            ffn_size=self.ffn_size,\n            vocabulary_size=self.vocab_size,\n            embedding=self.token_embedding,\n            dropout=self.dropout,\n            attention_dropout=self.attention_dropout,\n            
relu_dropout=self.relu_dropout,\n            padding_idx=self.pad_token_idx,\n            learn_positional_embeddings=self.learn_positional_embeddings,\n            embeddings_scale=self.embeddings_scale,\n            reduction=self.reduction,\n            n_positions=self.n_positions,\n        )\n\n        self.conv_decoder = TransformerDecoder(\n            self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size,\n            embedding=self.token_embedding,\n            dropout=self.dropout,\n            attention_dropout=self.attention_dropout,\n            relu_dropout=self.relu_dropout,\n            embeddings_scale=self.embeddings_scale,\n            learn_positional_embeddings=self.learn_positional_embeddings,\n            padding_idx=self.pad_token_idx,\n            n_positions=self.n_positions\n        )\n\n        self.conv_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)\n\n        logger.debug('[Finish build conv layer]')\n\n    def _starts(self, batch_size):\n        \"\"\"Return bsz start tokens.\"\"\"\n        return self.START.detach().expand(batch_size, 1)\n\n    def _decode_forced_with_kg(self, token_encoding, response):\n        batch_size, seq_len = response.shape\n        start = self._starts(batch_size)\n        inputs = torch.cat((start, response[:, :-1]), dim=-1).long()\n\n        dialog_latent, _ = self.conv_decoder(inputs, token_encoding)  # (bs, seq_len, dim)\n\n        gen_logits = F.linear(dialog_latent, self.token_embedding.weight)  # (bs, seq_len, vocab_size)\n        preds = gen_logits.argmax(dim=-1)\n        return gen_logits, preds\n\n    def _decode_greedy_with_kg(self, token_encoding):\n        batch_size = token_encoding[0].shape[0]\n        inputs = self._starts(batch_size).long()\n        incr_state = None\n        logits = []\n        for _ in range(self.longest_label):\n            dialog_latent, incr_state = self.conv_decoder(inputs, token_encoding, incr_state)\n            dialog_latent = 
dialog_latent[:, -1:, :]  # (bs, 1, dim)\n\n            gen_logits = F.linear(dialog_latent, self.token_embedding.weight)\n            preds = gen_logits.argmax(dim=-1).long()\n            logits.append(gen_logits)\n            inputs = torch.cat((inputs, preds), dim=1)\n\n            finished = ((inputs == self.end_token_idx).sum(dim=-1) > 0).sum().item() == batch_size\n            if finished:\n                break\n        logits = torch.cat(logits, dim=1)\n        return logits, inputs\n\n    def _decode_beam_search_with_kg(self, token_encoding, beam=4):\n        batch_size = token_encoding[0].shape[0]\n        xs = self._starts(batch_size).long().reshape(1, batch_size, -1)\n        incr_state = None\n        sequences = [[[list(), list(), 1.0]]] * batch_size\n        for i in range(self.longest_label):\n            # at beginning there is 1 candidate, when i!=0 there are 4 candidates\n            if i == 1:\n                token_encoding = (token_encoding[0].repeat(beam, 1, 1),\n                                  token_encoding[1].repeat(beam, 1, 1))\n            if i != 0:\n                xs = []\n                for d in range(len(sequences[0])):\n                    for j in range(batch_size):\n                        text = sequences[j][d][0]\n                        xs.append(text)\n                xs = torch.stack(xs).reshape(beam, batch_size, -1)  # (beam, batch_size, _)\n\n            dialog_latent, incr_state = self.conv_decoder(xs.reshape(len(sequences[0]) * batch_size, -1),\n                                                          token_encoding,\n                                                          incr_state)\n            dialog_latent = dialog_latent[:, -1:, :]  # (bs, 1, dim)\n            gen_logits = F.linear(dialog_latent, self.token_embedding.weight)\n\n            logits = gen_logits.reshape(len(sequences[0]), batch_size, 1, -1)\n            # turn into probabilities,in case of negative numbers\n            probs, preds = 
torch.nn.functional.softmax(logits).topk(beam, dim=-1)\n\n            # (candeidate, bs, 1 , beam) during first loop, candidate=1, otherwise candidate=beam\n\n            for j in range(batch_size):\n                all_candidates = []\n                for n in range(len(sequences[j])):\n                    for k in range(beam):\n                        prob = sequences[j][n][2]\n                        logit = sequences[j][n][1]\n                        if logit == []:\n                            logit_tmp = logits[n][j][0].unsqueeze(0)\n                        else:\n                            logit_tmp = torch.cat((logit, logits[n][j][0].unsqueeze(0)), dim=0)\n                        seq_tmp = torch.cat((xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1)))\n                        candidate = [seq_tmp, logit_tmp, prob * probs[n][j][0][k]]\n                        all_candidates.append(candidate)\n                ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True)\n                sequences[j] = ordered[:beam]\n\n            # check if everyone has generated an end token\n            all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == batch_size\n            if all_finished:\n                break\n        logits = torch.stack([seq[0][1] for seq in sequences])\n        xs = torch.stack([seq[0][0] for seq in sequences])\n        return logits, xs\n\n    def forward(self, batch, mode):\n        context_tokens, context_entities, context_words, response = batch\n\n        # encoder-decoder\n        tokens_encoding = self.conv_encoder(context_tokens)\n        if mode != 'test':\n            self.longest_label = max(self.longest_label, response.shape[1])\n            logits, preds = self._decode_forced_with_kg(tokens_encoding,\n                                                        response)\n\n            logits = logits.view(-1, logits.shape[-1])\n            response = response.view(-1)\n            loss = 
self.conv_loss(logits, response)\n            return loss, preds\n        else:\n            logits, preds = self._decode_greedy_with_kg(tokens_encoding)\n            return preds\n"
  },
  {
    "path": "crslab/model/crs/__init__.py",
    "content": "from .inspired import *\nfrom .kbrd import *\nfrom .kgsf import *\nfrom .redial import *\nfrom .tgredial import *\nfrom .ntrd import *\n"
  },
  {
    "path": "crslab/model/crs/inspired/__init__.py",
    "content": "from .inspired_conv import InspiredConvModel\nfrom .inspired_rec import InspiredRecModel\n"
  },
  {
    "path": "crslab/model/crs/inspired/inspired_conv.py",
    "content": "# @Time   : 2021/3/10\n# @Author : Beichen Zhang\n# @Email  : zhangbeichen724@gmail.com\n\nimport os\n\nimport torch\nfrom transformers import GPT2LMHeadModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\nfrom .modules import SequenceCrossEntropyLoss\n\n\nclass InspiredConvModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        context_truncate: A integer indicating the length of dialogue context.\n        response_truncate: A integer indicating the length of dialogue response.\n        pad_id: A integer indicating the id of padding token.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.context_truncate = opt['context_truncate']\n        self.response_truncate = opt['response_truncate']\n        self.pad_id = vocab['pad']\n        self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1\n\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['gpt2'][language]\n        dpath = os.path.join(PRETRAIN_PATH, \"gpt2\", language)\n        super(InspiredConvModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self):\n        \"\"\"build model for seeker and recommender separately\"\"\"\n        self.model_sk = GPT2LMHeadModel.from_pretrained(self.dpath)\n        self.model_rm = GPT2LMHeadModel.from_pretrained(self.dpath)\n        self.loss = SequenceCrossEntropyLoss(self.pad_id, self.label_smoothing)\n\n    def converse(self, batch, 
mode):\n        \"\"\"\n        Args:\n            batch: ::\n\n                {\n                    'roles': (batch_size),\n                    'input_ids': (batch_size, max_seq_length),\n                    'context': (batch_size, context_truncate)\n                }\n\n        \"\"\"\n        roles, input_ids, context, _ = batch\n        input_ids_iters = input_ids.unsqueeze(1)\n\n        past = None\n        lm_logits_all = []\n\n        if mode != 'test':\n            for turn, iter in enumerate(input_ids_iters):\n                if (roles[turn] == 0):\n                    # considering that gpt2 only supports up to 1024 tokens\n                    if past is not None and past[0].shape[3] + iter.shape[1] > 1024:\n                        past = None\n                    outputs = self.model_sk(iter, past_key_values=past)\n                    lm_logits, past = outputs.logits, outputs.past_key_values\n                    lm_logits_all.append(lm_logits)\n                else:\n                    if past is not None and past[0].shape[3] + iter.shape[1] > 1024:\n                        past = None\n                    outputs = self.model_rm(iter, past_key_values=past)\n                    lm_logits, past = outputs.logits, outputs.past_key_values\n                    lm_logits_all.append(lm_logits)\n\n            lm_logits_all = torch.cat(lm_logits_all, dim=0)  # (b_s, seq_len, vocab_size)\n\n            # index from 1 to self.reponse_truncate is valid response\n            loss = self.calculate_loss(\n                lm_logits_all[:, -self.response_truncate:-1, :],\n                input_ids[:, -self.response_truncate + 1:])\n\n            pred = torch.max(lm_logits_all, dim=2)[1]  # (b_s, seq_len)\n            pred = pred[:, -self.response_truncate:]\n\n            return loss, pred\n        else:\n            return self.generate(roles, context)\n\n    def generate(self, roles, context):\n        \"\"\"\n        Args:\n            roles: the role of each speak 
corresponding to the utterance in batch, shape=(b_s)\n            context: torch.tensor, shape=(b_s, context_turncate)\n\n        Returns:\n            generated_response: torch.tensor, shape=(b_s, reponse_turncate-1)\n        \"\"\"\n        generated_response = []\n        former_hidden_state = None\n        context = context[..., -self.response_truncate + 1:]\n\n        for i in range(self.response_truncate - 1):\n            last_hidden_state_all = []\n            context_iters = context.unsqueeze(1)\n            for turn, iter in enumerate(context_iters):\n                if roles[turn] == 0:\n                    outputs = self.model_sk(iter, former_hidden_state)  # (1, s_l, v_s),\n                else:\n                    outputs = self.model_rm(iter, former_hidden_state)  # (1, s_l, v_s),\n                last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values\n                last_hidden_state_all.append(last_hidden_state)\n\n            last_hidden_state_all = torch.cat(last_hidden_state_all, dim=0)\n            next_token_logits = last_hidden_state_all[:, -1, :]  # (b_s, v_s)\n            preds = next_token_logits.argmax(dim=-1).long()  # (b_s)\n\n            context = preds.unsqueeze(1)\n            generated_response.append(preds)\n\n        generated_response = torch.stack(generated_response).T\n\n        return generated_response\n\n    def calculate_loss(self, logit, labels):\n        \"\"\"\n\n        Args:\n            preds: torch.FloatTensor, shape=(b_s, response_truncate, vocab_size)\n            labels: torch.LongTensor, shape=(b_s, response_truncate)\n\n        \"\"\"\n\n        loss = self.loss(logit, labels)\n        return loss\n"
  },
  {
    "path": "crslab/model/crs/inspired/inspired_rec.py",
    "content": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nBERT\n====\nReferences:\n    Devlin, Jacob, et al. `\"BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.\"`_ in NAACL 2019.\n\n.. _`\"BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.\"`:\n   https://www.aclweb.org/anthology/N19-1423/\n\n\"\"\"\n\nimport os\n\nfrom loguru import logger\nfrom torch import nn\nfrom transformers import BertModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\n\n\nclass InspiredRecModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        item_size: A integer indicating the number of items.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.item_size = vocab['n_entity']\n\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['bert'][language]\n        dpath = os.path.join(PRETRAIN_PATH, \"bert\", language)\n        super(InspiredRecModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self):\n        # build BERT layer, give the architecture, load pretrained parameters\n        self.bert = BertModel.from_pretrained(self.dpath)\n        # print(self.item_size)\n        self.bert_hidden_size = self.bert.config.hidden_size\n        self.mlp = 
nn.Linear(self.bert_hidden_size, self.item_size)\n\n        # this loss may conduct to some weakness\n        self.rec_loss = nn.CrossEntropyLoss()\n\n        logger.debug('[Finish build rec layer]')\n\n    def recommend(self, batch, mode='train'):\n        context, mask, y = batch\n\n        bert_embed = self.bert(context, attention_mask=mask).pooler_output\n\n        rec_scores = self.mlp(bert_embed)  # bs, item_size\n\n        rec_loss = self.rec_loss(rec_scores, y)\n\n        return rec_loss, rec_scores\n"
  },
  {
    "path": "crslab/model/crs/inspired/modules.py",
    "content": "# @Time   : 2021/3/10\n# @Author : Beichen Zhang\n# @Email  : zhangbeichen724@gmail.com\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass SequenceCrossEntropyLoss(nn.Module):\n    \"\"\"\n\n    Attributes:\n        ignore_index: indices corresponding tokens which should be ignored in calculating loss.\n        label_smoothing: determine smoothing value in cross entropy loss. should be less than 1.0.\n\n    \"\"\"\n\n    def __init__(self, ignore_index=None, label_smoothing=-1):\n        super().__init__()\n        self.ignore_index = ignore_index\n        self.label_smoothing = label_smoothing\n\n    def forward(self, logits, labels):\n        \"\"\"\n\n        Args:\n            logits: (batch_size, max_seq_len, vocal_size)\n            labels: (batch_size, max_seq_len)\n\n        \"\"\"\n        if self.label_smoothing > 1.0:\n            raise ValueError('The param label_smoothing should be in the range of 0.0 to 1.0.')\n        if self.ignore_index == None:\n            mask = torch.ones_like(labels, dtype=torch.float)\n        else:\n            mask = (labels != self.ignore_index).float()\n        logits_flat = logits.reshape(-1, logits.size(-1))  # (b_s * s_l, num_classes)\n        log_probs_flat = F.log_softmax(logits_flat, dim=-1)\n        labels_flat = labels.reshape(-1, 1).long()  # (b_s * s_l, 1)\n\n        if self.label_smoothing > 0.0:\n            num_classes = logits.size(-1)\n            smoothing_value = self.label_smoothing / float(num_classes)\n            one_hot_labels = torch.zeros_like(log_probs_flat).scatter_(-1, labels_flat,\n                                                                       1.0 - self.label_smoothing)  # fill all the correct indices with 1 - smoothing value.\n            smoothed_labels = one_hot_labels + smoothing_value\n            negative_log_likelihood_flat = -log_probs_flat * smoothed_labels\n            negative_log_likelihood_flat = 
negative_log_likelihood_flat.sum(-1, keepdim=True)\n        else:\n            negative_log_likelihood_flat = -torch.gather(log_probs_flat, dim=1, index=labels_flat)  # (b_s * s_l, 1)\n\n        negative_log_likelihood = negative_log_likelihood_flat.view(-1, logits.shape[1])  # (b_s, s_l)\n        loss = negative_log_likelihood * mask\n\n        loss = loss.sum(1) / (mask.sum(1) + 1e-13)\n        loss = loss.mean()\n\n        return loss\n"
  },
  {
    "path": "crslab/model/crs/kbrd/__init__.py",
    "content": "from .kbrd import KBRDModel\n"
  },
  {
    "path": "crslab/model/crs/kbrd/kbrd.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/4\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time   : 2020/1/3, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nKBRD\n====\nReferences:\n    Chen, Qibin, et al. `\"Towards Knowledge-Based Recommender Dialog System.\"`_ in EMNLP 2019.\n\n.. _`\"Towards Knowledge-Based Recommender Dialog System.\"`:\n   https://www.aclweb.org/anthology/D19-1189/\n\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom loguru import logger\nfrom torch import nn\nfrom torch_geometric.nn import RGCNConv\n\nfrom crslab.model.base import BaseModel\nfrom crslab.model.utils.functions import edge_to_pyg_format\nfrom crslab.model.utils.modules.attention import SelfAttentionBatch\nfrom crslab.model.utils.modules.transformer import TransformerDecoder, TransformerEncoder\n\n\nclass KBRDModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        vocab_size: A integer indicating the vocabulary size.\n        pad_token_idx: A integer indicating the id of padding token.\n        start_token_idx: A integer indicating the id of start token.\n        end_token_idx: A integer indicating the id of end token.\n        token_emb_dim: A integer indicating the dimension of token embedding layer.\n        pretrain_embedding: A string indicating the path of pretrained embedding.\n        n_entity: A integer indicating the number of entities.\n        n_relation: A integer indicating the number of relation in KG.\n        num_bases: A integer indicating the number of bases.\n        kg_emb_dim: A integer indicating the dimension of kg embedding.\n        user_emb_dim: A integer indicating the dimension of user embedding.\n        n_heads: A integer indicating the number of heads.\n        n_layers: A integer indicating the number of layer.\n        ffn_size: A integer indicating the size of ffn hidden.\n        dropout: A float 
indicating the dropout rate.\n        attention_dropout: A integer indicating the dropout rate of attention layer.\n        relu_dropout: A integer indicating the dropout rate of relu layer.\n        learn_positional_embeddings: A boolean indicating if we learn the positional embedding.\n        embeddings_scale: A boolean indicating if we use the embeddings scale.\n        reduction: A boolean indicating if we use the reduction.\n        n_positions: A integer indicating the number of position.\n        longest_label: A integer indicating the longest length for response generation.\n        user_proj_dim: A integer indicating dim to project for user embedding.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.device = device\n        self.gpu = opt.get(\"gpu\", [-1])\n        # vocab\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.vocab_size = vocab['vocab_size']\n        self.token_emb_dim = opt.get('token_emb_dim', 300)\n        self.pretrain_embedding = side_data.get('embedding', None)\n        # kg\n        self.n_entity = vocab['n_entity']\n        entity_kg = side_data['entity_kg']\n        self.n_relation = entity_kg['n_relation']\n        self.edge_idx, self.edge_type = edge_to_pyg_format(entity_kg['edge'], 'RGCN')\n        self.edge_idx = self.edge_idx.to(device)\n        self.edge_type = self.edge_type.to(device)\n        self.num_bases = opt.get('num_bases', 8)\n        self.kg_emb_dim = opt.get('kg_emb_dim', 300)\n        self.user_emb_dim = self.kg_emb_dim\n        # 
transformer\n        self.n_heads = opt.get('n_heads', 2)\n        self.n_layers = opt.get('n_layers', 2)\n        self.ffn_size = opt.get('ffn_size', 300)\n        self.dropout = opt.get('dropout', 0.1)\n        self.attention_dropout = opt.get('attention_dropout', 0.0)\n        self.relu_dropout = opt.get('relu_dropout', 0.1)\n        self.embeddings_scale = opt.get('embedding_scale', True)\n        self.learn_positional_embeddings = opt.get('learn_positional_embeddings', False)\n        self.reduction = opt.get('reduction', False)\n        self.n_positions = opt.get('n_positions', 1024)\n        self.longest_label = opt.get('longest_label', 1)\n        self.user_proj_dim = opt.get('user_proj_dim', 512)\n\n        super(KBRDModel, self).__init__(opt, device)\n\n    def build_model(self, *args, **kwargs):\n        self._build_embedding()\n        self._build_kg_layer()\n        self._build_recommendation_layer()\n        self._build_conversation_layer()\n\n    def _build_embedding(self):\n        if self.pretrain_embedding is not None:\n            self.token_embedding = nn.Embedding.from_pretrained(\n                torch.as_tensor(self.pretrain_embedding, dtype=torch.float), freeze=False,\n                padding_idx=self.pad_token_idx)\n        else:\n            self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx)\n            nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)\n            nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0)\n        logger.debug('[Build embedding]')\n\n    def _build_kg_layer(self):\n        self.kg_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases)\n        self.kg_attn = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim)\n        logger.debug('[Build kg layer]')\n\n    def _build_recommendation_layer(self):\n        self.rec_bias = nn.Linear(self.kg_emb_dim, self.n_entity)\n        
self.rec_loss = nn.CrossEntropyLoss()\n        logger.debug('[Build recommendation layer]')\n\n    def _build_conversation_layer(self):\n        self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long))\n        self.dialog_encoder = TransformerEncoder(\n            self.n_heads,\n            self.n_layers,\n            self.token_emb_dim,\n            self.ffn_size,\n            self.vocab_size,\n            self.token_embedding,\n            self.dropout,\n            self.attention_dropout,\n            self.relu_dropout,\n            self.pad_token_idx,\n            self.learn_positional_embeddings,\n            self.embeddings_scale,\n            self.reduction,\n            self.n_positions\n        )\n        self.decoder = TransformerDecoder(\n            self.n_heads,\n            self.n_layers,\n            self.token_emb_dim,\n            self.ffn_size,\n            self.vocab_size,\n            self.token_embedding,\n            self.dropout,\n            self.attention_dropout,\n            self.relu_dropout,\n            self.embeddings_scale,\n            self.learn_positional_embeddings,\n            self.pad_token_idx,\n            self.n_positions\n        )\n        self.user_proj_1 = nn.Linear(self.user_emb_dim, self.user_proj_dim)\n        self.user_proj_2 = nn.Linear(self.user_proj_dim, self.vocab_size)\n        self.conv_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)\n        logger.debug('[Build conversation layer]')\n\n    def encode_user(self, entity_lists, kg_embedding):\n        user_repr_list = []\n        for entity_list in entity_lists:\n            if entity_list is None:\n                user_repr_list.append(torch.zeros(self.user_emb_dim, device=self.device))\n                continue\n            user_repr = kg_embedding[entity_list]\n            user_repr = self.kg_attn(user_repr)\n            user_repr_list.append(user_repr)\n        return torch.stack(user_repr_list, dim=0)  # (bs, 
dim)\n\n    def recommend(self, batch, mode):\n        context_entities, item = batch['context_entities'], batch['item']\n        kg_embedding = self.kg_encoder(None, self.edge_idx, self.edge_type)\n        user_embedding = self.encode_user(context_entities, kg_embedding)\n        scores = F.linear(user_embedding, kg_embedding, self.rec_bias.bias)\n        loss = self.rec_loss(scores, item)\n        return loss, scores\n\n    def _starts(self, batch_size):\n        \"\"\"Return bsz start tokens.\"\"\"\n        return self.START.detach().expand(batch_size, 1)\n\n    def decode_forced(self, encoder_states, user_embedding, resp):\n        bsz = resp.size(0)\n        seqlen = resp.size(1)\n        inputs = resp.narrow(1, 0, seqlen - 1)\n        inputs = torch.cat([self._starts(bsz), inputs], 1)\n        latent, _ = self.decoder(inputs, encoder_states)\n        token_logits = F.linear(latent, self.token_embedding.weight)\n        user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)\n        sum_logits = token_logits + user_logits\n        _, preds = sum_logits.max(dim=-1)\n        return sum_logits, preds\n\n    def decode_greedy(self, encoder_states, user_embedding):\n\n        bsz = encoder_states[0].shape[0]\n        xs = self._starts(bsz)\n        incr_state = None\n        logits = []\n        for i in range(self.longest_label):\n            scores, incr_state = self.decoder(xs, encoder_states, incr_state)  # incr_state is always None\n            scores = scores[:, -1:, :]\n            token_logits = F.linear(scores, self.token_embedding.weight)\n            user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)\n            sum_logits = token_logits + user_logits\n            probs, preds = sum_logits.max(dim=-1)\n            logits.append(scores)\n            xs = torch.cat([xs, preds], dim=1)\n            # check if everyone has generated an end token\n            all_finished = ((xs == 
self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz\n            if all_finished:\n                break\n        logits = torch.cat(logits, 1)\n        return logits, xs\n\n    def decode_beam_search(self, encoder_states, user_embedding, beam=4):\n        bsz = encoder_states[0].shape[0]\n        xs = self._starts(bsz).reshape(1, bsz, -1)  # (batch_size, _)\n        sequences = [[[list(), list(), 1.0]]] * bsz\n        for i in range(self.longest_label):\n            # at beginning there is 1 candidate, when i!=0 there are 4 candidates\n            if i != 0:\n                xs = []\n                for d in range(len(sequences[0])):\n                    for j in range(bsz):\n                        text = sequences[j][d][0]\n                        xs.append(text)\n                xs = torch.stack(xs).reshape(beam, bsz, -1)  # (beam, batch_size, _)\n\n            with torch.no_grad():\n                if i == 1:\n                    user_embedding = user_embedding.repeat(beam, 1)\n                    encoder_states = (encoder_states[0].repeat(beam, 1, 1),\n                                      encoder_states[1].repeat(beam, 1, 1))\n\n                scores, _ = self.decoder(xs.reshape(len(sequences[0]) * bsz, -1), encoder_states)\n                scores = scores[:, -1:, :]\n                token_logits = F.linear(scores, self.token_embedding.weight)\n                user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)\n                sum_logits = token_logits + user_logits\n\n            logits = sum_logits.reshape(len(sequences[0]), bsz, 1, -1)\n            scores = scores.reshape(len(sequences[0]), bsz, 1, -1)\n            logits = torch.nn.functional.softmax(logits)  # turn into probabilities,in case of negative numbers\n            probs, preds = logits.topk(beam, dim=-1)\n            # (candeidate, bs, 1 , beam) during first loop, candidate=1, otherwise candidate=beam\n\n            for j in range(bsz):\n             
   all_candidates = []\n                for n in range(len(sequences[j])):\n                    for k in range(beam):\n                        prob = sequences[j][n][2]\n                        score = sequences[j][n][1]\n                        if score == []:\n                            score_tmp = scores[n][j][0].unsqueeze(0)\n                        else:\n                            score_tmp = torch.cat((score, scores[n][j][0].unsqueeze(0)), dim=0)\n                        seq_tmp = torch.cat((xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1)))\n                        candidate = [seq_tmp, score_tmp, prob * probs[n][j][0][k]]\n                        all_candidates.append(candidate)\n                ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True)\n                sequences[j] = ordered[:beam]\n\n            # check if everyone has generated an end token\n            all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz\n            if all_finished:\n                break\n        logits = torch.stack([seq[0][1] for seq in sequences])\n        xs = torch.stack([seq[0][0] for seq in sequences])\n        return logits, xs\n\n    def converse(self, batch, mode):\n        context_tokens, context_entities, response = batch['context_tokens'], batch['context_entities'], batch[\n            'response']\n        kg_embedding = self.kg_encoder(None, self.edge_idx, self.edge_type)\n        user_embedding = self.encode_user(context_entities, kg_embedding)\n        encoder_state = self.dialog_encoder(context_tokens)\n        if mode != 'test':\n            self.longest_label = max(self.longest_label, response.shape[1])\n            logits, preds = self.decode_forced(encoder_state, user_embedding, response)\n            logits = logits.view(-1, logits.shape[-1])\n            labels = response.view(-1)\n            return self.conv_loss(logits, labels), preds\n        else:\n            _, preds = 
self.decode_greedy(encoder_state, user_embedding)\n            return preds\n\n    def forward(self, batch, mode, stage):\n        if len(self.gpu) >= 2:\n            self.edge_idx = self.edge_idx.cuda(torch.cuda.current_device())\n            self.edge_type = self.edge_type.cuda(torch.cuda.current_device())\n        if stage == \"conv\":\n            return self.converse(batch, mode)\n        if stage == \"rec\":\n            return self.recommend(batch, mode)\n\n    def freeze_parameters(self):\n        freeze_models = [self.kg_encoder, self.kg_attn, self.rec_bias]\n        for model in freeze_models:\n            for p in model.parameters():\n                p.requires_grad = False"
  },
  {
    "path": "crslab/model/crs/kgsf/__init__.py",
    "content": "from .kgsf import KGSFModel\n"
  },
  {
    "path": "crslab/model/crs/kgsf/kgsf.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020/12/29, 2021/1/4\n# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nKGSF\n====\nReferences:\n    Zhou, Kun, et al. `\"Improving Conversational Recommender Systems via Knowledge Graph based Semantic Fusion.\"`_ in KDD 2020.\n\n.. _`\"Improving Conversational Recommender Systems via Knowledge Graph based Semantic Fusion.\"`:\n   https://dl.acm.org/doi/abs/10.1145/3394486.3403143\n\n\"\"\"\n\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom loguru import logger\nfrom torch import nn\nfrom torch_geometric.nn import GCNConv, RGCNConv\n\nfrom crslab.config import MODEL_PATH\nfrom crslab.model.base import BaseModel\nfrom crslab.model.utils.functions import edge_to_pyg_format\nfrom crslab.model.utils.modules.attention import SelfAttentionSeq\nfrom crslab.model.utils.modules.transformer import TransformerEncoder\nfrom .modules import GateLayer, TransformerDecoderKG\nfrom .resources import resources\n\n\nclass KGSFModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        vocab_size: A integer indicating the vocabulary size.\n        pad_token_idx: A integer indicating the id of padding token.\n        start_token_idx: A integer indicating the id of start token.\n        end_token_idx: A integer indicating the id of end token.\n        token_emb_dim: A integer indicating the dimension of token embedding layer.\n        pretrain_embedding: A string indicating the path of pretrained embedding.\n        n_word: A integer indicating the number of words.\n        n_entity: A integer indicating the number of entities.\n        pad_word_idx: A integer indicating the id of word padding.\n        pad_entity_idx: A integer indicating the id of entity padding.\n        num_bases: A integer indicating the number of bases.\n  
      kg_emb_dim: A integer indicating the dimension of kg embedding.\n        n_heads: A integer indicating the number of heads.\n        n_layers: A integer indicating the number of layer.\n        ffn_size: A integer indicating the size of ffn hidden.\n        dropout: A float indicating the dropout rate.\n        attention_dropout: A integer indicating the dropout rate of attention layer.\n        relu_dropout: A integer indicating the dropout rate of relu layer.\n        learn_positional_embeddings: A boolean indicating if we learn the positional embedding.\n        embeddings_scale: A boolean indicating if we use the embeddings scale.\n        reduction: A boolean indicating if we use the reduction.\n        n_positions: A integer indicating the number of position.\n        response_truncate = A integer indicating the longest length for response generation.\n        pretrained_embedding: A string indicating the path of pretrained embedding.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.device = device\n        self.gpu = opt.get(\"gpu\", [-1])\n        # vocab\n        self.vocab_size = vocab['vocab_size']\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.token_emb_dim = opt['token_emb_dim']\n        self.pretrained_embedding = side_data.get('embedding', None)\n        # kg\n        self.n_word = vocab['n_word']\n        self.n_entity = vocab['n_entity']\n        self.pad_word_idx = vocab['pad_word']\n        self.pad_entity_idx = vocab['pad_entity']\n        entity_kg = 
side_data['entity_kg']\n        self.n_relation = entity_kg['n_relation']\n        entity_edges = entity_kg['edge']\n        self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format(entity_edges, 'RGCN')\n        self.entity_edge_idx = self.entity_edge_idx.to(device)\n        self.entity_edge_type = self.entity_edge_type.to(device)\n        word_edges = side_data['word_kg']['edge']\n\n        self.word_edges = edge_to_pyg_format(word_edges, 'GCN').to(device)\n\n        self.num_bases = opt['num_bases']\n        self.kg_emb_dim = opt['kg_emb_dim']\n        # transformer\n        self.n_heads = opt['n_heads']\n        self.n_layers = opt['n_layers']\n        self.ffn_size = opt['ffn_size']\n        self.dropout = opt['dropout']\n        self.attention_dropout = opt['attention_dropout']\n        self.relu_dropout = opt['relu_dropout']\n        self.learn_positional_embeddings = opt['learn_positional_embeddings']\n        self.embeddings_scale = opt['embeddings_scale']\n        self.reduction = opt['reduction']\n        self.n_positions = opt['n_positions']\n        self.response_truncate = opt.get('response_truncate', 20)\n        # copy mask\n        dataset = opt['dataset']\n        dpath = os.path.join(MODEL_PATH, \"kgsf\", dataset)\n        resource = resources[dataset]\n        super(KGSFModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self):\n        self._init_embeddings()\n        self._build_kg_layer()\n        self._build_infomax_layer()\n        self._build_recommendation_layer()\n        self._build_conversation_layer()\n\n    def _init_embeddings(self):\n        if self.pretrained_embedding is not None:\n            self.token_embedding = nn.Embedding.from_pretrained(\n                torch.as_tensor(self.pretrained_embedding, dtype=torch.float), freeze=False,\n                padding_idx=self.pad_token_idx)\n        else:\n            self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, 
self.pad_token_idx)\n            nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)\n            nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0)\n\n        self.word_kg_embedding = nn.Embedding(self.n_word, self.kg_emb_dim, self.pad_word_idx)\n        nn.init.normal_(self.word_kg_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)\n        nn.init.constant_(self.word_kg_embedding.weight[self.pad_word_idx], 0)\n\n        logger.debug('[Finish init embeddings]')\n\n    def _build_kg_layer(self):\n        # db encoder\n        self.entity_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases)\n        self.entity_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim)\n\n        # concept encoder\n        self.word_encoder = GCNConv(self.kg_emb_dim, self.kg_emb_dim)\n        self.word_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim)\n\n        # gate mechanism\n        self.gate_layer = GateLayer(self.kg_emb_dim)\n\n        logger.debug('[Finish build kg layer]')\n\n    def _build_infomax_layer(self):\n        self.infomax_norm = nn.Linear(self.kg_emb_dim, self.kg_emb_dim)\n        self.infomax_bias = nn.Linear(self.kg_emb_dim, self.n_entity)\n        self.infomax_loss = nn.MSELoss(reduction='sum')\n\n        logger.debug('[Finish build infomax layer]')\n\n    def _build_recommendation_layer(self):\n        self.rec_bias = nn.Linear(self.kg_emb_dim, self.n_entity)\n        self.rec_loss = nn.CrossEntropyLoss()\n\n        logger.debug('[Finish build rec layer]')\n\n    def _build_conversation_layer(self):\n        self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long))\n        self.conv_encoder = TransformerEncoder(\n            n_heads=self.n_heads,\n            n_layers=self.n_layers,\n            embedding_size=self.token_emb_dim,\n            ffn_size=self.ffn_size,\n            vocabulary_size=self.vocab_size,\n            
embedding=self.token_embedding,\n            dropout=self.dropout,\n            attention_dropout=self.attention_dropout,\n            relu_dropout=self.relu_dropout,\n            padding_idx=self.pad_token_idx,\n            learn_positional_embeddings=self.learn_positional_embeddings,\n            embeddings_scale=self.embeddings_scale,\n            reduction=self.reduction,\n            n_positions=self.n_positions,\n        )\n\n        self.conv_entity_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)\n        self.conv_entity_attn_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)\n        self.conv_word_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)\n        self.conv_word_attn_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)\n\n        self.copy_norm = nn.Linear(self.ffn_size * 3, self.token_emb_dim)\n        self.copy_output = nn.Linear(self.token_emb_dim, self.vocab_size)\n        self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, \"copy_mask.npy\")).astype(bool),\n                                         ).to(self.device)\n\n        self.conv_decoder = TransformerDecoderKG(\n            self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size,\n            embedding=self.token_embedding,\n            dropout=self.dropout,\n            attention_dropout=self.attention_dropout,\n            relu_dropout=self.relu_dropout,\n            embeddings_scale=self.embeddings_scale,\n            learn_positional_embeddings=self.learn_positional_embeddings,\n            padding_idx=self.pad_token_idx,\n            n_positions=self.n_positions\n        )\n        self.conv_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)\n\n        logger.debug('[Finish build conv layer]')\n\n    def pretrain_infomax(self, batch):\n        \"\"\"\n        words: (batch_size, word_length)\n        entity_labels: (batch_size, n_entity)\n        \"\"\"\n        words, entity_labels = batch\n\n        loss_mask = torch.sum(entity_labels)\n   
     if loss_mask.item() == 0:\n            return None\n\n        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)\n        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)\n\n        word_representations = word_graph_representations[words]\n        word_padding_mask = words.eq(self.pad_word_idx)  # (bs, seq_len)\n\n        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)\n        word_info_rep = self.infomax_norm(word_attn_rep)  # (bs, dim)\n        info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias)  # (bs, #entity)\n        loss = self.infomax_loss(info_predict, entity_labels) / loss_mask\n        return loss\n\n    def recommend(self, batch, mode):\n        \"\"\"\n        context_entities: (batch_size, entity_length)\n        context_words: (batch_size, word_length)\n        movie: (batch_size)\n        \"\"\"\n        context_entities, context_words, entities, movie = batch\n\n        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)\n        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)\n\n        entity_padding_mask = context_entities.eq(self.pad_entity_idx)  # (bs, entity_len)\n        word_padding_mask = context_words.eq(self.pad_word_idx)  # (bs, word_len)\n\n        entity_representations = entity_graph_representations[context_entities]\n        word_representations = word_graph_representations[context_words]\n\n        entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask)\n        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)\n\n        user_rep = self.gate_layer(entity_attn_rep, word_attn_rep)\n        rec_scores = F.linear(user_rep, entity_graph_representations, self.rec_bias.bias)  # (bs, #entity)\n\n        rec_loss = 
self.rec_loss(rec_scores, movie)\n\n        info_loss_mask = torch.sum(entities)\n        if info_loss_mask.item() == 0:\n            info_loss = None\n        else:\n            word_info_rep = self.infomax_norm(word_attn_rep)  # (bs, dim)\n            info_predict = F.linear(word_info_rep, entity_graph_representations,\n                                    self.infomax_bias.bias)  # (bs, #entity)\n            info_loss = self.infomax_loss(info_predict, entities) / info_loss_mask\n\n        return rec_loss, info_loss, rec_scores\n\n    def freeze_parameters(self):\n        freeze_models = [self.word_kg_embedding, self.entity_encoder, self.entity_self_attn, self.word_encoder,\n                         self.word_self_attn, self.gate_layer, self.infomax_bias, self.infomax_norm, self.rec_bias]\n        for model in freeze_models:\n            for p in model.parameters():\n                p.requires_grad = False\n\n    def _starts(self, batch_size):\n        \"\"\"Return bsz start tokens.\"\"\"\n        return self.START.detach().expand(batch_size, 1)\n\n    def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask,\n                               word_reps, word_emb_attn, word_mask, response):\n        batch_size, seq_len = response.shape\n        start = self._starts(batch_size)\n        inputs = torch.cat((start, response[:, :-1]), dim=-1).long()\n\n        dialog_latent, _ = self.conv_decoder(inputs, token_encoding, word_reps, word_mask,\n                                             entity_reps, entity_mask)  # (bs, seq_len, dim)\n        entity_latent = entity_emb_attn.unsqueeze(1).expand(-1, seq_len, -1)\n        word_latent = word_emb_attn.unsqueeze(1).expand(-1, seq_len, -1)\n        copy_latent = self.copy_norm(\n            torch.cat((entity_latent, word_latent, dialog_latent), dim=-1))  # (bs, seq_len, dim)\n\n        copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(\n            0)  # (bs, 
seq_len, vocab_size)\n        gen_logits = F.linear(dialog_latent, self.token_embedding.weight)  # (bs, seq_len, vocab_size)\n        sum_logits = copy_logits + gen_logits\n        preds = sum_logits.argmax(dim=-1)\n        return sum_logits, preds\n\n    def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask,\n                               word_reps, word_emb_attn, word_mask):\n        batch_size = token_encoding[0].shape[0]\n        inputs = self._starts(batch_size).long()\n        incr_state = None\n        logits = []\n        for _ in range(self.response_truncate):\n            dialog_latent, incr_state = self.conv_decoder(inputs, token_encoding, word_reps, word_mask,\n                                                          entity_reps, entity_mask, incr_state)\n            dialog_latent = dialog_latent[:, -1:, :]  # (bs, 1, dim)\n            db_latent = entity_emb_attn.unsqueeze(1)\n            concept_latent = word_emb_attn.unsqueeze(1)\n            copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1))\n\n            copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0)\n            gen_logits = F.linear(dialog_latent, self.token_embedding.weight)\n            sum_logits = copy_logits + gen_logits\n            preds = sum_logits.argmax(dim=-1).long()\n            logits.append(sum_logits)\n            inputs = torch.cat((inputs, preds), dim=1)\n\n            finished = ((inputs == self.end_token_idx).sum(dim=-1) > 0).sum().item() == batch_size\n            if finished:\n                break\n        logits = torch.cat(logits, dim=1)\n        return logits, inputs\n\n    def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask,\n                                    word_reps, word_emb_attn, word_mask, beam=4):\n        batch_size = token_encoding[0].shape[0]\n        inputs = 
self._starts(batch_size).long().reshape(1, batch_size, -1)\n        incr_state = None\n\n        sequences = [[[list(), list(), 1.0]]] * batch_size\n        for i in range(self.response_truncate):\n            if i == 1:\n                token_encoding = (token_encoding[0].repeat(beam, 1, 1),\n                                  token_encoding[1].repeat(beam, 1, 1))\n                entity_reps = entity_reps.repeat(beam, 1, 1)\n                entity_emb_attn = entity_emb_attn.repeat(beam, 1)\n                entity_mask = entity_mask.repeat(beam, 1)\n                word_reps = word_reps.repeat(beam, 1, 1)\n                word_emb_attn = word_emb_attn.repeat(beam, 1)\n                word_mask = word_mask.repeat(beam, 1)\n\n            # at beginning there is 1 candidate, when i!=0 there are 4 candidates\n            if i != 0:\n                inputs = []\n                for d in range(len(sequences[0])):\n                    for j in range(batch_size):\n                        text = sequences[j][d][0]\n                        inputs.append(text)\n                inputs = torch.stack(inputs).reshape(beam, batch_size, -1)  # (beam, batch_size, _)\n\n            with torch.no_grad():\n                dialog_latent, incr_state = self.conv_decoder(\n                    inputs.reshape(len(sequences[0]) * batch_size, -1),\n                    token_encoding, word_reps, word_mask,\n                    entity_reps, entity_mask, incr_state\n                )\n                dialog_latent = dialog_latent[:, -1:, :]  # (bs, 1, dim)\n                db_latent = entity_emb_attn.unsqueeze(1)\n                concept_latent = word_emb_attn.unsqueeze(1)\n                copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1))\n\n                copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0)\n                gen_logits = F.linear(dialog_latent, self.token_embedding.weight)\n                sum_logits = 
copy_logits + gen_logits\n\n            logits = sum_logits.reshape(len(sequences[0]), batch_size, 1, -1)\n            # turn into probabilities,in case of negative numbers\n            probs, preds = torch.nn.functional.softmax(logits).topk(beam, dim=-1)\n\n            # (candeidate, bs, 1 , beam) during first loop, candidate=1, otherwise candidate=beam\n\n            for j in range(batch_size):\n                all_candidates = []\n                for n in range(len(sequences[j])):\n                    for k in range(beam):\n                        prob = sequences[j][n][2]\n                        logit = sequences[j][n][1]\n                        if logit == []:\n                            logit_tmp = logits[n][j][0].unsqueeze(0)\n                        else:\n                            logit_tmp = torch.cat((logit, logits[n][j][0].unsqueeze(0)), dim=0)\n                        seq_tmp = torch.cat((inputs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1)))\n                        candidate = [seq_tmp, logit_tmp, prob * probs[n][j][0][k]]\n                        all_candidates.append(candidate)\n                ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True)\n                sequences[j] = ordered[:beam]\n\n            # check if everyone has generated an end token\n            all_finished = ((inputs == self.end_token_idx).sum(dim=1) > 0).sum().item() == batch_size\n            if all_finished:\n                break\n        logits = torch.stack([seq[0][1] for seq in sequences])\n        inputs = torch.stack([seq[0][0] for seq in sequences])\n        return logits, inputs\n\n    def converse(self, batch, mode):\n        context_tokens, context_entities, context_words, response = batch\n\n        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)\n        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)\n\n        entity_padding_mask = 
context_entities.eq(self.pad_entity_idx)  # (bs, entity_len)\n        word_padding_mask = context_words.eq(self.pad_word_idx)  # (bs, seq_len)\n\n        entity_representations = entity_graph_representations[context_entities]\n        word_representations = word_graph_representations[context_words]\n\n        entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask)\n        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)\n\n        # encoder-decoder\n        tokens_encoding = self.conv_encoder(context_tokens)\n        conv_entity_emb = self.conv_entity_attn_norm(entity_attn_rep)\n        conv_word_emb = self.conv_word_attn_norm(word_attn_rep)\n        conv_entity_reps = self.conv_entity_norm(entity_representations)\n        conv_word_reps = self.conv_word_norm(word_representations)\n        if mode != 'test':\n            logits, preds = self._decode_forced_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb,\n                                                        entity_padding_mask,\n                                                        conv_word_reps, conv_word_emb, word_padding_mask,\n                                                        response)\n\n            logits = logits.view(-1, logits.shape[-1])\n            response = response.view(-1)\n            loss = self.conv_loss(logits, response)\n            return loss, preds\n        else:\n            logits, preds = self._decode_greedy_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb,\n                                                        entity_padding_mask,\n                                                        conv_word_reps, conv_word_emb, word_padding_mask)\n            return preds\n\n    def forward(self, batch, stage, mode):\n        if len(self.gpu) >= 2:\n            #  forward function operates on different gpus, the weight of graph network need to be copied to other gpu\n            self.entity_edge_idx = 
self.entity_edge_idx.cuda(torch.cuda.current_device())\n            self.entity_edge_type = self.entity_edge_type.cuda(torch.cuda.current_device())\n            self.word_edges = self.word_edges.cuda(torch.cuda.current_device())\n            self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, \"copy_mask.npy\")).astype(bool),\n                                             ).cuda(torch.cuda.current_device())\n        if stage == \"pretrain\":\n            loss = self.pretrain_infomax(batch)\n        elif stage == \"rec\":\n            loss = self.recommend(batch, mode)\n        elif stage == \"conv\":\n            loss = self.converse(batch, mode)\n        return loss\n"
  },
  {
    "path": "crslab/model/crs/kgsf/modules.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn as nn\n\nfrom crslab.model.utils.modules.transformer import MultiHeadAttention, TransformerFFN, _create_selfattn_mask, \\\n    _normalize, \\\n    create_position_codes\n\n\nclass GateLayer(nn.Module):\n    def __init__(self, input_dim):\n        super(GateLayer, self).__init__()\n        self._norm_layer1 = nn.Linear(input_dim * 2, input_dim)\n        self._norm_layer2 = nn.Linear(input_dim, 1)\n\n    def forward(self, input1, input2):\n        norm_input = self._norm_layer1(torch.cat([input1, input2], dim=-1))\n        gate = torch.sigmoid(self._norm_layer2(norm_input))  # (bs, 1)\n        gated_emb = gate * input1 + (1 - gate) * input2  # (bs, dim)\n        return gated_emb\n\n\nclass TransformerDecoderLayerKG(nn.Module):\n    def __init__(\n            self,\n            n_heads,\n            embedding_size,\n            ffn_size,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            dropout=0.0,\n    ):\n        super().__init__()\n        self.dim = embedding_size\n        self.ffn_dim = ffn_size\n        self.dropout = nn.Dropout(p=dropout)\n\n        self.self_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm1 = nn.LayerNorm(embedding_size)\n\n        self.encoder_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm2 = nn.LayerNorm(embedding_size)\n\n        self.encoder_db_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm2_db = nn.LayerNorm(embedding_size)\n\n        self.encoder_kg_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm2_kg = nn.LayerNorm(embedding_size)\n\n        self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)\n        
self.norm3 = nn.LayerNorm(embedding_size)\n\n    def forward(self, x, encoder_output, encoder_mask, kg_encoder_output, kg_encoder_mask, db_encoder_output,\n                db_encoder_mask):\n        decoder_mask = _create_selfattn_mask(x)\n        # first self attn\n        residual = x\n        # don't peak into the future!\n        x = self.self_attention(query=x, mask=decoder_mask)\n        x = self.dropout(x)  # --dropout\n        x = x + residual\n        x = _normalize(x, self.norm1)\n\n        residual = x\n        x = self.encoder_db_attention(\n            query=x,\n            key=db_encoder_output,\n            value=db_encoder_output,\n            mask=db_encoder_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm2_db)\n\n        residual = x\n        x = self.encoder_kg_attention(\n            query=x,\n            key=kg_encoder_output,\n            value=kg_encoder_output,\n            mask=kg_encoder_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm2_kg)\n\n        residual = x\n        x = self.encoder_attention(\n            query=x,\n            key=encoder_output,\n            value=encoder_output,\n            mask=encoder_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm2)\n\n        # finally the ffn\n        residual = x\n        x = self.ffn(x)\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm3)\n\n        return x\n\n\nclass TransformerDecoderKG(nn.Module):\n    \"\"\"\n    Transformer Decoder layer.\n\n    :param int n_heads: the number of multihead attention heads.\n    :param int n_layers: number of transformer layers.\n    :param int embedding_size: the embedding sizes. 
Must be a multiple of n_heads.\n    :param int ffn_size: the size of the hidden layer in the FFN\n    :param embedding: an embedding matrix for the bottom layer of the transformer.\n        If none, one is created for this encoder.\n    :param float dropout: Dropout used around embeddings and before layer\n        layer normalizations. This is used in Vaswani 2017 and works well on\n        large datasets.\n    :param float attention_dropout: Dropout performed after the multhead attention\n        softmax. This is not used in Vaswani 2017.\n    :param float relu_dropout: Dropout used after the ReLU in the FFN. Not used\n        in Vaswani 2017, but used in Tensor2Tensor.\n    :param int padding_idx: Reserved padding index in the embeddings matrix.\n    :param bool learn_positional_embeddings: If off, sinusoidal embeddings are\n        used. If on, position embeddings are learned from scratch.\n    :param bool embeddings_scale: Scale embeddings relative to their dimensionality.\n        Found useful in fairseq.\n    :param int n_positions: Size of the position embeddings matrix.\n    \"\"\"\n\n    def __init__(\n            self,\n            n_heads,\n            n_layers,\n            embedding_size,\n            ffn_size,\n            vocabulary_size,\n            embedding,\n            dropout=0.0,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            embeddings_scale=True,\n            learn_positional_embeddings=False,\n            padding_idx=None,\n            n_positions=1024,\n    ):\n        super().__init__()\n        self.embedding_size = embedding_size\n        self.ffn_size = ffn_size\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.dim = embedding_size\n        self.embeddings_scale = embeddings_scale\n        self.dropout = nn.Dropout(dropout)  # --dropout\n\n        self.out_dim = embedding_size\n        assert embedding_size % n_heads == 0, \\\n            'Transformer embedding size must 
be a multiple of n_heads'\n\n        self.embeddings = embedding\n\n        # create the positional embeddings\n        self.position_embeddings = nn.Embedding(n_positions, embedding_size)\n        if not learn_positional_embeddings:\n            create_position_codes(\n                n_positions, embedding_size, out=self.position_embeddings.weight\n            )\n        else:\n            nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)\n\n        # build the model\n        self.layers = nn.ModuleList()\n        for _ in range(self.n_layers):\n            self.layers.append(TransformerDecoderLayerKG(\n                n_heads, embedding_size, ffn_size,\n                attention_dropout=attention_dropout,\n                relu_dropout=relu_dropout,\n                dropout=dropout,\n            ))\n\n    def forward(self, input, encoder_state, kg_encoder_output, kg_encoder_mask,\n                db_encoder_output, db_encoder_mask, incr_state=None):\n        encoder_output, encoder_mask = encoder_state\n\n        seq_len = input.size(1)\n        positions = input.new(seq_len).long()  # (seq_len)\n        positions = torch.arange(seq_len, out=positions).unsqueeze(0)  # (1, seq_len)\n        tensor = self.embeddings(input)  # (bs, seq_len, embed_dim)\n        if self.embeddings_scale:\n            tensor = tensor * np.sqrt(self.dim)\n        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)\n        tensor = self.dropout(tensor)  # --dropout\n\n        for layer in self.layers:\n            tensor = layer(tensor, encoder_output, encoder_mask, kg_encoder_output, kg_encoder_mask, db_encoder_output,\n                           db_encoder_mask)\n\n        return tensor, None\n"
  },
  {
    "path": "crslab/model/crs/kgsf/resources.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/13\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/15\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'ReDial': {\n        'version': '0.2',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1',\n            'kgsf_redial.zip',\n            'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548',\n        ),\n    },\n    'TGReDial': {\n        'version': '0.2',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1',\n            'kgsf_tgredial.zip',\n            'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1',\n        ),\n    },\n    'GoRecDial': {\n        'version': '0.1',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ER5u2yMmgDNFvHuW6lKZLEkBKZkOkxMtZGK0bBQ-jvfLNw?download=1',\n            'kgsf_gorecdial.zip',\n            'f2f57ebb8f688f38a98ee41fe3a87e9362aed945ec9078869407f799da322633',\n        )\n    },\n    'OpenDialKG': {\n        'version': '0.1',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1',\n            'kgsf_opendialkg.zip',\n            '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61'\n        )\n    },\n    'Inspired': {\n        'version': '0.1',\n        'file': DownloadableFile(\n            
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1',\n            'kgsf_inspired.zip',\n            '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d'\n        )\n    },\n    'DuRecDial': {\n        'version': '0.1',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1',\n            'kgsf_durecdial.zip',\n            'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef'\n        )\n    }\n}\n"
  },
  {
    "path": "crslab/model/crs/ntrd/__init__.py",
    "content": "from .ntrd import NTRDModel"
  },
  {
    "path": "crslab/model/crs/ntrd/modules.py",
    "content": "# @Time   : 2021/10/06\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\nimport numpy as np\nimport torch\nfrom torch import nn as nn\n\nfrom crslab.model.utils.modules.transformer import MultiHeadAttention, TransformerFFN, _create_selfattn_mask, \\\n    _normalize, \\\n    create_position_codes\n\n\nclass GateLayer(nn.Module):\n    def __init__(self, input_dim):\n        super(GateLayer, self).__init__()\n        self._norm_layer1 = nn.Linear(input_dim * 2, input_dim)\n        self._norm_layer2 = nn.Linear(input_dim, 1)\n\n    def forward(self, input1, input2):\n        norm_input = self._norm_layer1(torch.cat([input1, input2], dim=-1))\n        gate = torch.sigmoid(self._norm_layer2(norm_input))  # (bs, 1)\n        gated_emb = gate * input1 + (1 - gate) * input2  # (bs, dim)\n        return gated_emb\n\n\nclass TransformerDecoderLayerKG(nn.Module):\n    def __init__(\n            self,\n            n_heads,\n            embedding_size,\n            ffn_size,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            dropout=0.0,\n    ):\n        super().__init__()\n        self.dim = embedding_size\n        self.ffn_dim = ffn_size\n        self.dropout = nn.Dropout(p=dropout)\n\n        self.self_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm1 = nn.LayerNorm(embedding_size)\n\n        self.encoder_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm2 = nn.LayerNorm(embedding_size)\n\n        self.encoder_db_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm2_db = nn.LayerNorm(embedding_size)\n\n        self.encoder_kg_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm2_kg = nn.LayerNorm(embedding_size)\n\n        
self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)\n        self.norm3 = nn.LayerNorm(embedding_size)\n\n    def forward(self, x, encoder_output, encoder_mask, kg_encoder_output, kg_encoder_mask, db_encoder_output,\n                db_encoder_mask):\n        decoder_mask = _create_selfattn_mask(x)\n        # first self attn\n        residual = x\n        # don't peak into the future!\n        x = self.self_attention(query=x, mask=decoder_mask)\n        x = self.dropout(x)  # --dropout\n        x = x + residual\n        x = _normalize(x, self.norm1)\n\n        residual = x\n        x = self.encoder_db_attention(\n            query=x,\n            key=db_encoder_output,\n            value=db_encoder_output,\n            mask=db_encoder_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm2_db)\n\n        residual = x\n        x = self.encoder_kg_attention(\n            query=x,\n            key=kg_encoder_output,\n            value=kg_encoder_output,\n            mask=kg_encoder_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm2_kg)\n\n        residual = x\n        x = self.encoder_attention(\n            query=x,\n            key=encoder_output,\n            value=encoder_output,\n            mask=encoder_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm2)\n\n        # finally the ffn\n        residual = x\n        x = self.ffn(x)\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm3)\n\n        return x\n\nclass TransformerDecoderLayerSelection(nn.Module):\n    def __init__(\n        self,\n        n_heads,\n        embedding_size,\n        ffn_size,\n        attention_dropout=0.0,\n        relu_dropout=0.0,\n        dropout=0.0,\n    ):\n        super().__init__()\n        
self.dim = embedding_size\n        self.ffn_dim = ffn_size\n        self.dropout = nn.Dropout(p=dropout)\n\n\n        self.encoder_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm1 = nn.LayerNorm(embedding_size)\n\n        self.movie_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm2 = nn.LayerNorm(embedding_size)\n\n        self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)\n        self.norm3 = nn.LayerNorm(embedding_size)\n\n    def forward(self, x, encoder_output, encoder_mask, movie_embed, movie_embed_mask):\n        residual = x\n        x = self.movie_attention(\n            query=x,\n            key=movie_embed,\n            value=movie_embed,\n            mask=movie_embed_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm1)\n\n        residual = x\n        x = self.encoder_attention(\n            query=x,\n            key=encoder_output,\n            value=encoder_output,\n            mask=encoder_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm2)\n        \n        # finally the ffn\n        residual = x\n        x = self.ffn(x)\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm3)\n\n        return x\n\nclass TransformerDecoderKG(nn.Module):\n    \"\"\"\n    Transformer Decoder layer.\n\n    :param int n_heads: the number of multihead attention heads.\n    :param int n_layers: number of transformer layers.\n    :param int embedding_size: the embedding sizes. 
Must be a multiple of n_heads.\n    :param int ffn_size: the size of the hidden layer in the FFN\n    :param embedding: an embedding matrix for the bottom layer of the transformer.\n        If none, one is created for this encoder.\n    :param float dropout: Dropout used around embeddings and before layer\n        layer normalizations. This is used in Vaswani 2017 and works well on\n        large datasets.\n    :param float attention_dropout: Dropout performed after the multhead attention\n        softmax. This is not used in Vaswani 2017.\n    :param float relu_dropout: Dropout used after the ReLU in the FFN. Not used\n        in Vaswani 2017, but used in Tensor2Tensor.\n    :param int padding_idx: Reserved padding index in the embeddings matrix.\n    :param bool learn_positional_embeddings: If off, sinusoidal embeddings are\n        used. If on, position embeddings are learned from scratch.\n    :param bool embeddings_scale: Scale embeddings relative to their dimensionality.\n        Found useful in fairseq.\n    :param int n_positions: Size of the position embeddings matrix.\n    \"\"\"\n\n    def __init__(\n            self,\n            n_heads,\n            n_layers,\n            embedding_size,\n            ffn_size,\n            vocabulary_size,\n            embedding,\n            dropout=0.0,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            embeddings_scale=True,\n            learn_positional_embeddings=False,\n            padding_idx=None,\n            n_positions=1024,\n    ):\n        super().__init__()\n        self.embedding_size = embedding_size\n        self.ffn_size = ffn_size\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.dim = embedding_size\n        self.embeddings_scale = embeddings_scale\n        self.dropout = nn.Dropout(dropout)  # --dropout\n\n        self.out_dim = embedding_size\n        assert embedding_size % n_heads == 0, \\\n            'Transformer embedding size must 
be a multiple of n_heads'\n\n        self.embeddings = embedding\n\n        # create the positional embeddings\n        self.position_embeddings = nn.Embedding(n_positions, embedding_size)\n        if not learn_positional_embeddings:\n            create_position_codes(\n                n_positions, embedding_size, out=self.position_embeddings.weight\n            )\n        else:\n            nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)\n\n        # build the model\n        self.layers = nn.ModuleList()\n        for _ in range(self.n_layers):\n            self.layers.append(TransformerDecoderLayerKG(\n                n_heads, embedding_size, ffn_size,\n                attention_dropout=attention_dropout,\n                relu_dropout=relu_dropout,\n                dropout=dropout,\n            ))\n\n    def forward(self, input, encoder_state, kg_encoder_output, kg_encoder_mask,\n                db_encoder_output, db_encoder_mask, incr_state=None):\n        encoder_output, encoder_mask = encoder_state\n\n        seq_len = input.size(1)\n        positions = input.new(seq_len).long()  # (seq_len)\n        positions = torch.arange(seq_len, out=positions).unsqueeze(0)  # (1, seq_len)\n        tensor = self.embeddings(input)  # (bs, seq_len, embed_dim)\n        if self.embeddings_scale:\n            tensor = tensor * np.sqrt(self.dim)\n        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)\n        tensor = self.dropout(tensor)  # --dropout\n\n        for layer in self.layers:\n            tensor = layer(tensor, encoder_output, encoder_mask, kg_encoder_output, kg_encoder_mask, db_encoder_output,\n                           db_encoder_mask)\n\n        return tensor, None\n\nclass TransformerDecoderSelection(nn.Module):\n    def __init__(\n            self,\n            n_heads,\n            n_layers,\n            embedding_size,\n            ffn_size,\n            vocabulary_size,\n            # embedding,\n           
 dropout=0.0,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            embeddings_scale=True,\n            learn_positional_embeddings=False,\n            padding_idx=None,\n            n_positions=1024,\n    ):\n        super().__init__()\n        self.embedding_size = embedding_size\n        self.ffn_size = ffn_size\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.dim = embedding_size\n        self.embeddings_scale = embeddings_scale\n        self.dropout = nn.Dropout(p=dropout)  # --dropout\n\n        self.out_dim = embedding_size\n        assert embedding_size % n_heads == 0, \\\n            'Transformer embedding size must be a multiple of n_heads'\n        \n        # build the model\n        self.layers = nn.ModuleList()\n        for _ in range(self.n_layers):\n            self.layers.append(TransformerDecoderLayerSelection(\n                n_heads, embedding_size, ffn_size,\n                attention_dropout=attention_dropout,\n                relu_dropout=relu_dropout,\n                dropout=dropout,\n            ))\n    \n    def forward(self, input, encoder_state,movie_embed,movie_embed_mask,incr_state=None):\n        encoder_output,encoder_mask = encoder_state\n\n\n        tensor = input # -- No dropout\n\n        for layer in self.layers:\n            tensor = layer(tensor,encoder_output,encoder_mask,movie_embed,movie_embed_mask)\n        \n        return tensor, None"
  },
  {
    "path": "crslab/model/crs/ntrd/ntrd.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2021/10/1\n# @Author  :   Zhipeng Zhao\n# @email   :   oran_official@outlook.com\n\n\nr\"\"\"\nNTRD\n====\nReferences:\n    Liang, Zujie, et al. `\"Learning Neural Templates for Recommender Dialogue System.\"`_ in EMNLP 2021.\n\n.. _`\"Learning Neural Templates for Recommender Dialogue System.\"`:\n   https://arxiv.org/pdf/2109.12302.pdf\n\n\"\"\"\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom loguru import logger\nfrom torch import nn\nfrom torch_geometric.nn import GCNConv, RGCNConv\n\nfrom crslab.config import MODEL_PATH\nfrom crslab.model.base import BaseModel\nfrom crslab.model.utils.functions import edge_to_pyg_format\nfrom crslab.model.utils.modules.attention import SelfAttentionSeq\nfrom crslab.model.utils.modules.transformer import TransformerEncoder\nfrom .modules import GateLayer, TransformerDecoderKG,TransformerDecoderSelection\nfrom .resources import resources\n\nclass NTRDModel(BaseModel):\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.device = device\n        self.gpu = opt.get(\"gpu\", [-1])\n        # vocab\n        self.vocab_size = vocab['vocab_size']\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.token_emb_dim = opt['token_emb_dim']\n        self.pretrained_embedding = side_data.get('embedding', None)\n        self.replace_token = opt.get('replace_token',None)\n        self.replace_token_idx = vocab[self.replace_token]\n        # kg\n        self.n_word = vocab['n_word']\n      
  self.n_entity = vocab['n_entity']\n        self.pad_word_idx = vocab['pad_word']\n        self.pad_entity_idx = vocab['pad_entity']\n        entity_kg = side_data['entity_kg']\n        self.n_relation = entity_kg['n_relation']\n        entity_edges = entity_kg['edge']\n        self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format(entity_edges, 'RGCN')\n        self.entity_edge_idx = self.entity_edge_idx.to(device)\n        self.entity_edge_type = self.entity_edge_type.to(device)\n        word_edges = side_data['word_kg']['edge']\n\n        self.word_edges = edge_to_pyg_format(word_edges, 'GCN').to(device)\n\n        self.num_bases = opt['num_bases']\n        self.kg_emb_dim = opt['kg_emb_dim']\n        # transformer\n        self.n_heads = opt['n_heads']\n        self.n_layers = opt['n_layers']\n        self.ffn_size = opt['ffn_size']\n        self.dropout = opt['dropout']\n        self.attention_dropout = opt['attention_dropout']\n        self.relu_dropout = opt['relu_dropout']\n        self.learn_positional_embeddings = opt['learn_positional_embeddings']\n        self.embeddings_scale = opt['embeddings_scale']\n        self.reduction = opt['reduction']\n        self.n_positions = opt['n_positions']\n        self.response_truncate = opt.get('response_truncate', 20)\n        # selector \n        self.n_movies = opt['n_movies']\n        # self.n_movies_label = opt['n_movies_label']\n        self.n_movies_label = 64362 # the number of entity2id\n        # copy mask\n        dataset = opt['dataset']\n        dpath = os.path.join(MODEL_PATH, \"kgsf\", dataset)\n        resource = resources[dataset]\n        # loss weight\n        self.gen_loss_weight = opt['gen_loss_weight']\n        super(NTRDModel, self).__init__(opt, device, dpath, resource)\n    \n    def build_model(self):\n        self._init_embeddings()\n        self._build_kg_layer()\n        self._build_infomax_layer()\n        self._build_recommendation_layer()\n        
self._build_conversation_layer()\n        self._build_movie_selector()\n    \n    def _init_embeddings(self):\n        if self.pretrained_embedding is not None:\n            self.token_embedding = nn.Embedding.from_pretrained(\n                torch.as_tensor(self.pretrained_embedding, dtype=torch.float), freeze=False,\n                padding_idx=self.pad_token_idx)\n        else:\n            self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx)\n            nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)\n            nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0)\n\n        self.word_kg_embedding = nn.Embedding(self.n_word, self.kg_emb_dim, self.pad_word_idx)\n        nn.init.normal_(self.word_kg_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)\n        nn.init.constant_(self.word_kg_embedding.weight[self.pad_word_idx], 0)\n\n        logger.debug('[Finish init embeddings]')\n\n    def _build_kg_layer(self):\n        # db encoder\n        self.entity_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases)\n        self.entity_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim)\n\n        # concept encoder\n        self.word_encoder = GCNConv(self.kg_emb_dim, self.kg_emb_dim)\n        self.word_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim)\n\n        # gate mechanism\n        self.gate_layer = GateLayer(self.kg_emb_dim)\n\n        logger.debug('[Finish build kg layer]')\n\n    def _build_infomax_layer(self):\n        self.infomax_norm = nn.Linear(self.kg_emb_dim, self.kg_emb_dim)\n        self.infomax_bias = nn.Linear(self.kg_emb_dim, self.n_entity)\n        self.infomax_loss = nn.MSELoss(reduction='sum')\n\n        logger.debug('[Finish build infomax layer]')\n\n    def _build_recommendation_layer(self):\n        self.rec_bias = nn.Linear(self.kg_emb_dim, self.n_entity)\n        self.rec_loss = 
nn.CrossEntropyLoss()\n\n        logger.debug('[Finish build rec layer]')\n\n    def _build_conversation_layer(self):\n        self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long))\n        self.conv_encoder = TransformerEncoder(\n            n_heads=self.n_heads,\n            n_layers=self.n_layers,\n            embedding_size=self.token_emb_dim,\n            ffn_size=self.ffn_size,\n            vocabulary_size=self.vocab_size,\n            embedding=self.token_embedding,\n            dropout=self.dropout,\n            attention_dropout=self.attention_dropout,\n            relu_dropout=self.relu_dropout,\n            padding_idx=self.pad_token_idx,\n            learn_positional_embeddings=self.learn_positional_embeddings,\n            embeddings_scale=self.embeddings_scale,\n            reduction=self.reduction,\n            n_positions=self.n_positions,\n        )\n\n        self.conv_entity_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)\n        self.conv_entity_attn_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)\n        self.conv_word_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)\n        self.conv_word_attn_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)\n\n        self.copy_norm = nn.Linear(self.ffn_size * 3, self.token_emb_dim)\n        self.copy_output = nn.Linear(self.token_emb_dim, self.vocab_size)\n        copy_mask = np.load(os.path.join(self.dpath, \"copy_mask.npy\")).astype(bool)\n        if self.replace_token:\n            if self.replace_token_idx < len(copy_mask):\n                copy_mask[self.replace_token_idx] = False\n            else:\n                copy_mask = np.insert(copy_mask,len(copy_mask),False)\n        self.copy_mask = torch.as_tensor(copy_mask).to(self.device)\n        \n\n        self.conv_decoder = TransformerDecoderKG(\n            self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size,\n            embedding=self.token_embedding,\n            
dropout=self.dropout,\n            attention_dropout=self.attention_dropout,\n            relu_dropout=self.relu_dropout,\n            embeddings_scale=self.embeddings_scale,\n            learn_positional_embeddings=self.learn_positional_embeddings,\n            padding_idx=self.pad_token_idx,\n            n_positions=self.n_positions\n        )\n        self.conv_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)\n\n        logger.debug('[Finish build conv layer]')\n\n    def pretrain_infomax(self, batch):\n        \"\"\"\n        words: (batch_size, word_length)\n        entity_labels: (batch_size, n_entity)\n        \"\"\"\n        words, entity_labels = batch\n\n        loss_mask = torch.sum(entity_labels)\n        if loss_mask.item() == 0:\n            return None\n\n        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)\n        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)\n\n        word_representations = word_graph_representations[words]\n        word_padding_mask = words.eq(self.pad_word_idx)  # (bs, seq_len)\n\n        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)\n        word_info_rep = self.infomax_norm(word_attn_rep)  # (bs, dim)\n        info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias)  # (bs, #entity)\n        loss = self.infomax_loss(info_predict, entity_labels) / loss_mask\n        return loss\n\n    def _build_movie_selector(self):\n        self.movie_selector = TransformerDecoderSelection(\n            n_heads=self.n_heads,\n            n_layers=self.n_layers,\n            embedding_size=self.token_emb_dim,\n            ffn_size=self.ffn_size,\n            vocabulary_size=self.n_movies_label,\n            # embedding=self.token_embedding,\n            dropout=self.dropout,\n            attention_dropout=self.attention_dropout,\n            
relu_dropout=self.relu_dropout,\n            padding_idx=self.pad_token_idx,\n            learn_positional_embeddings=self.learn_positional_embeddings,\n            embeddings_scale=self.embeddings_scale,\n            n_positions=self.n_positions,\n        )\n        self.matching_linear = nn.Linear(self.token_emb_dim,self.n_movies_label)\n        self.sel_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)\n\n    def recommend(self, batch, mode):\n        \"\"\"\n        context_entities: (batch_size, entity_length)\n        context_words: (batch_size, word_length)\n        movie: (batch_size)\n        \"\"\"\n        context_entities, context_words, entities, movie = batch\n\n        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)\n        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)\n\n        entity_padding_mask = context_entities.eq(self.pad_entity_idx)  # (bs, entity_len)\n        word_padding_mask = context_words.eq(self.pad_word_idx)  # (bs, word_len)\n\n        entity_representations = entity_graph_representations[context_entities]\n        word_representations = word_graph_representations[context_words]\n\n        entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask)\n        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)\n\n        user_rep = self.gate_layer(entity_attn_rep, word_attn_rep)\n        rec_scores = F.linear(user_rep, entity_graph_representations, self.rec_bias.bias)  # (bs, #entity)\n\n        rec_loss = self.rec_loss(rec_scores, movie)\n\n        info_loss_mask = torch.sum(entities)\n        if info_loss_mask.item() == 0:\n            info_loss = None\n        else:\n            word_info_rep = self.infomax_norm(word_attn_rep)  # (bs, dim)\n            info_predict = F.linear(word_info_rep, entity_graph_representations,\n                                    
self.infomax_bias.bias)  # (bs, #entity)\n            info_loss = self.infomax_loss(info_predict, entities) / info_loss_mask\n\n        return rec_loss, info_loss, rec_scores\n\n    def freeze_parameters(self):\n        freeze_models = [self.word_kg_embedding, self.entity_encoder, self.entity_self_attn, self.word_encoder,\n                         self.word_self_attn, self.gate_layer, self.infomax_bias, self.infomax_norm, self.rec_bias]\n        for model in freeze_models:\n            for p in model.parameters():\n                p.requires_grad = False\n\n    def _starts(self, batch_size):\n        \"\"\"Return bsz start tokens.\"\"\"\n        return self.START.detach().expand(batch_size, 1)\n    \n    def converse(self, batch, mode):\n        context_tokens, context_entities, context_words, response, all_movies = batch\n\n        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)\n        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)\n\n        entity_padding_mask = context_entities.eq(self.pad_entity_idx)  # (bs, entity_len)\n        word_padding_mask = context_words.eq(self.pad_word_idx)  # (bs, seq_len)\n\n        entity_representations = entity_graph_representations[context_entities]\n        word_representations = word_graph_representations[context_words]\n\n        entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask)\n        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)\n\n        # encoder-decoder\n        tokens_encoding = self.conv_encoder(context_tokens)\n        conv_entity_emb = self.conv_entity_attn_norm(entity_attn_rep)\n        conv_word_emb = self.conv_word_attn_norm(word_attn_rep)\n        conv_entity_reps = self.conv_entity_norm(entity_representations)\n        conv_word_reps = self.conv_word_norm(word_representations)\n\n        if mode != 'test':\n            logits, preds,latent = 
self._decode_forced_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb,\n                                                        entity_padding_mask,\n                                                        conv_word_reps, conv_word_emb, word_padding_mask,\n                                                        response)\n\n            logits_ = logits.view(-1, logits.shape[-1])\n            response_ = response.view(-1)\n            gen_loss = self.conv_loss(logits_, response_)\n\n            assert torch.sum(all_movies!=0, dim=(0,1)) == torch.sum((response == 30000), dim=(0,1)) #30000 means the idx of [ITEM]\n            masked_for_selection_token = (response == self.replace_token_idx) \n\n            matching_tensor,_ = self.movie_selector(latent,tokens_encoding,conv_word_reps,word_padding_mask)\n            matching_logits_ = self.matching_linear(matching_tensor)\n\n            matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze(-1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1])\n\n            all_movies = torch.masked_select(all_movies,(all_movies != 0)) \n            matching_logits = matching_logits.view(-1,matching_logits.shape[-1])\n            all_movies = all_movies.view(-1)\n            selection_loss = self.sel_loss(matching_logits,all_movies)\n            return gen_loss,selection_loss, preds\n        else:\n            logits, preds,latent = self._decode_greedy_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb,\n                                                        entity_padding_mask,\n                                                        conv_word_reps, conv_word_emb, word_padding_mask)\n            \n            preds_for_selection = preds[:, 1:] # skip the start_ind\n            masked_for_selection_token = (preds_for_selection == self.replace_token_idx)\n\n            matching_tensor,_ = self.movie_selector(latent,tokens_encoding,conv_word_reps,word_padding_mask)\n 
           matching_logits_ = self.matching_linear(matching_tensor)\n            matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze(-1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1])\n\n            if matching_logits.shape[0] is not 0:\n                    #W1: greedy\n                    _, matching_pred = matching_logits.max(dim=-1) # [bsz * dynamic_movie_nums] \n            else:\n                matching_pred = None\n            return preds,matching_pred,matching_logits_\n    \n    def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask,\n                               word_reps, word_emb_attn, word_mask):\n        batch_size = token_encoding[0].shape[0]\n        inputs = self._starts(batch_size).long()\n        incr_state = None\n        logits = []\n        latents = []\n        for _ in range(self.response_truncate):\n            dialog_latent, incr_state = self.conv_decoder(inputs, token_encoding, word_reps, word_mask,\n                                                          entity_reps, entity_mask, incr_state)\n            dialog_latent = dialog_latent[:, -1:, :]  # (bs, 1, dim)\n            latents.append(dialog_latent)\n            db_latent = entity_emb_attn.unsqueeze(1)\n            concept_latent = word_emb_attn.unsqueeze(1)\n            copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1))\n\n            copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0)\n            gen_logits = F.linear(dialog_latent, self.token_embedding.weight)\n            sum_logits = copy_logits + gen_logits\n            preds = sum_logits.argmax(dim=-1).long()\n            logits.append(sum_logits)\n            inputs = torch.cat((inputs, preds), dim=1)\n\n            finished = ((inputs == self.end_token_idx).sum(dim=-1) > 0).sum().item() == batch_size\n            if finished:\n                break\n   
     logits = torch.cat(logits, dim=1)\n        latents = torch.cat(latents, dim=1)\n        return logits, inputs, latents\n\n    def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask,\n                               word_reps, word_emb_attn, word_mask, response):\n        batch_size, seq_len = response.shape\n        start = self._starts(batch_size)\n        inputs = torch.cat((start, response[:, :-1]), dim=-1).long()\n\n        dialog_latent, _ = self.conv_decoder(inputs, token_encoding, word_reps, word_mask,\n                                             entity_reps, entity_mask)  # (bs, seq_len, dim)\n        \n        entity_latent = entity_emb_attn.unsqueeze(1).expand(-1, seq_len, -1)\n        word_latent = word_emb_attn.unsqueeze(1).expand(-1, seq_len, -1)\n        copy_latent = self.copy_norm(\n            torch.cat((entity_latent, word_latent, dialog_latent), dim=-1))  # (bs, seq_len, dim)\n\n        copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(\n            0)  # (bs, seq_len, vocab_size)\n        gen_logits = F.linear(dialog_latent, self.token_embedding.weight)  # (bs, seq_len, vocab_size)\n        sum_logits = copy_logits + gen_logits\n        preds = sum_logits.argmax(dim=-1)\n        return sum_logits, preds, dialog_latent\n        \n\n    \n    def forward(self, batch, stage, mode):\n        if len(self.gpu) >= 2:\n            #  forward function operates on different gpus, the weight of graph network need to be copied to other gpu\n            self.entity_edge_idx = self.entity_edge_idx.cuda(torch.cuda.current_device())\n            self.entity_edge_type = self.entity_edge_type.cuda(torch.cuda.current_device())\n            self.word_edges = self.word_edges.cuda(torch.cuda.current_device())\n            self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, \"copy_mask.npy\")).astype(bool),\n                                             
).cuda(torch.cuda.current_device())\n        if stage == \"pretrain\":\n            loss = self.pretrain_infomax(batch)\n        elif stage == \"rec\":\n            loss = self.recommend(batch, mode)\n        elif stage == \"conv\":\n            loss = self.converse(batch, mode)\n        return loss"
  },
  {
    "path": "crslab/model/crs/ntrd/resources.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/13\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/12/15\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\nresources = {\n    'ReDial': {\n        'version': '0.2',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1',\n            'kgsf_redial.zip',\n            'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548',\n        ),\n    },\n    'TGReDial': {\n        'version': '0.2',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1',\n            'kgsf_tgredial.zip',\n            'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1',\n        ),\n    },\n    'GoRecDial': {\n        'version': '0.1',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ER5u2yMmgDNFvHuW6lKZLEkBKZkOkxMtZGK0bBQ-jvfLNw?download=1',\n            'kgsf_gorecdial.zip',\n            'f2f57ebb8f688f38a98ee41fe3a87e9362aed945ec9078869407f799da322633',\n        )\n    },\n    'OpenDialKG': {\n        'version': '0.1',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1',\n            'kgsf_opendialkg.zip',\n            '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61'\n        )\n    },\n    'Inspired': {\n        'version': '0.1',\n        'file': DownloadableFile(\n            
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1',\n            'kgsf_inspired.zip',\n            '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d'\n        )\n    },\n    'DuRecDial': {\n        'version': '0.1',\n        'file': DownloadableFile(\n            'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1',\n            'kgsf_durecdial.zip',\n            'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef'\n        )\n    }\n}\n"
  },
  {
    "path": "crslab/model/crs/redial/__init__.py",
    "content": "from .redial_conv import ReDialConvModel\nfrom .redial_rec import ReDialRecModel\n"
  },
  {
    "path": "crslab/model/crs/redial/modules.py",
    "content": "# @Time   : 2020/12/4\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE:\n# @Time   : 2020/12/16\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence\n\nfrom crslab.model.utils.functions import sort_for_packed_sequence\n\n\nclass HRNN(nn.Module):\n    def __init__(self,\n                 utterance_encoder_hidden_size,\n                 dialog_encoder_hidden_size,\n                 dialog_encoder_num_layers,\n                 pad_token_idx,\n                 embedding=None,\n                 use_dropout=False,\n                 dropout=0.3):\n        super(HRNN, self).__init__()\n        self.pad_token_idx = pad_token_idx\n        # embedding\n        self.embedding_size = embedding.weight.shape[1]\n        self.embedding = embedding\n        # utterance encoder\n        self.utterance_encoder_hidden_size = utterance_encoder_hidden_size\n        self.utterance_encoder = nn.GRU(\n            input_size=self.embedding_size,\n            hidden_size=self.utterance_encoder_hidden_size,\n            batch_first=True,\n            bidirectional=True\n        )\n        # conversation encoder\n        self.dialog_encoder = nn.GRU(\n            input_size=(2 * self.utterance_encoder_hidden_size),\n            hidden_size=dialog_encoder_hidden_size,\n            num_layers=dialog_encoder_num_layers,\n            batch_first=True\n        )\n        # dropout\n        self.use_dropout = use_dropout\n        if self.use_dropout:\n            self.dropout = nn.Dropout(p=dropout)\n\n    def get_utterance_encoding(self, context, utterance_lengths):\n        \"\"\"\n        :param context: (batch_size, max_conversation_length, max_utterance_length)\n        :param utterance_lengths: (batch_size, max_conversation_length)\n        :return utterance_encoding: (batch_size, max_conversation_length, 
2 * utterance_encoder_hidden_size)\n        \"\"\"\n        batch_size, max_conv_length = context.shape[:2]\n        utterance_lengths = utterance_lengths.reshape(-1)  # (bs * conv_len)\n        sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(utterance_lengths)\n\n        # reshape and reorder\n        sorted_utterances = context.view(batch_size * max_conv_length, -1).index_select(0, sorted_idx)\n\n        # consider valid sequences only(length > 0)\n        num_positive_lengths = torch.sum(utterance_lengths > 0)\n        sorted_utterances = sorted_utterances[:num_positive_lengths]\n        sorted_lengths = sorted_lengths[:num_positive_lengths]\n\n        embedded = self.embedding(sorted_utterances)\n        if self.use_dropout:\n            embedded = self.dropout(embedded)\n\n        packed_utterances = pack_padded_sequence(embedded, sorted_lengths, batch_first=True)\n        _, utterance_encoding = self.utterance_encoder(packed_utterances)\n\n        # concat the hidden states of the last layer (two directions of the GRU)\n        utterance_encoding = torch.cat((utterance_encoding[-1], utterance_encoding[-2]), 1)\n        if self.use_dropout:\n            utterance_encoding = self.dropout(utterance_encoding)\n\n        # complete the missing sequences (of length 0)\n        if num_positive_lengths < batch_size * max_conv_length:\n            pad_tensor = utterance_encoding.new_full(\n                (batch_size * max_conv_length - num_positive_lengths, 2 * self.utterance_encoder_hidden_size),\n                self.pad_token_idx)\n            utterance_encoding = torch.cat((utterance_encoding, pad_tensor), 0)\n\n        # retrieve original utterance order and Reshape to separate contexts\n        utterance_encoding = utterance_encoding.index_select(0, rev_idx)\n        utterance_encoding = utterance_encoding.view(batch_size, max_conv_length,\n                                                     2 * self.utterance_encoder_hidden_size)\n        
return utterance_encoding\n\n    def forward(self, context, utterance_lengths, dialog_lengths):\n        \"\"\"\n        :param context: (batch_size, max_context_length, max_utterance_length)\n        :param utterance_lengths: (batch_size, max_context_length)\n        :param dialog_lengths: (batch_size)\n        :return context_state: (batch_size, context_encoder_hidden_size)\n        \"\"\"\n        utterance_encoding = self.get_utterance_encoding(context, utterance_lengths)  # (bs, conv_len, 2 * utt_dim)\n        sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(dialog_lengths)\n\n        # reorder in decreasing sequence length\n        sorted_representations = utterance_encoding.index_select(0, sorted_idx)\n        packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths, batch_first=True)\n\n        _, context_state = self.dialog_encoder(packed_sequences)\n        context_state = context_state.index_select(1, rev_idx)\n        if self.use_dropout:\n            context_state = self.dropout(context_state)\n        return context_state[-1]\n\n\nclass SwitchingDecoder(nn.Module):\n    def __init__(self, hidden_size, context_size, num_layers, vocab_size, embedding, pad_token_idx):\n        super(SwitchingDecoder, self).__init__()\n        self.pad_token_idx = pad_token_idx\n        self.hidden_size = hidden_size\n        self.context_size = context_size\n        self.num_layers = num_layers\n        if context_size != hidden_size:\n            raise ValueError(\"The context size {} must match the hidden size {} in DecoderGRU\".format(\n                context_size, hidden_size))\n\n        self.embedding = embedding\n        embedding_dim = embedding.weight.shape[1]\n        self.decoder = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size,\n                              num_layers=num_layers, batch_first=True)\n        self.out = nn.Linear(hidden_size, vocab_size)\n        self.switch = nn.Linear(hidden_size + context_size, 
1)\n\n    def forward(self, request, request_lengths, context_state):\n        \"\"\"\n        :param request: (batch_size, max_utterance_length)\n        :param request_lengths: (batch_size)\n        :param context_state: (batch_size, context_encoder_hidden_size)\n        :return log_probabilities: (batch_size, max_utterance_length, vocab_size + 1)\n        \"\"\"\n        batch_size, max_utterance_length = request.shape\n\n        # sort for pack\n        sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths)\n        sorted_request = request.index_select(0, sorted_idx)\n        embedded_request = self.embedding(sorted_request)  # (batch_size, max_utterance_length, embed_dim)\n        packed_request = pack_padded_sequence(embedded_request, sorted_lengths, batch_first=True)\n\n        sorted_context_state = context_state.index_select(0, sorted_idx)\n        h_0 = sorted_context_state.unsqueeze(0).expand(\n            self.num_layers, batch_size, self.hidden_size\n        ).contiguous()  # require context_size == hidden_size\n\n        sorted_vocab_state, _ = self.decoder(packed_request, h_0)\n        sorted_vocab_state, _ = pad_packed_sequence(sorted_vocab_state,\n                                                    batch_first=True)  # (batch_size, max_request_length, decoder_hidden_size)\n\n        _, max_request_length, decoder_hidden_size = sorted_vocab_state.shape\n        pad_tensor = sorted_vocab_state.new_full(\n            (batch_size, max_utterance_length - max_request_length, decoder_hidden_size), self.pad_token_idx)\n        sorted_vocab_state = torch.cat((sorted_vocab_state, pad_tensor),\n                                       dim=1)  # (batch_size, max_utterance_length, decoder_hidden_size)\n        sorted_language_output = self.out(sorted_vocab_state)  # (batch_size, max_utterance_length, vocab_size)\n\n        # expand context to each time step\n        expanded_sorted_context_state = 
sorted_context_state.unsqueeze(1).expand(\n            batch_size, max_utterance_length, self.context_size\n        ).contiguous()  # (batch_size, max_utterance_length, context_size)\n        # compute switch\n        switch_input = torch.cat((expanded_sorted_context_state, sorted_vocab_state),\n                                 dim=2)  # (batch_size, max_utterance_length, context_size + decoder_hidden_size)\n        switch = self.switch(switch_input)  # (batch_size, max_utterance_length, 1)\n\n        sorted_output = torch.cat((\n            F.logsigmoid(switch) + F.log_softmax(sorted_language_output, dim=2),\n            F.logsigmoid(-switch)  # for item\n        ), dim=2)\n        output = sorted_output.index_select(0, rev_idx)  # (batch_size, max_utterance_length, vocab_size + 1)\n\n        return output\n"
  },
  {
    "path": "crslab/model/crs/redial/redial_conv.py",
    "content": "# @Time   : 2020/12/4\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nReDial_Conv\n===========\nReferences:\n    Li, Raymond, et al. `\"Towards deep conversational recommendations.\"`_ in NeurIPS.\n\n.. _`\"Towards deep conversational recommendations.\"`:\n   https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html\n\n\"\"\"\n\nimport torch\nfrom torch import nn\n\nfrom crslab.model.base import BaseModel\nfrom .modules import HRNN, SwitchingDecoder\n\n\nclass ReDialConvModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        vocab_size: A integer indicating the vocabulary size.\n        pad_token_idx: A integer indicating the id of padding token.\n        start_token_idx: A integer indicating the id of start token.\n        end_token_idx: A integer indicating the id of end token.\n        unk_token_idx: A integer indicating the id of unk token.\n        pretrained_embedding: A string indicating the path of pretrained embedding.\n        embedding_dim: A integer indicating the dimension of item embedding.\n        utterance_encoder_hidden_size: A integer indicating the size of hidden size in utterance encoder.\n        dialog_encoder_hidden_size: A integer indicating the size of hidden size in dialog encoder.\n        dialog_encoder_num_layers: A integer indicating the number of layers in dialog encoder.\n        use_dropout: A boolean indicating if we use the dropout.\n        dropout: A float indicating the dropout rate.\n        decoder_hidden_size: A integer indicating the size of hidden size in decoder.\n        decoder_num_layers: A integer indicating the number of layer in decoder.\n        decoder_embedding_dim: A integer indicating the dimension of embedding in decoder.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        
\"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        # dataset\n        self.vocab_size = vocab['vocab_size']\n        self.pad_token_idx = vocab['pad']\n        self.start_token_idx = vocab['start']\n        self.end_token_idx = vocab['end']\n        self.unk_token_idx = vocab['unk']\n        self.pretrained_embedding = side_data.get('embedding', None)\n        self.embedding_dim = opt.get('embedding_dim', None)\n        if opt.get('embedding', None) and self.embedding_dim is None:\n            raise\n        # HRNN\n        self.utterance_encoder_hidden_size = opt['utterance_encoder_hidden_size']\n        self.dialog_encoder_hidden_size = opt['dialog_encoder_hidden_size']\n        self.dialog_encoder_num_layers = opt['dialog_encoder_num_layers']\n        self.use_dropout = opt['use_dropout']\n        self.dropout = opt['dropout']\n        # SwitchingDecoder\n        self.decoder_hidden_size = opt['decoder_hidden_size']\n        self.decoder_num_layers = opt['decoder_num_layers']\n        self.decoder_embedding_dim = opt['decoder_embedding_dim']\n\n        super(ReDialConvModel, self).__init__(opt, device)\n\n    def build_model(self):\n        if self.opt.get('embedding', None) and self.pretrained_embedding is not None:\n            embedding = nn.Embedding.from_pretrained(\n                torch.as_tensor(self.pretrained_embedding, dtype=torch.float), freeze=False,\n                padding_idx=self.pad_token_idx)\n        else:\n            embedding = nn.Embedding(self.vocab_size, self.embedding_dim)\n\n        self.encoder = HRNN(\n            embedding=embedding,\n            utterance_encoder_hidden_size=self.utterance_encoder_hidden_size,\n            
dialog_encoder_hidden_size=self.dialog_encoder_hidden_size,\n            dialog_encoder_num_layers=self.dialog_encoder_num_layers,\n            use_dropout=self.use_dropout,\n            dropout=self.dropout,\n            pad_token_idx=self.pad_token_idx\n        )\n\n        self.decoder = SwitchingDecoder(\n            hidden_size=self.decoder_hidden_size,\n            context_size=self.dialog_encoder_hidden_size,\n            num_layers=self.decoder_num_layers,\n            vocab_size=self.vocab_size,\n            embedding=embedding,\n            pad_token_idx=self.pad_token_idx\n        )\n        self.loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)\n\n    def forward(self, batch, mode):\n        \"\"\"\n        Args:\n            batch: ::\n\n                {\n                    'context': (batch_size, max_context_length, max_utterance_length),\n                    'context_lengths': (batch_size),\n                    'utterance_lengths': (batch_size, max_context_length),\n                    'request': (batch_size, max_utterance_length),\n                    'request_lengths': (batch_size),\n                    'response': (batch_size, max_utterance_length)\n                }\n\n        \"\"\"\n        assert mode in ('train', 'valid', 'test')\n        if mode == 'train':\n            self.train()\n        else:\n            self.eval()\n\n        context = batch['context']\n        utterance_lengths = batch['utterance_lengths']\n        context_lengths = batch['context_lengths']\n        context_state = self.encoder(context, utterance_lengths,\n                                     context_lengths)  # (batch_size, context_encoder_hidden_size)\n\n        request = batch['request']\n        request_lengths = batch['request_lengths']\n        log_probs = self.decoder(request, request_lengths,\n                                 context_state)  # (batch_size, max_utterance_length, vocab_size + 1)\n        preds = log_probs.argmax(dim=-1)  # 
(batch_size, max_utterance_length)\n\n        log_probs = log_probs.view(-1, log_probs.shape[-1])\n        response = batch['response'].view(-1)\n        loss = self.loss(log_probs, response)\n\n        return loss, preds\n"
  },
  {
    "path": "crslab/model/crs/redial/redial_rec.py",
    "content": "# @Time   : 2020/12/4\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nReDial_Rec\n==========\nReferences:\n    Li, Raymond, et al. `\"Towards deep conversational recommendations.\"`_ in NeurIPS.\n\n.. _`\"Towards deep conversational recommendations.\"`:\n   https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html\n\n\"\"\"\n\nimport torch.nn as nn\n\nfrom crslab.model.base import BaseModel\n\n\nclass ReDialRecModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        n_entity: A integer indicating the number of entities.\n        layer_sizes: A integer indicating the size of layer in autorec.\n        pad_entity_idx: A integer indicating the id of entity padding.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.n_entity = vocab['n_entity']\n        self.layer_sizes = opt['autorec_layer_sizes']\n        self.pad_entity_idx = vocab['pad_entity']\n\n        super(ReDialRecModel, self).__init__(opt, device)\n\n    def build_model(self):\n        # AutoRec\n        if self.opt['autorec_f'] == 'identity':\n            self.f = lambda x: x\n        elif self.opt['autorec_f'] == 'sigmoid':\n            self.f = nn.Sigmoid()\n        elif self.opt['autorec_f'] == 'relu':\n            self.f = nn.ReLU()\n        else:\n            raise ValueError(\"Got invalid function name for f : {}\".format(self.opt['autorec_f']))\n\n        if self.opt['autorec_g'] == 'identity':\n            
self.g = lambda x: x\n        elif self.opt['autorec_g'] == 'sigmoid':\n            self.g = nn.Sigmoid()\n        elif self.opt['autorec_g'] == 'relu':\n            self.g = nn.ReLU()\n        else:\n            raise ValueError(\"Got invalid function name for g : {}\".format(self.opt['autorec_g']))\n\n        self.encoder = nn.ModuleList([nn.Linear(self.n_entity, self.layer_sizes[0]) if i == 0\n                                      else nn.Linear(self.layer_sizes[i - 1], self.layer_sizes[i])\n                                      for i in range(len(self.layer_sizes))])\n        self.user_repr_dim = self.layer_sizes[-1]\n        self.decoder = nn.Linear(self.user_repr_dim, self.n_entity)\n        self.loss = nn.CrossEntropyLoss()\n\n    def forward(self, batch, mode):\n        \"\"\"\n\n        Args:\n            batch: ::\n\n                {\n                    'context_entities': (batch_size, n_entity),\n                    'item': (batch_size)\n                }\n\n            mode (str)\n\n        \"\"\"\n        context_entities = batch['context_entities']\n        for i, layer in enumerate(self.encoder):\n            context_entities = self.f(layer(context_entities))\n        scores = self.g(self.decoder(context_entities))\n        loss = self.loss(scores, batch['item'])\n\n        return loss, scores\n"
  },
  {
    "path": "crslab/model/crs/tgredial/__init__.py",
    "content": "from .tg_conv import TGConvModel\nfrom .tg_policy import TGPolicyModel\nfrom .tg_rec import TGRecModel\n"
  },
  {
    "path": "crslab/model/crs/tgredial/tg_conv.py",
    "content": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2021/1/7, 2020/12/15, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou\n# @Email  : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com\n\nr\"\"\"\nTGReDial_Conv\n=============\nReferences:\n    Zhou, Kun, et al. `\"Towards Topic-Guided Conversational Recommender System.\"`_ in COLING 2020.\n\n.. _`\"Towards Topic-Guided Conversational Recommender System.\"`:\n   https://www.aclweb.org/anthology/2020.coling-main.365/\n\n\"\"\"\n\nimport os\n\nimport torch\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import GPT2LMHeadModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\n\n\nclass TGConvModel(BaseModel):\n    \"\"\"\n        \n    Attributes:\n        context_truncate: A integer indicating the length of dialogue context.\n        response_truncate: A integer indicating the length of dialogue response.\n        pad_id: A integer indicating the id of padding token.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.context_truncate = opt['context_truncate']\n        self.response_truncate = opt['response_truncate']\n        self.pad_id = vocab['pad']\n\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['gpt2'][language]\n        dpath = os.path.join(PRETRAIN_PATH, 'gpt2', language)\n        super(TGConvModel, self).__init__(opt, device, dpath, resource)\n\n    
def build_model(self):\n        \"\"\"build model\"\"\"\n        self.model = GPT2LMHeadModel.from_pretrained(self.dpath)\n        self.loss = CrossEntropyLoss(ignore_index=self.pad_id)\n\n    def forward(self, batch, mode):\n        if mode == 'test' or mode == 'infer':\n            enhanced_context = batch[1]\n            return self.generate(enhanced_context)\n        else:\n            enhanced_input_ids = batch[0]\n            # torch.tensor's shape = (bs, seq_len, v_s); tuple's length = 12\n            lm_logits = self.model(enhanced_input_ids).logits\n\n            # index from 1 to self.reponse_truncate is valid response\n            loss = self.calculate_loss(\n                lm_logits[:, -self.response_truncate:-1, :],\n                enhanced_input_ids[:, -self.response_truncate + 1:])\n\n            pred = torch.max(lm_logits, dim=2)[1]  # [bs, seq_len]\n            pred = pred[:, -self.response_truncate:]\n\n            return loss, pred\n\n    def generate(self, context):\n        \"\"\"\n        Args:\n            context: torch.tensor, shape=(bs, context_turncate)\n\n        Returns:\n            generated_response: torch.tensor, shape=(bs, reponse_turncate-1)\n        \"\"\"\n        generated_response = []\n        former_hidden_state = None\n        context = context[..., -self.response_truncate + 1:]\n\n        for i in range(self.response_truncate - 1):\n            outputs = self.model(context, former_hidden_state)  # (bs, c_t, v_s),\n            last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values\n\n            next_token_logits = last_hidden_state[:, -1, :]  # (bs, v_s)\n            preds = next_token_logits.argmax(dim=-1).long()  # (bs)\n\n            context = preds.unsqueeze(1)\n            generated_response.append(preds)\n\n        generated_response = torch.stack(generated_response).T\n\n        return generated_response\n\n    def generate_bs(self, context, beam=4):\n        context = context[..., 
-self.response_truncate + 1:]\n        context_former = context\n        batch_size = context.shape[0]\n        sequences = [[[list(), 1.0]]] * batch_size\n        for i in range(self.response_truncate - 1):\n            if sequences != [[[list(), 1.0]]] * batch_size:\n                context = []\n                for i in range(batch_size):\n                    for cand in sequences[i]:\n                        text = torch.cat(\n                            (context_former[i], torch.tensor(cand[0]).to(self.device)))  # 由于取消了state，与之前的context拼接\n                        context.append(text)\n                context = torch.stack(context)\n            with torch.no_grad():\n                outputs = self.model(context)\n            last_hidden_state, state = outputs.logits, outputs.past_key_values\n            next_token_logits = last_hidden_state[:, -1, :]\n            next_token_probs = torch.nn.functional.softmax(next_token_logits)\n            topk = torch.topk(next_token_probs, beam, dim=-1)\n            probs = topk.values.reshape([batch_size, -1, beam])  # (bs, candidate, beam)\n            preds = topk.indices.reshape([batch_size, -1, beam])  # (bs, candidate, beam)\n\n            for j in range(batch_size):\n                all_candidates = []\n                for n in range(len(sequences[j])):\n                    for k in range(beam):\n                        seq = sequences[j][n][0]\n                        prob = sequences[j][n][1]\n                        seq_tmp = seq.copy()\n                        seq_tmp.append(preds[j][n][k])\n                        candidate = [seq_tmp, prob * probs[j][n][k]]\n                        all_candidates.append(candidate)\n                ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)\n                sequences[j] = ordered[:beam]\n\n        res = []\n        for i in range(batch_size):\n            res.append(torch.stack(sequences[i][0][0]))\n        res = torch.stack(res)\n        return 
res\n\n    def calculate_loss(self, logit, labels):\n        \"\"\"\n        Args:\n            preds: torch.FloatTensor, shape=(bs, response_truncate, vocab_size)\n            labels: torch.LongTensor, shape=(bs, response_truncate)\n\n        \"\"\"\n\n        loss = self.loss(logit.reshape(-1, logit.size(-1)), labels.reshape(-1))\n        return loss\n"
  },
  {
    "path": "crslab/model/crs/tgredial/tg_policy.py",
    "content": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2021/1/7, 2020/12/15, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou\n# @Email  : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com\n\nr\"\"\"\nTGReDial_Policy\n===============\nReferences:\n    Zhou, Kun, et al. `\"Towards Topic-Guided Conversational Recommender System.\"`_ in COLING 2020.\n\n.. _`\"Towards Topic-Guided Conversational Recommender System.\"`:\n   https://www.aclweb.org/anthology/2020.coling-main.365/\n\n\"\"\"\n\nimport os\n\nimport torch\nfrom torch import nn\nfrom transformers import BertModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\n\n\nclass TGPolicyModel(BaseModel):\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n        \n        \"\"\"\n        self.topic_class_num = vocab['n_topic']\n        self.n_sent = opt.get('n_sent', 10)\n\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['bert'][language]\n        dpath = os.path.join(PRETRAIN_PATH, \"bert\", language)\n        super(TGPolicyModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self, *args, **kwargs):\n        \"\"\"build model\"\"\"\n        self.context_bert = BertModel.from_pretrained(self.dpath)\n        self.topic_bert = BertModel.from_pretrained(self.dpath)\n        self.profile_bert = BertModel.from_pretrained(self.dpath)\n\n        self.bert_hidden_size = 
self.context_bert.config.hidden_size\n        self.state2topic_id = nn.Linear(self.bert_hidden_size * 3,\n                                        self.topic_class_num)\n\n        self.loss = nn.CrossEntropyLoss()\n\n    def forward(self, batch, mode):\n        # conv_id, message_id, context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n        context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n\n        context_rep = self.context_bert(\n            context,\n            context_mask).pooler_output  # (bs, hidden_size)\n\n        topic_rep = self.topic_bert(\n            topic_path_kw,\n            tp_mask).pooler_output  # (bs, hidden_size)\n\n        bs = user_profile.shape[0] // self.n_sent\n        profile_rep = self.profile_bert(user_profile, profile_mask).pooler_output  # (bs, word_num, hidden)\n        profile_rep = profile_rep.view(bs, self.n_sent, -1)\n        profile_rep = torch.mean(profile_rep, dim=1)  # (bs, hidden)\n\n        state_rep = torch.cat((context_rep, topic_rep, profile_rep), dim=1)  # [bs, hidden_size*3]\n        topic_scores = self.state2topic_id(state_rep)\n        topic_loss = self.loss(topic_scores, y)\n\n        return topic_loss, topic_scores\n"
  },
  {
    "path": "crslab/model/crs/tgredial/tg_rec.py",
    "content": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2021/1/7, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @Email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nTGReDial_Rec\n============\nReferences:\n    Zhou, Kun, et al. `\"Towards Topic-Guided Conversational Recommender System.\"`_ in COLING 2020.\n\n.. _`\"Towards Topic-Guided Conversational Recommender System.\"`:\n   https://www.aclweb.org/anthology/2020.coling-main.365/\n\n\"\"\"\n\nimport os\n\nimport torch\nfrom loguru import logger\nfrom torch import nn\nfrom transformers import BertModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\nfrom crslab.model.recommendation.sasrec.modules import SASRec\n\n\nclass TGRecModel(BaseModel):\n    \"\"\"\n        \n    Attributes:\n        hidden_dropout_prob: A float indicating the dropout rate to dropout hidden state in SASRec.\n        initializer_range: A float indicating the range of parameters initization in SASRec.\n        hidden_size: A integer indicating the size of hidden state in SASRec.\n        max_seq_length: A integer indicating the max interaction history length.\n        item_size: A integer indicating the number of items.\n        num_attention_heads: A integer indicating the head number in SASRec.\n        attention_probs_dropout_prob: A float indicating the dropout rate in attention layers.\n        hidden_act: A string indicating the activation function type in SASRec.\n        num_hidden_layers: A integer indicating the number of hidden layers in SASRec.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n   
         vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.hidden_dropout_prob = opt['hidden_dropout_prob']\n        self.initializer_range = opt['initializer_range']\n        self.hidden_size = opt['hidden_size']\n        self.max_seq_length = opt['max_history_items']\n        self.item_size = vocab['n_entity'] + 1\n        self.num_attention_heads = opt['num_attention_heads']\n        self.attention_probs_dropout_prob = opt['attention_probs_dropout_prob']\n        self.hidden_act = opt['hidden_act']\n        self.num_hidden_layers = opt['num_hidden_layers']\n\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['bert'][language]\n        dpath = os.path.join(PRETRAIN_PATH, \"bert\", language)\n        super(TGRecModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self):\n        # build BERT layer, give the architecture, load pretrained parameters\n        self.bert = BertModel.from_pretrained(self.dpath)\n        self.bert_hidden_size = self.bert.config.hidden_size\n        self.concat_embed_size = self.bert_hidden_size + self.hidden_size\n        self.fusion = nn.Linear(self.concat_embed_size, self.item_size)\n        self.SASREC = SASRec(self.hidden_dropout_prob, self.device,\n                             self.initializer_range, self.hidden_size,\n                             self.max_seq_length, self.item_size,\n                             self.num_attention_heads,\n                             self.attention_probs_dropout_prob,\n                             self.hidden_act, self.num_hidden_layers)\n\n        # this loss may conduct to some weakness\n        self.rec_loss = nn.CrossEntropyLoss()\n\n        logger.debug('[Finish build rec layer]')\n\n    def forward(self, batch, mode):\n        context, mask, input_ids, target_pos, input_mask, sample_negs, y = batch\n\n        bert_embed = 
self.bert(context, attention_mask=mask).pooler_output\n\n        sequence_output = self.SASREC(input_ids, input_mask)  # bs, max_len, hidden_size2\n        sas_embed = sequence_output[:, -1, :]  # bs, hidden_size2\n\n        embed = torch.cat((sas_embed, bert_embed), dim=1)\n        rec_scores = self.fusion(embed)  # bs, item_size\n\n        if mode == 'infer':\n            return rec_scores\n        else:\n            rec_loss = self.rec_loss(rec_scores, y)\n            return rec_loss, rec_scores\n"
  },
  {
    "path": "crslab/model/policy/__init__.py",
    "content": "from .conv_bert import ConvBERTModel\nfrom .mgcg import MGCGModel\nfrom .pmi import PMIModel\nfrom .profile_bert import ProfileBERTModel\nfrom .topic_bert import TopicBERTModel\n"
  },
  {
    "path": "crslab/model/policy/conv_bert/__init__.py",
    "content": "from .conv_bert import ConvBERTModel\n"
  },
  {
    "path": "crslab/model/policy/conv_bert/conv_bert.py",
    "content": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nConv_BERT\n=========\nReferences:\n    Zhou, Kun, et al. `\"Towards Topic-Guided Conversational Recommender System.\"`_ in COLING 2020.\n\n.. _`\"Towards Topic-Guided Conversational Recommender System.\"`:\n   https://www.aclweb.org/anthology/2020.coling-main.365/\n\n\"\"\"\n\nimport os\n\nfrom torch import nn\nfrom transformers import BertModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom ...pretrained_models import resources\n\n\nclass ConvBERTModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        topic_class_num: the number of topic.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n        \n        \"\"\"\n        self.topic_class_num = vocab['n_topic']\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['bert'][language]\n        dpath = os.path.join(PRETRAIN_PATH, \"bert\", language)\n        super(ConvBERTModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self, *args, **kwargs):\n        \"\"\"build model\"\"\"\n        self.context_bert = BertModel.from_pretrained(self.dpath)\n\n        self.bert_hidden_size = self.context_bert.config.hidden_size\n        self.state2topic_id = nn.Linear(self.bert_hidden_size,\n                                        self.topic_class_num)\n\n        self.loss = 
nn.CrossEntropyLoss()\n\n    def forward(self, batch, mode):\n        # conv_id, message_id, context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n        context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n\n        context_rep = self.context_bert(\n            context,\n            context_mask).pooler_output  # [bs, hidden_size]\n\n        topic_scores = self.state2topic_id(context_rep)\n\n        topic_loss = self.loss(topic_scores, y)\n\n        return topic_loss, topic_scores\n"
  },
  {
    "path": "crslab/model/policy/mgcg/__init__.py",
    "content": "from .mgcg import MGCGModel\n"
  },
  {
    "path": "crslab/model/policy/mgcg/mgcg.py",
    "content": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nMGCG\n====\nReferences:\n    Liu, Zeming, et al. `\"Towards Conversational Recommendation over Multi-Type Dialogs.\"`_ in ACL 2020.\n\n.. _\"Towards Conversational Recommendation over Multi-Type Dialogs.\":\n   https://www.aclweb.org/anthology/2020.acl-main.98/\n\n\"\"\"\n\nimport torch\nfrom torch import nn\nfrom torch.nn.utils.rnn import pack_padded_sequence\n\nfrom crslab.model.base import BaseModel\n\n\nclass MGCGModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        topic_class_num: A integer indicating the number of topic.\n        vocab_size: A integer indicating the size of vocabulary.\n        embedding_dim: A integer indicating the dimension of embedding layer.\n        hidden_size: A integer indicating the size of hidden state.\n        num_layers: A integer indicating the number of layers in GRU.\n        dropout_hidden: A float indicating the dropout rate of hidden state.\n        n_sent: A integer indicating sequence length in user profile.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n        \n        \"\"\"\n        self.topic_class_num = vocab['n_topic']\n        self.vocab_size = vocab['vocab_size']\n        self.embedding_dim = opt['embedding_dim']\n        self.hidden_size = opt['hidden_size']\n        self.num_layers = opt['num_layers']\n        self.dropout_hidden = opt['dropout_hidden']\n        self.n_sent = opt.get('n_sent', 10)\n\n   
     super(MGCGModel, self).__init__(opt, device)\n\n    def build_model(self, *args, **kwargs):\n        \"\"\"build model\"\"\"\n        self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)\n        self.context_lstm = nn.LSTM(self.embedding_dim,\n                                    self.hidden_size,\n                                    self.num_layers,\n                                    dropout=self.dropout_hidden,\n                                    batch_first=True)\n\n        self.topic_lstm = nn.LSTM(self.embedding_dim,\n                                  self.hidden_size,\n                                  self.num_layers,\n                                  dropout=self.dropout_hidden,\n                                  batch_first=True)\n\n        self.profile_lstm = nn.LSTM(self.embedding_dim,\n                                    self.hidden_size,\n                                    self.num_layers,\n                                    dropout=self.dropout_hidden,\n                                    batch_first=True)\n\n        self.state2topic_id = nn.Linear(self.hidden_size * 3,\n                                        self.topic_class_num)\n        self.loss = nn.CrossEntropyLoss()\n\n    def get_length(self, input):\n        return [torch.sum((ids != 0).long()).item() for ids in input]\n\n    def forward(self, batch, mode):\n        # conv_id, message_id, context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n        context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n\n        len_context = self.get_length(context)\n        len_tp = self.get_length(topic_path_kw)\n        len_profile = self.get_length(user_profile)\n\n        bs_, word_num = user_profile.shape\n        bs = bs_ // self.n_sent\n\n        context = self.embeddings(context)\n        topic_path_kw = self.embeddings(topic_path_kw)\n        user_profile = self.embeddings(user_profile)\n\n        context = 
pack_padded_sequence(context,\n                                       len_context,\n                                       enforce_sorted=False,\n                                       batch_first=True)\n        topic_path_kw = pack_padded_sequence(topic_path_kw,\n                                             len_tp,\n                                             enforce_sorted=False,\n                                             batch_first=True)\n        user_profile = pack_padded_sequence(user_profile,\n                                            len_profile,\n                                            enforce_sorted=False,\n                                            batch_first=True)\n\n        init_h0 = (torch.zeros(self.num_layers, bs,\n                               self.hidden_size).to(self.device),\n                   torch.zeros(self.num_layers, bs,\n                               self.hidden_size).to(self.device))\n\n        # batch, seq_len, num_directions * hidden_size        # num_layers * num_directions, batch, hidden_size\n        context_output, (context_h, _) = self.context_lstm(context, init_h0)\n        topic_output, (topic_h, _) = self.topic_lstm(topic_path_kw, init_h0)\n        # batch*sent_num, seq_len, num_directions * hidden_size\n        init_h0 = (torch.zeros(self.num_layers, bs * self.n_sent,\n                               self.hidden_size).to(self.device),\n                   torch.zeros(self.num_layers, bs * self.n_sent,\n                               self.hidden_size).to(self.device))\n        profile_output, (profile_h,\n                         _) = self.profile_lstm(user_profile, init_h0)\n\n        # batch, hidden_size\n        context_rep = context_h[-1]\n        topic_rep = topic_h[-1]\n\n        profile_rep = profile_h[-1]\n        profile_rep = profile_rep.view(bs, self.n_sent, -1)\n        # batch, hidden_size\n        profile_rep = torch.mean(profile_rep, dim=1)\n\n        state_rep = torch.cat((context_rep, topic_rep, 
profile_rep), 1)\n\n        topic_scores = self.state2topic_id(state_rep)\n\n        topic_loss = self.loss(topic_scores, y)\n\n        return topic_loss, topic_scores\n"
  },
  {
    "path": "crslab/model/policy/pmi/__init__.py",
    "content": "from .pmi import PMIModel\n"
  },
  {
    "path": "crslab/model/policy/pmi/pmi.py",
    "content": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nPMI\n===\n\"\"\"\n\nfrom collections import defaultdict\n\nimport torch\n\nfrom crslab.model.base import BaseModel\n\n\nclass PMIModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        topic_class_num: A integer indicating the number of topic.\n        pad_topic: A integer indicating the id of topic padding.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.topic_class_num = vocab['n_topic']\n        self.pad_topic = vocab['pad_topic']\n        super(PMIModel, self).__init__(opt, device)\n\n    def build_model(self, *args, **kwargs):\n        \"\"\"build model\"\"\"\n        self.topic_to_num = defaultdict(int)\n        self.t2gram_to_num = defaultdict(int)\n        self.last_topic_to_target_topic = defaultdict(int)\n\n    def forward(self, batch, mode):\n        # conv_id, message_id, context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n        context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, target = batch\n\n        if mode == 'train':\n            for topic_path in topic_path_kw:\n                topic_path = [topic_id.item() for topic_id in topic_path if topic_id.item() != self.pad_topic]\n                for topic in topic_path:\n                    self.topic_to_num[topic] += 1\n                for i in range(1, len(topic_path)):\n                    
self.t2gram_to_num[(topic_path[i - 1], topic_path[i])] += 1\n                self.last_topic_to_target_topic[(topic_path[-1], target[0])] += 1\n\n        test_last_topic_to_target_topic = defaultdict(int)\n        for topic_path in topic_path_kw:\n            topic_path = [topic_id.item() for topic_id in topic_path if topic_id.item() != self.pad_topic]\n            test_last_topic_to_target_topic[(topic_path[-1], target[0])] += 1\n\n        total_1_gram = sum(self.topic_to_num.values())\n        total_2_gram = sum(self.t2gram_to_num.values())\n        p_1_gram = {topic: num / total_1_gram for topic, num in self.topic_to_num.items()}\n        p_2_gram = {topic_tuple: num / total_2_gram for topic_tuple, num in self.t2gram_to_num.items()}\n\n        topic_scores = []\n        for (last_topic, target_topic), num in test_last_topic_to_target_topic.items():\n            candidate_topic_to_PMI = {}\n            for cnad_topic in self.topic_to_num:\n                if (last_topic, cnad_topic) in p_2_gram:\n                    candidate_topic_to_PMI[cnad_topic] = p_2_gram.get((last_topic, cnad_topic), 0) / (\n                            p_1_gram.get(last_topic, 0) * p_1_gram.get(cnad_topic, 0))\n            top_cand = dict(sorted(candidate_topic_to_PMI.items(), key=lambda x: x[1], reverse=True))\n            # top_cand = [topic for topic, num in top_cand]\n            topic_scores.append([top_cand.get(topic_id, 0) for topic_id in range(self.topic_class_num)])\n\n        return None, torch.tensor(topic_scores, dtype=torch.long)\n"
  },
  {
    "path": "crslab/model/policy/profile_bert/__init__.py",
    "content": "from .profile_bert import ProfileBERTModel\n"
  },
  {
    "path": "crslab/model/policy/profile_bert/profile_bert.py",
    "content": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nProfile_BERT\n============\nReferences:\n    Zhou, Kun, et al. `\"Towards Topic-Guided Conversational Recommender System.\"`_ in COLING 2020.\n\n.. _`\"Towards Topic-Guided Conversational Recommender System.\"`:\n   https://www.aclweb.org/anthology/2020.coling-main.365/\n\n\"\"\"\n\nimport os\n\nimport torch\nfrom torch import nn\nfrom transformers import BertModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\n\n\nclass ProfileBERTModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        topic_class_num: A integer indicating the number of topic.\n        n_sent: A integer indicating sequence length in user profile.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n        \n        \"\"\"\n        self.topic_class_num = vocab['n_topic']\n        self.n_sent = opt.get('n_sent', 10)\n\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['bert'][language]\n        dpath = os.path.join(PRETRAIN_PATH, \"bert\", language)\n        super(ProfileBERTModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self, *args, **kwargs):\n        \"\"\"build model\"\"\"\n        self.profile_bert = BertModel.from_pretrained(self.dpath)\n\n        self.bert_hidden_size = 
self.profile_bert.config.hidden_size\n        self.state2topic_id = nn.Linear(self.bert_hidden_size,\n                                        self.topic_class_num)\n\n        self.loss = nn.CrossEntropyLoss()\n\n    def forward(self, batch, mode):\n        # conv_id, message_id, context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n        context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n\n        bs = user_profile.size(0) // self.n_sent\n        profile_rep = self.profile_bert(\n            user_profile, profile_mask).pooler_output  # (bs, word_num, hidden)\n        profile_rep = profile_rep.view(bs, self.n_sent, -1)\n        profile_rep = torch.mean(profile_rep, dim=1)  # (bs, hidden)\n\n        topic_scores = self.state2topic_id(profile_rep)\n\n        topic_loss = self.loss(topic_scores, y)\n\n        return topic_loss, topic_scores\n"
  },
  {
    "path": "crslab/model/policy/topic_bert/__init__.py",
    "content": "from .topic_bert import TopicBERTModel\n"
  },
  {
    "path": "crslab/model/policy/topic_bert/topic_bert.py",
    "content": "# @Time   : 2020/12/17\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nTopic_BERT\n==========\nReferences:\n    Zhou, Kun, et al. `\"Towards Topic-Guided Conversational Recommender System.\"`_ in COLING 2020.\n\n.. _`\"Towards Topic-Guided Conversational Recommender System.\"`:\n   https://www.aclweb.org/anthology/2020.coling-main.365/\n\n\"\"\"\n\nimport os\n\nfrom torch import nn\nfrom transformers import BertModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\n\n\nclass TopicBERTModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        topic_class_num: A integer indicating the number of topic.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n        \n        \"\"\"\n        self.topic_class_num = vocab['n_topic']\n\n        language = dataset_language_map[opt['dataset']]\n        dpath = os.path.join(PRETRAIN_PATH, \"bert\", language)\n        resource = resources['bert'][language]\n        super(TopicBERTModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self, *args, **kwargs):\n        \"\"\"build model\"\"\"\n        self.topic_bert = BertModel.from_pretrained(self.dpath)\n\n        self.bert_hidden_size = self.topic_bert.config.hidden_size\n        self.state2topic_id = nn.Linear(self.bert_hidden_size,\n                                        self.topic_class_num)\n\n  
      self.loss = nn.CrossEntropyLoss()\n\n    def forward(self, batch, mode):\n        # conv_id, message_id, context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n        context, context_mask, topic_path_kw, tp_mask, user_profile, profile_mask, y = batch\n\n        topic_rep = self.topic_bert(\n            topic_path_kw,\n            tp_mask).pooler_output  # (bs, hidden_size)\n\n        topic_scores = self.state2topic_id(topic_rep)\n\n        topic_loss = self.loss(topic_scores, y)\n\n        return topic_loss, topic_scores\n"
  },
  {
    "path": "crslab/model/pretrained_models.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2021/1/6\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2021/1/7\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.download import DownloadableFile\n\n\"\"\"Download links of pretrain models.\n\nNow we provide the following models:\n\n- `BERT`_: zh, en\n- `GPT2`_: zh, en\n\n.. _BERT:\n   https://www.aclweb.org/anthology/N19-1423/\n.. _GPT2:\n   https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf\n    \n\"\"\"\n\nresources = {\n    'bert': {\n        'zh': {\n            'version': '0.1',\n            'file': DownloadableFile(\n                'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1',\n                'bert_zh.zip',\n                'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3'\n            )\n        },\n        'en': {\n            'version': '0.1',\n            'file': DownloadableFile(\n                'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1',\n                'bert_en.zip',\n                '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3'\n            )\n        },\n    },\n    'gpt2': {\n        'zh': {\n            'version': '0.1',\n            'file': DownloadableFile(\n                'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1',\n                'gpt2_zh.zip',\n                '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43'\n            )\n        },\n        'en': {\n            'version': '0.1',\n            'file': DownloadableFile(\n                
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1',\n                'gpt2_en.zip',\n                '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e'\n            )\n        }\n    }\n}\n"
  },
  {
    "path": "crslab/model/recommendation/__init__.py",
    "content": "from .bert import BERTModel\nfrom .gru4rec import GRU4RECModel\nfrom .popularity import PopularityModel\nfrom .sasrec import SASRECModel\nfrom .textcnn import TextCNNModel\n"
  },
  {
    "path": "crslab/model/recommendation/bert/__init__.py",
    "content": "from .bert import BERTModel\n"
  },
  {
    "path": "crslab/model/recommendation/bert/bert.py",
    "content": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2021/1/7, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nBERT\n====\nReferences:\n    Devlin, Jacob, et al. `\"BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.\"`_ in NAACL 2019.\n\n.. _`\"BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.\"`:\n   https://www.aclweb.org/anthology/N19-1423/\n\n\"\"\"\n\nimport os\n\nfrom loguru import logger\nfrom torch import nn\nfrom transformers import BertModel\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import dataset_language_map\nfrom crslab.model.base import BaseModel\nfrom crslab.model.pretrained_models import resources\n\n\nclass BERTModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        item_size: A integer indicating the number of items.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.item_size = vocab['n_entity']\n\n        language = dataset_language_map[opt['dataset']]\n        resource = resources['bert'][language]\n        dpath = os.path.join(PRETRAIN_PATH, \"bert\", language)\n        super(BERTModel, self).__init__(opt, device, dpath, resource)\n\n    def build_model(self):\n        # build BERT layer, give the architecture, load pretrained parameters\n        self.bert = BertModel.from_pretrained(self.dpath)\n        # print(self.item_size)\n        self.bert_hidden_size = self.bert.config.hidden_size\n        self.mlp = nn.Linear(self.bert_hidden_size, 
self.item_size)\n\n        # this loss may lead to some weakness\n\n        self.rec_loss = nn.CrossEntropyLoss()\n\n        logger.debug('[Finish build rec layer]')\n\n    def forward(self, batch, mode='train'):\n        context, mask, input_ids, target_pos, input_mask, sample_negs, y = batch\n\n        bert_embed = self.bert(context, attention_mask=mask).pooler_output\n\n        rec_scores = self.mlp(bert_embed)  # bs, item_size\n\n        rec_loss = self.rec_loss(rec_scores, y)\n\n        return rec_loss, rec_scores\n"
  },
  {
    "path": "crslab/model/recommendation/gru4rec/__init__.py",
    "content": "from .gru4rec import GRU4RECModel\n"
  },
  {
    "path": "crslab/model/recommendation/gru4rec/gru4rec.py",
    "content": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nGRU4REC\n=======\nReferences:\n    Hidasi, Balázs, et al. `\"Session-Based Recommendations with Recurrent Neural Networks.\"`_ in ICLR 2016.\n\n.. _`\"Session-Based Recommendations with Recurrent Neural Networks.\"`:\n   https://arxiv.org/abs/1511.06939\n\n\"\"\"\n\nimport torch\nfrom loguru import logger\nfrom torch import nn\nfrom torch.nn.utils.rnn import pack_padded_sequence\nfrom torch.nn.utils.rnn import pad_packed_sequence\n\nfrom crslab.model.base import BaseModel\n\n\nclass GRU4RECModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        item_size: A integer indicating the number of items.\n        hidden_size: A integer indicating the hidden state size in GRU.\n        num_layers: A integer indicating the number of GRU layers.\n        dropout_hidden: A float indicating the dropout rate to dropout hidden state.\n        dropout_input: A integer indicating if we dropout the input of model.\n        embedding_dim: A integer indicating the dimension of item embedding.\n        batch_size: A integer indicating the batch size.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.item_size = vocab['n_entity'] + 1\n        self.hidden_size = opt['gru_hidden_size']\n        self.num_layers = opt['num_layers']\n        self.dropout_hidden = opt['dropout_hidden']\n        self.dropout_input = opt['dropout_input']\n        
self.embedding_dim = opt['embedding_dim']\n        self.batch_size = opt['batch_size']\n\n        super(GRU4RECModel, self).__init__(opt, device)\n\n    def build_model(self):\n        self.item_embeddings = nn.Embedding(self.item_size, self.embedding_dim)\n        self.gru = nn.GRU(self.embedding_dim,\n                          self.hidden_size,\n                          self.num_layers,\n                          dropout=self.dropout_hidden,\n                          batch_first=True)\n\n        logger.debug('[Finish build rec layer]')\n\n    def reconstruct_input(self, input_ids):\n        \"\"\"\n        convert the padding from left to right\n        \"\"\"\n\n        def reverse_padding(ids):\n            ans = [0] * len(ids)\n            idx = 0\n            for m_id in ids:\n                m_id = m_id.item()\n                if m_id != 0:\n                    ans[idx] = m_id\n                    idx += 1\n            return ans\n\n        input_len = [torch.sum((ids != 0).long()).item() for ids in input_ids]\n        input_ids = [reverse_padding(ids) for ids in input_ids]\n        input_ids = torch.tensor(input_ids, dtype=torch.long)\n        input_mask = (input_ids != 0).long()\n\n        return input_ids.to(self.device), input_len, input_mask.to(self.device)\n\n    def cross_entropy(self, seq_out, pos_ids, neg_ids, input_mask):\n        # [batch seq_len hidden_size]\n        pos_emb = self.item_embeddings(pos_ids)\n        neg_emb = self.item_embeddings(neg_ids)\n\n        # [batch*seq_len hidden_size]\n        pos = pos_emb.view(-1, pos_emb.size(2))\n        neg = neg_emb.view(-1, neg_emb.size(2))\n\n        # [batch*seq_len hidden_size]\n        seq_emb = seq_out.contiguous().view(-1, self.hidden_size)\n\n        # [batch*seq_len]\n        pos_logits = torch.sum(pos * seq_emb, -1)\n        neg_logits = torch.sum(neg * seq_emb, -1)\n\n        # [batch*seq_len]\n        istarget = (pos_ids > 0).view(pos_ids.size(0) * pos_ids.size(1)).float()\n        
loss = torch.sum(\n            - torch.log(torch.sigmoid(pos_logits) + 1e-24) * istarget -\n            torch.log(1 - torch.sigmoid(neg_logits) + 1e-24) * istarget\n        ) / torch.sum(istarget)\n\n        return loss\n\n    def forward(self, batch, mode):\n        \"\"\"\n        Args:\n            input_ids: padding in left, [pad, pad, id1, id2, ..., idn]\n            target_ids: padding in left, [pad, pad, id2, id3, ..., y]\n        \"\"\"\n        context, mask, input_ids, target_pos, input_mask, sample_negs, y = batch\n\n        input_ids, input_len, input_mask = self.reconstruct_input(input_ids)\n        target_pos, _, _ = self.reconstruct_input(target_pos)\n        sample_negs, _, _ = self.reconstruct_input(sample_negs)\n        embedded = self.item_embeddings(input_ids)  # (batch, seq_len, hidden_size)\n        input_len = [len_ if len_ > 0 else 1 for len_ in input_len]\n        embedded = pack_padded_sequence(\n            embedded, input_len, enforce_sorted=False,\n            batch_first=True)  # (num_layers , batch, hidden_size)\n\n        output, hidden = self.gru(embedded)\n        output, output_len = pad_packed_sequence(output, batch_first=True)\n\n        batch, seq_len, hidden_size = output.size()\n        logit = output.view(batch, seq_len, hidden_size)\n\n        last_logit = logit[:, -1, :]\n        rec_scores = torch.matmul(last_logit, self.item_embeddings.weight.data.T)\n        rec_scores = rec_scores.squeeze(1)\n\n        max_out_len = max([len_ for len_ in output_len])\n        rec_loss = self.cross_entropy(logit, target_pos[:, :max_out_len],\n                                      sample_negs[:, :max_out_len], input_mask)\n\n        return rec_loss, rec_scores\n"
  },
  {
    "path": "crslab/model/recommendation/gru4rec/modules.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass Embedding(nn.Module):\n    def __init__(self, item_size, embedding_dim):\n        super(Embedding, self).__init__()\n        self.embedding = nn.Embedding(item_size, embedding_dim)\n\n    def forward(self, input: torch.Tensor):\n        return self.embedding(input)\n\n\nclass GRU4REC(nn.Module):\n    def __init__(self, item_size, embedding_dim, hidden_size, num_layers, dropout_hidden):\n        super(GRU4REC, self).__init__()\n        self.module_dict = nn.ModuleDict({\n            'gru': nn.GRU(embedding_dim,\n                          hidden_size,\n                          num_layers,\n                          dropout=dropout_hidden,\n                          batch_first=True),\n            'item_embeddings': Embedding(item_size, embedding_dim),\n        })\n        # self.param = nn.ParameterDict({\n        #     'hidden_size': hidden_size\n        # })\n        self.hidden_size = hidden_size\n        # self.item_embeddings = Embedding(item_size, embedding_dim)\n        # self.gru = nn.GRU(embedding_dim,\n        #                   hidden_size,\n        #                   num_layers,\n        #                   dropout=dropout_hidden,\n        #                   batch_first=True)\n        # self.rec_loss = self.cross_entropy\n\n    def cross_entropy(self, seq_out, pos_ids, neg_ids):\n        # [batch seq_len hidden_size]\n        pos_emb = self.module_dict['item_embeddings'](pos_ids)\n        neg_emb = self.module_dict['item_embeddings'](neg_ids)\n\n        # [batch*seq_len hidden_size]\n        pos = pos_emb.view(-1, pos_emb.size(2))\n        neg = neg_emb.view(-1, neg_emb.size(2))\n\n        # [batch*seq_len hidden_size]\n        seq_emb = seq_out.contiguous().view(-1, self.hidden_size)\n\n        # [batch*seq_len]\n        pos_logits = torch.sum(pos * seq_emb, -1)\n        neg_logits = torch.sum(neg * seq_emb, -1)\n\n        # [batch*seq_len]\n        istarget = (pos_ids > 
0).view(pos_ids.size(0) * pos_ids.size(1)).float()\n        loss = torch.sum(\n            - torch.log(torch.sigmoid(pos_logits) + 1e-24) * istarget -\n            torch.log(1 - torch.sigmoid(neg_logits) + 1e-24) * istarget\n        ) / torch.sum(istarget)\n\n        return loss\n\n    def forward(self, input: torch.Tensor):\n        return self.module_dict['gru'](input)\n"
  },
  {
    "path": "crslab/model/recommendation/popularity/__init__.py",
    "content": "from .popularity import PopularityModel\n"
  },
  {
    "path": "crslab/model/recommendation/popularity/popularity.py",
    "content": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nPopularity\n==========\n\"\"\"\n\nfrom collections import defaultdict\n\nimport torch\nfrom loguru import logger\n\nfrom crslab.model.base import BaseModel\n\n\nclass PopularityModel(BaseModel):\n    \"\"\"\n\n    Attributes:\n        item_size: A integer indicating the number of items.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.item_size = vocab['n_entity']\n        super(PopularityModel, self).__init__(opt, device)\n\n    def build_model(self):\n        self.item_frequency = defaultdict(int)\n        logger.debug('[Finish build rec layer]')\n\n    def forward(self, batch, mode):\n        context, mask, input_ids, target_pos, input_mask, sample_negs, y = batch\n        if mode == 'train':\n            for ids in input_ids:\n                for id in ids:\n                    self.item_frequency[id.item()] += 1\n\n        bs = input_ids.shape[0]\n        rec_score = [self.item_frequency.get(item_id, 0) for item_id in range(self.item_size)]\n        rec_scores = torch.tensor([rec_score] * bs, dtype=torch.long)\n        loss = torch.zeros(1, requires_grad=True)\n        return loss, rec_scores\n"
  },
  {
    "path": "crslab/model/recommendation/sasrec/__init__.py",
    "content": "from .sasrec import SASRECModel\n"
  },
  {
    "path": "crslab/model/recommendation/sasrec/modules.py",
    "content": "# @Time   : 2020/12/13\n# @Author : Kun Zhou\n# @Email  : wxl1999@foxmail.com\n\n# UPDATE\n# @Time   : 2020/12/13, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nimport copy\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass SASRec(nn.Module):\n    def __init__(self, hidden_dropout_prob, device, initializer_range,\n                 hidden_size, max_seq_length, item_size, num_attention_heads,\n                 attention_probs_dropout_prob, hidden_act, num_hidden_layers):\n        super(SASRec, self).__init__()\n        self.hidden_dropout_prob = hidden_dropout_prob\n        self.device = device\n        self.initializer_range = initializer_range\n        self.hidden_size = hidden_size\n        self.max_seq_length = max_seq_length\n        self.item_size = item_size\n        self.num_attention_heads = num_attention_heads\n        self.attention_probs_dropout_prob = attention_probs_dropout_prob\n        self.hidden_act = hidden_act\n        self.num_hidden_layers = num_hidden_layers\n\n        self.build_model()\n        self.init_model()\n\n    def build_model(self):\n        self.embeddings = Embeddings(self.item_size, self.hidden_size,\n                                     self.max_seq_length,\n                                     self.hidden_dropout_prob)\n        self.encoder = Encoder(self.num_hidden_layers, self.hidden_size,\n                               self.num_attention_heads,\n                               self.hidden_dropout_prob, self.hidden_act,\n                               self.attention_probs_dropout_prob)\n\n        self.act = nn.Tanh()\n        self.dropout = nn.Dropout(p=self.hidden_dropout_prob)\n\n    def init_model(self):\n        self.apply(self.init_sas_weights)\n\n    def forward(self,\n                input_ids,\n                attention_mask=None,\n                output_all_encoded_layers=True):\n        if 
attention_mask is None:\n            attention_mask = torch.ones_like(input_ids)  # (bs, seq_len)\n        # We create a 3D attention mask from a 2D tensor mask.\n        # Sizes are [batch_size, 1, 1, to_seq_length]\n        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]\n        # this attention mask is more simple than the triangular masking of causal attention\n        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.\n        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(\n            2)  # torch.int64, (bs, 1, 1, seq_len)\n        # 添加mask 只关注前几个物品进行推荐\n        max_len = attention_mask.size(-1)\n        attn_shape = (1, max_len, max_len)\n        subsequent_mask = torch.triu(torch.ones(attn_shape),\n                                     diagonal=1)  # torch.uint8\n        subsequent_mask = (subsequent_mask == 0).unsqueeze(1)\n        subsequent_mask = subsequent_mask.long().to(self.device)\n        extended_attention_mask = extended_attention_mask * subsequent_mask\n\n        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n        # masked positions, this operation will create a tensor which is 0.0 for\n        # positions we want to attend and -10000.0 for masked positions.\n        # Since we are adding it to the raw scores before the softmax, this is\n        # effectively the same as removing these entirely.\n        # extended_attention_mask = extended_attention_mask.to(\n        #   dtype=next(self.parameters()).dtype)  # fp16 compatibility\n        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n        embedding = self.embeddings(input_ids)\n\n        encoded_layers = self.encoder(\n            embedding,\n            extended_attention_mask,\n            output_all_encoded_layers=output_all_encoded_layers)\n        # [B L H]\n        sequence_output = encoded_layers[-1]\n        return sequence_output\n\n    def 
init_sas_weights(self, module):\n        \"\"\" Initialize the weights.\n        \"\"\"\n        if isinstance(module, (nn.Linear, nn.Embedding)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.initializer_range)\n        elif isinstance(module, LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n        if isinstance(module, nn.Linear) and module.bias is not None:\n            module.bias.data.zero_()\n\n    def save_model(self, file_name):\n        torch.save(self.cpu().state_dict(), file_name)\n        self.to(self.device)\n\n    def load_model(self, path):\n        load_states = torch.load(path, map_location=self.device)\n        load_states_keys = set(load_states.keys())\n        this_states_keys = set(self.state_dict().keys())\n        assert this_states_keys.issubset(load_states_keys)\n        key_not_used = load_states_keys - this_states_keys\n        for key in key_not_used:\n            del load_states[key]\n\n        self.load_state_dict(load_states)\n\n    def compute_loss(self, y_pred, y, subset='test'):\n        pass\n\n    def cross_entropy(self, seq_out, pos_ids, neg_ids):\n        # [batch seq_len hidden_size]\n        pos_emb = self.embeddings.item_embeddings(pos_ids)\n        neg_emb = self.embeddings.item_embeddings(neg_ids)\n\n        # [batch*seq_len hidden_size]\n        pos = pos_emb.view(-1, pos_emb.size(2))\n        neg = neg_emb.view(-1, neg_emb.size(2))\n\n        # [batch*seq_len hidden_size]\n        seq_emb = seq_out.view(-1, self.hidden_size)\n\n        # [batch*seq_len]\n        istarget = (pos_ids > 0).view(-1).float()\n        loss = torch.sum(-torch.log(torch.sigmoid(pos_logits) + 1e-24) 
*\n                         istarget -\n                         torch.log(1 - torch.sigmoid(neg_logits) + 1e-24) *\n                         istarget) / torch.sum(istarget)\n\n        return loss\n\n\ndef gelu(x):\n    \"\"\"Implementation of the gelu activation function.\n\n    For information: OpenAI GPT's gelu is slightly different\n    (and gives slightly different results):\n    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) *\n    (x + 0.044715 * torch.pow(x, 3))))\n    Also see https://arxiv.org/abs/1606.08415\n\n    \"\"\"\n    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))\n\n\ndef swish(x):\n    return x * torch.sigmoid(x)\n\n\nACT2FN = {\"gelu\": gelu, \"relu\": F.relu, \"swish\": swish}\n\n\nclass LayerNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-12):\n        \"\"\"Construct a layernorm module in the TF style (epsilon inside the square root).\"\"\"\n        super(LayerNorm, self).__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size), requires_grad=True)\n        self.bias = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)\n        self.variance_epsilon = eps\n\n    def forward(self, x):\n        u = x.mean(-1, keepdim=True)\n        s = (x - u).pow(2).mean(-1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.variance_epsilon)\n        return self.weight * x + self.bias\n\n\nclass Embeddings(nn.Module):\n    \"\"\"Construct the embeddings from item, position, attribute.\"\"\"\n\n    def __init__(self, item_size, hidden_size, max_seq_length,\n                 hidden_dropout_prob):\n        super(Embeddings, self).__init__()\n\n        self.item_embeddings = nn.Embedding(item_size, hidden_size)\n        self.position_embeddings = nn.Embedding(max_seq_length, hidden_size)\n\n        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)\n        self.dropout = nn.Dropout(hidden_dropout_prob)\n\n    def forward(self, input_ids):\n        seq_length = input_ids.size(1)\n\n        position_ids = 
torch.arange(seq_length,\n                                    dtype=torch.long,\n                                    device=input_ids.device)\n        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)\n        items_embeddings = self.item_embeddings(input_ids)\n        position_embeddings = self.position_embeddings(position_ids)\n\n        embeddings = items_embeddings + position_embeddings\n        embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n\n        return embeddings\n\n\nclass SelfAttention(nn.Module):\n    def __init__(self, hidden_size, num_attention_heads, hidden_dropout_prob,\n                 attention_probs_dropout_prob):\n        super(SelfAttention, self).__init__()\n        if hidden_size % num_attention_heads != 0:\n            raise ValueError(\n                \"The hidden size (%d) is not a multiple of the number of attention \"\n                \"heads (%d)\" % (hidden_size, num_attention_heads))\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_size = int(hidden_size / num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(hidden_size, self.all_head_size)\n        self.key = nn.Linear(hidden_size, self.all_head_size)\n        self.value = nn.Linear(hidden_size, self.all_head_size)\n\n        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)\n\n        self.dense = nn.Linear(hidden_size, hidden_size)\n        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)\n        self.out_dropout = nn.Dropout(hidden_dropout_prob)\n\n    def transpose_for_scores(self, x):\n        \"\"\"\n        Args:\n            x: (bs, seq_len, all_head_size)\n\n        Returns:\n            x.permute(0, 2, 1, 3), (bs, num_heads, seq_len, head_size)\n\n        \"\"\"\n        new_x_shape = x.size()[:-1] + (self.num_attention_heads,\n                                       
self.attention_head_size)\n        x = x.view(*new_x_shape)\n\n        return x.permute(0, 2, 1, 3)\n\n    def forward(self, input_tensor, attention_mask):\n        mixed_query_layer = self.query(input_tensor)\n        mixed_key_layer = self.key(input_tensor)\n        mixed_value_layer = self.value(input_tensor)\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n        key_layer = self.transpose_for_scores(mixed_key_layer)\n        value_layer = self.transpose_for_scores(mixed_value_layer)\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        attention_scores = torch.matmul(query_layer, key_layer.transpose(\n            -1, -2))  # (bs, num_heads, seq_len, seq_len)\n\n        attention_scores = attention_scores / math.sqrt(\n            self.attention_head_size)\n        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)\n\n        # [batch_size heads seq_len seq_len] scores\n        # [batch_size 1 1 seq_len]\n        attention_scores = attention_scores + attention_mask\n\n        # Normalize the attention scores to probabilities.\n        attention_probs = nn.Softmax(dim=-1)(attention_scores)\n\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        attention_probs = self.attn_dropout(attention_probs)\n\n        context_layer = torch.matmul(attention_probs, value_layer)\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + (\n            self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        hidden_states = self.dense(context_layer)\n        hidden_states = self.out_dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n\n        return hidden_states\n\n\nclass 
Intermediate(nn.Module):\n    def __init__(self, hidden_size, hidden_act, hidden_dropout_prob):\n        super(Intermediate, self).__init__()\n        self.dense_1 = nn.Linear(hidden_size, hidden_size * 4)\n        if isinstance(hidden_act, str):\n            self.intermediate_act_fn = ACT2FN[hidden_act]\n        else:\n            self.intermediate_act_fn = hidden_act\n\n        self.dense_2 = nn.Linear(hidden_size * 4, hidden_size)\n        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)\n        self.dropout = nn.Dropout(hidden_dropout_prob)\n\n    def forward(self, input_tensor):\n\n        hidden_states = self.dense_1(input_tensor)\n        hidden_states = self.intermediate_act_fn(hidden_states)\n\n        hidden_states = self.dense_2(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass Layer(nn.Module):\n    def __init__(self, hidden_size, num_attention_heads, hidden_dropout_prob,\n                 hidden_act, attention_probs_dropout_prob):\n        super(Layer, self).__init__()\n        self.attention = SelfAttention(hidden_size, num_attention_heads,\n                                       hidden_dropout_prob,\n                                       attention_probs_dropout_prob)\n        self.intermediate = Intermediate(hidden_size, hidden_act, hidden_dropout_prob)\n\n    def forward(self, hidden_states, attention_mask):\n        attention_output = self.attention(hidden_states, attention_mask)\n        intermediate_output = self.intermediate(attention_output)\n        return intermediate_output\n\n\nclass Encoder(nn.Module):\n    def __init__(self, num_hidden_layers, hidden_size, num_attention_heads,\n                 hidden_dropout_prob, hidden_act,\n                 attention_probs_dropout_prob):\n        super(Encoder, self).__init__()\n        layer = Layer(hidden_size, num_attention_heads, hidden_dropout_prob,\n                
      hidden_act, attention_probs_dropout_prob)\n        self.layer = nn.ModuleList(\n            [copy.deepcopy(layer) for _ in range(num_hidden_layers)])\n\n    def forward(self,\n                hidden_states,\n                attention_mask,\n                output_all_encoded_layers=True):\n        all_encoder_layers = []\n        for layer_module in self.layer:\n            hidden_states = layer_module(hidden_states, attention_mask)\n            if output_all_encoded_layers:\n                all_encoder_layers.append(hidden_states)\n        if not output_all_encoded_layers:\n            all_encoder_layers.append(hidden_states)\n        return all_encoder_layers\n"
  },
  {
    "path": "crslab/model/recommendation/sasrec/sasrec.py",
    "content": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nSASREC\n======\nReferences:\n    Kang, Wang-Cheng, and Julian McAuley. `\"Self-attentive sequential recommendation.\"`_ in ICDM 2018.\n\n.. _`\"Self-attentive sequential recommendation.\"`:\n   https://ieeexplore.ieee.org/abstract/document/8594844\n\n\"\"\"\n\nimport torch\nfrom loguru import logger\nfrom torch import nn\n\nfrom crslab.model.base import BaseModel\nfrom crslab.model.recommendation.sasrec.modules import SASRec\n\n\nclass SASRECModel(BaseModel):\n    \"\"\"\n        \n    Attributes:\n        hidden_dropout_prob: A float indicating the dropout rate to dropout hidden state in SASRec.\n        initializer_range: A float indicating the range of parameters initiation in SASRec.\n        hidden_size: A integer indicating the size of hidden state in SASRec.\n        max_seq_length: A integer indicating the max interaction history length.\n        item_size: A integer indicating the number of items.\n        num_attention_heads: A integer indicating the head number in SASRec.\n        attention_probs_dropout_prob: A float indicating the dropout rate in attention layers.\n        hidden_act: A string indicating the activation function type in SASRec.\n        num_hidden_layers: A integer indicating the number of hidden layers in SASRec.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.hidden_dropout_prob = 
opt['hidden_dropout_prob']\n        self.initializer_range = opt['initializer_range']\n        self.hidden_size = opt['hidden_size']\n        self.max_seq_length = opt['max_history_items']\n        self.item_size = vocab['n_entity'] + 1\n        self.num_attention_heads = opt['num_attention_heads']\n        self.attention_probs_dropout_prob = opt['attention_probs_dropout_prob']\n        self.hidden_act = opt['hidden_act']\n        self.num_hidden_layers = opt['num_hidden_layers']\n\n        super(SASRECModel, self).__init__(opt, device)\n\n    def build_model(self):\n        # build the SASRec encoder with the configured hyperparameters\n        self.SASREC = SASRec(self.hidden_dropout_prob, self.device,\n                             self.initializer_range, self.hidden_size,\n                             self.max_seq_length, self.item_size,\n                             self.num_attention_heads,\n                             self.attention_probs_dropout_prob,\n                             self.hidden_act, self.num_hidden_layers)\n\n        # this loss may lead to some weakness\n        self.rec_loss = nn.CrossEntropyLoss()\n\n        logger.debug('[Finish build rec layer]')\n\n    def forward(self, batch, mode):\n        context, mask, input_ids, target_pos, input_mask, sample_negs, y = batch\n        # print(input_ids.shape)\n        sequence_output = self.SASREC(input_ids, input_mask)  # bs, max_len, hidden_size2\n\n        logit = sequence_output[:, -1:, :]\n        rec_scores = torch.matmul(logit, self.SASREC.embeddings.item_embeddings.weight.data.T)\n        rec_scores = rec_scores.squeeze(1)\n        # print('rec_scores.shape', rec_scores.shape)\n\n        rec_loss = self.SASREC.cross_entropy(sequence_output, target_pos,\n                                             sample_negs)\n\n        return rec_loss, rec_scores\n"
  },
  {
    "path": "crslab/model/recommendation/textcnn/__init__.py",
    "content": "from .textcnn import TextCNNModel\n"
  },
  {
    "path": "crslab/model/recommendation/textcnn/textcnn.py",
    "content": "# @Time   : 2020/12/16\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE\n# @Time   : 2020/12/29, 2021/1/4\n# @Author : Xiaolei Wang, Yuanhang Zhou\n# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com\n\nr\"\"\"\nTextCNN\n=======\nReferences:\n    Kim, Yoon. `\"Convolutional Neural Networks for Sentence Classification.\"`_ in EMNLP 2014.\n\n.. _`\"Convolutional Neural Networks for Sentence Classification.\"`:\n   https://www.aclweb.org/anthology/D14-1181/\n\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom loguru import logger\nfrom torch import nn\n\nfrom crslab.model.base import BaseModel\n\n\nclass TextCNNModel(BaseModel):\n    \"\"\"\n        \n    Attributes:\n        movie_num: A integer indicating the number of items.\n        num_filters: A string indicating the number of filter in CNN.\n        embed: A integer indicating the size of embedding layer.\n        filter_sizes: A string indicating the size of filter in CNN.\n        dropout: A float indicating the dropout rate.\n\n    \"\"\"\n\n    def __init__(self, opt, device, vocab, side_data):\n        \"\"\"\n\n        Args:\n            opt (dict): A dictionary record the hyper parameters.\n            device (torch.device): A variable indicating which device to place the data and model.\n            vocab (dict): A dictionary record the vocabulary information.\n            side_data (dict): A dictionary record the side data.\n\n        \"\"\"\n        self.movie_num = vocab['n_entity']\n        self.num_filters = opt['num_filters']\n        self.embed = opt['embed']\n        self.filter_sizes = eval(opt['filter_sizes'])\n        self.dropout = opt['dropout']\n        super(TextCNNModel, self).__init__(opt, device)\n\n    def conv_and_pool(self, x, conv):\n        x = F.relu(conv(x)).squeeze(3)\n        x = F.max_pool1d(x, x.size(2)).squeeze(2)\n        return x\n\n    def build_model(self):\n        self.embedding = nn.Embedding(self.movie_num, 
self.embed)\n\n        self.convs = nn.ModuleList(\n            [nn.Conv2d(1, self.num_filters, (k, self.embed)) for k in self.filter_sizes])\n        self.dropout = nn.Dropout(self.dropout)\n        self.fc = nn.Linear(self.num_filters * len(self.filter_sizes), self.movie_num)\n\n        # this loss may conduct to some weakness\n        self.rec_loss = nn.CrossEntropyLoss()\n\n        logger.debug('[Finish build rec layer]')\n\n    def forward(self, batch, mode):\n        context, mask, input_ids, target_pos, input_mask, sample_negs, y = batch\n\n        out = self.embedding(context)\n        out = out.unsqueeze(1)\n        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)\n        out = self.dropout(out)\n        out = self.fc(out)\n\n        rec_scores = out\n        rec_loss = self.rec_loss(out, y)\n\n        return rec_loss, rec_scores\n"
  },
  {
    "path": "crslab/model/utils/__init__.py",
    "content": ""
  },
  {
    "path": "crslab/model/utils/functions.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/11/26\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2020/11/16\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nimport torch\n\n\ndef edge_to_pyg_format(edge, type='RGCN'):\n    if type == 'RGCN':\n        edge_sets = torch.as_tensor(edge, dtype=torch.long)\n        edge_idx = edge_sets[:, :2].t()\n        edge_type = edge_sets[:, 2]\n        return edge_idx, edge_type\n    elif type == 'GCN':\n        edge_set = [[co[0] for co in edge], [co[1] for co in edge]]\n        return torch.as_tensor(edge_set, dtype=torch.long)\n    else:\n        raise NotImplementedError('type {} has not been implemented', type)\n\n\ndef sort_for_packed_sequence(lengths: torch.Tensor):\n    \"\"\"\n    :param lengths: 1D array of lengths\n    :return: sorted_lengths (lengths in descending order), sorted_idx (indices to sort), rev_idx (indices to retrieve original order)\n\n    \"\"\"\n    sorted_idx = torch.argsort(lengths, descending=True)  # idx to sort by length\n    rev_idx = torch.argsort(sorted_idx)  # idx to retrieve original order\n    sorted_lengths = lengths[sorted_idx]\n\n    return sorted_lengths, sorted_idx, rev_idx\n"
  },
  {
    "path": "crslab/model/utils/modules/__init__.py",
    "content": ""
  },
  {
    "path": "crslab/model/utils/modules/attention.py",
    "content": "# -*- coding: utf-8 -*-\n# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass SelfAttentionBatch(nn.Module):\n    def __init__(self, dim, da, alpha=0.2, dropout=0.5):\n        super(SelfAttentionBatch, self).__init__()\n        self.dim = dim\n        self.da = da\n        self.alpha = alpha\n        self.dropout = dropout\n        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)\n        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)\n        nn.init.xavier_uniform_(self.a.data, gain=1.414)\n        nn.init.xavier_uniform_(self.b.data, gain=1.414)\n\n    def forward(self, h):\n        # h: (N, dim)\n        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)\n        attention = F.softmax(e, dim=0)  # (N)\n        return torch.matmul(attention, h)  # (dim)\n\n\nclass SelfAttentionSeq(nn.Module):\n    def __init__(self, dim, da, alpha=0.2, dropout=0.5):\n        super(SelfAttentionSeq, self).__init__()\n        self.dim = dim\n        self.da = da\n        self.alpha = alpha\n        self.dropout = dropout\n        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)\n        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)\n        nn.init.xavier_uniform_(self.a.data, gain=1.414)\n        nn.init.xavier_uniform_(self.b.data, gain=1.414)\n\n    def forward(self, h, mask=None, return_logits=False):\n        \"\"\"\n        For the padding tokens, its corresponding mask is True\n        if mask==[1, 1, 1, ...]\n        \"\"\"\n        # h: (batch, seq_len, dim), mask: (batch, seq_len)\n        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b)  # (batch, seq_len, 1)\n        if mask is not None:\n    
        full_mask = -1e30 * mask.float()\n            batch_mask = torch.sum((mask == False), -1).bool().float().unsqueeze(-1)  # for all padding one, the mask=0\n            mask = full_mask * batch_mask\n            e += mask.unsqueeze(-1)\n        attention = F.softmax(e, dim=1)  # (batch, seq_len, 1)\n        # (batch, dim)\n        if return_logits:\n            return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1), attention.squeeze(-1)\n        else:\n            return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1)\n"
  },
  {
    "path": "crslab/model/utils/modules/transformer.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\nimport math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\"\"\"Near infinity, useful as a large penalty for scoring when inf is bad.\"\"\"\nNEAR_INF = 1e20\nNEAR_INF_FP16 = 65504\n\n\ndef neginf(dtype):\n    \"\"\"Returns a representable finite number near -inf for a dtype.\"\"\"\n    if dtype is torch.float16:\n        return -NEAR_INF_FP16\n    else:\n        return -NEAR_INF\n\n\ndef _create_selfattn_mask(x):\n    # figure out how many timestamps we need\n    bsz = x.size(0)\n    time = x.size(1)\n    # make sure that we don't look into the future\n    mask = torch.tril(x.new(time, time).fill_(1))\n    # broadcast across batch\n    mask = mask.unsqueeze(0).expand(bsz, -1, -1)\n    return mask\n\n\ndef create_position_codes(n_pos, dim, out):\n    position_enc = np.array([\n        [pos / np.power(10000, 2 * j / dim) for j in range(dim // 2)]\n        for pos in range(n_pos)\n    ])\n\n    out.data[:, 0::2] = torch.as_tensor(np.sin(position_enc))\n    out.data[:, 1::2] = torch.as_tensor(np.cos(position_enc))\n    out.detach_()\n    out.requires_grad = False\n\n\ndef _normalize(tensor, norm_layer):\n    \"\"\"Broadcast layer norm\"\"\"\n    size = tensor.size()\n    return norm_layer(tensor.view(-1, size[-1])).view(size)\n\n\nclass MultiHeadAttention(nn.Module):\n    def __init__(self, n_heads, dim, dropout=.0):\n        super(MultiHeadAttention, self).__init__()\n        self.n_heads = n_heads\n        self.dim = dim\n\n        self.attn_dropout = nn.Dropout(p=dropout)  # --attention-dropout\n        self.q_lin = nn.Linear(dim, dim)\n        self.k_lin = nn.Linear(dim, dim)\n        self.v_lin = nn.Linear(dim, dim)\n        # TODO: merge for the initialization step\n        nn.init.xavier_normal_(self.q_lin.weight)\n    
    nn.init.xavier_normal_(self.k_lin.weight)\n        nn.init.xavier_normal_(self.v_lin.weight)\n        # and set biases to 0\n        self.out_lin = nn.Linear(dim, dim)\n\n        nn.init.xavier_normal_(self.out_lin.weight)\n\n    def forward(self, query, key=None, value=None, mask=None):\n        # Input is [B, query_len, dim]\n        # Mask is [B, key_len] (selfattn) or [B, key_len, key_len] (enc attn)\n        batch_size, query_len, dim = query.size()\n        assert dim == self.dim, \\\n            f'Dimensions do not match: {dim} query vs {self.dim} configured'\n        assert mask is not None, 'Mask is None, please specify a mask'\n        n_heads = self.n_heads\n        dim_per_head = dim // n_heads\n        scale = math.sqrt(dim_per_head)\n\n        def prepare_head(tensor):\n            # input is [batch_size, seq_len, n_heads * dim_per_head]\n            # output is [batch_size * n_heads, seq_len, dim_per_head]\n            bsz, seq_len, _ = tensor.size()\n            tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head)\n            tensor = tensor.transpose(1, 2).contiguous().view(\n                batch_size * n_heads,\n                seq_len,\n                dim_per_head\n            )\n            return tensor\n\n        # q, k, v are the transformed values\n        if key is None and value is None:\n            # self attention\n            key = value = query\n        elif value is None:\n            # key and value are the same, but query differs\n            # self attention\n            value = key\n        _, key_len, dim = key.size()\n\n        q = prepare_head(self.q_lin(query))\n        k = prepare_head(self.k_lin(key))\n        v = prepare_head(self.v_lin(value))\n\n        dot_prod = q.div_(scale).bmm(k.transpose(1, 2))\n        # [B * n_heads, query_len, key_len]\n        attn_mask = (\n            (mask == 0)\n                .view(batch_size, 1, -1, key_len)\n                .repeat(1, n_heads, 1, 1)\n           
     .expand(batch_size, n_heads, query_len, key_len)\n                .view(batch_size * n_heads, query_len, key_len)\n        )\n        assert attn_mask.shape == dot_prod.shape\n        dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype))\n\n        attn_weights = F.softmax(dot_prod, dim=-1).type_as(query)\n        attn_weights = self.attn_dropout(attn_weights)  # --attention-dropout\n\n        attentioned = attn_weights.bmm(v)\n        attentioned = (\n            attentioned.type_as(query)\n                .view(batch_size, n_heads, query_len, dim_per_head)\n                .transpose(1, 2).contiguous()\n                .view(batch_size, query_len, dim)\n        )\n\n        out = self.out_lin(attentioned)\n\n        return out\n\n\nclass TransformerFFN(nn.Module):\n    def __init__(self, dim, dim_hidden, relu_dropout=.0):\n        super(TransformerFFN, self).__init__()\n        self.relu_dropout = nn.Dropout(p=relu_dropout)\n        self.lin1 = nn.Linear(dim, dim_hidden)\n        self.lin2 = nn.Linear(dim_hidden, dim)\n        nn.init.xavier_uniform_(self.lin1.weight)\n        nn.init.xavier_uniform_(self.lin2.weight)\n        # TODO: initialize biases to 0\n\n    def forward(self, x):\n        x = F.relu(self.lin1(x))\n        x = self.relu_dropout(x)  # --relu-dropout\n        x = self.lin2(x)\n        return x\n\n\nclass TransformerEncoderLayer(nn.Module):\n    def __init__(\n            self,\n            n_heads,\n            embedding_size,\n            ffn_size,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            dropout=0.0,\n    ):\n        super().__init__()\n        self.dim = embedding_size\n        self.ffn_dim = ffn_size\n        self.attention = MultiHeadAttention(\n            n_heads, embedding_size,\n            dropout=attention_dropout,  # --attention-dropout\n        )\n        self.norm1 = nn.LayerNorm(embedding_size)\n        self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)\n 
       self.norm2 = nn.LayerNorm(embedding_size)\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, tensor, mask):\n        tensor = tensor + self.dropout(self.attention(tensor, mask=mask))\n        tensor = _normalize(tensor, self.norm1)\n        tensor = tensor + self.dropout(self.ffn(tensor))\n        tensor = _normalize(tensor, self.norm2)\n        tensor *= mask.unsqueeze(-1).type_as(tensor)\n        return tensor\n\n\nclass TransformerEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder module.\n\n    :param int n_heads: the number of multihead attention heads.\n    :param int n_layers: number of transformer layers.\n    :param int embedding_size: the embedding sizes. Must be a multiple of n_heads.\n    :param int ffn_size: the size of the hidden layer in the FFN\n    :param embedding: an embedding matrix for the bottom layer of the transformer.\n        If none, one is created for this encoder.\n    :param float dropout: Dropout used around embeddings and before layer\n        layer normalizations. This is used in Vaswani 2017 and works well on\n        large datasets.\n    :param float attention_dropout: Dropout performed after the multhead attention\n        softmax. This is not used in Vaswani 2017.\n    :param float relu_dropout: Dropout used after the ReLU in the FFN. Not used\n        in Vaswani 2017, but used in Tensor2Tensor.\n    :param int padding_idx: Reserved padding index in the embeddings matrix.\n    :param bool learn_positional_embeddings: If off, sinusoidal embeddings are\n        used. 
If on, position embeddings are learned from scratch.\n    :param bool embeddings_scale: Scale embeddings relative to their dimensionality.\n        Found useful in fairseq.\n    :param bool reduction: If true, returns the mean vector for the entire encoding\n        sequence.\n    :param int n_positions: Size of the position embeddings matrix.\n    \"\"\"\n\n    def __init__(\n            self,\n            n_heads,\n            n_layers,\n            embedding_size,\n            ffn_size,\n            vocabulary_size,\n            embedding=None,\n            dropout=0.0,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            padding_idx=0,\n            learn_positional_embeddings=False,\n            embeddings_scale=False,\n            reduction=True,\n            n_positions=1024\n    ):\n        super(TransformerEncoder, self).__init__()\n\n        self.embedding_size = embedding_size\n        self.ffn_size = ffn_size\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.dim = embedding_size\n        self.embeddings_scale = embeddings_scale\n        self.reduction = reduction\n        self.padding_idx = padding_idx\n        # this is --dropout, not --relu-dropout or --attention-dropout\n        self.dropout = nn.Dropout(dropout)\n        self.out_dim = embedding_size\n        assert embedding_size % n_heads == 0, \\\n            'Transformer embedding size must be a multiple of n_heads'\n\n        # check input formats:\n        if embedding is not None:\n            assert (\n                    embedding_size is None or embedding_size == embedding.weight.shape[1]\n            ), \"Embedding dim must match the embedding size.\"\n\n        if embedding is not None:\n            self.embeddings = embedding\n        else:\n            assert False\n            assert padding_idx is not None\n            self.embeddings = nn.Embedding(\n                vocabulary_size, embedding_size, padding_idx=padding_idx\n     
       )\n            nn.init.normal_(self.embeddings.weight, 0, embedding_size ** -0.5)\n\n        # create the positional embeddings\n        self.position_embeddings = nn.Embedding(n_positions, embedding_size)\n        if not learn_positional_embeddings:\n            create_position_codes(\n                n_positions, embedding_size, out=self.position_embeddings.weight\n            )\n        else:\n            nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)\n\n        # build the model\n        self.layers = nn.ModuleList()\n        for _ in range(self.n_layers):\n            self.layers.append(TransformerEncoderLayer(\n                n_heads, embedding_size, ffn_size,\n                attention_dropout=attention_dropout,\n                relu_dropout=relu_dropout,\n                dropout=dropout,\n            ))\n\n    def forward(self, input):\n        \"\"\"\n            input data is a FloatTensor of shape [batch, seq_len, dim]\n            mask is a ByteTensor of shape [batch, seq_len], filled with 1 when\n            inside the sequence and 0 outside.\n        \"\"\"\n        mask = input != self.padding_idx\n        positions = (mask.cumsum(dim=1, dtype=torch.int64) - 1).clamp_(min=0)\n        tensor = self.embeddings(input)\n        if self.embeddings_scale:\n            tensor = tensor * np.sqrt(self.dim)\n        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)\n        # --dropout on the embeddings\n        tensor = self.dropout(tensor)\n\n        tensor *= mask.unsqueeze(-1).type_as(tensor)\n        for i in range(self.n_layers):\n            tensor = self.layers[i](tensor, mask)\n\n        if self.reduction:\n            divisor = mask.type_as(tensor).sum(dim=1).unsqueeze(-1).clamp(min=1e-7)\n            output = tensor.sum(dim=1) / divisor\n            return output\n        else:\n            output = tensor\n            return output, mask\n\n\nclass TransformerDecoderLayer(nn.Module):\n    
def __init__(\n            self,\n            n_heads,\n            embedding_size,\n            ffn_size,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            dropout=0.0,\n    ):\n        super().__init__()\n        self.dim = embedding_size\n        self.ffn_dim = ffn_size\n        self.dropout = nn.Dropout(p=dropout)\n\n        self.self_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm1 = nn.LayerNorm(embedding_size)\n\n        self.encoder_attention = MultiHeadAttention(\n            n_heads, embedding_size, dropout=attention_dropout\n        )\n        self.norm2 = nn.LayerNorm(embedding_size)\n\n        self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)\n        self.norm3 = nn.LayerNorm(embedding_size)\n\n    def forward(self, x, encoder_output, encoder_mask):\n        decoder_mask = self._create_selfattn_mask(x)\n        # first self attn\n        residual = x\n        # don't peak into the future!\n        x = self.self_attention(query=x, mask=decoder_mask)\n        x = self.dropout(x)  # --dropout\n        x = x + residual\n        x = _normalize(x, self.norm1)\n\n        residual = x\n        x = self.encoder_attention(\n            query=x,\n            key=encoder_output,\n            value=encoder_output,\n            mask=encoder_mask\n        )\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm2)\n\n        # finally the ffn\n        residual = x\n        x = self.ffn(x)\n        x = self.dropout(x)  # --dropout\n        x = residual + x\n        x = _normalize(x, self.norm3)\n\n        return x\n\n    def _create_selfattn_mask(self, x):\n        # figure out how many timestamps we need\n        bsz = x.size(0)\n        time = x.size(1)\n        # make sure that we don't look into the future\n        mask = torch.tril(x.new(time, time).fill_(1))\n        # broadcast 
across batch\n        mask = mask.unsqueeze(0).expand(bsz, -1, -1)\n        return mask\n\n\nclass TransformerDecoder(nn.Module):\n    \"\"\"\n    Transformer Decoder layer.\n\n    :param int n_heads: the number of multihead attention heads.\n    :param int n_layers: number of transformer layers.\n    :param int embedding_size: the embedding sizes. Must be a multiple of n_heads.\n    :param int ffn_size: the size of the hidden layer in the FFN\n    :param embedding: an embedding matrix for the bottom layer of the transformer.\n        If none, one is created for this encoder.\n    :param float dropout: Dropout used around embeddings and before layer\n        layer normalizations. This is used in Vaswani 2017 and works well on\n        large datasets.\n    :param float attention_dropout: Dropout performed after the multhead attention\n        softmax. This is not used in Vaswani 2017.\n    :param int padding_idx: Reserved padding index in the embeddings matrix.\n    :param bool learn_positional_embeddings: If off, sinusoidal embeddings are\n        used. 
If on, position embeddings are learned from scratch.\n    :param bool embeddings_scale: Scale embeddings relative to their dimensionality.\n        Found useful in fairseq.\n    :param int n_positions: Size of the position embeddings matrix.\n    \"\"\"\n\n    def __init__(\n            self,\n            n_heads,\n            n_layers,\n            embedding_size,\n            ffn_size,\n            vocabulary_size,\n            embedding=None,\n            dropout=0.0,\n            attention_dropout=0.0,\n            relu_dropout=0.0,\n            embeddings_scale=True,\n            learn_positional_embeddings=False,\n            padding_idx=None,\n            n_positions=1024,\n    ):\n        super().__init__()\n        self.embedding_size = embedding_size\n        self.ffn_size = ffn_size\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.dim = embedding_size\n        self.embeddings_scale = embeddings_scale\n        self.dropout = nn.Dropout(p=dropout)  # --dropout\n\n        self.out_dim = embedding_size\n        assert embedding_size % n_heads == 0, \\\n            'Transformer embedding size must be a multiple of n_heads'\n\n        self.embeddings = embedding\n\n        # create the positional embeddings\n        self.position_embeddings = nn.Embedding(n_positions, embedding_size)\n        if not learn_positional_embeddings:\n            create_position_codes(\n                n_positions, embedding_size, out=self.position_embeddings.weight\n            )\n        else:\n            nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)\n\n        # build the model\n        self.layers = nn.ModuleList()\n        for _ in range(self.n_layers):\n            self.layers.append(TransformerDecoderLayer(\n                n_heads, embedding_size, ffn_size,\n                attention_dropout=attention_dropout,\n                relu_dropout=relu_dropout,\n                dropout=dropout,\n            ))\n\n    def 
forward(self, input, encoder_state, incr_state=None):\n        encoder_output, encoder_mask = encoder_state\n\n        seq_len = input.shape[1]\n        positions = input.new_empty(seq_len).long()\n        positions = torch.arange(seq_len, out=positions).unsqueeze(0)  # (batch, seq_len)\n        tensor = self.embeddings(input)\n        if self.embeddings_scale:\n            tensor = tensor * np.sqrt(self.dim)\n        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)\n        tensor = self.dropout(tensor)  # --dropout\n\n        for layer in self.layers:\n            tensor = layer(tensor, encoder_output, encoder_mask)\n\n        return tensor, None\n"
  },
  {
    "path": "crslab/quick_start/__init__.py",
    "content": "from .quick_start import run_crslab\n"
  },
  {
    "path": "crslab/quick_start/quick_start.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2021/1/8\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2021/1/9\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nfrom crslab.config import Config\nfrom crslab.data import get_dataset, get_dataloader\nfrom crslab.system import get_system\n\n\ndef run_crslab(config, save_data=False, restore_data=False, save_system=False, restore_system=False,\n               interact=False, debug=False, tensorboard=False):\n    \"\"\"A fast running api, which includes the complete process of training and testing models on specified datasets.\n\n    Args:\n        config (Config or str): an instance of ``Config`` or path to the config file,\n            which should be in ``yaml`` format. You can use default config provided in the `Github repo`_,\n            or write it by yourself.\n        save_data (bool): whether to save data. Defaults to False.\n        restore_data (bool): whether to restore data. Defaults to False.\n        save_system (bool): whether to save system. Defaults to False.\n        restore_system (bool): whether to restore system. Defaults to False.\n        interact (bool): whether to interact with the system. Defaults to False.\n        debug (bool): whether to debug the system. Defaults to False.\n\n    .. 
_Github repo:\n       https://github.com/RUCAIBox/CRSLab\n\n    \"\"\"\n    # dataset & dataloader\n    if isinstance(config['tokenize'], str):\n        CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data)\n        side_data = CRS_dataset.side_data\n        vocab = CRS_dataset.vocab\n\n        train_dataloader = get_dataloader(config, CRS_dataset.train_data, vocab)\n        valid_dataloader = get_dataloader(config, CRS_dataset.valid_data, vocab)\n        test_dataloader = get_dataloader(config, CRS_dataset.test_data, vocab)\n    else:\n        tokenized_dataset = {}\n        train_dataloader = {}\n        valid_dataloader = {}\n        test_dataloader = {}\n        vocab = {}\n        side_data = {}\n\n        for task, tokenize in config['tokenize'].items():\n            if tokenize in tokenized_dataset:\n                dataset = tokenized_dataset[tokenize]\n            else:\n                dataset = get_dataset(config, tokenize, restore_data, save_data)\n                tokenized_dataset[tokenize] = dataset\n            train_data = dataset.train_data\n            valid_data = dataset.valid_data\n            test_data = dataset.test_data\n            side_data[task] = dataset.side_data\n            vocab[task] = dataset.vocab\n\n            train_dataloader[task] = get_dataloader(config, train_data, vocab[task])\n            valid_dataloader[task] = get_dataloader(config, valid_data, vocab[task])\n            test_dataloader[task] = get_dataloader(config, test_data, vocab[task])\n    # system\n    CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system,\n                     interact, debug, tensorboard)\n    if interact:\n        CRS.interact()\n    else:\n        CRS.fit()\n        if save_system:\n            CRS.save_model()\n"
  },
  {
    "path": "crslab/system/__init__.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020/12/29\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\n# @Time    :   2021/10/6\n# @Author  :   Zhipeng Zhao\n# @email   :   oran_official@outlook.com\n\n\n\nfrom loguru import logger\n\nfrom .inspired import InspiredSystem\nfrom .kbrd import KBRDSystem\nfrom .kgsf import KGSFSystem\nfrom .redial import ReDialSystem\nfrom .ntrd import NTRDSystem\nfrom .tgredial import TGReDialSystem\n\nsystem_register_table = {\n    'ReDialRec_ReDialConv': ReDialSystem,\n    'KBRD': KBRDSystem,\n    'KGSF': KGSFSystem,\n    'TGRec_TGConv': TGReDialSystem,\n    'TGRec_TGConv_TGPolicy': TGReDialSystem,\n    'InspiredRec_InspiredConv': InspiredSystem,\n    'GPT2': TGReDialSystem,\n    'Transformer': TGReDialSystem,\n    'ConvBERT': TGReDialSystem,\n    'ProfileBERT': TGReDialSystem,\n    'TopicBERT': TGReDialSystem,\n    'PMI': TGReDialSystem,\n    'MGCG': TGReDialSystem,\n    'BERT': TGReDialSystem,\n    'SASREC': TGReDialSystem,\n    'GRU4REC': TGReDialSystem,\n    'Popularity': TGReDialSystem,\n    'TextCNN': TGReDialSystem,\n    'NTRD': NTRDSystem\n}\n\n\ndef get_system(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,\n               interact=False, debug=False, tensorboard=False):\n    \"\"\"\n    return the system class\n    \"\"\"\n    model_name = opt['model_name']\n    if model_name in system_register_table:\n        system = system_register_table[model_name](opt, train_dataloader, valid_dataloader, test_dataloader, vocab,\n                                                   side_data, restore_system, interact, debug, tensorboard)\n        logger.info(f'[Build system {model_name}]')\n        return system\n    else:\n        raise NotImplementedError('The system with model [{}] in dataset [{}] has not been implemented'.\n                     
             format(model_name, opt['dataset']))\n"
  },
  {
    "path": "crslab/system/base.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2021/1/9\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2021/11/5\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\nimport os\nfrom abc import ABC, abstractmethod\nimport numpy as np\nimport random\nimport nltk\nimport torch\nfrom fuzzywuzzy.process import extractOne\nfrom loguru import logger\nfrom nltk import word_tokenize\nfrom torch import optim\nfrom transformers import AdamW, Adafactor\n\nfrom crslab.config import SAVE_PATH\nfrom crslab.evaluator import get_evaluator\nfrom crslab.evaluator.metrics.base import AverageMetric\nfrom crslab.model import get_model\nfrom crslab.system.utils import lr_scheduler\nfrom crslab.system.utils.functions import compute_grad_norm\n\noptim_class = {}\noptim_class.update({k: v for k, v in optim.__dict__.items() if not k.startswith('__') and k[0].isupper()})\noptim_class.update({'AdamW': AdamW, 'Adafactor': Adafactor})\nlr_scheduler_class = {k: v for k, v in lr_scheduler.__dict__.items() if not k.startswith('__') and k[0].isupper()}\ntransformers_tokenizer = ('bert', 'gpt2')\n\n\nclass BaseSystem(ABC):\n    \"\"\"Base class for all system\"\"\"\n\n    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,\n                 interact=False, debug=False, tensorboard=False):\n        \"\"\"\n\n        Args:\n            opt (dict): Indicating the hyper parameters.\n            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.\n            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.\n            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.\n            vocab (dict): Indicating the vocabulary.\n            
side_data (dict): Indicating the side data.\n            restore_system (bool, optional): Indicating if we store system after training. Defaults to False.\n            interact (bool, optional): Indicating if we interact with system. Defaults to False.\n            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.\n            tensorboard (bool, optional) Indicating if we monitor the training performance in tensorboard. Defaults to False. \n\n        \"\"\"\n        self.opt = opt\n        if opt[\"gpu\"] == [-1]:\n            self.device = torch.device('cpu')\n        elif len(opt[\"gpu\"]) == 1:\n            self.device = torch.device('cuda')\n        else:\n            self.device = torch.device('cuda')\n        # seed\n        if 'seed' in opt:\n            seed = int(opt['seed'])\n            random.seed(seed)\n            np.random.seed(seed)\n            torch.manual_seed(seed)\n            torch.cuda.manual_seed(seed)\n            torch.cuda.manual_seed_all(seed)\n            logger.info(f'[Set seed] {seed}')\n        # data\n        if debug:\n            self.train_dataloader = valid_dataloader\n            self.valid_dataloader = valid_dataloader\n            self.test_dataloader = test_dataloader\n        else:\n            self.train_dataloader = train_dataloader\n            self.valid_dataloader = valid_dataloader\n            self.test_dataloader = test_dataloader\n        self.vocab = vocab\n        self.side_data = side_data\n        # model\n        if 'model' in opt:\n            self.model = get_model(opt, opt['model'], self.device, vocab, side_data).to(self.device)\n        else:\n            if 'rec_model' in opt:\n                self.rec_model = get_model(opt, opt['rec_model'], self.device, vocab['rec'], side_data['rec']).to(\n                    self.device)\n            if 'conv_model' in opt:\n                self.conv_model = get_model(opt, opt['conv_model'], self.device, vocab['conv'], 
side_data['conv']).to(\n                    self.device)\n            if 'policy_model' in opt:\n                self.policy_model = get_model(opt, opt['policy_model'], self.device, vocab['policy'],\n                                              side_data['policy']).to(self.device)\n        model_file_name = opt.get('model_file', f'{opt[\"model_name\"]}.pth')\n        self.model_file = os.path.join(SAVE_PATH, model_file_name)\n        if restore_system:\n            self.restore_model()\n\n        if not interact:\n            self.evaluator = get_evaluator(opt.get('evaluator', 'standard'), opt['dataset'], tensorboard)\n\n    def init_optim(self, opt, parameters):\n        self.optim_opt = opt\n        parameters = list(parameters)\n        if isinstance(parameters[0], dict):\n            for i, d in enumerate(parameters):\n                parameters[i]['params'] = list(d['params'])\n\n        # gradient acumulation\n        self.update_freq = opt.get('update_freq', 1)\n        self._number_grad_accum = 0\n\n        self.gradient_clip = opt.get('gradient_clip', -1)\n\n        self.build_optimizer(parameters)\n        self.build_lr_scheduler()\n\n        if isinstance(parameters[0], dict):\n            self.parameters = []\n            for d in parameters:\n                self.parameters.extend(d['params'])\n        else:\n            self.parameters = parameters\n\n        # early stop\n        self.need_early_stop = self.optim_opt.get('early_stop', False)\n        if self.need_early_stop:\n            logger.debug('[Enable early stop]')\n            self.reset_early_stop_state()\n\n    def build_optimizer(self, parameters):\n        optimizer_opt = self.optim_opt['optimizer']\n        optimizer = optimizer_opt.pop('name')\n        self.optimizer = optim_class[optimizer](parameters, **optimizer_opt)\n        logger.info(f\"[Build optimizer: {optimizer}]\")\n\n    def build_lr_scheduler(self):\n        \"\"\"\n        Create the learning rate scheduler, and assign 
it to self.scheduler. This\n        scheduler will be updated upon a call to receive_metrics. May also create\n        self.warmup_scheduler, if appropriate.\n\n        :param state_dict states: Possible state_dict provided by model\n            checkpoint, for restoring LR state\n        :param bool hard_reset: If true, the LR scheduler should ignore the\n            state dictionary.\n        \"\"\"\n        if self.optim_opt.get('lr_scheduler', None):\n            lr_scheduler_opt = self.optim_opt['lr_scheduler']\n            lr_scheduler = lr_scheduler_opt.pop('name')\n            self.scheduler = lr_scheduler_class[lr_scheduler](self.optimizer, **lr_scheduler_opt)\n            logger.info(f\"[Build scheduler {lr_scheduler}]\")\n\n    def reset_early_stop_state(self):\n        self.best_valid = None\n        self.drop_cnt = 0\n        self.impatience = self.optim_opt.get('impatience', 3)\n        if self.optim_opt['stop_mode'] == 'max':\n            self.stop_mode = 1\n        elif self.optim_opt['stop_mode'] == 'min':\n            self.stop_mode = -1\n        else:\n            raise\n        logger.debug('[Reset early stop state]')\n\n    @abstractmethod\n    def fit(self):\n        \"\"\"fit the whole system\"\"\"\n        pass\n\n    @abstractmethod\n    def step(self, batch, stage, mode):\n        \"\"\"calculate loss and prediction for batch data under certrain stage and mode\n\n        Args:\n            batch (dict or tuple): batch data\n            stage (str): recommendation/policy/conversation etc.\n            mode (str): train/valid/test\n        \"\"\"\n        pass\n\n    def backward(self, loss):\n        \"\"\"empty grad, backward loss and update params\n\n        Args:\n            loss (torch.Tensor):\n        \"\"\"\n        self._zero_grad()\n\n        if self.update_freq > 1:\n            self._number_grad_accum = (self._number_grad_accum + 1) % self.update_freq\n            loss /= self.update_freq\n        
loss.backward(loss.clone().detach())\n\n        self._update_params()\n\n    def _zero_grad(self):\n        if self._number_grad_accum != 0:\n            # if we're accumulating gradients, don't actually zero things out yet.\n            return\n        self.optimizer.zero_grad()\n\n    def _update_params(self):\n        if self.update_freq > 1:\n            # we're doing gradient accumulation, so we don't only want to step\n            # every N updates instead\n            # self._number_grad_accum is updated in backward function\n            if self._number_grad_accum != 0:\n                return\n\n        if self.gradient_clip > 0:\n            grad_norm = torch.nn.utils.clip_grad_norm_(\n                self.parameters, self.gradient_clip\n            )\n            self.evaluator.optim_metrics.add('grad norm', AverageMetric(grad_norm))\n            self.evaluator.optim_metrics.add(\n                'grad clip ratio',\n                AverageMetric(float(grad_norm > self.gradient_clip)),\n            )\n        else:\n            grad_norm = compute_grad_norm(self.parameters)\n            self.evaluator.optim_metrics.add('grad norm', AverageMetric(grad_norm))\n\n        self.optimizer.step()\n\n        if hasattr(self, 'scheduler'):\n            self.scheduler.train_step()\n\n    def adjust_lr(self, metric=None):\n        \"\"\"adjust learning rate w/o metric by scheduler\n\n        Args:\n            metric (optional): Defaults to None.\n        \"\"\"\n        if not hasattr(self, 'scheduler') or self.scheduler is None:\n            return\n        self.scheduler.valid_step(metric)\n        logger.debug('[Adjust learning rate after valid epoch]')\n\n    def early_stop(self, metric):\n        if not self.need_early_stop:\n            return False\n        if self.best_valid is None or metric * self.stop_mode > self.best_valid * self.stop_mode:\n            self.best_valid = metric\n            self.drop_cnt = 0\n            logger.info('[Get new best 
model]')\n            return False\n        else:\n            self.drop_cnt += 1\n            if self.drop_cnt >= self.impatience:\n                logger.info('[Early stop]')\n                return True\n\n    def save_model(self):\n        r\"\"\"Store the model parameters.\"\"\"\n        state = {}\n        if hasattr(self, 'model'):\n            state['model_state_dict'] = self.model.state_dict()\n        if hasattr(self, 'rec_model'):\n            state['rec_state_dict'] = self.rec_model.state_dict()\n        if hasattr(self, 'conv_model'):\n            state['conv_state_dict'] = self.conv_model.state_dict()\n        if hasattr(self, 'policy_model'):\n            state['policy_state_dict'] = self.policy_model.state_dict()\n\n        os.makedirs(SAVE_PATH, exist_ok=True)\n        torch.save(state, self.model_file)\n        logger.info(f'[Save model into {self.model_file}]')\n\n    def restore_model(self):\n        r\"\"\"Store the model parameters.\"\"\"\n        if not os.path.exists(self.model_file):\n            raise ValueError(f'Saved model [{self.model_file}] does not exist')\n        checkpoint = torch.load(self.model_file, map_location=self.device)\n        if hasattr(self, 'model'):\n            self.model.load_state_dict(checkpoint['model_state_dict'])\n        if hasattr(self, 'rec_model'):\n            self.rec_model.load_state_dict(checkpoint['rec_state_dict'])\n        if hasattr(self, 'conv_model'):\n            self.conv_model.load_state_dict(checkpoint['conv_state_dict'])\n        if hasattr(self, 'policy_model'):\n            self.policy_model.load_state_dict(checkpoint['policy_state_dict'])\n        logger.info(f'[Restore model from {self.model_file}]')\n\n    @abstractmethod\n    def interact(self):\n        pass\n\n    def init_interact(self):\n        self.finished = False\n        self.context = {\n            'rec': {},\n            'conv': {}\n        }\n        for key in self.context:\n            self.context[key]['context_tokens'] 
= []\n            self.context[key]['context_entities'] = []\n            self.context[key]['context_words'] = []\n            self.context[key]['context_items'] = []\n            self.context[key]['user_profile'] = []\n            self.context[key]['interaction_history'] = []\n            self.context[key]['entity_set'] = set()\n            self.context[key]['word_set'] = set()\n\n    def update_context(self, stage, token_ids=None, entity_ids=None, item_ids=None, word_ids=None):\n        if token_ids is not None:\n            self.context[stage]['context_tokens'].append(token_ids)\n        if item_ids is not None:\n            self.context[stage]['context_items'] += item_ids\n        if entity_ids is not None:\n            for entity_id in entity_ids:\n                if entity_id not in self.context[stage]['entity_set']:\n                    self.context[stage]['entity_set'].add(entity_id)\n                    self.context[stage]['context_entities'].append(entity_id)\n        if word_ids is not None:\n            for word_id in word_ids:\n                if word_id not in self.context[stage]['word_set']:\n                    self.context[stage]['word_set'].add(word_id)\n                    self.context[stage]['context_words'].append(word_id)\n\n    def get_input(self, language):\n        print(\"Enter [EXIT] if you want to quit.\")\n\n        if language == 'zh':\n            language = 'chinese'\n        elif language == 'en':\n            language = 'english'\n        else:\n            raise\n        text = input(f\"Enter Your Message in {language}: \")\n\n        if '[EXIT]' in text:\n            self.finished = True\n        return text\n\n    def tokenize(self, text, tokenizer, path=None):\n        tokenize_fun = getattr(self, tokenizer + '_tokenize')\n        if path is not None:\n            return tokenize_fun(text, path)\n        else:\n            return tokenize_fun(text)\n\n    def nltk_tokenize(self, text):\n        nltk.download('punkt')\n        
return word_tokenize(text)\n\n    def bert_tokenize(self, text, path):\n        if not hasattr(self, 'bert_tokenizer'):\n            from transformers import AutoTokenizer\n            self.bert_tokenizer = AutoTokenizer.from_pretrained(path)\n        return self.bert_tokenizer.tokenize(text)\n\n    def gpt2_tokenize(self, text, path):\n        if not hasattr(self, 'gpt2_tokenizer'):\n            from transformers import AutoTokenizer\n            self.gpt2_tokenizer = AutoTokenizer.from_pretrained(path)\n        return self.gpt2_tokenizer.tokenize(text)\n\n    def pkuseg_tokenize(self, text):\n        if not hasattr(self, 'pkuseg_tokenizer'):\n            import pkuseg\n            self.pkuseg_tokenizer = pkuseg.pkuseg()\n        return self.pkuseg_tokenizer.cut(text)\n\n    def link(self, tokens, entities):\n        linked_entities = []\n        for token in tokens:\n            entity = extractOne(token, entities, score_cutoff=90)\n            if entity:\n                linked_entities.append(entity[0])\n        return linked_entities\n"
  },
  {
    "path": "crslab/system/inspired.py",
    "content": "# @Time   : 2021/3/1\n# @Author : Beichen Zhang\n# @Email  : zhangbeichen724@gmail.com\n\nimport torch\nfrom loguru import logger\nfrom math import floor\n\nfrom crslab.data import dataset_language_map\nfrom crslab.evaluator.metrics.base import AverageMetric\nfrom crslab.evaluator.metrics.gen import PPLMetric\nfrom crslab.system.base import BaseSystem\nfrom crslab.system.utils.functions import ind2txt\n\n\nclass InspiredSystem(BaseSystem):\n    \"\"\"This is the system for Inspired model\"\"\"\n\n    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,\n                 interact=False, debug=False, tensorboard=False):\n        \"\"\"\n\n        Args:\n            opt (dict): Indicating the hyper parameters.\n            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.\n            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.\n            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.\n            vocab (dict): Indicating the vocabulary.\n            side_data (dict): Indicating the side data.\n            restore_system (bool, optional): Indicating if we store system after training. Defaults to False.\n            interact (bool, optional): Indicating if we interact with system. Defaults to False.\n            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.\n            tensorboard (bool, optional) Indicating if we monitor the training performance in tensorboard. Defaults to False. 
\n\n        \"\"\"\n        super(InspiredSystem, self).__init__(opt, train_dataloader, valid_dataloader,\n                                             test_dataloader, vocab, side_data, restore_system, interact, debug,\n                                             tensorboard)\n\n        if hasattr(self, 'conv_model'):\n            self.ind2tok = vocab['conv']['ind2tok']\n            self.end_token_idx = vocab['conv']['end']\n        if hasattr(self, 'rec_model'):\n            self.item_ids = side_data['rec']['item_entity_ids']\n            self.id2entity = vocab['rec']['id2entity']\n\n        if hasattr(self, 'rec_model'):\n            self.rec_optim_opt = self.opt['rec']\n            self.rec_epoch = self.rec_optim_opt['epoch']\n            self.rec_batch_size = self.rec_optim_opt['batch_size']\n\n        if hasattr(self, 'conv_model'):\n            self.conv_optim_opt = self.opt['conv']\n            self.conv_epoch = self.conv_optim_opt['epoch']\n            self.conv_batch_size = self.conv_optim_opt['batch_size']\n            if self.conv_optim_opt.get('lr_scheduler', None) and 'Transformers' in self.conv_optim_opt['lr_scheduler'][\n                'name']:\n                batch_num = 0\n                for _ in self.train_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                    batch_num += 1\n                conv_training_steps = self.conv_epoch * floor(batch_num / self.conv_optim_opt.get('update_freq', 1))\n                self.conv_optim_opt['lr_scheduler']['training_steps'] = conv_training_steps\n\n        self.language = dataset_language_map[self.opt['dataset']]\n\n    def rec_evaluate(self, rec_predict, item_label):\n        rec_predict = rec_predict.cpu()\n        rec_predict = rec_predict[:, self.item_ids]\n        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)\n        rec_ranks = rec_ranks.tolist()\n        item_label = item_label.tolist()\n        for rec_rank, item in zip(rec_ranks, item_label):\n 
           item = self.item_ids.index(item)\n            self.evaluator.rec_evaluate(rec_rank, item)\n\n    def conv_evaluate(self, prediction, response):\n        \"\"\"\n        Args:\n            prediction: torch.LongTensor, shape=(bs, response_truncate-1)\n            response: (torch.LongTensor, torch.LongTensor), shape=((bs, response_truncate), (bs, response_truncate))\n\n            the first token in response is <|endoftext|>,  it is not in prediction\n        \"\"\"\n        prediction = prediction.tolist()\n        response = response.tolist()\n        for p, r in zip(prediction, response):\n            p_str = ind2txt(p, self.ind2tok, self.end_token_idx)\n            r_str = ind2txt(r[1:], self.ind2tok, self.end_token_idx)\n            self.evaluator.gen_evaluate(p_str, [r_str])\n\n    def step(self, batch, stage, mode):\n        \"\"\"\n        stage: ['policy', 'rec', 'conv']\n        mode: ['train', 'val', 'test]\n        \"\"\"\n        batch = [ele.to(self.device) for ele in batch]\n        if stage == 'rec':\n            if mode == 'train':\n                self.rec_model.train()\n            else:\n                self.rec_model.eval()\n\n            rec_loss, rec_predict = self.rec_model.recommend(batch, mode)\n            if mode == \"train\":\n                self.backward(rec_loss)\n            else:\n                self.rec_evaluate(rec_predict, batch[-1])\n            rec_loss = rec_loss.item()\n            self.evaluator.optim_metrics.add(\"rec_loss\",\n                                             AverageMetric(rec_loss))\n        elif stage == \"conv\":\n            if mode != \"test\":\n                # train + valid: need to compute ppl\n                gen_loss, pred = self.conv_model.converse(batch, mode)\n                if mode == 'train':\n                    self.conv_model.train()\n                    self.backward(gen_loss)\n                else:\n                    self.conv_model.eval()\n                    
self.conv_evaluate(pred, batch[-1])\n                gen_loss = gen_loss.item()\n                self.evaluator.optim_metrics.add(\"gen_loss\",\n                                                 AverageMetric(gen_loss))\n                self.evaluator.gen_metrics.add(\"ppl\", PPLMetric(gen_loss))\n            else:\n                # generate response in conv_model.step\n                pred = self.conv_model.converse(batch, mode)\n                self.conv_evaluate(pred, batch[-1])\n        else:\n            raise\n\n    def train_recommender(self):\n        if hasattr(self.rec_model, 'bert'):\n            bert_param = list(self.rec_model.bert.named_parameters())\n            bert_param_name = ['bert.' + n for n, p in bert_param]\n        else:\n            bert_param = []\n            bert_param_name = []\n        other_param = [\n            name_param for name_param in self.rec_model.named_parameters()\n            if name_param[0] not in bert_param_name\n        ]\n        params = [{'params': [p for n, p in bert_param], 'lr': self.rec_optim_opt['lr_bert']},\n                  {'params': [p for n, p in other_param]}]\n        self.init_optim(self.rec_optim_opt, params)\n\n        for epoch in range(self.rec_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Recommendation epoch {str(epoch)}]')\n            for batch in self.train_dataloader['rec'].get_rec_data(self.rec_batch_size,\n                                                                   shuffle=True):\n                self.step(batch, stage='rec', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader['rec'].get_rec_data(\n                        self.rec_batch_size, shuffle=False):\n                    self.step(batch, stage='rec', mode='val')\n                self.evaluator.report(epoch=epoch, 
mode='val')\n                # early stop\n                metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50']\n                if self.early_stop(metric):\n                    break\n        # test\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader['rec'].get_rec_data(self.rec_batch_size,\n                                                                  shuffle=False):\n                self.step(batch, stage='rec', mode='test')\n            self.evaluator.report(mode='test')\n\n    def train_conversation(self):\n        self.init_optim(self.conv_optim_opt, self.conv_model.parameters())\n\n        for epoch in range(self.conv_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Conversation epoch {str(epoch)}]')\n            for batch in self.train_dataloader['conv'].get_conv_data(\n                    batch_size=self.conv_batch_size, shuffle=True):\n                self.step(batch, stage='conv', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader['conv'].get_conv_data(\n                        batch_size=self.conv_batch_size, shuffle=False):\n                    self.step((batch), stage='conv', mode='val')\n                self.evaluator.report(epoch=epoch, mode='val')\n                # early stop\n                metric = self.evaluator.gen_metrics['ppl']\n                if self.early_stop(metric):\n                    break\n        # test\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader['conv'].get_conv_data(\n                    batch_size=self.conv_batch_size, shuffle=False):\n                self.step((batch), stage='conv', mode='test')\n            
self.evaluator.report(mode='test')\n\n    def fit(self):\n        if hasattr(self, 'rec_model'):\n            self.train_recommender()\n        if hasattr(self, 'conv_model'):\n            self.train_conversation()\n\n    def interact(self):\n        pass\n"
  },
  {
    "path": "crslab/system/kbrd.py",
    "content": "# -*- encoding: utf-8 -*-\n# @Time    :   2020/12/4\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\n# UPDATE\n# @Time    :   2021/1/3\n# @Author  :   Xiaolei Wang\n# @email   :   wxl1999@foxmail.com\n\nimport os\n\nimport torch\nfrom loguru import logger\n\nfrom crslab.evaluator.metrics.base import AverageMetric\nfrom crslab.evaluator.metrics.gen import PPLMetric\nfrom crslab.system.base import BaseSystem\nfrom crslab.system.utils.functions import ind2txt\n\n\nclass KBRDSystem(BaseSystem):\n    \"\"\"This is the system for KBRD model\"\"\"\n\n    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,\n                 interact=False, debug=False, tensorboard=False):\n        \"\"\"\n\n        Args:\n            opt (dict): Indicating the hyper parameters.\n            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.\n            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.\n            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.\n            vocab (dict): Indicating the vocabulary.\n            side_data (dict): Indicating the side data.\n            restore_system (bool, optional): Indicating if we store system after training. Defaults to False.\n            interact (bool, optional): Indicating if we interact with system. Defaults to False.\n            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.\n            tensorboard (bool, optional) Indicating if we monitor the training performance in tensorboard. Defaults to False. 
\n\n        \"\"\"\n        super(KBRDSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data,\n                                         restore_system, interact, debug, tensorboard)\n\n        self.ind2tok = vocab['ind2tok']\n        self.end_token_idx = vocab['end']\n        self.item_ids = side_data['item_entity_ids']\n\n        self.rec_optim_opt = opt['rec']\n        self.conv_optim_opt = opt['conv']\n        self.rec_epoch = self.rec_optim_opt['epoch']\n        self.conv_epoch = self.conv_optim_opt['epoch']\n        self.rec_batch_size = self.rec_optim_opt['batch_size']\n        self.conv_batch_size = self.conv_optim_opt['batch_size']\n\n    def rec_evaluate(self, rec_predict, item_label):\n        rec_predict = rec_predict.cpu()\n        rec_predict = rec_predict[:, self.item_ids]\n        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)\n        rec_ranks = rec_ranks.tolist()\n        item_label = item_label.tolist()\n        for rec_rank, label in zip(rec_ranks, item_label):\n            label = self.item_ids.index(label)\n            self.evaluator.rec_evaluate(rec_rank, label)\n\n    def conv_evaluate(self, prediction, response):\n        prediction = prediction.tolist()\n        response = response.tolist()\n        for p, r in zip(prediction, response):\n            p_str = ind2txt(p, self.ind2tok, self.end_token_idx)\n            r_str = ind2txt(r, self.ind2tok, self.end_token_idx)\n            self.evaluator.gen_evaluate(p_str, [r_str])\n\n    def step(self, batch, stage, mode):\n        assert stage in ('rec', 'conv')\n        assert mode in ('train', 'valid', 'test')\n\n        for k, v in batch.items():\n            if isinstance(v, torch.Tensor):\n                batch[k] = v.to(self.device)\n\n        if stage == 'rec':\n            rec_loss, rec_scores = self.model.forward(batch, mode, stage)\n            rec_loss = rec_loss.sum()\n            if mode == 'train':\n                
self.backward(rec_loss)\n            else:\n                self.rec_evaluate(rec_scores, batch['item'])\n            rec_loss = rec_loss.item()\n            self.evaluator.optim_metrics.add(\"rec_loss\", AverageMetric(rec_loss))\n        else:\n            if mode != 'test':\n                gen_loss, preds = self.model.forward(batch, mode, stage)\n                if mode == 'train':\n                    self.backward(gen_loss)\n                else:\n                    self.conv_evaluate(preds, batch['response'])\n                gen_loss = gen_loss.item()\n                self.evaluator.optim_metrics.add('gen_loss', AverageMetric(gen_loss))\n                self.evaluator.gen_metrics.add(\"ppl\", PPLMetric(gen_loss))\n            else:\n                preds = self.model.forward(batch, mode, stage)\n                self.conv_evaluate(preds, batch['response'])\n\n    def train_recommender(self):\n        self.init_optim(self.rec_optim_opt, self.model.parameters())\n\n        for epoch in range(self.rec_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Recommendation epoch {str(epoch)}]')\n            logger.info('[Train]')\n            for batch in self.train_dataloader.get_rec_data(self.rec_batch_size):\n                self.step(batch, stage='rec', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            logger.info('[Valid]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):\n                    self.step(batch, stage='rec', mode='valid')\n                self.evaluator.report(epoch=epoch, mode='valid')\n                # early stop\n                metric = self.evaluator.optim_metrics['rec_loss']\n                if self.early_stop(metric):\n                    break\n        # test\n        logger.info('[Test]')\n        with torch.no_grad():\n 
           self.evaluator.reset_metrics()\n            for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):\n                self.step(batch, stage='rec', mode='test')\n            self.evaluator.report(mode='test')\n\n    def train_conversation(self):\n        if os.environ[\"CUDA_VISIBLE_DEVICES\"] == '-1':\n            self.model.freeze_parameters()\n        elif len(os.environ[\"CUDA_VISIBLE_DEVICES\"]) == 1:\n            self.model.freeze_parameters()\n        else:\n            self.model.module.freeze_parameters()\n        self.init_optim(self.conv_optim_opt, self.model.parameters())\n\n        for epoch in range(self.conv_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Conversation epoch {str(epoch)}]')\n            logger.info('[Train]')\n            for batch in self.train_dataloader.get_conv_data(batch_size=self.conv_batch_size):\n                self.step(batch, stage='conv', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            logger.info('[Valid]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                    self.step(batch, stage='conv', mode='valid')\n                self.evaluator.report(epoch=epoch, mode='valid')\n                # early stop\n                metric = self.evaluator.optim_metrics['gen_loss']\n                if self.early_stop(metric):\n                    break\n        # test\n        logger.info('[Test]')\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                self.step(batch, stage='conv', mode='test')\n            self.evaluator.report(mode='test')\n\n    def fit(self):\n        self.train_recommender()\n        
self.train_conversation()\n\n    def interact(self):\n        pass\n"
  },
  {
    "path": "crslab/system/kgsf.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2021/1/3\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nimport os\n\nimport torch\nfrom loguru import logger\n\nfrom crslab.evaluator.metrics.base import AverageMetric\nfrom crslab.evaluator.metrics.gen import PPLMetric\nfrom crslab.system.base import BaseSystem\nfrom crslab.system.utils.functions import ind2txt\n\n\nclass KGSFSystem(BaseSystem):\n    \"\"\"This is the system for KGSF model\"\"\"\n\n    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,\n                 interact=False, debug=False, tensorboard=False):\n        \"\"\"\n\n        Args:\n            opt (dict): Indicating the hyper parameters.\n            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.\n            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.\n            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.\n            vocab (dict): Indicating the vocabulary.\n            side_data (dict): Indicating the side data.\n            restore_system (bool, optional): Indicating if we store system after training. Defaults to False.\n            interact (bool, optional): Indicating if we interact with system. Defaults to False.\n            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.\n            tensorboard (bool, optional) Indicating if we monitor the training performance in tensorboard. Defaults to False. 
\n\n        \"\"\"\n        super(KGSFSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data,\n                                         restore_system, interact, debug, tensorboard)\n\n        self.ind2tok = vocab['ind2tok']\n        self.end_token_idx = vocab['end']\n        self.item_ids = side_data['item_entity_ids']\n\n        self.pretrain_optim_opt = self.opt['pretrain']\n        self.rec_optim_opt = self.opt['rec']\n        self.conv_optim_opt = self.opt['conv']\n        self.pretrain_epoch = self.pretrain_optim_opt['epoch']\n        self.rec_epoch = self.rec_optim_opt['epoch']\n        self.conv_epoch = self.conv_optim_opt['epoch']\n        self.pretrain_batch_size = self.pretrain_optim_opt['batch_size']\n        self.rec_batch_size = self.rec_optim_opt['batch_size']\n        self.conv_batch_size = self.conv_optim_opt['batch_size']\n\n    def rec_evaluate(self, rec_predict, item_label):\n        rec_predict = rec_predict.cpu()\n        rec_predict = rec_predict[:, self.item_ids]\n        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)\n        rec_ranks = rec_ranks.tolist()\n        item_label = item_label.tolist()\n        for rec_rank, item in zip(rec_ranks, item_label):\n            item = self.item_ids.index(item)\n            self.evaluator.rec_evaluate(rec_rank, item)\n\n    def conv_evaluate(self, prediction, response):\n        prediction = prediction.tolist()\n        response = response.tolist()\n        for p, r in zip(prediction, response):\n            p_str = ind2txt(p, self.ind2tok, self.end_token_idx)\n            r_str = ind2txt(r, self.ind2tok, self.end_token_idx)\n            self.evaluator.gen_evaluate(p_str, [r_str])\n\n    def step(self, batch, stage, mode):\n        batch = [ele.to(self.device) for ele in batch]\n        if stage == 'pretrain':\n            info_loss = self.model.forward(batch, stage, mode)\n            if info_loss is not None:\n                
self.backward(info_loss.sum())\n                info_loss = info_loss.sum().item()\n                self.evaluator.optim_metrics.add(\"info_loss\", AverageMetric(info_loss))\n        elif stage == 'rec':\n            rec_loss, info_loss, rec_predict = self.model.forward(batch, stage, mode)\n            if info_loss:\n                loss = rec_loss + 0.025 * info_loss\n            else:\n                loss = rec_loss\n            if mode == \"train\":\n                self.backward(loss.sum())\n            else:\n                self.rec_evaluate(rec_predict, batch[-1])\n            rec_loss = rec_loss.sum().item()\n            self.evaluator.optim_metrics.add(\"rec_loss\", AverageMetric(rec_loss))\n            if info_loss:\n                info_loss = info_loss.sum().item()\n                self.evaluator.optim_metrics.add(\"info_loss\", AverageMetric(info_loss))\n        elif stage == \"conv\":\n            if mode != \"test\":\n                gen_loss, pred = self.model.forward(batch, stage, mode)\n                if mode == 'train':\n                    self.backward(gen_loss.sum())\n                else:\n                    self.conv_evaluate(pred, batch[-1])\n                gen_loss = gen_loss.sum().item()\n                self.evaluator.optim_metrics.add(\"gen_loss\", AverageMetric(gen_loss))\n                self.evaluator.gen_metrics.add(\"ppl\", PPLMetric(gen_loss))\n            else:\n                pred = self.model.forward(batch, stage, mode)\n                self.conv_evaluate(pred, batch[-1])\n        else:\n            raise\n\n    def pretrain(self):\n        self.init_optim(self.pretrain_optim_opt, self.model.parameters())\n\n        for epoch in range(self.pretrain_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Pretrain epoch {str(epoch)}]')\n            for batch in self.train_dataloader.get_pretrain_data(self.pretrain_batch_size, shuffle=False):\n                self.step(batch, stage=\"pretrain\", 
mode='train')\n            self.evaluator.report()\n\n    def train_recommender(self):\n        self.init_optim(self.rec_optim_opt, self.model.parameters())\n\n        for epoch in range(self.rec_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Recommendation epoch {str(epoch)}]')\n            logger.info('[Train]')\n            for batch in self.train_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):\n                self.step(batch, stage='rec', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            logger.info('[Valid]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):\n                    self.step(batch, stage='rec', mode='val')\n                self.evaluator.report(epoch=epoch, mode='val')\n                # early stop\n                metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50']\n                if self.early_stop(metric):\n                    break\n        # test\n        logger.info('[Test]')\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):\n                self.step(batch, stage='rec', mode='test')\n            self.evaluator.report(mode='test')\n\n    def train_conversation(self):\n        if os.environ[\"CUDA_VISIBLE_DEVICES\"] == '-1':\n            self.model.freeze_parameters()\n        else:\n            self.model.module.freeze_parameters()\n        self.init_optim(self.conv_optim_opt, self.model.parameters())\n\n        for epoch in range(self.conv_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Conversation epoch {str(epoch)}]')\n            logger.info('[Train]')\n            for batch in 
self.train_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                self.step(batch, stage='conv', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            logger.info('[Valid]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                    self.step(batch, stage='conv', mode='val')\n                self.evaluator.report(epoch=epoch, mode='val')\n        # test\n        logger.info('[Test]')\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                self.step(batch, stage='conv', mode='test')\n            self.evaluator.report(mode='test')\n\n    def fit(self):\n        self.pretrain()\n        self.train_recommender()\n        self.train_conversation()\n\n    def interact(self):\n        pass\n"
  },
  {
    "path": "crslab/system/ntrd.py",
    "content": "# @Time   : 2021/10/05\n# @Author : Zhipeng Zhao\n# @Email  : oran_official@outlook.com\n\nimport os\nfrom crslab.evaluator.metrics import gen\nfrom numpy.core.numeric import NaN\n\nimport torch\nfrom loguru import logger\n\nfrom crslab.evaluator.metrics.base import AverageMetric\nfrom crslab.evaluator.metrics.gen import PPLMetric\nfrom crslab.system.base import BaseSystem\nfrom crslab.system.utils.functions import ind2slot,ind2txt_with_slots\n\n\nclass NTRDSystem(BaseSystem):\n    \"\"\"This is the system for NTRD model\"\"\"\n    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,\n                 interact=False, debug=False, tensorboard=False):\n        \n        super(NTRDSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data,\n                                         restore_system, interact, debug, tensorboard)\n\n        self.ind2tok = vocab['ind2tok']\n        self.ind2movie = vocab['id2entity']\n        self.end_token_idx = vocab['end']\n        self.item_ids = side_data['item_entity_ids']\n\n        self.pretrain_optim_opt = self.opt['pretrain']\n        self.rec_optim_opt = self.opt['rec']\n        self.conv_optim_opt = self.opt['conv']\n        self.pretrain_epoch = self.pretrain_optim_opt['epoch']\n        self.rec_epoch = self.rec_optim_opt['epoch']\n        self.conv_epoch = self.conv_optim_opt['epoch']\n        self.pretrain_batch_size = self.pretrain_optim_opt['batch_size']\n        self.rec_batch_size = self.rec_optim_opt['batch_size']\n        self.conv_batch_size = self.conv_optim_opt['batch_size']\n\n        # loss weight\n        self.gen_loss_weight = self.opt['gen_loss_weight']\n    def rec_evaluate(self, rec_predict, item_label):\n        rec_predict = rec_predict.cpu()\n        rec_predict = rec_predict[:, self.item_ids]\n        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)\n        rec_ranks = 
rec_ranks.tolist()\n        item_label = item_label.tolist()\n        for rec_rank, item in zip(rec_ranks, item_label):\n            item = self.item_ids.index(item)\n            self.evaluator.rec_evaluate(rec_rank, item)\n\n    def conv_evaluate(self, prediction,movie_prediction,response,movie_response):\n        prediction = prediction.tolist()\n        response = response.tolist()\n        if movie_prediction != None:\n            movie_prediction = movie_prediction * (movie_prediction!=-1) \n            movie_prediction = torch.masked_select(movie_prediction,(movie_prediction!=0)) \n            movie_prediction = movie_prediction.tolist()\n            movie_prediction = ind2slot(movie_prediction,self.ind2movie)\n        if movie_response != None:\n            movie_response = movie_response * (movie_response!=-1)\n            movie_response = torch.masked_select(movie_response,(movie_response!=0))\n            movie_response = movie_response.tolist()\n            movie_response = ind2slot(movie_response,self.ind2movie)\n\n        for p, r in zip(prediction,response):\n            p_str = ind2txt_with_slots(p, movie_prediction, self.ind2tok, self.end_token_idx)\n            p_str = p_str[1:]\n            r_str = ind2txt_with_slots(r, movie_response, self.ind2tok, self.end_token_idx)\n            self.evaluator.gen_evaluate(p_str, [r_str])\n    \n    def step(self, batch, stage, mode):\n        '''\n        converse:\n        context_tokens, context_entities, context_words, response,all_movies = batch\n\n        recommend\n        context_entities, context_words, entities, movie = batch\n        '''\n        batch = [ele.to(self.device) for ele in batch]\n        if stage == 'pretrain':\n            info_loss = self.model.forward(batch, stage, mode)\n            if info_loss is not None:\n                self.backward(info_loss.sum())\n                info_loss = info_loss.sum().item()\n                self.evaluator.optim_metrics.add(\"info_loss\", 
AverageMetric(info_loss))\n        elif stage == 'rec':\n            rec_loss, info_loss, rec_predict = self.model.forward(batch, stage, mode)\n            if info_loss:\n                loss = rec_loss + 0.025 * info_loss\n            else:\n                loss = rec_loss\n            if mode == \"train\":\n                self.backward(loss.sum())\n            else:\n                self.rec_evaluate(rec_predict, batch[-1])\n            rec_loss = rec_loss.sum().item()\n            self.evaluator.optim_metrics.add(\"rec_loss\", AverageMetric(rec_loss))\n            if info_loss:\n                info_loss = info_loss.sum().item()\n                self.evaluator.optim_metrics.add(\"info_loss\", AverageMetric(info_loss))\n        elif stage == \"conv\":\n            if mode != \"test\":\n                gen_loss,selection_loss,pred = self.model.forward(batch, stage, mode)\n                if mode == 'train':\n                    loss = self.gen_loss_weight * gen_loss + selection_loss\n                    self.backward(loss.sum())\n                    loss = loss.sum().item()\n                    self.evaluator.optim_metrics.add(\"gen_total_loss\", AverageMetric(loss))\n                gen_loss = gen_loss.sum().item()\n                \n\n                self.evaluator.optim_metrics.add(\"gen_loss\", AverageMetric(gen_loss))\n                self.evaluator.gen_metrics.add(\"ppl\", PPLMetric(gen_loss))\n                selection_loss = selection_loss.sum().item()\n                self.evaluator.optim_metrics.add('sel_loss',AverageMetric(selection_loss))\n\n            else:\n                pred,matching_pred,matching_logist = self.model.forward(batch, stage, mode)\n                self.conv_evaluate(pred,matching_pred,batch[-2],batch[-1])\n        else:\n            raise\n\n\n\n    def pretrain(self):\n        self.init_optim(self.pretrain_optim_opt, self.model.parameters())\n\n        for epoch in range(self.pretrain_epoch):\n            
self.evaluator.reset_metrics()\n            logger.info(f'[Pretrain epoch {str(epoch)}]')\n            for batch in self.train_dataloader.get_pretrain_data(self.pretrain_batch_size, shuffle=False):\n                self.step(batch, stage=\"pretrain\", mode='train')\n            self.evaluator.report()\n\n    def train_recommender(self):\n        self.init_optim(self.rec_optim_opt, self.model.parameters())\n\n        for epoch in range(self.rec_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Recommendation epoch {str(epoch)}]')\n            logger.info('[Train]')\n            for batch in self.train_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):\n                self.step(batch, stage='rec', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            logger.info('[Valid]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):\n                    self.step(batch, stage='rec', mode='val')\n                self.evaluator.report(epoch=epoch, mode='val')\n                # early stop\n                metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50']\n                if self.early_stop(metric):\n                    break\n        # test\n        logger.info('[Test]')\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):\n                self.step(batch, stage='rec', mode='test')\n            self.evaluator.report(mode='test')\n    \n    def train_conversation(self):\n        if os.environ[\"CUDA_VISIBLE_DEVICES\"] == '-1':\n            self.model.freeze_parameters()\n        else:\n            self.model.module.freeze_parameters()\n        self.init_optim(self.conv_optim_opt, 
self.model.parameters())\n\n        for epoch in range(self.conv_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Conversation epoch {str(epoch)}]')\n            logger.info('[Train]')\n            for batch in self.train_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                self.step(batch, stage='conv', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            logger.info('[Valid]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                    self.step(batch, stage='conv', mode='val')\n                self.evaluator.report(epoch=epoch, mode='val')\n            # test\n            logger.info('[Test]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                    self.step(batch, stage='conv', mode='test')\n                self.evaluator.report(mode='test')\n\n    def fit(self):\n        self.pretrain()\n        self.train_recommender()\n        self.train_conversation()\n\n    def interact(self):\n        pass\n    \n\n"
  },
  {
    "path": "crslab/system/redial.py",
    "content": "# @Time   : 2020/12/4\n# @Author : Chenzhan Shang\n# @Email  : czshang@outlook.com\n\n# UPDATE\n# @Time   : 2021/1/3\n# @Author : Xiaolei Wang\n# @email  : wxl1999@foxmail.com\n\nimport torch\nfrom loguru import logger\n\nfrom crslab.data import dataset_language_map\nfrom crslab.evaluator.metrics.base import AverageMetric\nfrom crslab.evaluator.metrics.gen import PPLMetric\nfrom crslab.system.base import BaseSystem\nfrom crslab.system.utils.functions import ind2txt\n\n\nclass ReDialSystem(BaseSystem):\n    \"\"\"This is the system for ReDial model\"\"\"\n\n    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,\n                 interact=False, debug=False, tensorboard=False):\n        \"\"\"\n\n        Args:\n            opt (dict): Indicating the hyper parameters.\n            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.\n            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.\n            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.\n            vocab (dict): Indicating the vocabulary.\n            side_data (dict): Indicating the side data.\n            restore_system (bool, optional): Indicating if we restore the system after training. Defaults to False.\n            interact (bool, optional): Indicating if we interact with system. Defaults to False.\n            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.\n            tensorboard (bool, optional): Indicating if we monitor the training performance in tensorboard. Defaults to False. 
\n\n        \"\"\"\n        super(ReDialSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data,\n                                           restore_system, interact, debug, tensorboard)\n        self.ind2tok = vocab['conv']['ind2tok']\n        self.end_token_idx = vocab['conv']['end']\n        self.item_ids = side_data['rec']['item_entity_ids']\n        self.id2entity = vocab['rec']['id2entity']\n\n        self.rec_optim_opt = opt['rec']\n        self.conv_optim_opt = opt['conv']\n        self.rec_epoch = self.rec_optim_opt['epoch']\n        self.conv_epoch = self.conv_optim_opt['epoch']\n        self.rec_batch_size = self.rec_optim_opt['batch_size']\n        self.conv_batch_size = self.conv_optim_opt['batch_size']\n\n        self.language = dataset_language_map[self.opt['dataset']]\n\n    def rec_evaluate(self, rec_predict, item_label):\n        rec_predict = rec_predict.cpu()\n        rec_predict = rec_predict[:, self.item_ids]\n        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)\n        rec_ranks = rec_ranks.tolist()\n        item_label = item_label.tolist()\n        for rec_rank, item in zip(rec_ranks, item_label):\n            item = self.item_ids.index(item)\n            self.evaluator.rec_evaluate(rec_rank, item)\n\n    def conv_evaluate(self, prediction, response):\n        prediction = prediction.tolist()\n        response = response.tolist()\n        for p, r in zip(prediction, response):\n            p_str = ind2txt(p, self.ind2tok, self.end_token_idx)\n            r_str = ind2txt(r, self.ind2tok, self.end_token_idx)\n            self.evaluator.gen_evaluate(p_str, [r_str])\n\n    def step(self, batch, stage, mode):\n        assert stage in ('rec', 'conv')\n        assert mode in ('train', 'valid', 'test')\n\n        for k, v in batch.items():\n            if isinstance(v, torch.Tensor):\n                batch[k] = v.to(self.device)\n\n        if stage == 'rec':\n            rec_loss, rec_scores = 
self.rec_model.forward(batch, mode=mode)\n            rec_loss = rec_loss.sum()\n            if mode == 'train':\n                self.backward(rec_loss)\n            else:\n                self.rec_evaluate(rec_scores, batch['item'])\n            rec_loss = rec_loss.item()\n            self.evaluator.optim_metrics.add(\"rec_loss\", AverageMetric(rec_loss))\n        else:\n            gen_loss, preds = self.conv_model.forward(batch, mode=mode)\n            gen_loss = gen_loss.sum()\n            if mode == 'train':\n                self.backward(gen_loss)\n            else:\n                self.conv_evaluate(preds, batch['response'])\n            gen_loss = gen_loss.item()\n            self.evaluator.optim_metrics.add('gen_loss', AverageMetric(gen_loss))\n            self.evaluator.gen_metrics.add('ppl', PPLMetric(gen_loss))\n\n    def train_recommender(self):\n        self.init_optim(self.rec_optim_opt, self.rec_model.parameters())\n\n        for epoch in range(self.rec_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Recommendation epoch {str(epoch)}]')\n            logger.info('[Train]')\n            for batch in self.train_dataloader['rec'].get_rec_data(batch_size=self.rec_batch_size):\n                self.step(batch, stage='rec', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')  # report train loss\n            # val\n            logger.info('[Valid]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader['rec'].get_rec_data(batch_size=self.rec_batch_size, shuffle=False):\n                    self.step(batch, stage='rec', mode='valid')\n                self.evaluator.report(epoch=epoch, mode='valid')  # report valid loss\n                # early stop\n                metric = self.evaluator.optim_metrics['rec_loss']\n                if self.early_stop(metric):\n                    break\n        # test\n        
logger.info('[Test]')\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader['rec'].get_rec_data(batch_size=self.rec_batch_size, shuffle=False):\n                self.step(batch, stage='rec', mode='test')\n            self.evaluator.report(mode='test')\n\n    def train_conversation(self):\n        self.init_optim(self.conv_optim_opt, self.conv_model.parameters())\n\n        for epoch in range(self.conv_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Conversation epoch {str(epoch)}]')\n            logger.info('[Train]')\n            for batch in self.train_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size):\n                self.step(batch, stage='conv', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            logger.info('[Valid]')\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size,\n                                                                         shuffle=False):\n                    self.step(batch, stage='conv', mode='valid')\n                self.evaluator.report(epoch=epoch, mode='valid')\n                metric = self.evaluator.optim_metrics['gen_loss']\n                if self.early_stop(metric):\n                    break\n        # test\n        logger.info('[Test]')\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                self.step(batch, stage='conv', mode='test')\n            self.evaluator.report(mode='test')\n\n    def fit(self):\n        self.train_recommender()\n        self.train_conversation()\n\n    def interact(self):\n        pass\n"
  },
  {
    "path": "crslab/system/tgredial.py",
    "content": "# @Time   : 2020/12/9\n# @Author : Yuanhang Zhou\n# @Email  : sdzyh002@gmail.com\n\n# UPDATE:\n# @Time   : 2021/1/3\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nimport os\n\nimport torch\nfrom loguru import logger\nfrom math import floor\n\nfrom crslab.config import PRETRAIN_PATH\nfrom crslab.data import get_dataloader, dataset_language_map\nfrom crslab.evaluator.metrics.base import AverageMetric\nfrom crslab.evaluator.metrics.gen import PPLMetric\nfrom crslab.system.base import BaseSystem\nfrom crslab.system.utils.functions import ind2txt\n\n\nclass TGReDialSystem(BaseSystem):\n    \"\"\"This is the system for TGReDial model\"\"\"\n\n    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,\n                 interact=False, debug=False, tensorboard=False):\n        \"\"\"\n\n        Args:\n            opt (dict): Indicating the hyper parameters.\n            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.\n            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.\n            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.\n            vocab (dict): Indicating the vocabulary.\n            side_data (dict): Indicating the side data.\n            restore_system (bool, optional): Indicating if we restore the system after training. Defaults to False.\n            interact (bool, optional): Indicating if we interact with system. Defaults to False.\n            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.\n            tensorboard (bool, optional): Indicating if we monitor the training performance in tensorboard. Defaults to False. 
\n\n        \"\"\"\n        super(TGReDialSystem, self).__init__(opt, train_dataloader, valid_dataloader,\n                                             test_dataloader, vocab, side_data, restore_system, interact, debug,\n                                             tensorboard)\n\n        if hasattr(self, 'conv_model'):\n            self.ind2tok = vocab['conv']['ind2tok']\n            self.end_token_idx = vocab['conv']['end']\n        if hasattr(self, 'rec_model'):\n            self.item_ids = side_data['rec']['item_entity_ids']\n            self.id2entity = vocab['rec']['id2entity']\n\n        if hasattr(self, 'rec_model'):\n            self.rec_optim_opt = self.opt['rec']\n            self.rec_epoch = self.rec_optim_opt['epoch']\n            self.rec_batch_size = self.rec_optim_opt['batch_size']\n\n        if hasattr(self, 'conv_model'):\n            self.conv_optim_opt = self.opt['conv']\n            self.conv_epoch = self.conv_optim_opt['epoch']\n            self.conv_batch_size = self.conv_optim_opt['batch_size']\n            if self.conv_optim_opt.get('lr_scheduler', None) and 'Transformers' in self.conv_optim_opt['lr_scheduler'][\n                'name']:\n                batch_num = 0\n                for _ in self.train_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size, shuffle=False):\n                    batch_num += 1\n                conv_training_steps = self.conv_epoch * floor(batch_num / self.conv_optim_opt.get('update_freq', 1))\n                self.conv_optim_opt['lr_scheduler']['training_steps'] = conv_training_steps\n\n        if hasattr(self, 'policy_model'):\n            self.policy_optim_opt = self.opt['policy']\n            self.policy_epoch = self.policy_optim_opt['epoch']\n            self.policy_batch_size = self.policy_optim_opt['batch_size']\n\n        self.language = dataset_language_map[self.opt['dataset']]\n\n    def rec_evaluate(self, rec_predict, item_label):\n        rec_predict = rec_predict.cpu()\n        
rec_predict = rec_predict[:, self.item_ids]\n        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)\n        rec_ranks = rec_ranks.tolist()\n        item_label = item_label.tolist()\n        for rec_rank, item in zip(rec_ranks, item_label):\n            item = self.item_ids.index(item)\n            self.evaluator.rec_evaluate(rec_rank, item)\n\n    def policy_evaluate(self, rec_predict, movie_label):\n        rec_predict = rec_predict.cpu()\n        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)\n        rec_ranks = rec_ranks.tolist()\n        movie_label = movie_label.tolist()\n        for rec_rank, movie in zip(rec_ranks, movie_label):\n            self.evaluator.rec_evaluate(rec_rank, movie)\n\n    def conv_evaluate(self, prediction, response):\n        \"\"\"\n        Args:\n            prediction: torch.LongTensor, shape=(bs, response_truncate-1)\n            response: torch.LongTensor, shape=(bs, response_truncate)\n\n            the first token in response is <|endoftext|>,  it is not in prediction\n        \"\"\"\n        prediction = prediction.tolist()\n        response = response.tolist()\n        for p, r in zip(prediction, response):\n            p_str = ind2txt(p, self.ind2tok, self.end_token_idx)\n            r_str = ind2txt(r[1:], self.ind2tok, self.end_token_idx)\n            self.evaluator.gen_evaluate(p_str, [r_str])\n\n    def step(self, batch, stage, mode):\n        \"\"\"\n        stage: ['policy', 'rec', 'conv']\n        mode: ['train', 'val', 'test]\n        \"\"\"\n        batch = [ele.to(self.device) for ele in batch]\n        if stage == 'policy':\n            if mode == 'train':\n                self.policy_model.train()\n            else:\n                self.policy_model.eval()\n\n            policy_loss, policy_predict = self.policy_model.forward(batch, mode)\n            if mode == \"train\" and policy_loss is not None:\n                policy_loss = policy_loss.sum()\n                self.backward(policy_loss)\n            
else:\n                self.policy_evaluate(policy_predict, batch[-1])\n            if isinstance(policy_loss, torch.Tensor):\n                policy_loss = policy_loss.item()\n                self.evaluator.optim_metrics.add(\"policy_loss\",\n                                                 AverageMetric(policy_loss))\n        elif stage == 'rec':\n            if mode == 'train':\n                self.rec_model.train()\n            else:\n                self.rec_model.eval()\n            rec_loss, rec_predict = self.rec_model.forward(batch, mode)\n            rec_loss = rec_loss.sum()\n            if mode == \"train\":\n                self.backward(rec_loss)\n            else:\n                self.rec_evaluate(rec_predict, batch[-1])\n            rec_loss = rec_loss.item()\n            self.evaluator.optim_metrics.add(\"rec_loss\",\n                                             AverageMetric(rec_loss))\n        elif stage == \"conv\":\n            if mode != \"test\":\n                # train + valid: need to compute ppl\n                gen_loss, pred = self.conv_model.forward(batch, mode)\n                gen_loss = gen_loss.sum()\n                if mode == 'train':\n                    self.backward(gen_loss)\n                else:\n                    self.conv_evaluate(pred, batch[-1])\n                gen_loss = gen_loss.item()\n                self.evaluator.optim_metrics.add(\"gen_loss\",\n                                                 AverageMetric(gen_loss))\n                self.evaluator.gen_metrics.add(\"ppl\", PPLMetric(gen_loss))\n            else:\n                # generate response in conv_model.step\n                pred = self.conv_model.forward(batch, mode)\n                self.conv_evaluate(pred, batch[-1])\n        else:\n            raise\n\n    def train_recommender(self):\n        if hasattr(self.rec_model, 'bert'):\n            if os.environ[\"CUDA_VISIBLE_DEVICES\"] == '-1':\n                bert_param = 
list(self.rec_model.bert.named_parameters())\n            else:\n                bert_param = list(self.rec_model.module.bert.named_parameters())\n            bert_param_name = ['bert.' + n for n, p in bert_param]\n        else:\n            bert_param = []\n            bert_param_name = []\n        other_param = [\n            name_param for name_param in self.rec_model.named_parameters()\n            if name_param[0] not in bert_param_name\n        ]\n        params = [{'params': [p for n, p in bert_param], 'lr': self.rec_optim_opt['lr_bert']},\n                  {'params': [p for n, p in other_param]}]\n        self.init_optim(self.rec_optim_opt, params)\n\n        for epoch in range(self.rec_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Recommendation epoch {str(epoch)}]')\n            for batch in self.train_dataloader['rec'].get_rec_data(self.rec_batch_size,\n                                                                   shuffle=True):\n                self.step(batch, stage='rec', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader['rec'].get_rec_data(\n                        self.rec_batch_size, shuffle=False):\n                    self.step(batch, stage='rec', mode='val')\n                self.evaluator.report(epoch=epoch, mode='val')\n                # early stop\n                metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50']\n                if self.early_stop(metric):\n                    break\n        # test\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader['rec'].get_rec_data(self.rec_batch_size,\n                                                                  shuffle=False):\n                self.step(batch, stage='rec', 
mode='test')\n            self.evaluator.report(mode='test')\n\n    def train_conversation(self):\n        self.init_optim(self.conv_optim_opt, self.conv_model.parameters())\n\n        for epoch in range(self.conv_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Conversation epoch {str(epoch)}]')\n            for batch in self.train_dataloader['conv'].get_conv_data(\n                    batch_size=self.conv_batch_size, shuffle=True):\n                self.step(batch, stage='conv', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader['conv'].get_conv_data(\n                        batch_size=self.conv_batch_size, shuffle=False):\n                    self.step(batch, stage='conv', mode='val')\n                self.evaluator.report(epoch=epoch, mode='val')\n                # early stop\n                metric = self.evaluator.gen_metrics['ppl']\n                if self.early_stop(metric):\n                    break\n        # test\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader['conv'].get_conv_data(\n                    batch_size=self.conv_batch_size, shuffle=False):\n                self.step(batch, stage='conv', mode='test')\n            self.evaluator.report(mode='test')\n\n    def train_policy(self):\n        policy_params = list(self.policy_model.named_parameters())\n        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n        params = [{\n            'params': [\n                p for n, p in policy_params\n                if not any(nd in n for nd in no_decay)\n            ],\n            'weight_decay':\n                self.policy_optim_opt['weight_decay']\n        }, {\n            'params': [\n                p for n, p in policy_params\n                if any(nd 
in n for nd in no_decay)\n            ],\n        }]\n        self.init_optim(self.policy_optim_opt, params)\n\n        for epoch in range(self.policy_epoch):\n            self.evaluator.reset_metrics()\n            logger.info(f'[Policy epoch {str(epoch)}]')\n            # change the shuffle to True\n            for batch in self.train_dataloader['policy'].get_policy_data(\n                    self.policy_batch_size, shuffle=True):\n                self.step(batch, stage='policy', mode='train')\n            self.evaluator.report(epoch=epoch, mode='train')\n            # val\n            with torch.no_grad():\n                self.evaluator.reset_metrics()\n                for batch in self.valid_dataloader['policy'].get_policy_data(\n                        self.policy_batch_size, shuffle=False):\n                    self.step(batch, stage='policy', mode='val')\n                self.evaluator.report(epoch=epoch, mode='val')\n                # early stop\n                metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50']\n                if self.early_stop(metric):\n                    break\n        # test\n        with torch.no_grad():\n            self.evaluator.reset_metrics()\n            for batch in self.test_dataloader['policy'].get_policy_data(\n                    self.policy_batch_size, shuffle=False):\n                self.step(batch, stage='policy', mode='test')\n            self.evaluator.report(mode='test')\n\n    def fit(self):\n        if hasattr(self, 'rec_model'):\n            self.train_recommender()\n        if hasattr(self, 'policy_model'):\n            self.train_policy()\n        if hasattr(self, 'conv_model'):\n            self.train_conversation()\n\n    def interact(self):\n        self.init_interact()\n        input_text = self.get_input(self.language)\n        while not self.finished:\n            # rec\n            if hasattr(self, 'rec_model'):\n                rec_input = 
self.process_input(input_text, 'rec')\n                scores = self.rec_model.forward(rec_input, 'infer')\n\n                scores = scores.cpu()[0]\n                scores = scores[self.item_ids]\n                _, rank = torch.topk(scores, 10, dim=-1)\n                item_ids = []\n                for r in rank.tolist():\n                    item_ids.append(self.item_ids[r])\n                first_item_id = item_ids[:1]\n                self.update_context('rec', entity_ids=first_item_id, item_ids=first_item_id)\n\n                print(f\"[Recommend]:\")\n                for item_id in item_ids:\n                    if item_id in self.id2entity:\n                        print(self.id2entity[item_id])\n            # conv\n            if hasattr(self, 'conv_model'):\n                conv_input = self.process_input(input_text, 'conv')\n                preds = self.conv_model.forward(conv_input, 'infer').tolist()[0]\n                p_str = ind2txt(preds, self.ind2tok, self.end_token_idx)\n\n                token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(p_str, 'conv')\n                self.update_context('conv', token_ids, entity_ids, movie_ids, word_ids)\n\n                print(f\"[Response]:\\n{p_str}\")\n            # input\n            input_text = self.get_input(self.language)\n\n    def process_input(self, input_text, stage):\n        token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(input_text, stage)\n        self.update_context(stage, token_ids, entity_ids, movie_ids, word_ids)\n\n        data = {'role': 'Seeker', 'context_tokens': self.context[stage]['context_tokens'],\n                'context_entities': self.context[stage]['context_entities'],\n                'context_words': self.context[stage]['context_words'],\n                'context_items': self.context[stage]['context_items'],\n                'user_profile': self.context[stage]['user_profile'],\n                'interaction_history': 
self.context[stage]['interaction_history']}\n        dataloader = get_dataloader(self.opt, data, self.vocab[stage])\n        if stage == 'rec':\n            data = dataloader.rec_interact(data)\n        elif stage == 'conv':\n            data = dataloader.conv_interact(data)\n\n        data = [ele.to(self.device) if isinstance(ele, torch.Tensor) else ele for ele in data]\n        return data\n\n    def convert_to_id(self, text, stage):\n        if self.language == 'zh':\n            tokens = self.tokenize(text, 'pkuseg')\n        elif self.language == 'en':\n            tokens = self.tokenize(text, 'nltk')\n        else:\n            raise\n\n        entities = self.link(tokens, self.side_data[stage]['entity_kg']['entity'])\n        words = self.link(tokens, self.side_data[stage]['word_kg']['entity'])\n\n        if self.opt['tokenize'][stage] in ('gpt2', 'bert'):\n            language = dataset_language_map[self.opt['dataset']]\n            path = os.path.join(PRETRAIN_PATH, self.opt['tokenize'][stage], language)\n            tokens = self.tokenize(text, 'bert', path)\n\n        token_ids = [self.vocab[stage]['tok2ind'].get(token, self.vocab[stage]['unk']) for token in tokens]\n        entity_ids = [self.vocab[stage]['entity2id'][entity] for entity in entities if\n                      entity in self.vocab[stage]['entity2id']]\n        movie_ids = [entity_id for entity_id in entity_ids if entity_id in self.item_ids]\n        word_ids = [self.vocab[stage]['word2id'][word] for word in words if word in self.vocab[stage]['word2id']]\n\n        return token_ids, entity_ids, movie_ids, word_ids\n"
  },
  {
    "path": "crslab/system/utils/__init__.py",
    "content": ""
  },
  {
    "path": "crslab/system/utils/functions.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2020/12/18\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\n# UPDATE:\n# @Time   : 2021/10/05\n# @Author  :   Zhipeng Zhao\n# @email   :   oran_official@outlook.com\n\nimport torch\n\n\ndef compute_grad_norm(parameters, norm_type=2.0):\n    \"\"\"\n    Compute norm over gradients of model parameters.\n\n    :param parameters:\n        the model parameters for gradient norm calculation. Iterable of\n        Tensors or single Tensor\n    :param norm_type:\n        type of p-norm to use\n\n    :returns:\n        the computed gradient norm\n    \"\"\"\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = [p for p in parameters if p is not None and p.grad is not None]\n    total_norm = 0\n    for p in parameters:\n        param_norm = p.grad.data.norm(norm_type)\n        total_norm += param_norm.item() ** norm_type\n    return total_norm ** (1.0 / norm_type)\n\n\ndef ind2txt(inds, ind2tok, end_token_idx=None, unk_token='unk'):\n    sentence = []\n    for ind in inds:\n        if isinstance(ind, torch.Tensor):\n            ind = ind.item()\n        if end_token_idx and ind == end_token_idx:\n            break\n        sentence.append(ind2tok.get(ind, unk_token))\n    return ' '.join(sentence)\n\ndef ind2txt_with_slots(inds,slots,ind2tok, end_token_idx=None, unk_token='unk',slot_token='[ITEM]'):\n    sentence = []\n    for ind in inds:\n        if isinstance(ind, torch.Tensor):\n            ind = ind.item()\n        if end_token_idx and ind == end_token_idx:\n            break\n        token = ind2tok.get(ind, unk_token)\n        if token == slot_token:\n            token = slots[0]\n            slots = slots[1:] \n        sentence.append(token)\n    return ' '.join(sentence)\n\ndef ind2slot(inds,ind2slot):\n    return [ ind2slot[ind] for ind in 
inds]\n"
  },
  {
    "path": "crslab/system/utils/lr_scheduler.py",
    "content": "# @Time   : 2020/12/1\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\n\nfrom abc import abstractmethod, ABC\n\n# UPDATE:\n# @Time   : 2020/12/14\n# @Author : Xiaolei Wang\n# @Email  : wxl1999@foxmail.com\nimport math\nimport numpy as np\nimport torch\nfrom loguru import logger\nfrom torch import optim\n\n\nclass LRScheduler(ABC):\n    \"\"\"\n    Class for LR Schedulers.\n\n    Includes some basic functionality by default - setting up the warmup\n    scheduler, passing the correct number of steps to train_step, loading and\n    saving states.\n    Subclasses must implement abstract methods train_step() and valid_step().\n    Schedulers should be initialized with lr_scheduler_factory().\n    __init__() should not be called directly.\n    \"\"\"\n\n    def __init__(self, optimizer, warmup_steps: int = 0):\n        \"\"\"\n        Initialize warmup scheduler. Specific main schedulers should be initialized in\n        the subclasses. Do not invoke this method diretly.\n\n        :param optimizer optimizer:\n            Optimizer being used for training. 
May be wrapped in\n            fp16_optimizer_wrapper depending on whether fp16 is used.\n        :param int warmup_steps:\n            Number of training step updates warmup scheduler should take.\n        \"\"\"\n        self._number_training_updates = 0\n        self.warmup_steps = warmup_steps\n        self._init_warmup_scheduler(optimizer)\n\n    def _warmup_lr(self, step):\n        \"\"\"\n        Return lr multiplier (on initial lr) for warmup scheduler.\n        \"\"\"\n        return float(step) / float(max(1, self.warmup_steps))\n\n    def _init_warmup_scheduler(self, optimizer):\n        if self.warmup_steps > 0:\n            self.warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._warmup_lr)\n        else:\n            self.warmup_scheduler = None\n\n    def _is_lr_warming_up(self):\n        \"\"\"\n        Check if we're warming up the learning rate.\n        \"\"\"\n        return (\n                hasattr(self, 'warmup_scheduler')\n                and self.warmup_scheduler is not None\n                and self._number_training_updates <= self.warmup_steps\n        )\n\n    def train_step(self):\n        \"\"\"\n        Use the number of train steps to adjust the warmup scheduler or the main\n        scheduler, depending on where in training we are.\n\n        Override this method to override the behavior for training schedulers.\n        \"\"\"\n        self._number_training_updates += 1\n        if self._is_lr_warming_up():\n            self.warmup_scheduler.step()\n        else:\n            self.train_adjust()\n\n    def valid_step(self, metric=None):\n        if self._is_lr_warming_up():\n            # we're not done warming up, so don't start using validation\n            # metrics to adjust schedule\n            return\n        self.valid_adjust(metric)\n\n    @abstractmethod\n    def train_adjust(self):\n        \"\"\"\n        Use the number of train steps to decide when to adjust LR schedule.\n\n        Override this method to 
override the behavior for training schedulers.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def valid_adjust(self, metric):\n        \"\"\"\n        Use the metrics to decide when to adjust LR schedule.\n\n        This uses the loss as the validation metric if present, if not this\n        function does nothing. Note that the model must be reporting loss for\n        this to work.\n\n        Override this method to override the behavior for validation schedulers.\n        \"\"\"\n        pass\n\n\nclass ReduceLROnPlateau(LRScheduler):\n    \"\"\"\n    Scheduler that decays by a multiplicative rate when valid loss plateaus.\n    \"\"\"\n\n    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001,\n                 threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, warmup_steps=0):\n        super(ReduceLROnPlateau, self).__init__(optimizer, warmup_steps)\n        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode=mode, factor=factor,\n                                                              patience=patience, verbose=verbose, threshold=threshold,\n                                                              threshold_mode=threshold_mode, cooldown=cooldown,\n                                                              min_lr=min_lr, eps=eps)\n\n    def train_adjust(self):\n        pass\n\n    def valid_adjust(self, metric):\n        self.scheduler.step(metric)\n\n\nclass StepLR(LRScheduler):\n    \"\"\"\n    Scheduler that decays by a fixed multiplicative rate at each valid step.\n    \"\"\"\n\n    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, warmup_steps=0):\n        super(StepLR, self).__init__(optimizer, warmup_steps)\n        self.scheduler = optim.lr_scheduler.StepLR(optimizer, step_size, gamma, last_epoch)\n\n    def train_adjust(self):\n        pass\n\n    def valid_adjust(self, metric=None):\n        self.scheduler.step()\n\n\nclass 
ConstantLR(LRScheduler):\n    def __init__(self, optimizer, warmup_steps=0):\n        super(ConstantLR, self).__init__(optimizer, warmup_steps)\n\n    def train_adjust(self):\n        pass\n\n    def valid_adjust(self, metric):\n        pass\n\n\nclass InvSqrtLR(LRScheduler):\n    \"\"\"\n    Scheduler that decays at an inverse square root rate.\n    \"\"\"\n\n    def __init__(self, optimizer, invsqrt_lr_decay_gamma=-1, last_epoch=-1, warmup_steps=0):\n        \"\"\"\n        invsqrt_lr_decay_gamma determines the cycle length of the inverse square root\n        scheduler.\n\n        When steps taken == invsqrt_lr_decay_gamma, the lr multiplier is 1\n        \"\"\"\n        super(InvSqrtLR, self).__init__(optimizer, warmup_steps)\n        self.invsqrt_lr_decay_gamma = invsqrt_lr_decay_gamma\n        if invsqrt_lr_decay_gamma <= 0:\n            logger.warning(\n                '--lr-scheduler invsqrt requires a value for '\n                '--invsqrt-lr-decay-gamma. Defaulting to set gamma to '\n                '--warmup-updates value for backwards compatibility.'\n            )\n            self.invsqrt_lr_decay_gamma = self.warmup_steps\n\n        self.decay_factor = np.sqrt(max(1, self.invsqrt_lr_decay_gamma))\n        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._invsqrt_lr, last_epoch)\n\n    def _invsqrt_lr(self, step):\n        return self.decay_factor / np.sqrt(max(1, self.invsqrt_lr_decay_gamma + step))\n\n    def train_adjust(self):\n        self.scheduler.step()\n\n    def valid_adjust(self, metric):\n        # this is a training step lr scheduler, nothing to adjust in validation\n        pass\n\n\nclass CosineAnnealingLR(LRScheduler):\n    \"\"\"\n    Scheduler that decays by a cosine function.\n    \"\"\"\n\n    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, warmup_steps=0):\n        \"\"\"\n        training_steps determines the cycle length of the cosine annealing.\n\n        It indicates the number of steps from 
1.0 multiplier to 0.0, which corresponds\n        to going from cos(0) to cos(pi)\n        \"\"\"\n        super(CosineAnnealingLR, self).__init__(optimizer, warmup_steps)\n        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min, last_epoch)\n\n    def train_adjust(self):\n        self.scheduler.step()\n\n    def valid_adjust(self, metric):\n        pass\n\n\nclass CosineAnnealingWarmRestartsLR(LRScheduler):\n    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, warmup_steps=0):\n        super(CosineAnnealingWarmRestartsLR, self).__init__(optimizer, warmup_steps)\n        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult, eta_min, last_epoch)\n\n    def train_adjust(self):\n        self.scheduler.step()\n\n    def valid_adjust(self, metric):\n        pass\n\n\nclass TransformersLinearLR(LRScheduler):\n    \"\"\"\n    Scheduler that decays linearly.\n    \"\"\"\n\n    def __init__(self, optimizer, training_steps, warmup_steps=0):\n        \"\"\"\n        training_steps determines the cycle length of the linear annealing.\n\n        It indicates the number of steps from 1.0 multiplier to 0.0\n        \"\"\"\n        super(TransformersLinearLR, self).__init__(optimizer, warmup_steps)\n        self.training_steps = training_steps - warmup_steps\n        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._linear_lr)\n\n    def _linear_lr(self, step):\n        return max(0.0, float(self.training_steps - step) / float(max(1, self.training_steps)))\n\n    def train_adjust(self):\n        self.scheduler.step()\n\n    def valid_adjust(self, metric):\n        pass\n\n\nclass TransformersCosineLR(LRScheduler):\n    def __init__(self, optimizer, training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1,\n                 warmup_steps: int = 0):\n        super(TransformersCosineLR, self).__init__(optimizer, warmup_steps)\n        self.training_steps = training_steps 
- warmup_steps\n        self.num_cycles = num_cycles\n        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._cosine_lr, last_epoch)\n\n    def _cosine_lr(self, step):\n        progress = float(step) / float(max(1, self.training_steps))\n        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))\n\n    def train_adjust(self):\n        self.scheduler.step()\n\n    def valid_adjust(self, metric):\n        pass\n\n\nclass TransformersCosineWithHardRestartsLR(LRScheduler):\n    def __init__(self, optimizer, training_steps: int, num_cycles: int = 1, last_epoch: int = -1,\n                 warmup_steps: int = 0):\n        super(TransformersCosineWithHardRestartsLR, self).__init__(optimizer, warmup_steps)\n        self.training_steps = training_steps - warmup_steps\n        self.num_cycles = num_cycles\n        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._cosine_with_hard_restarts_lr, last_epoch)\n\n    def _cosine_with_hard_restarts_lr(self, step):\n        progress = float(step) / float(max(1, self.training_steps))\n        if progress >= 1.0:\n            return 0.0\n        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(self.num_cycles) * progress) % 1.0))))\n\n    def train_adjust(self):\n        self.scheduler.step()\n\n    def valid_adjust(self, metric):\n        pass\n\n\nclass TransformersPolynomialDecayLR(LRScheduler):\n    def __init__(self, optimizer, training_steps, lr_end=1e-7, power=1.0, last_epoch=-1, warmup_steps=0):\n        super(TransformersPolynomialDecayLR, self).__init__(optimizer, warmup_steps)\n        self.training_steps = training_steps - warmup_steps\n        self.lr_init = optimizer.defaults[\"lr\"]\n        self.lr_end = lr_end\n        assert self.lr_init > lr_end, f\"lr_end ({lr_end}) must be be smaller than initial lr ({self.lr_init})\"\n        self.power = power\n        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
self._polynomial_decay_lr, last_epoch)\n\n    def _polynomial_decay_lr(self, step):\n        if step > self.training_steps:\n            return self.lr_end / self.lr_init  # as LambdaLR multiplies by lr_init\n        else:\n            lr_range = self.lr_init - self.lr_end\n            decay_steps = self.training_steps\n            pct_remaining = 1 - step / decay_steps\n            decay = lr_range * pct_remaining ** self.power + self.lr_end\n            return decay / self.lr_init  # as LambdaLR multiplies by lr_init\n\n    def train_adjust(self):\n        self.scheduler.step()\n\n    def valid_adjust(self, metric):\n        pass\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset SOURCEDIR=source\nset BUILDDIR=build\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\n\techo.installed, then set the SPHINXBUILD environment variable to point\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\n\techo.may add the Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.http://sphinx-doc.org/\n\texit /b 1\n)\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\n\n:end\npopd\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "numpy~=1.19.4\nsentencepiece<0.1.92\ndataclasses~=0.7\ntransformers~=4.1.1\nfasttext~=0.9.2\npkuseg~=0.0.25\npyyaml~=5.4\ntqdm~=4.55.0\nloguru~=0.5.3\nnltk~=3.4.4\nrequests~=2.25.1\nscikit-learn~=0.24.0\nfuzzywuzzy~=0.18.0\n"
  },
  {
    "path": "docs/requirements_geometric.txt",
    "content": "-f https://pytorch-geometric.com/whl/torch-1.4.0+cpu.html\ntorch-cluster==1.5.4\ntorch-scatter==2.0.4\ntorch-sparse==0.6.1\ntorch-spline-conv==1.2.0\ntorch-geometric~=1.6.3"
  },
  {
    "path": "docs/requirements_sphinx.txt",
    "content": "sphinx~=3.4.1\nsphinx_rtd_theme~=0.5.0\nrecommonmark~=0.7.1"
  },
  {
    "path": "docs/requirements_torch.txt",
    "content": "-f https://download.pytorch.org/whl/torch_stable.html\ntorch==1.4.0+cpu\ntorchvision==0.5.0+cpu"
  },
  {
    "path": "docs/source/api/crslab.config.rst",
    "content": "crslab.config package\n=====================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.config.config\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.config\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.dataloader.rst",
    "content": "crslab.data.dataloader package\n==============================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.data.dataloader.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataloader.kbrd\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataloader.kgsf\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataloader.redial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataloader.tgredial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataloader.utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.data.dataloader\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.dataset.durecdial.rst",
    "content": "crslab.data.dataset.durecdial package\n=====================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.data.dataset.durecdial.durecdial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataset.durecdial.resources\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.data.dataset.durecdial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.dataset.gorecdial.rst",
    "content": "crslab.data.dataset.gorecdial package\n=====================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.data.dataset.gorecdial.gorecdial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataset.gorecdial.resources\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.data.dataset.gorecdial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.dataset.inspired.rst",
    "content": "crslab.data.dataset.inspired package\n====================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.data.dataset.inspired.inspired\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataset.inspired.resources\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.data.dataset.inspired\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.dataset.opendialkg.rst",
    "content": "crslab.data.dataset.opendialkg package\n======================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.data.dataset.opendialkg.opendialkg\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataset.opendialkg.resources\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.data.dataset.opendialkg\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.dataset.redial.rst",
    "content": "crslab.data.dataset.redial package\n==================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.data.dataset.redial.redial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataset.redial.resources\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.data.dataset.redial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.dataset.rst",
    "content": "crslab.data.dataset package\n===========================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.data.dataset.durecdial\n   crslab.data.dataset.gorecdial\n   crslab.data.dataset.inspired\n   crslab.data.dataset.opendialkg\n   crslab.data.dataset.redial\n   crslab.data.dataset.tgredial\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.data.dataset.base_dataset\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.data.dataset\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.dataset.tgredial.rst",
    "content": "crslab.data.dataset.tgredial package\n====================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.data.dataset.tgredial.resources\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.data.dataset.tgredial.tgredial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.data.dataset.tgredial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.data.rst",
    "content": "crslab.data package\n===================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.data.dataloader\n   crslab.data.dataset\n\nModule contents\n---------------\n\n.. automodule:: crslab.data\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.evaluator.metrics.rst",
    "content": "crslab.evaluator.metrics package\n================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.evaluator.metrics.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.evaluator.metrics.gen\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.evaluator.metrics.rec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.evaluator.metrics\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.evaluator.rst",
    "content": "crslab.evaluator package\n========================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.evaluator.metrics\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.evaluator.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.evaluator.conv\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.evaluator.embeddings\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.evaluator.end2end\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.evaluator.rec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.evaluator.standard\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.evaluator.utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.evaluator\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.conversation.gpt2.rst",
    "content": "crslab.model.conversation.gpt2 package\n======================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.conversation.gpt2.gpt2\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.conversation.gpt2\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.conversation.rst",
    "content": "crslab.model.conversation package\n=================================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.model.conversation.gpt2\n   crslab.model.conversation.transformer\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.conversation\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.conversation.transformer.rst",
    "content": "crslab.model.conversation.transformer package\n=============================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.conversation.transformer.transformer\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.conversation.transformer\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.crs.kbrd.rst",
    "content": "crslab.model.crs.kbrd package\n=============================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.crs.kbrd.kbrd\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.crs.kbrd\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.crs.kgsf.rst",
    "content": "crslab.model.crs.kgsf package\n=============================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.crs.kgsf.kgsf\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.crs.kgsf.modules\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.crs.kgsf.resources\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.crs.kgsf\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.crs.redial.rst",
    "content": "crslab.model.crs.redial package\n===============================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.crs.redial.modules\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.crs.redial.redial_conv\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.crs.redial.redial_rec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.crs.redial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.crs.rst",
    "content": "crslab.model.crs package\n========================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.model.crs.kbrd\n   crslab.model.crs.kgsf\n   crslab.model.crs.redial\n   crslab.model.crs.tgredial\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.crs\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.crs.tgredial.rst",
    "content": "crslab.model.crs.tgredial package\n=================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.crs.tgredial.tg_conv\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.crs.tgredial.tg_policy\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.crs.tgredial.tg_rec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.crs.tgredial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.policy.conv_bert.rst",
    "content": "crslab.model.policy.conv\\_bert package\n======================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.policy.conv_bert.conv_bert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.policy.conv_bert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.policy.mgcg.rst",
    "content": "crslab.model.policy.mgcg package\n================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.policy.mgcg.mgcg\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.policy.mgcg\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.policy.pmi.rst",
    "content": "crslab.model.policy.pmi package\n===============================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.policy.pmi.pmi\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.policy.pmi\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.policy.profile_bert.rst",
    "content": "crslab.model.policy.profile\\_bert package\n=========================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.policy.profile_bert.profile_bert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.policy.profile_bert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.policy.rst",
    "content": "crslab.model.policy package\n===========================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.model.policy.conv_bert\n   crslab.model.policy.mgcg\n   crslab.model.policy.pmi\n   crslab.model.policy.profile_bert\n   crslab.model.policy.topic_bert\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.policy\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.policy.topic_bert.rst",
    "content": "crslab.model.policy.topic\\_bert package\n=======================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.policy.topic_bert.topic_bert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.policy.topic_bert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.recommendation.bert.rst",
    "content": "crslab.model.recommendation.bert package\n========================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.recommendation.bert.bert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.recommendation.bert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.recommendation.gru4rec.rst",
    "content": "crslab.model.recommendation.gru4rec package\n===========================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.recommendation.gru4rec.gru4rec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.recommendation.gru4rec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.recommendation.popularity.rst",
    "content": "crslab.model.recommendation.popularity package\n==============================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.recommendation.popularity.popularity\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.recommendation.popularity\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.recommendation.rst",
    "content": "crslab.model.recommendation package\n===================================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.model.recommendation.bert\n   crslab.model.recommendation.gru4rec\n   crslab.model.recommendation.popularity\n   crslab.model.recommendation.sasrec\n   crslab.model.recommendation.textcnn\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.recommendation\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.recommendation.sasrec.rst",
    "content": "crslab.model.recommendation.sasrec package\n==========================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.recommendation.sasrec.modules\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.recommendation.sasrec.sasrec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.recommendation.sasrec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.recommendation.textcnn.rst",
    "content": "crslab.model.recommendation.textcnn package\n===========================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.recommendation.textcnn.textcnn\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.recommendation.textcnn\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.rst",
    "content": "crslab.model package\n====================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.model.conversation\n   crslab.model.crs\n   crslab.model.policy\n   crslab.model.recommendation\n   crslab.model.utils\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.pretrain_models\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.utils.modules.rst",
    "content": "crslab.model.utils.modules package\n==================================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.utils.modules.attention\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.model.utils.modules.transformer\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.utils.modules\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.model.utils.rst",
    "content": "crslab.model.utils package\n==========================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.model.utils.modules\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.model.utils.functions\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.model.utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.quick_start.rst",
    "content": "crslab.quick\\_start package\n===========================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.quick_start.quick_start\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.quick_start\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.rst",
    "content": "crslab package\n==============\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab.config\n   crslab.data\n   crslab.evaluator\n   crslab.model\n   crslab.quick_start\n   crslab.system\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.download\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/crslab.system.rst",
    "content": "crslab.system package\n=====================\n\nSubmodules\n----------\n\n\n.. automodule:: crslab.system.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.system.kbrd\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.system.kgsf\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.system.lr_scheduler\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.system.redial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.system.tgredial\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n\n.. automodule:: crslab.system.utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: crslab.system\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api/modules.rst",
    "content": "crslab\n======\n\n.. toctree::\n   :maxdepth: 1\n\n   crslab\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\n\nsys.path.insert(0, os.path.abspath('../../'))\n\nfrom recommonmark.transform import AutoStructify\n\n# -- Project information -----------------------------------------------------\n\nproject = 'CRSLab'\ncopyright = '2021, RUC AIBox'\nauthor = 'RUC AIBox'\n\n# The full version, including alpha/beta/rc tags\nrelease = '0.1.1'\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.napoleon',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.doctest',\n    'sphinx.ext.intersphinx',\n    'sphinx.ext.todo',\n    'sphinx.ext.coverage',\n    'sphinx.ext.mathjax',\n    'sphinx.ext.viewcode',\n    'recommonmark'\n]\n\nsource_suffix = ['.rst', '.md']\nautoclass_content = \"both\"\n\n# napoleon\nnapoleon_include_private_with_doc = True\nnapoleon_use_admonition_for_examples = True\nnapoleon_use_admonition_for_notes = True\nnapoleon_use_admonition_for_references = True\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = []\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_rtd_theme'\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n\n\ndef setup(app):\n    app.add_config_value('recommonmark_config', {\n        'auto_toc_tree_section': 'Contents',\n    }, True)\n    app.add_transform(AutoStructify)\n"
  },
  {
    "path": "docs/source/index.md",
    "content": "# CRSLab\n\n```eval_rst\n\n.. image:: https://img.shields.io/pypi/v/crslab\n    :target: https://pypi.org/project/crslab\n    \n.. image:: https://img.shields.io/github/v/release/rucaibox/crslab.svg\n    :target: https://github.com/rucaibox/crslab/releases\n    \n.. image:: https://img.shields.io/badge/License-MIT-blue.svg\n    :target: ../../../LICENSE\n    \n.. image:: https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B\n    :target: https://arxiv.org/abs/2101.00939\n\n.. toctree::\n   :maxdepth: 1\n   :caption: API REFERENCE\n\n   api/crslab.quick_start\n   api/crslab.config\n   api/crslab.data\n   api/crslab.evaluator\n   api/crslab.model\n   api/crslab.system\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n```"
  },
  {
    "path": "requirements.txt",
    "content": "numpy~=1.19.4\nsentencepiece<0.1.92\ndataclasses~=0.7; python_version<'3.7'\ntransformers~=4.1.1\nfasttext~=0.9.2\npkuseg~=0.0.25\npyyaml~=5.4\ntqdm~=4.55.0\nloguru~=0.5.3\nnltk~=3.4.4\nrequests~=2.25.1\nscikit-learn~=0.24.0\nfuzzywuzzy~=0.18.0\ntensorboard~=2.4.1\n"
  },
  {
    "path": "run_crslab.py",
    "content": "# @Time   : 2020/11/22\n# @Author : Kun Zhou\n# @Email  : francis_kun_zhou@163.com\n\n# UPDATE:\n# @Time   : 2020/11/24, 2021/1/9\n# @Author : Kun Zhou, Xiaolei Wang\n# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com\n\nimport argparse\nimport warnings\n\nfrom crslab.config import Config\n\nwarnings.filterwarnings('ignore')\n\nif __name__ == '__main__':\n    # parse args\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-c', '--config', type=str,\n                        default='config/crs/tgredial/tgredial.yaml', help='config file(yaml) path')\n    parser.add_argument('-g', '--gpu', type=str, default='-1',\n                        help='specify GPU id(s) to use, we now support multiple GPUs. Defaults to CPU(-1).')\n    parser.add_argument('-sd', '--save_data', action='store_true',\n                        help='save processed dataset')\n    parser.add_argument('-rd', '--restore_data', action='store_true',\n                        help='restore processed dataset')\n    parser.add_argument('-ss', '--save_system', action='store_true',\n                        help='save trained system')\n    parser.add_argument('-rs', '--restore_system', action='store_true',\n                        help='restore trained system')\n    parser.add_argument('-d', '--debug', action='store_true',\n                        help='use valid dataset to debug your system')\n    parser.add_argument('-i', '--interact', action='store_true',\n                        help='interact with your system instead of training')\n    parser.add_argument('-tb', '--tensorboard', action='store_true',\n                        help='enable tensorboard to monitor train performance')\n    args, _ = parser.parse_known_args()\n    config = Config(args.config, args.gpu, args.debug)\n\n    from crslab.quick_start import run_crslab\n\n    run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact,\n               args.debug, args.tensorboard)\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\ntry:\n    import torch\n    import torch_geometric\nexcept Exception:\n    raise Exception('Please install PyTorch and PyTorch Geometric manually first.\\n' +\n                    'View CRSLab GitHub page for more information: https://github.com/RUCAIBox/CRSLab')\n\nsetup_requires = []\n\ninstall_requires = [\n    'numpy~=1.19.4',\n    'sentencepiece<0.1.92',\n    \"dataclasses~=0.7;python_version<'3.7'\",\n    'transformers~=4.1.1',\n    'fasttext~=0.9.2',\n    'pkuseg~=0.0.25',\n    'pyyaml~=5.4',\n    'tqdm~=4.55.0',\n    'loguru~=0.5.3',\n    'nltk~=3.4.4',\n    'requests~=2.25.1',\n    'scikit-learn~=0.24.0',\n    'fuzzywuzzy~=0.18.0',\n    'tensorboard~=2.4.1',\n]\n\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: MIT License\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"Topic :: Scientific/Engineering :: Human Machine Interfaces\"\n]\n\nwith open(\"README.md\", \"r\", encoding=\"utf-8\") as f:\n    long_description = f.read()\n\nsetup(\n    name='crslab',\n    version='0.1.1',  # please remember to edit crslab/__init__.py in response, once updating the version\n    author='CRSLabTeam',\n    author_email='francis_kun_zhou@163.com',\n    description='An Open-Source Toolkit for Building Conversational Recommender System(CRS)',\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    url='https://github.com/RUCAIBox/CRSLab',\n    packages=[\n        package for package in find_packages()\n        if package.startswith('crslab')\n    ],\n    classifiers=classifiers,\n    install_requires=install_requires,\n    setup_requires=setup_requires,\n    python_requires='>=3.6',\n)\n"
  }
]